Commit adf6a68

Fix!: Dont normalize aliases in merge and when matched (#5014)
1 parent eb4c0b4 commit adf6a68

6 files changed (+184, -83 lines)

sqlmesh/core/dialect.py

Lines changed: 2 additions & 10 deletions

@@ -1421,18 +1421,10 @@ def replace_merge_table_aliases(
     """
     from sqlmesh.core.engine_adapter.base import MERGE_SOURCE_ALIAS, MERGE_TARGET_ALIAS
 
-    normalized_merge_source_alias = quote_identifiers(
-        normalize_identifiers(exp.to_identifier(MERGE_SOURCE_ALIAS), dialect), dialect=dialect
-    )
-
-    normalized_merge_target_alias = quote_identifiers(
-        normalize_identifiers(exp.to_identifier(MERGE_TARGET_ALIAS), dialect), dialect=dialect
-    )
-
     if isinstance(expression, exp.Column) and (first_part := expression.parts[0]):
         if first_part.this.lower() in ("target", "dbt_internal_dest", "__merge_target__"):
-            first_part.replace(normalized_merge_target_alias)
+            first_part.replace(exp.to_identifier(MERGE_TARGET_ALIAS, quoted=True))
         elif first_part.this.lower() in ("source", "dbt_internal_source", "__merge_source__"):
-            first_part.replace(normalized_merge_source_alias)
+            first_part.replace(exp.to_identifier(MERGE_SOURCE_ALIAS, quoted=True))
 
     return expression
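
The change in a nutshell: `replace_merge_table_aliases` previously pushed the canonical aliases through dialect identifier normalization, which lowercases unquoted identifiers in dialects such as Hive; it now emits them as quoted identifiers, so their case survives in every dialect. A minimal standalone sqlglot sketch of the difference (not SQLMesh code; assumes sqlglot is installed):

```python
from sqlglot import exp
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers

MERGE_TARGET_ALIAS = "__MERGE_TARGET__"

# Old behavior: dialect normalization rewrites the unquoted alias's case.
normalized = normalize_identifiers(exp.to_identifier(MERGE_TARGET_ALIAS), dialect="hive")
assert normalized.name == "__merge_target__"  # case lost

# New behavior: a quoted identifier is rendered verbatim in every dialect.
verbatim = exp.to_identifier(MERGE_TARGET_ALIAS, quoted=True)
assert verbatim.sql(dialect="hive") == "`__MERGE_TARGET__`"
```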

sqlmesh/core/model/kind.py

Lines changed: 4 additions & 6 deletions

@@ -478,10 +478,9 @@ def _when_matched_validator(
                 v = v[1:-1]
 
             v = t.cast(exp.Whens, d.parse_one(v, into=exp.Whens, dialect=dialect))
-        else:
-            v = t.cast(exp.Whens, v.transform(d.replace_merge_table_aliases, dialect=dialect))
 
-        return validate_expression(v, dialect=dialect)
+        v = validate_expression(v, dialect=dialect)
+        return t.cast(exp.Whens, v.transform(d.replace_merge_table_aliases, dialect=dialect))
 
     @field_validator("merge_filter", mode="before")
     def _merge_filter_validator(
@@ -497,10 +496,9 @@ def _merge_filter_validator(
         if isinstance(v, str):
            v = v.strip()
            v = d.parse_one(v, dialect=dialect)
-        else:
-            v = v.transform(d.replace_merge_table_aliases, dialect=dialect)
 
-        return validate_expression(v, dialect=dialect)
+        v = validate_expression(v, dialect=dialect)
+        return v.transform(d.replace_merge_table_aliases, dialect=dialect)
 
     @property
     def data_hash_values(self) -> t.List[t.Optional[str]]:
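
Both validators now follow the same order: parse (if given a string), validate, then rewrite the aliases. The rewrite therefore runs on every input path, instead of only the already-parsed `else:` branch as before. A hedged, runnable sketch of that final step, mirroring the `transform` call in the diff (assumes sqlmesh is installed; `d` is `sqlmesh.core.dialect`):

```python
from sqlmesh.core import dialect as d

# Same rewrite the validators apply after validation: `source`/`target`
# column prefixes become the canonical quoted merge aliases.
expr = d.parse_one("source.event_date > target.event_date", dialect="duckdb")
rewritten = expr.transform(d.replace_merge_table_aliases, dialect="duckdb")
print(rewritten.sql(dialect="duckdb"))
# '"__MERGE_SOURCE__"."event_date" > "__MERGE_TARGET__"."event_date"'
```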

tests/core/engine_adapter/integration/test_integration.py

Lines changed: 128 additions & 1 deletion

@@ -24,7 +24,7 @@
 from sqlmesh.core.dialect import select_from_values
 from sqlmesh.core.model import Model, load_sql_based_model
 from sqlmesh.core.engine_adapter.shared import DataObject, DataObjectType
-from sqlmesh.core.engine_adapter.mixins import RowDiffMixin
+from sqlmesh.core.engine_adapter.mixins import RowDiffMixin, LogicalMergeMixin
 from sqlmesh.core.model.definition import create_sql_model
 from sqlmesh.core.plan import Plan
 from sqlmesh.core.state_sync.db import EngineAdapterStateSync
@@ -1897,6 +1897,133 @@ def _mutate_config(current_gateway_name: str, config: Config):
     ctx.cleanup(context)
 
 
+def test_incremental_by_unique_key_model_when_matched(ctx: TestContext):
+    if not ctx.supports_merge:
+        pytest.skip(f"{ctx.dialect} on {ctx.gateway} doesnt support merge")
+
+    # DuckDB and some other engines use logical_merge which doesn't support when_matched
+    if isinstance(ctx.engine_adapter, LogicalMergeMixin):
+        pytest.skip(
+            f"{ctx.dialect} on {ctx.gateway} uses logical merge which doesn't support when_matched"
+        )
+
+    def _mutate_config(current_gateway_name: str, config: Config):
+        connection = config.gateways[current_gateway_name].connection
+        connection.concurrent_tasks = 1
+        if current_gateway_name == "inttest_redshift":
+            connection.enable_merge = True
+
+    context = ctx.create_context(_mutate_config)
+    schema = ctx.schema(TEST_SCHEMA)
+
+    # Create seed data with multiple days
+    seed_query = ctx.input_data(
+        pd.DataFrame(
+            [
+                [1, "item_a", 100, "2020-01-01"],
+                [2, "item_b", 200, "2020-01-01"],
+                [1, "item_a_changed", 150, "2020-01-02"],  # Same item_id, different name and value
+                [2, "item_b_changed", 250, "2020-01-02"],  # Same item_id, different name and value
+                [3, "item_c", 300, "2020-01-02"],  # New item on day 2
+            ],
+            columns=["item_id", "name", "value", "event_date"],
+        ),
+        columns_to_types={
+            "item_id": exp.DataType.build("integer"),
+            "name": exp.DataType.build("text"),
+            "value": exp.DataType.build("integer"),
+            "event_date": exp.DataType.build("date"),
+        },
+    )
+    context.upsert_model(
+        create_sql_model(name=f"{schema}.seed_model", query=seed_query, kind="FULL")
+    )
+
+    table_format = ""
+    if ctx.dialect == "athena":
+        # INCREMENTAL_BY_UNIQUE_KEY uses MERGE which is only supported in Athena on Iceberg tables
+        table_format = "table_format iceberg,"
+
+    # Create model with when_matched clause that only updates the value column
+    # BUT keeps the existing name column unchanged
+    # batch_size=1 is so that we trigger merge on second batch and verify behaviour of when_matched
+    context.upsert_model(
+        load_sql_based_model(
+            d.parse(
+                f"""MODEL (
+                    name {schema}.test_model_when_matched,
+                    kind INCREMENTAL_BY_UNIQUE_KEY (
+                        unique_key item_id,
+                        batch_size 1,
+                        merge_filter source.event_date > target.event_date,
+                        when_matched WHEN MATCHED THEN UPDATE SET target.value = source.value, target.event_date = source.event_date
+                    ),
+                    {table_format}
+                    start '2020-01-01',
+                    end '2020-01-02',
+                    cron '@daily'
+                );
+
+                select item_id, name, value, event_date
+                from {schema}.seed_model
+                where event_date between @start_date and @end_date""",
+            )
+        )
+    )
+
+    try:
+        # Initial plan to create the model and run it
+        context.plan(auto_apply=True, no_prompts=True)
+
+        test_model = context.get_model(f"{schema}.test_model_when_matched")
+
+        # Verify that the model has the when_matched clause and merge_filter
+        assert test_model.kind.when_matched is not None
+        assert (
+            test_model.kind.when_matched.sql()
+            == '(WHEN MATCHED THEN UPDATE SET "__MERGE_TARGET__"."value" = "__MERGE_SOURCE__"."value", "__MERGE_TARGET__"."event_date" = "__MERGE_SOURCE__"."event_date")'
+        )
+        assert test_model.merge_filter is not None
+        assert (
+            test_model.merge_filter.sql()
+            == '"__MERGE_SOURCE__"."event_date" > "__MERGE_TARGET__"."event_date"'
+        )
+
+        actual_df = (
+            ctx.get_current_data(test_model.fqn).sort_values(by="item_id").reset_index(drop=True)
+        )
+
+        # Expected results after batch processing:
+        # - Day 1: Items 1 and 2 are inserted (first insert)
+        # - Day 2: Items 1 and 2 are merged (when_matched clause preserves names but updates values/dates)
+        #          Item 3 is inserted as new
+        expected_df = (
+            pd.DataFrame(
+                [
+                    [1, "item_a", 150, "2020-01-02"],  # name from day 1, value and date from day 2
+                    [2, "item_b", 250, "2020-01-02"],  # name from day 1, value and date from day 2
+                    [3, "item_c", 300, "2020-01-02"],  # new item from day 2
+                ],
+                columns=["item_id", "name", "value", "event_date"],
+            )
+            .sort_values(by="item_id")
+            .reset_index(drop=True)
+        )
+
+        # Convert date columns to string for comparison
+        actual_df["event_date"] = actual_df["event_date"].astype(str)
+        expected_df["event_date"] = expected_df["event_date"].astype(str)
+
+        pd.testing.assert_frame_equal(
+            actual_df,
+            expected_df,
+            check_dtype=False,
+        )
+
+    finally:
+        ctx.cleanup(context)
+
+
 def test_managed_model_upstream_forward_only(ctx: TestContext):
     """
     This scenario goes as follows:
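
The assertions in this test pin down the load-time behavior the commit fixes: `source`/`target` references written in a model definition are rewritten to the canonical merge aliases with their case intact. A hedged, standalone reduction of that check (assumes sqlmesh is installed; the model name and columns are illustrative, not from the repo):

```python
from sqlmesh.core import dialect as d
from sqlmesh.core.model import load_sql_based_model

model = load_sql_based_model(
    d.parse(
        """MODEL (
            name db.example,
            kind INCREMENTAL_BY_UNIQUE_KEY (
                unique_key item_id,
                merge_filter source.event_date > target.event_date,
                when_matched WHEN MATCHED THEN UPDATE SET target.value = source.value
            )
        );
        SELECT 1 AS item_id, 2 AS value, CURRENT_DATE AS event_date"""
    )
)
print(model.merge_filter.sql())
# '"__MERGE_SOURCE__"."event_date" > "__MERGE_TARGET__"."event_date"'
print(model.kind.when_matched.sql())
# '(WHEN MATCHED THEN UPDATE SET "__MERGE_TARGET__"."value" = "__MERGE_SOURCE__"."value")'
```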

tests/core/test_model.py

Lines changed: 24 additions & 24 deletions

@@ -5480,7 +5480,7 @@ def test_when_matched():
         """
     )
 
-    expected_when_matched = "(WHEN MATCHED THEN UPDATE SET `__merge_target__`.`salary` = COALESCE(`__merge_source__`.`salary`, `__merge_target__`.`salary`))"
+    expected_when_matched = "(WHEN MATCHED THEN UPDATE SET `__MERGE_TARGET__`.`salary` = COALESCE(`__MERGE_SOURCE__`.`salary`, `__MERGE_TARGET__`.`salary`))"
 
     model = load_sql_based_model(expressions, dialect="hive")
     assert model.kind.when_matched.sql(dialect="hive") == expected_when_matched
@@ -5514,9 +5514,9 @@ def test_when_matched():
     kind INCREMENTAL_BY_UNIQUE_KEY (
         unique_key ("purchase_order_id"),
         when_matched (
-            WHEN MATCHED AND "__merge_source__"."_operation" = 1 THEN DELETE
-            WHEN MATCHED AND "__merge_source__"."_operation" <> 1 THEN UPDATE SET
-                "__merge_target__"."purchase_order_id" = 1
+            WHEN MATCHED AND "__MERGE_SOURCE__"."_operation" = 1 THEN DELETE
+            WHEN MATCHED AND "__MERGE_SOURCE__"."_operation" <> 1 THEN UPDATE SET
+                "__MERGE_TARGET__"."purchase_order_id" = 1
         ),
         batch_concurrency 1,
         forward_only FALSE,
@@ -5567,7 +5567,7 @@ def fingerprint_merge(
     kind INCREMENTAL_BY_UNIQUE_KEY (
         unique_key ("purchase_order_id"),
         when_matched (
-            WHEN MATCHED AND "__merge_source__"."salary" <> "__merge_target__"."salary" THEN UPDATE SET
+            WHEN MATCHED AND "__MERGE_SOURCE__"."salary" <> "__MERGE_TARGET__"."salary" THEN UPDATE SET
                 ARRAY('target.update_datetime = source.update_datetime', 'target.salary = source.salary')
         ),
         batch_concurrency 1,
@@ -5601,8 +5601,8 @@ def test_when_matched_multiple():
     )
 
     expected_when_matched = [
-        "WHEN MATCHED AND `__merge_source__`.`x` = 1 THEN UPDATE SET `__merge_target__`.`salary` = COALESCE(`__merge_source__`.`salary`, `__merge_target__`.`salary`)",
-        "WHEN MATCHED THEN UPDATE SET `__merge_target__`.`salary` = COALESCE(`__merge_source__`.`salary`, `__merge_target__`.`salary`)",
+        "WHEN MATCHED AND `__MERGE_SOURCE__`.`x` = 1 THEN UPDATE SET `__MERGE_TARGET__`.`salary` = COALESCE(`__MERGE_SOURCE__`.`salary`, `__MERGE_TARGET__`.`salary`)",
+        "WHEN MATCHED THEN UPDATE SET `__MERGE_TARGET__`.`salary` = COALESCE(`__MERGE_SOURCE__`.`salary`, `__MERGE_TARGET__`.`salary`)",
     ]
 
     model = load_sql_based_model(expressions, dialect="hive", variables={"schema": "db"})
@@ -5643,13 +5643,13 @@ def test_when_matched_merge_filter_multi_part_columns():
     )
 
     expected_when_matched = [
-        "WHEN MATCHED AND `__merge_source__`.`record`.`nested_record`.`field` = 1 THEN UPDATE SET `__merge_target__`.`repeated_record`.`sub_repeated_record`.`sub_field` = COALESCE(`__merge_source__`.`repeated_record`.`sub_repeated_record`.`sub_field`, `__merge_target__`.`repeated_record`.`sub_repeated_record`.`sub_field`)",
-        "WHEN MATCHED THEN UPDATE SET `__merge_target__`.`repeated_record`.`sub_repeated_record`.`sub_field` = COALESCE(`__merge_source__`.`repeated_record`.`sub_repeated_record`.`sub_field`, `__merge_target__`.`repeated_record`.`sub_repeated_record`.`sub_field`)",
+        "WHEN MATCHED AND `__MERGE_SOURCE__`.`record`.`nested_record`.`field` = 1 THEN UPDATE SET `__MERGE_TARGET__`.`repeated_record`.`sub_repeated_record`.`sub_field` = COALESCE(`__MERGE_SOURCE__`.`repeated_record`.`sub_repeated_record`.`sub_field`, `__MERGE_TARGET__`.`repeated_record`.`sub_repeated_record`.`sub_field`)",
+        "WHEN MATCHED THEN UPDATE SET `__MERGE_TARGET__`.`repeated_record`.`sub_repeated_record`.`sub_field` = COALESCE(`__MERGE_SOURCE__`.`repeated_record`.`sub_repeated_record`.`sub_field`, `__MERGE_TARGET__`.`repeated_record`.`sub_repeated_record`.`sub_field`)",
     ]
 
     expected_merge_filter = (
-        "`__merge_source__`.`record`.`nested_record`.`field` < `__merge_target__`.`record`.`nested_record`.`field` AND "
-        "`__merge_target__`.`repeated_record`.`sub_repeated_record`.`sub_field` > `__merge_source__`.`repeated_record`.`sub_repeated_record`.`sub_field`"
+        "`__MERGE_SOURCE__`.`record`.`nested_record`.`field` < `__MERGE_TARGET__`.`record`.`nested_record`.`field` AND "
+        "`__MERGE_TARGET__`.`repeated_record`.`sub_repeated_record`.`sub_field` > `__MERGE_SOURCE__`.`repeated_record`.`sub_repeated_record`.`sub_field`"
     )
 
     model = load_sql_based_model(expressions, dialect="bigquery", variables={"schema": "db"})
@@ -6679,7 +6679,7 @@ def test_unrendered_macros_sql_model(mocker: MockerFixture) -> None:
     assert model.unique_key[0] == exp.column("a", quoted=True)
     assert (
         t.cast(exp.Expression, model.merge_filter).sql()
-        == '"__merge_source__"."id" > 0 AND "__merge_target__"."updated_at" < @end_ds AND "__merge_source__"."updated_at" > @start_ds AND @merge_filter_var'
+        == '"__MERGE_SOURCE__"."id" > 0 AND "__MERGE_TARGET__"."updated_at" < @end_ds AND "__MERGE_SOURCE__"."updated_at" > @start_ds AND @merge_filter_var'
     )
 
 
@@ -6775,7 +6775,7 @@ def model_with_macros(evaluator, **kwargs):
     assert python_sql_model.unique_key[0] == exp.column("a", quoted=True)
     assert (
         python_sql_model.merge_filter.sql()
-        == '"__merge_source__"."id" > 0 AND "__merge_target__"."updated_at" < @end_ds AND "__merge_source__"."updated_at" > @start_ds AND @merge_filter_var'
+        == '"__MERGE_SOURCE__"."id" > 0 AND "__MERGE_TARGET__"."updated_at" < @end_ds AND "__MERGE_SOURCE__"."updated_at" > @start_ds AND @merge_filter_var'
     )
 
 
@@ -7862,7 +7862,7 @@ def test_model_kind_to_expression():
         .sql()
         == """INCREMENTAL_BY_UNIQUE_KEY (
   unique_key ("a"),
-  when_matched (WHEN MATCHED THEN UPDATE SET "__merge_target__"."b" = COALESCE("__merge_source__"."b", "__merge_target__"."b")),
+  when_matched (WHEN MATCHED THEN UPDATE SET "__MERGE_TARGET__"."b" = COALESCE("__MERGE_SOURCE__"."b", "__MERGE_TARGET__"."b")),
   batch_concurrency 1,
   forward_only FALSE,
   disable_restatement FALSE,
@@ -7890,7 +7890,7 @@ def test_model_kind_to_expression():
         .sql()
         == """INCREMENTAL_BY_UNIQUE_KEY (
   unique_key ("a"),
-  when_matched (WHEN MATCHED AND "__merge_source__"."x" = 1 THEN UPDATE SET "__merge_target__"."b" = COALESCE("__merge_source__"."b", "__merge_target__"."b") WHEN MATCHED THEN UPDATE SET "__merge_target__"."b" = COALESCE("__merge_source__"."b", "__merge_target__"."b")),
+  when_matched (WHEN MATCHED AND "__MERGE_SOURCE__"."x" = 1 THEN UPDATE SET "__MERGE_TARGET__"."b" = COALESCE("__MERGE_SOURCE__"."b", "__MERGE_TARGET__"."b") WHEN MATCHED THEN UPDATE SET "__MERGE_TARGET__"."b" = COALESCE("__MERGE_SOURCE__"."b", "__MERGE_TARGET__"."b")),
   batch_concurrency 1,
   forward_only FALSE,
   disable_restatement FALSE,
@@ -8151,7 +8151,7 @@ def test_merge_filter():
         """
    )
 
-    expected_incremental_predicate = f"`{MERGE_SOURCE_ALIAS.lower()}`.`salary` > 0"
+    expected_incremental_predicate = f"`{MERGE_SOURCE_ALIAS}`.`salary` > 0"
 
     model = load_sql_based_model(expressions, dialect="hive")
     assert model.kind.merge_filter.sql(dialect="hive") == expected_incremental_predicate
@@ -8194,19 +8194,19 @@ def test_merge_filter():
     kind INCREMENTAL_BY_UNIQUE_KEY (
         unique_key ("purchase_order_id"),
         when_matched (
-            WHEN MATCHED AND "{MERGE_SOURCE_ALIAS.lower()}"."_operation" = 1 THEN DELETE
-            WHEN MATCHED AND "{MERGE_SOURCE_ALIAS.lower()}"."_operation" <> 1 THEN UPDATE SET
-                "{MERGE_TARGET_ALIAS.lower()}"."purchase_order_id" = 1
+            WHEN MATCHED AND "{MERGE_SOURCE_ALIAS}"."_operation" = 1 THEN DELETE
+            WHEN MATCHED AND "{MERGE_SOURCE_ALIAS}"."_operation" <> 1 THEN UPDATE SET
+                "{MERGE_TARGET_ALIAS}"."purchase_order_id" = 1
         ),
         merge_filter (
-            "{MERGE_SOURCE_ALIAS.lower()}"."ds" > (
+            "{MERGE_SOURCE_ALIAS}"."ds" > (
                 SELECT
                     MAX("ds")
                 FROM "db"."test"
            )
-            AND "{MERGE_SOURCE_ALIAS.lower()}"."ds" > @start_ds
-            AND "{MERGE_SOURCE_ALIAS.lower()}"."_operation" <> 1
-            AND "{MERGE_TARGET_ALIAS.lower()}"."start_date" > CURRENT_DATE + INTERVAL '7' DAY
+            AND "{MERGE_SOURCE_ALIAS}"."ds" > @start_ds
+            AND "{MERGE_SOURCE_ALIAS}"."_operation" <> 1
+            AND "{MERGE_TARGET_ALIAS}"."start_date" > CURRENT_DATE + INTERVAL '7' DAY
         ),
         batch_concurrency 1,
         forward_only FALSE,
@@ -8224,7 +8224,7 @@ def test_merge_filter():
     rendered_merge_filters = model.render_merge_filter(start="2023-01-01", end="2023-01-02")
     assert (
         rendered_merge_filters.sql(dialect="hive")
-        == "(`__merge_source__`.`ds` > (SELECT MAX(`ds`) FROM `db`.`test`) AND `__merge_source__`.`ds` > '2023-01-01' AND `__merge_source__`.`_operation` <> 1 AND `__merge_target__`.`start_date` > CURRENT_DATE + INTERVAL '7' DAY)"
+        == "(`__MERGE_SOURCE__`.`ds` > (SELECT MAX(`ds`) FROM `db`.`test`) AND `__MERGE_SOURCE__`.`ds` > '2023-01-01' AND `__MERGE_SOURCE__`.`_operation` <> 1 AND `__MERGE_TARGET__`.`start_date` > CURRENT_DATE + INTERVAL '7' DAY)"
     )
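
All of these expectation changes fall out of one fact: a quoted identifier keeps its exact case in every dialect, and only the quote characters differ. A small sqlglot sketch (assumes sqlglot is installed):

```python
from sqlglot import exp

# Build `__MERGE_TARGET__`.`salary` with both parts explicitly quoted.
col = exp.column(
    exp.to_identifier("salary", quoted=True),
    table=exp.to_identifier("__MERGE_TARGET__", quoted=True),
)
print(col.sql(dialect="hive"))      # `__MERGE_TARGET__`.`salary`
print(col.sql(dialect="duckdb"))    # "__MERGE_TARGET__"."salary"
print(col.sql(dialect="bigquery"))  # `__MERGE_TARGET__`.`salary`
```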
