Skip to content

Commit 210ceb6

Browse files
authored
Merge branch 'main' into fix/invalidate-nonexistent-environment
2 parents 93d0d8a + 479b5ba commit 210ceb6

8 files changed

Lines changed: 124 additions & 4 deletions

File tree

sqlmesh/core/dialect.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,10 @@ def _parse_if(self: Parser) -> t.Optional[exp.Expr]:
556556
# to parse a statement / command to support the macro @IF(condition, statement)
557557
index = self._index
558558
try:
559+
if self.dialect == "tsql":
560+
if not (self._index >= 2 and self._tokens[self._index - 2].text == "@"):
561+
return self.__parse_if() # type: ignore
562+
return Parser.__parse_if(self) # type: ignore
559563
return self.__parse_if() # type: ignore
560564
except ParseError:
561565
self._retreat(index)
@@ -1133,8 +1137,8 @@ def extend_sqlglot() -> None:
11331137
_override(Parser, _parse_value)
11341138
_override(Parser, _parse_lambda)
11351139
_override(Parser, _parse_types)
1136-
_override(TSQL.Parser, Parser._parse_if)
11371140
_override(Parser, _parse_if)
1141+
_override(TSQL.Parser, Parser._parse_if)
11381142
_override(Parser, _parse_id_var)
11391143
_override(Parser, _parse_interval_span)
11401144
_override(Parser, _warn_unsupported)

sqlmesh/core/engine_adapter/snowflake.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,7 @@ def query_factory() -> Query:
407407
elif isinstance(df, pd.DataFrame):
408408
from snowflake.connector.pandas_tools import write_pandas
409409

410-
ordered_df = df[list(source_columns_to_types)]
410+
ordered_df = df[list(source_columns_to_types)].reset_index(drop=True)
411411

412412
# Workaround for https://github.com/snowflakedb/snowflake-connector-python/issues/1034
413413
# The above issue has already been fixed upstream, but we keep the following

sqlmesh/core/model/seed.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def read(self, batch_size: t.Optional[int] = None) -> t.Generator[pd.DataFrame,
113113
batch_size = batch_size or df.size
114114
batch_start = 0
115115
while batch_start < df.shape[0]:
116-
yield df.iloc[batch_start : batch_start + batch_size, :]
116+
yield df.iloc[batch_start : batch_start + batch_size, :].copy()
117117
batch_start += batch_size
118118

119119
def _get_df(self) -> pd.DataFrame:

sqlmesh/dbt/source.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,12 @@ def canonical_name(self, context: DbtContext) -> str:
8282
f"'source' macro failed for '{self.config_name}' with exception '{e}'."
8383
)
8484

85+
identifier = relation.identifier or ""
86+
needs_identifier_quoting = "." in identifier or " " in identifier
8587
relation = relation.quote(
8688
database=False,
8789
schema=False,
88-
identifier=False,
90+
identifier=needs_identifier_quoting,
8991
)
9092
if relation.database == context.target.database:
9193
relation = relation.include(database=False)

tests/core/engine_adapter/test_snowflake.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -469,6 +469,26 @@ def test_df_to_source_queries_use_schema(
469469
assert 'USE SCHEMA "other_catalog"."other_db"' in to_sql_calls(adapter)
470470

471471

472+
def test_df_to_source_queries_reset_non_default_index(
473+
make_mocked_engine_adapter: t.Callable, mocker: MockerFixture
474+
):
475+
mocker.patch(
476+
"sqlmesh.core.engine_adapter.snowflake.SnowflakeEngineAdapter.table_exists",
477+
return_value=False,
478+
)
479+
write_pandas = mocker.patch("snowflake.connector.pandas_tools.write_pandas", return_value=None)
480+
adapter = make_mocked_engine_adapter(SnowflakeEngineAdapter)
481+
482+
df = pd.DataFrame({"a": [2, 3], "b": [5, 6]}, index=[1, 2])
483+
adapter.replace_query(
484+
"other_db.test_table", df, {"a": exp.DataType.build("INT"), "b": exp.DataType.build("INT")}
485+
)
486+
487+
uploaded_df = write_pandas.call_args.args[1]
488+
assert uploaded_df.index.equals(pd.RangeIndex(start=0, stop=2, step=1))
489+
assert uploaded_df.to_dict("list") == {"a": [2, 3], "b": [5, 6]}
490+
491+
472492
def test_create_managed_table(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture):
473493
adapter = make_mocked_engine_adapter(SnowflakeEngineAdapter)
474494

tests/core/test_dialect.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -709,6 +709,20 @@ def test_conditional_statement():
709709
q = parse_one("@IF(cond, VACUUM ANALYZE);", read="postgres")
710710
assert q.sql(dialect="postgres") == "@IF(cond, VACUUM ANALYZE)"
711711

712+
# Verify that the original error case from issue #5823 (Required keyword: 'true' missing) is resolved.
713+
# It must be parsed as a macro function containing an Anonymous expression rather than exp.If.
714+
q = parse_one("@IF(1 = 1, ALTER TABLE x ADD y INT);", read="tsql")
715+
assert q.sql(dialect="tsql") == "@IF(1 = 1, ALTER TABLE x ADD y INTEGER)"
716+
assert isinstance(q.this, exp.Anonymous)
717+
assert q.this.name == "IF"
718+
719+
# Note: SQLGlot's fallback Command parser strips the quotes from string literals in statements it cannot fully parse, so 'hello' round-trips as hello
720+
q = parse_one("@IF(cond, PRINT 'hello');", read="tsql")
721+
assert q.sql(dialect="tsql") == "@IF(cond, PRINT hello)"
722+
723+
q = parse_one("@IF(@runtime_stage = 'evaluating', SELECT 1);", read="tsql")
724+
assert q.sql(dialect="tsql") == "@IF(@runtime_stage = 'evaluating', SELECT 1)"
725+
712726

713727
def test_model_name_cannot_be_string():
714728
with pytest.raises(ParseError) as parse_error:

tests/core/test_seed.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,21 @@ def test_read_custom_settings():
5858
pd.testing.assert_frame_equal(next(dfs), expected_df)
5959

6060

61+
def test_read_returns_independent_batches():
62+
content = """key,value
63+
1,one
64+
2,two
65+
"""
66+
seed = Seed(content=content)
67+
seed_reader = seed.reader()
68+
69+
batches = list(seed_reader.read(batch_size=1))
70+
batches[0].at[0, "value"] = "changed"
71+
72+
assert [df["value"].tolist() for df in batches] == [["changed"], ["two"]]
73+
assert next(seed_reader.read())["value"].tolist() == ["one", "two"]
74+
75+
6176
def test_column_hashes():
6277
content = """key,value,ds
6378
1,one,2022-01-01

tests/dbt/test_config.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -521,6 +521,71 @@ def test_quoting():
521521
assert str(BaseRelation.create(**source.relation_info)) == 'foo."bar"'
522522

523523

524+
def test_source_canonical_name_with_dots_and_spaces(mocker):
525+
from dbt.adapters.base import BaseRelation
526+
527+
mock_context = mocker.Mock()
528+
mock_context.target.database = "target_db"
529+
530+
def mock_source_macro(source_name, table_name):
531+
if table_name == "my_table_dot":
532+
identifier = "FILENAME.CSV"
533+
elif table_name == "my_table_space":
534+
identifier = "my table space"
535+
else:
536+
identifier = "my_table_std"
537+
return BaseRelation.create(
538+
database="RAW_DEV",
539+
schema="raw_schema",
540+
identifier=identifier,
541+
)
542+
543+
mock_context.get_callable_macro.return_value = mock_source_macro
544+
545+
# 1. Identifier with a dot
546+
source_dot = SourceConfig(
547+
name="my_table_dot",
548+
source_name="my_source",
549+
identifier="FILENAME.CSV",
550+
)
551+
assert source_dot.canonical_name(mock_context) == 'RAW_DEV.raw_schema."FILENAME.CSV"'
552+
553+
# 2. Identifier with a space
554+
source_space = SourceConfig(
555+
name="my_table_space",
556+
source_name="my_source",
557+
identifier="my table space",
558+
)
559+
assert source_space.canonical_name(mock_context) == 'RAW_DEV.raw_schema."my table space"'
560+
561+
# 3. Standard identifier (without dots or spaces) should not be quoted
562+
source_std = SourceConfig(
563+
name="my_table_std",
564+
source_name="my_source",
565+
identifier="my_table_std",
566+
)
567+
assert source_std.canonical_name(mock_context) == "RAW_DEV.raw_schema.my_table_std"
568+
569+
# 4. Database matching the target database is omitted, for both a dotted (quoted) and a standard identifier
570+
mock_context_target_db = mocker.Mock()
571+
mock_context_target_db.target.database = "RAW_DEV"
572+
mock_context_target_db.get_callable_macro.return_value = mock_source_macro
573+
574+
source_dot_target = SourceConfig(
575+
name="my_table_dot",
576+
source_name="my_source",
577+
identifier="FILENAME.CSV",
578+
)
579+
source_std_target = SourceConfig(
580+
name="my_table_std",
581+
source_name="my_source",
582+
identifier="my_table_std",
583+
)
584+
585+
assert source_dot_target.canonical_name(mock_context_target_db) == 'raw_schema."FILENAME.CSV"'
586+
assert source_std_target.canonical_name(mock_context_target_db) == "raw_schema.my_table_std"
587+
588+
524589
def _test_warehouse_config(
525590
config_yaml: str, target_class: t.Type[TargetConfig], *params_path: str
526591
) -> TargetConfig:

0 commit comments

Comments
 (0)