Skip to content

Commit 9c5f889

Browse files
committed
Fix: Normalize when_matched and merge_filter expressions to the source dialect
1 parent f43a4c3 commit 9c5f889

File tree

6 files changed

+186
-57
lines changed

6 files changed

+186
-57
lines changed

sqlmesh/core/dialect.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1388,17 +1388,26 @@ def is_meta_expression(v: t.Any) -> bool:
13881388
return isinstance(v, (Audit, Metric, Model))
13891389

13901390

1391-
def replace_merge_table_aliases(expression: exp.Expression) -> exp.Expression:
1391+
def replace_merge_table_aliases(
1392+
expression: exp.Expression, dialect: t.Optional[str] = None
1393+
) -> exp.Expression:
13921394
"""
13931395
Resolves references from the "source" and "target" tables (or their DBT equivalents)
13941396
with the corresponding SQLMesh merge aliases (MERGE_SOURCE_ALIAS and MERGE_TARGET_ALIAS)
13951397
"""
13961398
from sqlmesh.core.engine_adapter.base import MERGE_SOURCE_ALIAS, MERGE_TARGET_ALIAS
13971399

1400+
normalized_merge_source_alias = normalize_identifiers(
1401+
exp.to_identifier(MERGE_SOURCE_ALIAS), dialect
1402+
)
1403+
normalized_merge_target_alias = normalize_identifiers(
1404+
exp.to_identifier(MERGE_TARGET_ALIAS), dialect
1405+
)
1406+
13981407
if isinstance(expression, exp.Column) and (first_part := expression.parts[0]):
13991408
if first_part.this.lower() in ("target", "dbt_internal_dest", "__merge_target__"):
1400-
first_part.replace(exp.to_identifier(MERGE_TARGET_ALIAS))
1409+
first_part.replace(normalized_merge_target_alias)
14011410
elif first_part.this.lower() in ("source", "dbt_internal_source", "__merge_source__"):
1402-
first_part.replace(exp.to_identifier(MERGE_SOURCE_ALIAS))
1411+
first_part.replace(normalized_merge_source_alias)
14031412

14041413
return expression

sqlmesh/core/model/definition.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -653,7 +653,7 @@ def render_merge_filter(
653653
)
654654
if len(rendered_exprs) != 1:
655655
raise SQLMeshError(f"Expected one expression but got {len(rendered_exprs)}")
656-
return rendered_exprs[0].transform(d.replace_merge_table_aliases)
656+
return rendered_exprs[0].transform(d.replace_merge_table_aliases, dialect=self.dialect)
657657

658658
def _render_properties(
659659
self, properties: t.Dict[str, exp.Expression] | SessionProperties, **render_kwargs: t.Any

sqlmesh/core/model/kind.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -468,15 +468,20 @@ def _when_matched_validator(
468468
return v
469469
if isinstance(v, list):
470470
v = " ".join(v)
471+
472+
dialect = get_dialect(info.data)
473+
471474
if isinstance(v, str):
472475
# Whens wrap the WHEN clauses, but the parentheses aren't parsed by sqlglot
473476
v = v.strip()
474477
if v.startswith("("):
475478
v = v[1:-1]
476479

477-
return t.cast(exp.Whens, d.parse_one(v, into=exp.Whens, dialect=get_dialect(info.data)))
480+
v = t.cast(exp.Whens, d.parse_one(v, into=exp.Whens, dialect=dialect))
481+
else:
482+
v = t.cast(exp.Whens, v.transform(d.replace_merge_table_aliases, dialect=dialect))
478483

479-
return t.cast(exp.Whens, v.transform(d.replace_merge_table_aliases))
484+
return normalize_identifiers(v, dialect=dialect)
480485

481486
@field_validator("merge_filter", mode="before")
482487
def _merge_filter_validator(
@@ -486,11 +491,16 @@ def _merge_filter_validator(
486491
) -> t.Optional[exp.Expression]:
487492
if v is None:
488493
return v
494+
495+
dialect = get_dialect(info.data)
496+
489497
if isinstance(v, str):
490498
v = v.strip()
491-
return d.parse_one(v, dialect=get_dialect(info.data))
499+
v = d.parse_one(v, dialect=dialect)
500+
else:
501+
v = v.transform(d.replace_merge_table_aliases, dialect=dialect)
492502

493-
return v.transform(d.replace_merge_table_aliases)
503+
return normalize_identifiers(v, dialect=dialect)
494504

495505
@property
496506
def data_hash_values(self) -> t.List[t.Optional[str]]:

tests/core/test_model.py

Lines changed: 123 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5366,7 +5366,7 @@ def test_when_matched():
53665366
"""
53675367
)
53685368

5369-
expected_when_matched = "(WHEN MATCHED THEN UPDATE SET __MERGE_TARGET__.salary = COALESCE(__MERGE_SOURCE__.salary, __MERGE_TARGET__.salary))"
5369+
expected_when_matched = "(WHEN MATCHED THEN UPDATE SET __merge_target__.salary = COALESCE(__merge_source__.salary, __merge_target__.salary))"
53705370

53715371
model = load_sql_based_model(expressions, dialect="hive")
53725372
assert model.kind.when_matched.sql() == expected_when_matched
@@ -5400,9 +5400,9 @@ def test_when_matched():
54005400
kind INCREMENTAL_BY_UNIQUE_KEY (
54015401
unique_key ("purchase_order_id"),
54025402
when_matched (
5403-
WHEN MATCHED AND __MERGE_SOURCE__._operation = 1 THEN DELETE
5404-
WHEN MATCHED AND __MERGE_SOURCE__._operation <> 1 THEN UPDATE SET
5405-
__MERGE_TARGET__.purchase_order_id = 1
5403+
WHEN MATCHED AND __merge_source__._operation = 1 THEN DELETE
5404+
WHEN MATCHED AND __merge_source__._operation <> 1 THEN UPDATE SET
5405+
__merge_target__.purchase_order_id = 1
54065406
),
54075407
batch_concurrency 1,
54085408
forward_only FALSE,
@@ -5453,7 +5453,7 @@ def fingerprint_merge(
54535453
kind INCREMENTAL_BY_UNIQUE_KEY (
54545454
unique_key ("purchase_order_id"),
54555455
when_matched (
5456-
WHEN MATCHED AND __MERGE_SOURCE__.salary <> __MERGE_TARGET__.salary THEN UPDATE SET
5456+
WHEN MATCHED AND __merge_source__.salary <> __merge_target__.salary THEN UPDATE SET
54575457
ARRAY('target.update_datetime = source.update_datetime', 'target.salary = source.salary')
54585458
),
54595459
batch_concurrency 1,
@@ -5487,8 +5487,8 @@ def test_when_matched_multiple():
54875487
)
54885488

54895489
expected_when_matched = [
5490-
"WHEN MATCHED AND __MERGE_SOURCE__.x = 1 THEN UPDATE SET __MERGE_TARGET__.salary = COALESCE(__MERGE_SOURCE__.salary, __MERGE_TARGET__.salary)",
5491-
"WHEN MATCHED THEN UPDATE SET __MERGE_TARGET__.salary = COALESCE(__MERGE_SOURCE__.salary, __MERGE_TARGET__.salary)",
5490+
"WHEN MATCHED AND __merge_source__.x = 1 THEN UPDATE SET __merge_target__.salary = COALESCE(__merge_source__.salary, __merge_target__.salary)",
5491+
"WHEN MATCHED THEN UPDATE SET __merge_target__.salary = COALESCE(__merge_source__.salary, __merge_target__.salary)",
54925492
]
54935493

54945494
model = load_sql_based_model(expressions, dialect="hive", variables={"schema": "db"})
@@ -5529,13 +5529,13 @@ def test_when_matched_merge_filter_multi_part_columns():
55295529
)
55305530

55315531
expected_when_matched = [
5532-
"WHEN MATCHED AND __MERGE_SOURCE__.record.nested_record.field = 1 THEN UPDATE SET __MERGE_TARGET__.repeated_record.sub_repeated_record.sub_field = COALESCE(__MERGE_SOURCE__.repeated_record.sub_repeated_record.sub_field, __MERGE_TARGET__.repeated_record.sub_repeated_record.sub_field)",
5533-
"WHEN MATCHED THEN UPDATE SET __MERGE_TARGET__.repeated_record.sub_repeated_record.sub_field = COALESCE(__MERGE_SOURCE__.repeated_record.sub_repeated_record.sub_field, __MERGE_TARGET__.repeated_record.sub_repeated_record.sub_field)",
5532+
"WHEN MATCHED AND __merge_source__.record.nested_record.field = 1 THEN UPDATE SET __merge_target__.repeated_record.sub_repeated_record.sub_field = COALESCE(__merge_source__.repeated_record.sub_repeated_record.sub_field, __merge_target__.repeated_record.sub_repeated_record.sub_field)",
5533+
"WHEN MATCHED THEN UPDATE SET __merge_target__.repeated_record.sub_repeated_record.sub_field = COALESCE(__merge_source__.repeated_record.sub_repeated_record.sub_field, __merge_target__.repeated_record.sub_repeated_record.sub_field)",
55345534
]
55355535

55365536
expected_merge_filter = (
5537-
"__MERGE_SOURCE__.record.nested_record.field < __MERGE_TARGET__.record.nested_record.field AND "
5538-
"__MERGE_TARGET__.repeated_record.sub_repeated_record.sub_field > __MERGE_SOURCE__.repeated_record.sub_repeated_record.sub_field"
5537+
"__merge_source__.record.nested_record.field < __merge_target__.record.nested_record.field AND "
5538+
"__merge_target__.repeated_record.sub_repeated_record.sub_field > __merge_source__.repeated_record.sub_repeated_record.sub_field"
55395539
)
55405540

55415541
model = load_sql_based_model(expressions, dialect="bigquery", variables={"schema": "db"})
@@ -5553,6 +5553,64 @@ def test_when_matched_merge_filter_multi_part_columns():
55535553
assert model.merge_filter.sql() == expected_merge_filter
55545554

55555555

5556+
def test_when_matched_normalization() -> None:
5557+
# unquoted should be normalized
5558+
expressions = d.parse(
5559+
"""
5560+
MODEL (
5561+
name test.employees,
5562+
kind INCREMENTAL_BY_UNIQUE_KEY (
5563+
unique_key name,
5564+
when_matched (
5565+
WHEN MATCHED THEN UPDATE SET
5566+
target.key_a = source.key_a,
5567+
target.key_b = source.key_b,
5568+
)
5569+
)
5570+
);
5571+
SELECT 'name' AS name, 1 AS key_a, 2 AS key_b;
5572+
"""
5573+
)
5574+
model = load_sql_based_model(expressions, dialect="snowflake")
5575+
5576+
assert isinstance(model.kind, IncrementalByUniqueKeyKind)
5577+
assert isinstance(model.kind.when_matched, exp.Whens)
5578+
first_expression = model.kind.when_matched.expressions[0]
5579+
assert isinstance(first_expression, exp.Expression)
5580+
assert (
5581+
first_expression.sql(dialect="snowflake", identify=True)
5582+
== 'WHEN MATCHED THEN UPDATE SET "__MERGE_TARGET__"."KEY_A" = "__MERGE_SOURCE__"."KEY_A", "__MERGE_TARGET__"."KEY_B" = "__MERGE_SOURCE__"."KEY_B"'
5583+
)
5584+
5585+
# quoted should be preserved
5586+
expressions = d.parse(
5587+
"""
5588+
MODEL (
5589+
name test.employees,
5590+
kind INCREMENTAL_BY_UNIQUE_KEY (
5591+
unique_key name,
5592+
when_matched (
5593+
WHEN MATCHED THEN UPDATE SET
5594+
target."kEy_A" = source."kEy_A",
5595+
target."kEY_b" = source.key_b,
5596+
)
5597+
)
5598+
);
5599+
SELECT 'name' AS name, 1 AS "kEy_A", 2 AS "kEY_b";
5600+
"""
5601+
)
5602+
model = load_sql_based_model(expressions, dialect="snowflake")
5603+
5604+
assert isinstance(model.kind, IncrementalByUniqueKeyKind)
5605+
assert isinstance(model.kind.when_matched, exp.Whens)
5606+
first_expression = model.kind.when_matched.expressions[0]
5607+
assert isinstance(first_expression, exp.Expression)
5608+
assert (
5609+
first_expression.sql(dialect="snowflake", identify=True)
5610+
== 'WHEN MATCHED THEN UPDATE SET "__MERGE_TARGET__"."kEy_A" = "__MERGE_SOURCE__"."kEy_A", "__MERGE_TARGET__"."kEY_b" = "__MERGE_SOURCE__"."KEY_B"'
5611+
)
5612+
5613+
55565614
def test_default_catalog_sql(assert_exp_eq):
55575615
"""
55585616
This test validates the hashing behavior of the system as it relates to the default catalog.
@@ -7583,7 +7641,7 @@ def test_model_kind_to_expression():
75837641
.sql()
75847642
== """INCREMENTAL_BY_UNIQUE_KEY (
75857643
unique_key ("a"),
7586-
when_matched (WHEN MATCHED THEN UPDATE SET __MERGE_TARGET__.b = COALESCE(__MERGE_SOURCE__.b, __MERGE_TARGET__.b)),
7644+
when_matched (WHEN MATCHED THEN UPDATE SET __merge_target__.b = COALESCE(__merge_source__.b, __merge_target__.b)),
75877645
batch_concurrency 1,
75887646
forward_only FALSE,
75897647
disable_restatement FALSE,
@@ -7611,7 +7669,7 @@ def test_model_kind_to_expression():
76117669
.sql()
76127670
== """INCREMENTAL_BY_UNIQUE_KEY (
76137671
unique_key ("a"),
7614-
when_matched (WHEN MATCHED AND __MERGE_SOURCE__.x = 1 THEN UPDATE SET __MERGE_TARGET__.b = COALESCE(__MERGE_SOURCE__.b, __MERGE_TARGET__.b) WHEN MATCHED THEN UPDATE SET __MERGE_TARGET__.b = COALESCE(__MERGE_SOURCE__.b, __MERGE_TARGET__.b)),
7672+
when_matched (WHEN MATCHED AND __merge_source__.x = 1 THEN UPDATE SET __merge_target__.b = COALESCE(__merge_source__.b, __merge_target__.b) WHEN MATCHED THEN UPDATE SET __merge_target__.b = COALESCE(__merge_source__.b, __merge_target__.b)),
76157673
batch_concurrency 1,
76167674
forward_only FALSE,
76177675
disable_restatement FALSE,
@@ -7872,7 +7930,7 @@ def test_merge_filter():
78727930
"""
78737931
)
78747932

7875-
expected_incremental_predicate = f"{MERGE_SOURCE_ALIAS}.salary > 0"
7933+
expected_incremental_predicate = f"{MERGE_SOURCE_ALIAS.lower()}.salary > 0"
78767934

78777935
model = load_sql_based_model(expressions, dialect="hive")
78787936
assert model.kind.merge_filter.sql() == expected_incremental_predicate
@@ -7913,19 +7971,19 @@ def test_merge_filter():
79137971
kind INCREMENTAL_BY_UNIQUE_KEY (
79147972
unique_key ("purchase_order_id"),
79157973
when_matched (
7916-
WHEN MATCHED AND {MERGE_SOURCE_ALIAS}._operation = 1 THEN DELETE
7917-
WHEN MATCHED AND {MERGE_SOURCE_ALIAS}._operation <> 1 THEN UPDATE SET
7918-
{MERGE_TARGET_ALIAS}.purchase_order_id = 1
7974+
WHEN MATCHED AND {MERGE_SOURCE_ALIAS.lower()}._operation = 1 THEN DELETE
7975+
WHEN MATCHED AND {MERGE_SOURCE_ALIAS.lower()}._operation <> 1 THEN UPDATE SET
7976+
{MERGE_TARGET_ALIAS.lower()}.purchase_order_id = 1
79197977
),
79207978
merge_filter (
7921-
{MERGE_SOURCE_ALIAS}.ds > (
7979+
{MERGE_SOURCE_ALIAS.lower()}.ds > (
79227980
SELECT
79237981
MAX(ds)
79247982
FROM db.test
79257983
)
7926-
AND {MERGE_SOURCE_ALIAS}.ds > @start_ds
7927-
AND {MERGE_SOURCE_ALIAS}._operation <> 1
7928-
AND {MERGE_TARGET_ALIAS}.start_date > DATEADD(day, -7, CURRENT_DATE)
7984+
AND {MERGE_SOURCE_ALIAS.lower()}.ds > @start_ds
7985+
AND {MERGE_SOURCE_ALIAS.lower()}._operation <> 1
7986+
AND {MERGE_TARGET_ALIAS.lower()}.start_date > DATEADD(day, -7, CURRENT_DATE)
79297987
),
79307988
batch_concurrency 1,
79317989
forward_only FALSE,
@@ -7943,7 +8001,49 @@ def test_merge_filter():
79438001
rendered_merge_filters = model.render_merge_filter(start="2023-01-01", end="2023-01-02")
79448002
assert (
79458003
rendered_merge_filters.sql()
7946-
== "(__MERGE_SOURCE__.ds > (SELECT MAX(ds) FROM db.test) AND __MERGE_SOURCE__.ds > '2023-01-01' AND __MERGE_SOURCE__._operation <> 1 AND __MERGE_TARGET__.start_date > DATEADD(day, -7, CURRENT_DATE))"
8004+
== "(__merge_source__.ds > (SELECT MAX(ds) FROM db.test) AND __merge_source__.ds > '2023-01-01' AND __merge_source__._operation <> 1 AND __merge_target__.start_date > DATEADD(day, -7, CURRENT_DATE))"
8005+
)
8006+
8007+
8008+
def test_merge_filter_normalization():
8009+
# unquoted gets normalized
8010+
expressions = d.parse(
8011+
"""
8012+
MODEL (
8013+
name db.employees,
8014+
kind INCREMENTAL_BY_UNIQUE_KEY (
8015+
unique_key name,
8016+
merge_filter source.salary > 0
8017+
)
8018+
);
8019+
SELECT 'name' AS name, 1 AS salary;
8020+
"""
8021+
)
8022+
8023+
model = load_sql_based_model(expressions, dialect="snowflake")
8024+
assert (
8025+
model.merge_filter.sql(dialect="snowflake", identify=True)
8026+
== '"__MERGE_SOURCE__"."SALARY" > 0'
8027+
)
8028+
8029+
# quoted gets preserved
8030+
expressions = d.parse(
8031+
"""
8032+
MODEL (
8033+
name db.employees,
8034+
kind INCREMENTAL_BY_UNIQUE_KEY (
8035+
unique_key name,
8036+
merge_filter source."SaLArY" > 0
8037+
)
8038+
);
8039+
SELECT 'name' AS name, 1 AS "SaLArY";
8040+
"""
8041+
)
8042+
8043+
model = load_sql_based_model(expressions, dialect="snowflake")
8044+
assert (
8045+
model.merge_filter.sql(dialect="snowflake", identify=True)
8046+
== '"__MERGE_SOURCE__"."SaLArY" > 0'
79478047
)
79488048

79498049

@@ -7970,7 +8070,7 @@ def predicate(
79708070
)
79718071

79728072
unrendered_merge_filter = (
7973-
f"@predicate(update_datetime) AND {MERGE_TARGET_ALIAS}.update_datetime > @start_dt"
8073+
f"@predicate(UPDATE_DATETIME) AND {MERGE_TARGET_ALIAS}.UPDATE_DATETIME > @start_dt"
79748074
)
79758075
expected_merge_filter = f"{MERGE_SOURCE_ALIAS}.UPDATE_DATETIME > DATE_ADD({MERGE_TARGET_ALIAS}.UPDATE_DATETIME, -7, 'DAY') AND {MERGE_TARGET_ALIAS}.UPDATE_DATETIME > CAST('2023-01-01 15:00:00+00:00' AS TIMESTAMPTZ)"
79768076

0 commit comments

Comments
 (0)