diff --git a/sqlmesh/core/dialect.py b/sqlmesh/core/dialect.py
index db77f0c461..00da1852ce 100644
--- a/sqlmesh/core/dialect.py
+++ b/sqlmesh/core/dialect.py
@@ -1388,17 +1388,27 @@ def is_meta_expression(v: t.Any) -> bool:
     return isinstance(v, (Audit, Metric, Model))
 
 
-def replace_merge_table_aliases(expression: exp.Expression) -> exp.Expression:
+def replace_merge_table_aliases(
+    expression: exp.Expression, dialect: t.Optional[str] = None
+) -> exp.Expression:
     """
     Resolves references from the "source" and "target" tables (or their DBT equivalents)
     with the corresponding SQLMesh merge aliases (MERGE_SOURCE_ALIAS and MERGE_TARGET_ALIAS)
     """
     from sqlmesh.core.engine_adapter.base import MERGE_SOURCE_ALIAS, MERGE_TARGET_ALIAS
 
+    normalized_merge_source_alias = quote_identifiers(
+        normalize_identifiers(exp.to_identifier(MERGE_SOURCE_ALIAS), dialect), dialect=dialect
+    )
+
+    normalized_merge_target_alias = quote_identifiers(
+        normalize_identifiers(exp.to_identifier(MERGE_TARGET_ALIAS), dialect), dialect=dialect
+    )
+
     if isinstance(expression, exp.Column) and (first_part := expression.parts[0]):
         if first_part.this.lower() in ("target", "dbt_internal_dest", "__merge_target__"):
-            first_part.replace(exp.to_identifier(MERGE_TARGET_ALIAS))
+            first_part.replace(normalized_merge_target_alias)
         elif first_part.this.lower() in ("source", "dbt_internal_source", "__merge_source__"):
-            first_part.replace(exp.to_identifier(MERGE_SOURCE_ALIAS))
+            first_part.replace(normalized_merge_source_alias)
 
     return expression
diff --git a/sqlmesh/core/model/definition.py b/sqlmesh/core/model/definition.py
index 6ccd0927cd..c8d6a0f836 100644
--- a/sqlmesh/core/model/definition.py
+++ b/sqlmesh/core/model/definition.py
@@ -653,7 +653,7 @@ def render_merge_filter(
         )
         if len(rendered_exprs) != 1:
             raise SQLMeshError(f"Expected one expression but got {len(rendered_exprs)}")
-        return rendered_exprs[0].transform(d.replace_merge_table_aliases)
+        return rendered_exprs[0].transform(d.replace_merge_table_aliases, dialect=self.dialect)
 
     def _render_properties(
         self, properties: t.Dict[str, exp.Expression] | SessionProperties, **render_kwargs: t.Any
diff --git a/sqlmesh/core/model/kind.py b/sqlmesh/core/model/kind.py
index 4a15023f2f..9dc54f4b83 100644
--- a/sqlmesh/core/model/kind.py
+++ b/sqlmesh/core/model/kind.py
@@ -34,6 +34,7 @@
     get_dialect,
     validate_string,
     positive_int_validator,
+    validate_expression,
 )
 
 
@@ -467,15 +468,20 @@ def _when_matched_validator(
             return v
         if isinstance(v, list):
             v = " ".join(v)
+
+        dialect = get_dialect(info.data)
+
         if isinstance(v, str):
             # Whens wrap the WHEN clauses, but the parentheses aren't parsed by sqlglot
             v = v.strip()
             if v.startswith("("):
                 v = v[1:-1]
-            return t.cast(exp.Whens, d.parse_one(v, into=exp.Whens, dialect=get_dialect(info.data)))
+            v = t.cast(exp.Whens, d.parse_one(v, into=exp.Whens, dialect=dialect))
+        else:
+            v = t.cast(exp.Whens, v.transform(d.replace_merge_table_aliases, dialect=dialect))
 
-        return t.cast(exp.Whens, v.transform(d.replace_merge_table_aliases))
+        return validate_expression(v, dialect=dialect)
 
     @field_validator("merge_filter", mode="before")
     def _merge_filter_validator(
@@ -485,11 +491,16 @@ def _merge_filter_validator(
     ) -> t.Optional[exp.Expression]:
         if v is None:
             return v
+
+        dialect = get_dialect(info.data)
+
         if isinstance(v, str):
             v = v.strip()
-            return d.parse_one(v, dialect=get_dialect(info.data))
+            v = d.parse_one(v, dialect=dialect)
+        else:
+            v = v.transform(d.replace_merge_table_aliases, dialect=dialect)
 
-        return v.transform(d.replace_merge_table_aliases)
+        return validate_expression(v, dialect=dialect)
 
     @property
     def data_hash_values(self) -> t.List[t.Optional[str]]:
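A minimal sketch (not part of the patch) of what the dialect-aware transform above does, assuming a SQLMesh checkout where these modules import as in the diff:

    from sqlglot import parse_one
    from sqlmesh.core import dialect as d

    expr = parse_one("source.salary > 0", dialect="snowflake")
    # Only the merge alias itself is normalized and quoted by the transform; the rest of
    # the expression is normalized later by validate_expression in the kind validators.
    print(expr.transform(d.replace_merge_table_aliases, dialect="snowflake").sql(dialect="snowflake"))
    # '"__MERGE_SOURCE__".salary > 0'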
diff --git a/sqlmesh/migrations/v0084_normalize_quote_when_matched_and_merge_filter.py b/sqlmesh/migrations/v0084_normalize_quote_when_matched_and_merge_filter.py
new file mode 100644
index 0000000000..24a6db9384
--- /dev/null
+++ b/sqlmesh/migrations/v0084_normalize_quote_when_matched_and_merge_filter.py
@@ -0,0 +1,9 @@
+"""
+Normalize and quote the when_matched and merge_filter properties of IncrementalByUniqueKeyKind
+to match how other properties (such as time_column and partitioned_by) are handled and to
+prevent un-normalized identifiers from being quoted at the EngineAdapter level
+"""
+
+
+def migrate(state_sync, **kwargs):  # type: ignore
+    pass
diff --git a/sqlmesh/utils/pydantic.py b/sqlmesh/utils/pydantic.py
index 3de15773a3..317e873aeb 100644
--- a/sqlmesh/utils/pydantic.py
+++ b/sqlmesh/utils/pydantic.py
@@ -16,6 +16,8 @@
 from sqlmesh.utils import str_to_bool
 
 if t.TYPE_CHECKING:
+    from sqlglot._typing import E
+
     Model = t.TypeVar("Model", bound="PydanticModel")
 
 
@@ -193,6 +195,12 @@ def validate_string(v: t.Any) -> str:
     return str(v)
 
 
+def validate_expression(expression: E, dialect: str) -> E:
+    # This normalizes and quotes identifiers in the given expression according to the specified dialect.
+    # It also sets expression.meta["dialect"] so the expression is serialized in the correct dialect for state.
+    return _get_field(expression, {"dialect": dialect})  # type: ignore
+
+
 def bool_validator(v: t.Any) -> bool:
     if isinstance(v, exp.Boolean):
         return v.this
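For reference, the normalize/quote/meta step that validate_expression performs can be sketched with sqlglot's public helpers (an approximation; the real implementation routes through the shared _get_field hook above):

    from sqlglot import parse_one
    from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
    from sqlglot.optimizer.qualify_columns import quote_identifiers

    e = parse_one('"__MERGE_SOURCE__".salary > 0', dialect="snowflake")
    e = normalize_identifiers(e, dialect="snowflake")  # unquoted salary -> SALARY
    e = quote_identifiers(e, dialect="snowflake")      # quote every identifier
    e.meta["dialect"] = "snowflake"                    # state serialization renders in this dialect
    print(e.sql(dialect="snowflake"))                  # "__MERGE_SOURCE__"."SALARY" > 0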
diff --git a/tests/core/test_model.py b/tests/core/test_model.py
index 4fc82875d7..a8c4d688a1 100644
--- a/tests/core/test_model.py
+++ b/tests/core/test_model.py
@@ -5366,13 +5366,13 @@ def test_when_matched():
         """
     )
 
-    expected_when_matched = "(WHEN MATCHED THEN UPDATE SET __MERGE_TARGET__.salary = COALESCE(__MERGE_SOURCE__.salary, __MERGE_TARGET__.salary))"
+    expected_when_matched = "(WHEN MATCHED THEN UPDATE SET `__merge_target__`.`salary` = COALESCE(`__merge_source__`.`salary`, `__merge_target__`.`salary`))"
 
     model = load_sql_based_model(expressions, dialect="hive")
-    assert model.kind.when_matched.sql() == expected_when_matched
+    assert model.kind.when_matched.sql(dialect="hive") == expected_when_matched
 
     model = SqlModel.parse_raw(model.json())
-    assert model.kind.when_matched.sql() == expected_when_matched
+    assert model.kind.when_matched.sql(dialect="hive") == expected_when_matched
 
     expressions = d.parse(
         """
@@ -5400,9 +5400,9 @@ def test_when_matched():
             kind INCREMENTAL_BY_UNIQUE_KEY (
                 unique_key ("purchase_order_id"),
                 when_matched (
-                    WHEN MATCHED AND __MERGE_SOURCE__._operation = 1 THEN DELETE
-                    WHEN MATCHED AND __MERGE_SOURCE__._operation <> 1 THEN UPDATE SET
-                    __MERGE_TARGET__.purchase_order_id = 1
+                    WHEN MATCHED AND "__merge_source__"."_operation" = 1 THEN DELETE
+                    WHEN MATCHED AND "__merge_source__"."_operation" <> 1 THEN UPDATE SET
+                    "__merge_target__"."purchase_order_id" = 1
                 ),
                 batch_concurrency 1,
                 forward_only FALSE,
@@ -5453,7 +5453,7 @@ def fingerprint_merge(
             kind INCREMENTAL_BY_UNIQUE_KEY (
                 unique_key ("purchase_order_id"),
                 when_matched (
-                    WHEN MATCHED AND __MERGE_SOURCE__.salary <> __MERGE_TARGET__.salary THEN UPDATE SET
+                    WHEN MATCHED AND "__merge_source__"."salary" <> "__merge_target__"."salary" THEN UPDATE SET
                     ARRAY('target.update_datetime = source.update_datetime', 'target.salary = source.salary')
                 ),
                 batch_concurrency 1,
@@ -5487,21 +5487,21 @@ def test_when_matched_multiple():
     )
 
     expected_when_matched = [
-        "WHEN MATCHED AND __MERGE_SOURCE__.x = 1 THEN UPDATE SET __MERGE_TARGET__.salary = COALESCE(__MERGE_SOURCE__.salary, __MERGE_TARGET__.salary)",
-        "WHEN MATCHED THEN UPDATE SET __MERGE_TARGET__.salary = COALESCE(__MERGE_SOURCE__.salary, __MERGE_TARGET__.salary)",
+        "WHEN MATCHED AND `__merge_source__`.`x` = 1 THEN UPDATE SET `__merge_target__`.`salary` = COALESCE(`__merge_source__`.`salary`, `__merge_target__`.`salary`)",
+        "WHEN MATCHED THEN UPDATE SET `__merge_target__`.`salary` = COALESCE(`__merge_source__`.`salary`, `__merge_target__`.`salary`)",
     ]
 
     model = load_sql_based_model(expressions, dialect="hive", variables={"schema": "db"})
     whens = model.kind.when_matched
     assert len(whens.expressions) == 2
-    assert whens.expressions[0].sql() == expected_when_matched[0]
-    assert whens.expressions[1].sql() == expected_when_matched[1]
+    assert whens.expressions[0].sql(dialect="hive") == expected_when_matched[0]
+    assert whens.expressions[1].sql(dialect="hive") == expected_when_matched[1]
 
     model = SqlModel.parse_raw(model.json())
     whens = model.kind.when_matched
     assert len(whens.expressions) == 2
-    assert whens.expressions[0].sql() == expected_when_matched[0]
-    assert whens.expressions[1].sql() == expected_when_matched[1]
+    assert whens.expressions[0].sql(dialect="hive") == expected_when_matched[0]
+    assert whens.expressions[1].sql(dialect="hive") == expected_when_matched[1]
 
 
 def test_when_matched_merge_filter_multi_part_columns():
@@ -5529,28 +5529,86 @@ def test_when_matched_merge_filter_multi_part_columns():
     )
 
     expected_when_matched = [
-        "WHEN MATCHED AND __MERGE_SOURCE__.record.nested_record.field = 1 THEN UPDATE SET __MERGE_TARGET__.repeated_record.sub_repeated_record.sub_field = COALESCE(__MERGE_SOURCE__.repeated_record.sub_repeated_record.sub_field, __MERGE_TARGET__.repeated_record.sub_repeated_record.sub_field)",
-        "WHEN MATCHED THEN UPDATE SET __MERGE_TARGET__.repeated_record.sub_repeated_record.sub_field = COALESCE(__MERGE_SOURCE__.repeated_record.sub_repeated_record.sub_field, __MERGE_TARGET__.repeated_record.sub_repeated_record.sub_field)",
+        "WHEN MATCHED AND `__merge_source__`.`record`.`nested_record`.`field` = 1 THEN UPDATE SET `__merge_target__`.`repeated_record`.`sub_repeated_record`.`sub_field` = COALESCE(`__merge_source__`.`repeated_record`.`sub_repeated_record`.`sub_field`, `__merge_target__`.`repeated_record`.`sub_repeated_record`.`sub_field`)",
+        "WHEN MATCHED THEN UPDATE SET `__merge_target__`.`repeated_record`.`sub_repeated_record`.`sub_field` = COALESCE(`__merge_source__`.`repeated_record`.`sub_repeated_record`.`sub_field`, `__merge_target__`.`repeated_record`.`sub_repeated_record`.`sub_field`)",
     ]
     expected_merge_filter = (
-        "__MERGE_SOURCE__.record.nested_record.field < __MERGE_TARGET__.record.nested_record.field AND "
-        "__MERGE_TARGET__.repeated_record.sub_repeated_record.sub_field > __MERGE_SOURCE__.repeated_record.sub_repeated_record.sub_field"
+        "`__merge_source__`.`record`.`nested_record`.`field` < `__merge_target__`.`record`.`nested_record`.`field` AND "
+        "`__merge_target__`.`repeated_record`.`sub_repeated_record`.`sub_field` > `__merge_source__`.`repeated_record`.`sub_repeated_record`.`sub_field`"
     )
 
     model = load_sql_based_model(expressions, dialect="bigquery", variables={"schema": "db"})
     whens = model.kind.when_matched
     assert len(whens.expressions) == 2
-    assert whens.expressions[0].sql() == expected_when_matched[0]
-    assert whens.expressions[1].sql() == expected_when_matched[1]
-    assert model.merge_filter.sql() == expected_merge_filter
assert whens.expressions[0].sql(dialect="bigquery") == expected_when_matched[0] + assert whens.expressions[1].sql(dialect="bigquery") == expected_when_matched[1] + assert model.merge_filter.sql(dialect="bigquery") == expected_merge_filter model = SqlModel.parse_raw(model.json()) whens = model.kind.when_matched assert len(whens.expressions) == 2 - assert whens.expressions[0].sql() == expected_when_matched[0] - assert whens.expressions[1].sql() == expected_when_matched[1] - assert model.merge_filter.sql() == expected_merge_filter + assert whens.expressions[0].sql(dialect="bigquery") == expected_when_matched[0] + assert whens.expressions[1].sql(dialect="bigquery") == expected_when_matched[1] + assert model.merge_filter.sql(dialect="bigquery") == expected_merge_filter + + +def test_when_matched_normalization() -> None: + # unquoted should be normalized and quoted + expressions = d.parse( + """ + MODEL ( + name test.employees, + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key name, + when_matched ( + WHEN MATCHED THEN UPDATE SET + target.key_a = source.key_a, + target.key_b = source.key_b, + ) + ) + ); + SELECT 'name' AS name, 1 AS key_a, 2 AS key_b; + """ + ) + model = load_sql_based_model(expressions, dialect="snowflake") + + assert isinstance(model.kind, IncrementalByUniqueKeyKind) + assert isinstance(model.kind.when_matched, exp.Whens) + first_expression = model.kind.when_matched.expressions[0] + assert isinstance(first_expression, exp.Expression) + assert ( + first_expression.sql(dialect="snowflake") + == 'WHEN MATCHED THEN UPDATE SET "__MERGE_TARGET__"."KEY_A" = "__MERGE_SOURCE__"."KEY_A", "__MERGE_TARGET__"."KEY_B" = "__MERGE_SOURCE__"."KEY_B"' + ) + + # quoted should be preserved + expressions = d.parse( + """ + MODEL ( + name test.employees, + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key name, + when_matched ( + WHEN MATCHED THEN UPDATE SET + target."kEy_A" = source."kEy_A", + target."kEY_b" = source.key_b, + ) + ) + ); + SELECT 'name' AS name, 1 AS "kEy_A", 2 AS "kEY_b"; + """ + ) + model = load_sql_based_model(expressions, dialect="snowflake") + + assert isinstance(model.kind, IncrementalByUniqueKeyKind) + assert isinstance(model.kind.when_matched, exp.Whens) + first_expression = model.kind.when_matched.expressions[0] + assert isinstance(first_expression, exp.Expression) + assert ( + first_expression.sql(dialect="snowflake") + == 'WHEN MATCHED THEN UPDATE SET "__MERGE_TARGET__"."kEy_A" = "__MERGE_SOURCE__"."kEy_A", "__MERGE_TARGET__"."kEY_b" = "__MERGE_SOURCE__"."KEY_B"' + ) def test_default_catalog_sql(assert_exp_eq): @@ -6492,11 +6550,11 @@ def model_with_macros(evaluator, **kwargs): == "@IF(@gateway = 'dev', @'hdfs://@{catalog_name}/@{schema_name}/dev/@{table_name}', @'s3://prod/@{table_name}')" ) - # Merge_filter will stay unrendered as well + # merge_filter will stay unrendered as well assert python_sql_model.unique_key[0] == exp.column("a", quoted=True) assert ( python_sql_model.merge_filter.sql() - == "source.id > 0 AND target.updated_at < @end_ds AND source.updated_at > @start_ds" + == '"source"."id" > 0 AND "target"."updated_at" < @end_ds AND "source"."updated_at" > @start_ds' ) @@ -7583,7 +7641,7 @@ def test_model_kind_to_expression(): .sql() == """INCREMENTAL_BY_UNIQUE_KEY ( unique_key ("a"), -when_matched (WHEN MATCHED THEN UPDATE SET __MERGE_TARGET__.b = COALESCE(__MERGE_SOURCE__.b, __MERGE_TARGET__.b)), +when_matched (WHEN MATCHED THEN UPDATE SET "__merge_target__"."b" = COALESCE("__merge_source__"."b", "__merge_target__"."b")), batch_concurrency 1, forward_only FALSE, 
 disable_restatement FALSE,
@@ -7611,7 +7669,7 @@ def test_model_kind_to_expression():
         .sql()
         == """INCREMENTAL_BY_UNIQUE_KEY (
 unique_key ("a"),
-when_matched (WHEN MATCHED AND __MERGE_SOURCE__.x = 1 THEN UPDATE SET __MERGE_TARGET__.b = COALESCE(__MERGE_SOURCE__.b, __MERGE_TARGET__.b) WHEN MATCHED THEN UPDATE SET __MERGE_TARGET__.b = COALESCE(__MERGE_SOURCE__.b, __MERGE_TARGET__.b)),
+when_matched (WHEN MATCHED AND "__merge_source__"."x" = 1 THEN UPDATE SET "__merge_target__"."b" = COALESCE("__merge_source__"."b", "__merge_target__"."b") WHEN MATCHED THEN UPDATE SET "__merge_target__"."b" = COALESCE("__merge_source__"."b", "__merge_target__"."b")),
 batch_concurrency 1,
 forward_only FALSE,
 disable_restatement FALSE,
@@ -7872,13 +7930,14 @@ def test_merge_filter():
         """
     )
 
-    expected_incremental_predicate = f"{MERGE_SOURCE_ALIAS}.salary > 0"
+    expected_incremental_predicate = f"`{MERGE_SOURCE_ALIAS.lower()}`.`salary` > 0"
 
     model = load_sql_based_model(expressions, dialect="hive")
-    assert model.kind.merge_filter.sql() == expected_incremental_predicate
+    assert model.kind.merge_filter.sql(dialect="hive") == expected_incremental_predicate
 
     model = SqlModel.parse_raw(model.json())
-    assert model.kind.merge_filter.sql() == expected_incremental_predicate
+    assert model.kind.merge_filter.sql(dialect="hive") == expected_incremental_predicate
+    assert model.dialect == "hive"
 
     expressions = d.parse(
         """
@@ -7894,7 +7953,7 @@ def test_merge_filter():
                     source.ds > (SELECT MAX(ds) FROM db.test) AND
                     source.ds > @start_ds AND
                     source._operation <> 1 AND
-                    target.start_date > dateadd(day, -7, current_date)
+                    target.start_date > date_add(current_date, interval 7 day)
                 )
             )
         );
@@ -7906,26 +7965,27 @@
         """
     )
 
-    model = SqlModel.parse_raw(load_sql_based_model(expressions).json())
-    assert d.format_model_expressions(model.render_definition()) == (
+    model = SqlModel.parse_raw(load_sql_based_model(expressions, dialect="duckdb").json())
+    assert d.format_model_expressions(model.render_definition(), dialect=model.dialect) == (
         f"""MODEL (
   name db.test,
+  dialect duckdb,
   kind INCREMENTAL_BY_UNIQUE_KEY (
     unique_key ("purchase_order_id"),
     when_matched (
-      WHEN MATCHED AND {MERGE_SOURCE_ALIAS}._operation = 1 THEN DELETE
-      WHEN MATCHED AND {MERGE_SOURCE_ALIAS}._operation <> 1 THEN UPDATE SET
-      {MERGE_TARGET_ALIAS}.purchase_order_id = 1
+      WHEN MATCHED AND "{MERGE_SOURCE_ALIAS.lower()}"."_operation" = 1 THEN DELETE
+      WHEN MATCHED AND "{MERGE_SOURCE_ALIAS.lower()}"."_operation" <> 1 THEN UPDATE SET
+      "{MERGE_TARGET_ALIAS.lower()}"."purchase_order_id" = 1
     ),
     merge_filter (
-      {MERGE_SOURCE_ALIAS}.ds > (
+      "{MERGE_SOURCE_ALIAS.lower()}"."ds" > (
         SELECT
-          MAX(ds)
-        FROM db.test
+          MAX("ds")
+        FROM "db"."test"
       )
-      AND {MERGE_SOURCE_ALIAS}.ds > @start_ds
-      AND {MERGE_SOURCE_ALIAS}._operation <> 1
-      AND {MERGE_TARGET_ALIAS}.start_date > DATEADD(day, -7, CURRENT_DATE)
+      AND "{MERGE_SOURCE_ALIAS.lower()}"."ds" > @start_ds
+      AND "{MERGE_SOURCE_ALIAS.lower()}"."_operation" <> 1
+      AND "{MERGE_TARGET_ALIAS.lower()}"."start_date" > CURRENT_DATE + INTERVAL '7' DAY
     ),
     batch_concurrency 1,
     forward_only FALSE,
@@ -7942,10 +8002,46 @@ def test_merge_filter():
     rendered_merge_filters = model.render_merge_filter(start="2023-01-01", end="2023-01-02")
 
     assert (
-        rendered_merge_filters.sql()
-        == "(__MERGE_SOURCE__.ds > (SELECT MAX(ds) FROM db.test) AND __MERGE_SOURCE__.ds > '2023-01-01' AND __MERGE_SOURCE__._operation <> 1 AND __MERGE_TARGET__.start_date > DATEADD(day, -7, CURRENT_DATE))"
+        rendered_merge_filters.sql(dialect="hive")
+ == "(`__merge_source__`.`ds` > (SELECT MAX(`ds`) FROM `db`.`test`) AND `__merge_source__`.`ds` > '2023-01-01' AND `__merge_source__`.`_operation` <> 1 AND `__merge_target__`.`start_date` > CURRENT_DATE + INTERVAL '7' DAY)" + ) + + +def test_merge_filter_normalization(): + # unquoted gets normalized and quoted + expressions = d.parse( + """ + MODEL ( + name db.employees, + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key name, + merge_filter source.salary > 0 + ) + ); + SELECT 'name' AS name, 1 AS salary; + """ ) + model = load_sql_based_model(expressions, dialect="snowflake") + assert model.merge_filter.sql(dialect="snowflake") == '"__MERGE_SOURCE__"."SALARY" > 0' + + # quoted gets preserved + expressions = d.parse( + """ + MODEL ( + name db.employees, + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key name, + merge_filter source."SaLArY" > 0 + ) + ); + SELECT 'name' AS name, 1 AS "SaLArY"; + """ + ) + + model = load_sql_based_model(expressions, dialect="snowflake") + assert model.merge_filter.sql(dialect="snowflake") == '"__MERGE_SOURCE__"."SaLArY" > 0' + def test_merge_filter_macro(): @macro() @@ -7969,19 +8065,20 @@ def predicate( """ ) - unrendered_merge_filter = ( - f"@predicate(update_datetime) AND {MERGE_TARGET_ALIAS}.update_datetime > @start_dt" + unrendered_merge_filter = f"""@predicate("UPDATE_DATETIME") AND "{MERGE_TARGET_ALIAS}"."UPDATE_DATETIME" > @start_dt""" + expected_merge_filter = ( + f"""\"{MERGE_SOURCE_ALIAS}"."UPDATE_DATETIME" > DATEADD(DAY, -7, "{MERGE_TARGET_ALIAS}"."UPDATE_DATETIME") """ + f"""AND "{MERGE_TARGET_ALIAS}"."UPDATE_DATETIME" > CAST('2023-01-01 15:00:00+00:00' AS TIMESTAMPTZ)""" ) - expected_merge_filter = f"{MERGE_SOURCE_ALIAS}.UPDATE_DATETIME > DATE_ADD({MERGE_TARGET_ALIAS}.UPDATE_DATETIME, -7, 'DAY') AND {MERGE_TARGET_ALIAS}.UPDATE_DATETIME > CAST('2023-01-01 15:00:00+00:00' AS TIMESTAMPTZ)" model = load_sql_based_model(expressions, dialect="snowflake") - assert model.kind.merge_filter.sql() == unrendered_merge_filter + assert model.kind.merge_filter.sql(dialect=model.dialect) == unrendered_merge_filter model = SqlModel.parse_raw(model.json()) - assert model.kind.merge_filter.sql() == unrendered_merge_filter + assert model.kind.merge_filter.sql(dialect=model.dialect) == unrendered_merge_filter rendered_merge_filters = model.render_merge_filter(start="2023-01-01 15:00:00") - assert rendered_merge_filters.sql() == expected_merge_filter + assert rendered_merge_filters.sql(dialect=model.dialect) == expected_merge_filter @pytest.mark.parametrize( diff --git a/tests/core/test_snapshot_evaluator.py b/tests/core/test_snapshot_evaluator.py index 3704c192bd..e4de741cdb 100644 --- a/tests/core/test_snapshot_evaluator.py +++ b/tests/core/test_snapshot_evaluator.py @@ -2211,13 +2211,19 @@ def test_create_incremental_by_unique_key_updated_at_exp(adapter_mock, make_snap source=False, then=exp.Update( expressions=[ - exp.column("name", MERGE_TARGET_ALIAS).eq( - exp.column("name", MERGE_SOURCE_ALIAS) + exp.column("name", MERGE_TARGET_ALIAS.lower(), quoted=True).eq( + exp.column("name", MERGE_SOURCE_ALIAS.lower(), quoted=True) ), - exp.column("updated_at", MERGE_TARGET_ALIAS).eq( + exp.column("updated_at", MERGE_TARGET_ALIAS.lower(), quoted=True).eq( exp.Coalesce( - this=exp.column("updated_at", MERGE_SOURCE_ALIAS), - expressions=[exp.column("updated_at", MERGE_TARGET_ALIAS)], + this=exp.column( + "updated_at", MERGE_SOURCE_ALIAS.lower(), quoted=True + ), + expressions=[ + exp.column( + "updated_at", MERGE_TARGET_ALIAS.lower(), quoted=True + ) + ], ) ), ], @@ -2273,16 
diff --git a/tests/core/test_snapshot_evaluator.py b/tests/core/test_snapshot_evaluator.py
index 3704c192bd..e4de741cdb 100644
--- a/tests/core/test_snapshot_evaluator.py
+++ b/tests/core/test_snapshot_evaluator.py
@@ -2211,13 +2211,19 @@ def test_create_incremental_by_unique_key_updated_at_exp(adapter_mock, make_snap
                 source=False,
                 then=exp.Update(
                     expressions=[
-                        exp.column("name", MERGE_TARGET_ALIAS).eq(
-                            exp.column("name", MERGE_SOURCE_ALIAS)
+                        exp.column("name", MERGE_TARGET_ALIAS.lower(), quoted=True).eq(
+                            exp.column("name", MERGE_SOURCE_ALIAS.lower(), quoted=True)
                         ),
-                        exp.column("updated_at", MERGE_TARGET_ALIAS).eq(
+                        exp.column("updated_at", MERGE_TARGET_ALIAS.lower(), quoted=True).eq(
                             exp.Coalesce(
-                                this=exp.column("updated_at", MERGE_SOURCE_ALIAS),
-                                expressions=[exp.column("updated_at", MERGE_TARGET_ALIAS)],
+                                this=exp.column(
+                                    "updated_at", MERGE_SOURCE_ALIAS.lower(), quoted=True
+                                ),
+                                expressions=[
+                                    exp.column(
+                                        "updated_at", MERGE_TARGET_ALIAS.lower(), quoted=True
+                                    )
+                                ],
                             )
                         ),
                     ],
@@ -2273,16 +2279,24 @@ def test_create_incremental_by_unique_key_multiple_updated_at_exp(adapter_mock,
             expressions=[
                 exp.When(
                     matched=True,
-                    condition=exp.column("id", MERGE_SOURCE_ALIAS).eq(exp.Literal.number(1)),
+                    condition=exp.column("id", MERGE_SOURCE_ALIAS.lower(), quoted=True).eq(
+                        exp.Literal.number(1)
+                    ),
                     then=exp.Update(
                         expressions=[
-                            exp.column("name", MERGE_TARGET_ALIAS).eq(
-                                exp.column("name", MERGE_SOURCE_ALIAS)
+                            exp.column("name", MERGE_TARGET_ALIAS.lower(), quoted=True).eq(
+                                exp.column("name", MERGE_SOURCE_ALIAS.lower(), quoted=True)
                             ),
-                            exp.column("updated_at", MERGE_TARGET_ALIAS).eq(
+                            exp.column("updated_at", MERGE_TARGET_ALIAS.lower(), quoted=True).eq(
                                 exp.Coalesce(
-                                    this=exp.column("updated_at", MERGE_SOURCE_ALIAS),
-                                    expressions=[exp.column("updated_at", MERGE_TARGET_ALIAS)],
+                                    this=exp.column(
+                                        "updated_at", MERGE_SOURCE_ALIAS.lower(), quoted=True
+                                    ),
+                                    expressions=[
+                                        exp.column(
+                                            "updated_at", MERGE_TARGET_ALIAS.lower(), quoted=True
+                                        )
+                                    ],
                                 )
                             ),
                         ],
@@ -2293,13 +2307,19 @@ def test_create_incremental_by_unique_key_multiple_updated_at_exp(adapter_mock,
                     source=False,
                     then=exp.Update(
                         expressions=[
-                            exp.column("name", MERGE_TARGET_ALIAS).eq(
-                                exp.column("name", MERGE_SOURCE_ALIAS)
+                            exp.column("name", MERGE_TARGET_ALIAS.lower(), quoted=True).eq(
+                                exp.column("name", MERGE_SOURCE_ALIAS.lower(), quoted=True)
                             ),
-                            exp.column("updated_at", MERGE_TARGET_ALIAS).eq(
+                            exp.column("updated_at", MERGE_TARGET_ALIAS.lower(), quoted=True).eq(
                                 exp.Coalesce(
-                                    this=exp.column("updated_at", MERGE_SOURCE_ALIAS),
-                                    expressions=[exp.column("updated_at", MERGE_TARGET_ALIAS)],
+                                    this=exp.column(
+                                        "updated_at", MERGE_SOURCE_ALIAS.lower(), quoted=True
+                                    ),
+                                    expressions=[
+                                        exp.column(
+                                            "updated_at", MERGE_TARGET_ALIAS.lower(), quoted=True
+                                        )
+                                    ],
                                 )
                             ),
                         ],
@@ -2384,16 +2404,16 @@ def test_create_incremental_by_unique_key_merge_filter(adapter_mock, make_snapsh
     assert model.merge_filter == exp.And(
         this=exp.And(
             this=exp.GT(
-                this=exp.column("id", MERGE_SOURCE_ALIAS),
+                this=exp.column("id", MERGE_SOURCE_ALIAS.lower(), quoted=True),
                 expression=exp.Literal(this="0", is_string=False),
             ),
             expression=exp.LT(
-                this=exp.column("updated_at", MERGE_TARGET_ALIAS),
+                this=exp.column("updated_at", MERGE_TARGET_ALIAS.lower(), quoted=True),
                 expression=d.MacroVar(this="end_ds"),
             ),
         ),
         expression=exp.GT(
-            this=exp.column("updated_at", MERGE_SOURCE_ALIAS),
+            this=exp.column("updated_at", MERGE_SOURCE_ALIAS.lower(), quoted=True),
             expression=d.MacroVar(this="start_ds"),
         ),
     )
@@ -2425,10 +2445,16 @@ def test_create_incremental_by_unique_key_merge_filter(adapter_mock, make_snapsh
                 matched=True,
                 then=exp.Update(
                     expressions=[
-                        exp.column("updated_at", MERGE_TARGET_ALIAS).eq(
+                        exp.column("updated_at", MERGE_TARGET_ALIAS.lower(), quoted=True).eq(
                             exp.Coalesce(
-                                this=exp.column("updated_at", MERGE_SOURCE_ALIAS),
-                                expressions=[exp.column("updated_at", MERGE_TARGET_ALIAS)],
+                                this=exp.column(
+                                    "updated_at", MERGE_SOURCE_ALIAS.lower(), quoted=True
+                                ),
+                                expressions=[
+                                    exp.column(
+                                        "updated_at", MERGE_TARGET_ALIAS.lower(), quoted=True
+                                    )
+                                ],
                             )
                         ),
                     ],
@@ -2439,16 +2465,16 @@ def test_create_incremental_by_unique_key_merge_filter(adapter_mock, make_snapsh
             merge_filter=exp.And(
                 this=exp.And(
                     this=exp.GT(
-                        this=exp.column("id", MERGE_SOURCE_ALIAS),
+                        this=exp.column("id", MERGE_SOURCE_ALIAS.lower(), quoted=True),
                         expression=exp.Literal(this="0", is_string=False),
                     ),
                     expression=exp.LT(
-                        this=exp.column("updated_at", MERGE_TARGET_ALIAS),
+                        this=exp.column("updated_at", MERGE_TARGET_ALIAS.lower(), quoted=True),
                         expression=exp.Literal(this="2020-01-02", is_string=True),
                     ),
                 ),
                 expression=exp.GT(
-                    this=exp.column("updated_at", MERGE_SOURCE_ALIAS),
+                    this=exp.column("updated_at", MERGE_SOURCE_ALIAS.lower(), quoted=True),
                     expression=exp.Literal(this="2020-01-01", is_string=True),
                 ),
             ),
diff --git a/tests/dbt/test_config.py b/tests/dbt/test_config.py
index bb8806c657..42172f53e4 100644
--- a/tests/dbt/test_config.py
+++ b/tests/dbt/test_config.py
@@ -97,7 +97,7 @@ def test_model_to_sqlmesh_fields():
         cluster_by=["a", '"b"'],
         incremental_predicates=[
             "55 > DBT_INTERNAL_SOURCE.b",
-            "DBT_INTERNAL_DEST.session_start > dateadd(day, -7, current_date)",
+            "DBT_INTERNAL_DEST.session_start > date_add(current_date, interval 7 day)",
         ],
         cron="@hourly",
         interval_unit="FIVE_MINUTE",
@@ -135,8 +135,8 @@ def test_model_to_sqlmesh_fields():
     assert kind.lookback == 3
     assert kind.on_destructive_change == OnDestructiveChange.ALLOW
     assert (
-        kind.merge_filter.sql()
-        == "55 > __MERGE_SOURCE__.b AND __MERGE_TARGET__.session_start > DATEADD(day, -7, CURRENT_DATE)"
+        kind.merge_filter.sql(dialect=model.dialect)
+        == """55 > "__merge_source__"."b" AND "__merge_target__"."session_start" > CURRENT_DATE + INTERVAL '7' DAY"""
     )
 
     model = model_config.update_with({"dialect": "snowflake"}).to_sqlmesh(context)
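As the dbt test above exercises, DBT_INTERNAL_SOURCE/DBT_INTERNAL_DEST references in incremental_predicates resolve to the SQLMesh merge aliases through the same transform. A sketch (duckdb chosen for illustration, assuming a SQLMesh checkout):

    from sqlglot import parse_one
    from sqlmesh.core import dialect as d

    pred = parse_one("DBT_INTERNAL_DEST.session_start > current_date", dialect="duckdb")
    print(pred.transform(d.replace_merge_table_aliases, dialect="duckdb").sql(dialect="duckdb"))
    # '"__merge_target__".session_start > CURRENT_DATE'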