Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions fakesnow/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import fakesnow.checks as checks
import fakesnow.expr as expr
import fakesnow.info_schema as info_schema
import fakesnow.macros as macros
import fakesnow.transforms as transforms
from fakesnow import logger
from fakesnow.copy_into import copy_into
Expand Down Expand Up @@ -446,6 +447,7 @@ def _execute(self, transformed: Expr, params: MutableParams | None = None) -> No
elif create_db_name := transformed.args.get("create_db_name"):
# we created a new database, so create the info schema extensions
self._duck_conn.execute(info_schema.per_db_creation_sql(create_db_name))
self._duck_conn.execute(macros.creation_sql(create_db_name))
result_sql = SQL_CREATED_DATABASE.substitute(name=create_db_name)

elif stage_name := transformed.args.get("create_stage_name"):
Expand Down
20 changes: 20 additions & 0 deletions tests/test_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,26 @@ def test_connect_db_path_can_create_database() -> None:
cursor.execute("CREATE DATABASE db2")


def test_create_database_registers_macros() -> None:
# _fs_* macros must be registered for databases created via CREATE DATABASE,
# not only for databases passed at connect time (issue #347)
with fakesnow.patch():
with snowflake.connector.connect() as conn:
cur = conn.cursor()
cur.execute("CREATE DATABASE foo")
cur.execute("USE DATABASE foo")
cur.execute("CREATE SCHEMA foo.bar")
cur.execute("USE SCHEMA foo.bar")
# TO_TIMESTAMP_NTZ relies on _fs_to_timestamp macro
result = cur.execute("SELECT TO_TIMESTAMP_NTZ('1672531200', 0)").fetchone()
assert result is not None
# LATERAL FLATTEN relies on _fs_flatten macro
cur.execute("CREATE TABLE foo.bar.t (arr ARRAY)")
cur.execute("INSERT INTO foo.bar.t SELECT ARRAY_CONSTRUCT(1, 2, 3)")
rows = cur.execute("SELECT value FROM foo.bar.t, LATERAL FLATTEN(INPUT => arr)").fetchall()
assert len(rows) == 3


def test_connect_db_path_reuse():
with tempfile.TemporaryDirectory(prefix="fakesnow-test") as db_path:
with (
Expand Down