Merged
12 changes: 2 additions & 10 deletions sqlmesh/core/dialect.py
@@ -1421,18 +1421,10 @@ def replace_merge_table_aliases(
"""
from sqlmesh.core.engine_adapter.base import MERGE_SOURCE_ALIAS, MERGE_TARGET_ALIAS

normalized_merge_source_alias = quote_identifiers(
normalize_identifiers(exp.to_identifier(MERGE_SOURCE_ALIAS), dialect), dialect=dialect
)

normalized_merge_target_alias = quote_identifiers(
normalize_identifiers(exp.to_identifier(MERGE_TARGET_ALIAS), dialect), dialect=dialect
)

if isinstance(expression, exp.Column) and (first_part := expression.parts[0]):
if first_part.this.lower() in ("target", "dbt_internal_dest", "__merge_target__"):
first_part.replace(normalized_merge_target_alias)
first_part.replace(exp.to_identifier(MERGE_TARGET_ALIAS, quoted=True))
Collaborator

It took me some time to grok this, but I see why it works. quoted=True is the key part that prevents it from being undone later.

This essentially ensures that everything is uppercase regardless of the engine's normalization strategy, which lines up with what EngineAdapter._merge uses.

EngineAdapter._merge produces an unquoted uppercase value that gets quoted by default during EngineAdapter.execute, which is why the alias either needs to be uppercase here or normalized in EngineAdapter._merge.

elif first_part.this.lower() in ("source", "dbt_internal_source", "__merge_source__"):
first_part.replace(normalized_merge_source_alias)
first_part.replace(exp.to_identifier(MERGE_SOURCE_ALIAS, quoted=True))

return expression
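
To illustrate the comment above: a minimal standalone sketch (assuming a recent sqlglot version; not part of this diff) showing that normalize_identifiers leaves quoted identifiers untouched, which is why quoted=True keeps the alias uppercase regardless of the engine's normalization strategy.

```python
# Sketch only: quoted identifiers survive dialect-specific normalization.
from sqlglot import exp
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers

unquoted = exp.to_identifier("__MERGE_TARGET__")
quoted = exp.to_identifier("__MERGE_TARGET__", quoted=True)

# DuckDB's strategy lowercases unquoted identifiers...
print(normalize_identifiers(unquoted, dialect="duckdb").sql(dialect="duckdb"))  # __merge_target__
# ...but skips quoted ones, so the alias stays in line with what EngineAdapter._merge emits.
print(normalize_identifiers(quoted, dialect="duckdb").sql(dialect="duckdb"))    # "__MERGE_TARGET__"
```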
10 changes: 4 additions & 6 deletions sqlmesh/core/model/kind.py
@@ -478,10 +478,9 @@ def _when_matched_validator(
v = v[1:-1]

v = t.cast(exp.Whens, d.parse_one(v, into=exp.Whens, dialect=dialect))
else:
v = t.cast(exp.Whens, v.transform(d.replace_merge_table_aliases, dialect=dialect))

return validate_expression(v, dialect=dialect)
v = validate_expression(v, dialect=dialect)
return t.cast(exp.Whens, v.transform(d.replace_merge_table_aliases, dialect=dialect))
Collaborator

Previously we only ran the replace_merge_table_aliases transform when loading from disk. When reading back from state, we assumed it had already been applied, so we didn't apply it again.

However, applying it unconditionally like you do here will transparently "fix" what was in state, right? So when people upgrade SQLMesh, everything should still match.

Contributor Author

Yes, that was my thinking behind it: to make it foolproof.
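
As a sanity check of the point above, here is a hypothetical standalone sketch (not part of this PR) of why re-applying the transform is safe: already-rewritten aliases map onto themselves, so the transform is idempotent and expressions read back from state converge to the same representation.

```python
# Hypothetical, simplified stand-in for d.replace_merge_table_aliases to show idempotence.
import sqlglot
from sqlglot import exp

TARGET_ALIASES = ("target", "dbt_internal_dest", "__merge_target__")
SOURCE_ALIASES = ("source", "dbt_internal_source", "__merge_source__")

def rewrite(node: exp.Expression) -> exp.Expression:
    if isinstance(node, exp.Column) and node.parts:
        first = node.parts[0]
        if first.this.lower() in TARGET_ALIASES:
            first.replace(exp.to_identifier("__MERGE_TARGET__", quoted=True))
        elif first.this.lower() in SOURCE_ALIASES:
            first.replace(exp.to_identifier("__MERGE_SOURCE__", quoted=True))
    return node

expr = sqlglot.parse_one("target.value = source.value")
once = expr.transform(rewrite)
twice = once.transform(rewrite)  # applying again (e.g. after a state round-trip) is a no-op
assert once.sql() == twice.sql()  # e.g. '"__MERGE_TARGET__".value = "__MERGE_SOURCE__".value'
```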


@field_validator("merge_filter", mode="before")
def _merge_filter_validator(
@@ -497,10 +496,9 @@ def _merge_filter_validator(
if isinstance(v, str):
v = v.strip()
v = d.parse_one(v, dialect=dialect)
else:
v = v.transform(d.replace_merge_table_aliases, dialect=dialect)

return validate_expression(v, dialect=dialect)
v = validate_expression(v, dialect=dialect)
return v.transform(d.replace_merge_table_aliases, dialect=dialect)

@property
def data_hash_values(self) -> t.List[t.Optional[str]]:
129 changes: 128 additions & 1 deletion tests/core/engine_adapter/integration/test_integration.py
@@ -24,7 +24,7 @@
from sqlmesh.core.dialect import select_from_values
from sqlmesh.core.model import Model, load_sql_based_model
from sqlmesh.core.engine_adapter.shared import DataObject, DataObjectType
from sqlmesh.core.engine_adapter.mixins import RowDiffMixin
from sqlmesh.core.engine_adapter.mixins import RowDiffMixin, LogicalMergeMixin
from sqlmesh.core.model.definition import create_sql_model
from sqlmesh.core.plan import Plan
from sqlmesh.core.state_sync.db import EngineAdapterStateSync
@@ -1897,6 +1897,133 @@ def _mutate_config(current_gateway_name: str, config: Config):
ctx.cleanup(context)


def test_incremental_by_unique_key_model_when_matched(ctx: TestContext):
if not ctx.supports_merge:
pytest.skip(f"{ctx.dialect} on {ctx.gateway} doesnt support merge")

# DuckDB and some other engines use logical_merge which doesn't support when_matched
if isinstance(ctx.engine_adapter, LogicalMergeMixin):
pytest.skip(
f"{ctx.dialect} on {ctx.gateway} uses logical merge which doesn't support when_matched"
)

def _mutate_config(current_gateway_name: str, config: Config):
connection = config.gateways[current_gateway_name].connection
connection.concurrent_tasks = 1
if current_gateway_name == "inttest_redshift":
connection.enable_merge = True

context = ctx.create_context(_mutate_config)
schema = ctx.schema(TEST_SCHEMA)

# Create seed data with multiple days
seed_query = ctx.input_data(
pd.DataFrame(
[
[1, "item_a", 100, "2020-01-01"],
[2, "item_b", 200, "2020-01-01"],
[1, "item_a_changed", 150, "2020-01-02"], # Same item_id, different name and value
[2, "item_b_changed", 250, "2020-01-02"], # Same item_id, different name and value
[3, "item_c", 300, "2020-01-02"], # New item on day 2
],
columns=["item_id", "name", "value", "event_date"],
),
columns_to_types={
"item_id": exp.DataType.build("integer"),
"name": exp.DataType.build("text"),
"value": exp.DataType.build("integer"),
"event_date": exp.DataType.build("date"),
},
)
context.upsert_model(
create_sql_model(name=f"{schema}.seed_model", query=seed_query, kind="FULL")
)

table_format = ""
if ctx.dialect == "athena":
# INCREMENTAL_BY_UNIQUE_KEY uses MERGE which is only supported in Athena on Iceberg tables
table_format = "table_format iceberg,"

# Create a model with a when_matched clause that only updates the value and event_date columns
# but keeps the existing name column unchanged.
# batch_size=1 ensures the merge is triggered on the second batch so we can verify when_matched behaviour
context.upsert_model(
load_sql_based_model(
d.parse(
f"""MODEL (
name {schema}.test_model_when_matched,
kind INCREMENTAL_BY_UNIQUE_KEY (
unique_key item_id,
batch_size 1,
merge_filter source.event_date > target.event_date,
when_matched WHEN MATCHED THEN UPDATE SET target.value = source.value, target.event_date = source.event_date
),
{table_format}
start '2020-01-01',
end '2020-01-02',
cron '@daily'
);

select item_id, name, value, event_date
from {schema}.seed_model
where event_date between @start_date and @end_date""",
)
)
)

try:
# Initial plan to create the model and run it
context.plan(auto_apply=True, no_prompts=True)

test_model = context.get_model(f"{schema}.test_model_when_matched")

# Verify that the model has the when_matched clause and merge_filter
assert test_model.kind.when_matched is not None
assert (
test_model.kind.when_matched.sql()
== '(WHEN MATCHED THEN UPDATE SET "__MERGE_TARGET__"."value" = "__MERGE_SOURCE__"."value", "__MERGE_TARGET__"."event_date" = "__MERGE_SOURCE__"."event_date")'
)
assert test_model.merge_filter is not None
assert (
test_model.merge_filter.sql()
== '"__MERGE_SOURCE__"."event_date" > "__MERGE_TARGET__"."event_date"'
)

actual_df = (
ctx.get_current_data(test_model.fqn).sort_values(by="item_id").reset_index(drop=True)
)

# Expected results after batch processing:
# - Day 1: Items 1 and 2 are inserted (first insert)
# - Day 2: Items 1 and 2 are merged (when_matched clause preserves names but updates values/dates)
#   Item 3 is inserted as new
expected_df = (
pd.DataFrame(
[
[1, "item_a", 150, "2020-01-02"], # name from day 1, value and date from day 2
[2, "item_b", 250, "2020-01-02"], # name from day 1, value and date from day 2
[3, "item_c", 300, "2020-01-02"], # new item from day 2
],
columns=["item_id", "name", "value", "event_date"],
)
.sort_values(by="item_id")
.reset_index(drop=True)
)

# Convert date columns to string for comparison
actual_df["event_date"] = actual_df["event_date"].astype(str)
expected_df["event_date"] = expected_df["event_date"].astype(str)

pd.testing.assert_frame_equal(
actual_df,
expected_df,
check_dtype=False,
)

finally:
ctx.cleanup(context)


def test_managed_model_upstream_forward_only(ctx: TestContext):
"""
This scenario goes as follows:
48 changes: 24 additions & 24 deletions tests/core/test_model.py
@@ -5480,7 +5480,7 @@ def test_when_matched():
"""
)

expected_when_matched = "(WHEN MATCHED THEN UPDATE SET `__merge_target__`.`salary` = COALESCE(`__merge_source__`.`salary`, `__merge_target__`.`salary`))"
expected_when_matched = "(WHEN MATCHED THEN UPDATE SET `__MERGE_TARGET__`.`salary` = COALESCE(`__MERGE_SOURCE__`.`salary`, `__MERGE_TARGET__`.`salary`))"

model = load_sql_based_model(expressions, dialect="hive")
assert model.kind.when_matched.sql(dialect="hive") == expected_when_matched
@@ -5514,9 +5514,9 @@ def test_when_matched():
kind INCREMENTAL_BY_UNIQUE_KEY (
unique_key ("purchase_order_id"),
when_matched (
WHEN MATCHED AND "__merge_source__"."_operation" = 1 THEN DELETE
WHEN MATCHED AND "__merge_source__"."_operation" <> 1 THEN UPDATE SET
"__merge_target__"."purchase_order_id" = 1
WHEN MATCHED AND "__MERGE_SOURCE__"."_operation" = 1 THEN DELETE
WHEN MATCHED AND "__MERGE_SOURCE__"."_operation" <> 1 THEN UPDATE SET
"__MERGE_TARGET__"."purchase_order_id" = 1
),
batch_concurrency 1,
forward_only FALSE,
@@ -5567,7 +5567,7 @@ def fingerprint_merge(
kind INCREMENTAL_BY_UNIQUE_KEY (
unique_key ("purchase_order_id"),
when_matched (
WHEN MATCHED AND "__merge_source__"."salary" <> "__merge_target__"."salary" THEN UPDATE SET
WHEN MATCHED AND "__MERGE_SOURCE__"."salary" <> "__MERGE_TARGET__"."salary" THEN UPDATE SET
ARRAY('target.update_datetime = source.update_datetime', 'target.salary = source.salary')
),
batch_concurrency 1,
@@ -5601,8 +5601,8 @@ def test_when_matched_multiple():
)

expected_when_matched = [
"WHEN MATCHED AND `__merge_source__`.`x` = 1 THEN UPDATE SET `__merge_target__`.`salary` = COALESCE(`__merge_source__`.`salary`, `__merge_target__`.`salary`)",
"WHEN MATCHED THEN UPDATE SET `__merge_target__`.`salary` = COALESCE(`__merge_source__`.`salary`, `__merge_target__`.`salary`)",
"WHEN MATCHED AND `__MERGE_SOURCE__`.`x` = 1 THEN UPDATE SET `__MERGE_TARGET__`.`salary` = COALESCE(`__MERGE_SOURCE__`.`salary`, `__MERGE_TARGET__`.`salary`)",
"WHEN MATCHED THEN UPDATE SET `__MERGE_TARGET__`.`salary` = COALESCE(`__MERGE_SOURCE__`.`salary`, `__MERGE_TARGET__`.`salary`)",
]

model = load_sql_based_model(expressions, dialect="hive", variables={"schema": "db"})
@@ -5643,13 +5643,13 @@ def test_when_matched_merge_filter_multi_part_columns():
)

expected_when_matched = [
"WHEN MATCHED AND `__merge_source__`.`record`.`nested_record`.`field` = 1 THEN UPDATE SET `__merge_target__`.`repeated_record`.`sub_repeated_record`.`sub_field` = COALESCE(`__merge_source__`.`repeated_record`.`sub_repeated_record`.`sub_field`, `__merge_target__`.`repeated_record`.`sub_repeated_record`.`sub_field`)",
"WHEN MATCHED THEN UPDATE SET `__merge_target__`.`repeated_record`.`sub_repeated_record`.`sub_field` = COALESCE(`__merge_source__`.`repeated_record`.`sub_repeated_record`.`sub_field`, `__merge_target__`.`repeated_record`.`sub_repeated_record`.`sub_field`)",
"WHEN MATCHED AND `__MERGE_SOURCE__`.`record`.`nested_record`.`field` = 1 THEN UPDATE SET `__MERGE_TARGET__`.`repeated_record`.`sub_repeated_record`.`sub_field` = COALESCE(`__MERGE_SOURCE__`.`repeated_record`.`sub_repeated_record`.`sub_field`, `__MERGE_TARGET__`.`repeated_record`.`sub_repeated_record`.`sub_field`)",
"WHEN MATCHED THEN UPDATE SET `__MERGE_TARGET__`.`repeated_record`.`sub_repeated_record`.`sub_field` = COALESCE(`__MERGE_SOURCE__`.`repeated_record`.`sub_repeated_record`.`sub_field`, `__MERGE_TARGET__`.`repeated_record`.`sub_repeated_record`.`sub_field`)",
]

expected_merge_filter = (
"`__merge_source__`.`record`.`nested_record`.`field` < `__merge_target__`.`record`.`nested_record`.`field` AND "
"`__merge_target__`.`repeated_record`.`sub_repeated_record`.`sub_field` > `__merge_source__`.`repeated_record`.`sub_repeated_record`.`sub_field`"
"`__MERGE_SOURCE__`.`record`.`nested_record`.`field` < `__MERGE_TARGET__`.`record`.`nested_record`.`field` AND "
"`__MERGE_TARGET__`.`repeated_record`.`sub_repeated_record`.`sub_field` > `__MERGE_SOURCE__`.`repeated_record`.`sub_repeated_record`.`sub_field`"
)

model = load_sql_based_model(expressions, dialect="bigquery", variables={"schema": "db"})
@@ -6679,7 +6679,7 @@ def test_unrendered_macros_sql_model(mocker: MockerFixture) -> None:
assert model.unique_key[0] == exp.column("a", quoted=True)
assert (
t.cast(exp.Expression, model.merge_filter).sql()
== '"__merge_source__"."id" > 0 AND "__merge_target__"."updated_at" < @end_ds AND "__merge_source__"."updated_at" > @start_ds AND @merge_filter_var'
== '"__MERGE_SOURCE__"."id" > 0 AND "__MERGE_TARGET__"."updated_at" < @end_ds AND "__MERGE_SOURCE__"."updated_at" > @start_ds AND @merge_filter_var'
)


@@ -6775,7 +6775,7 @@ def model_with_macros(evaluator, **kwargs):
assert python_sql_model.unique_key[0] == exp.column("a", quoted=True)
assert (
python_sql_model.merge_filter.sql()
== '"__merge_source__"."id" > 0 AND "__merge_target__"."updated_at" < @end_ds AND "__merge_source__"."updated_at" > @start_ds AND @merge_filter_var'
== '"__MERGE_SOURCE__"."id" > 0 AND "__MERGE_TARGET__"."updated_at" < @end_ds AND "__MERGE_SOURCE__"."updated_at" > @start_ds AND @merge_filter_var'
)


@@ -7862,7 +7862,7 @@ def test_model_kind_to_expression():
.sql()
== """INCREMENTAL_BY_UNIQUE_KEY (
unique_key ("a"),
when_matched (WHEN MATCHED THEN UPDATE SET "__merge_target__"."b" = COALESCE("__merge_source__"."b", "__merge_target__"."b")),
when_matched (WHEN MATCHED THEN UPDATE SET "__MERGE_TARGET__"."b" = COALESCE("__MERGE_SOURCE__"."b", "__MERGE_TARGET__"."b")),
batch_concurrency 1,
forward_only FALSE,
disable_restatement FALSE,
@@ -7890,7 +7890,7 @@ def test_model_kind_to_expression():
.sql()
== """INCREMENTAL_BY_UNIQUE_KEY (
unique_key ("a"),
when_matched (WHEN MATCHED AND "__merge_source__"."x" = 1 THEN UPDATE SET "__merge_target__"."b" = COALESCE("__merge_source__"."b", "__merge_target__"."b") WHEN MATCHED THEN UPDATE SET "__merge_target__"."b" = COALESCE("__merge_source__"."b", "__merge_target__"."b")),
when_matched (WHEN MATCHED AND "__MERGE_SOURCE__"."x" = 1 THEN UPDATE SET "__MERGE_TARGET__"."b" = COALESCE("__MERGE_SOURCE__"."b", "__MERGE_TARGET__"."b") WHEN MATCHED THEN UPDATE SET "__MERGE_TARGET__"."b" = COALESCE("__MERGE_SOURCE__"."b", "__MERGE_TARGET__"."b")),
batch_concurrency 1,
forward_only FALSE,
disable_restatement FALSE,
@@ -8151,7 +8151,7 @@ def test_merge_filter():
"""
)

expected_incremental_predicate = f"`{MERGE_SOURCE_ALIAS.lower()}`.`salary` > 0"
expected_incremental_predicate = f"`{MERGE_SOURCE_ALIAS}`.`salary` > 0"

model = load_sql_based_model(expressions, dialect="hive")
assert model.kind.merge_filter.sql(dialect="hive") == expected_incremental_predicate
@@ -8194,19 +8194,19 @@ def test_merge_filter():
kind INCREMENTAL_BY_UNIQUE_KEY (
unique_key ("purchase_order_id"),
when_matched (
WHEN MATCHED AND "{MERGE_SOURCE_ALIAS.lower()}"."_operation" = 1 THEN DELETE
WHEN MATCHED AND "{MERGE_SOURCE_ALIAS.lower()}"."_operation" <> 1 THEN UPDATE SET
"{MERGE_TARGET_ALIAS.lower()}"."purchase_order_id" = 1
WHEN MATCHED AND "{MERGE_SOURCE_ALIAS}"."_operation" = 1 THEN DELETE
WHEN MATCHED AND "{MERGE_SOURCE_ALIAS}"."_operation" <> 1 THEN UPDATE SET
"{MERGE_TARGET_ALIAS}"."purchase_order_id" = 1
),
merge_filter (
"{MERGE_SOURCE_ALIAS.lower()}"."ds" > (
"{MERGE_SOURCE_ALIAS}"."ds" > (
SELECT
MAX("ds")
FROM "db"."test"
)
AND "{MERGE_SOURCE_ALIAS.lower()}"."ds" > @start_ds
AND "{MERGE_SOURCE_ALIAS.lower()}"."_operation" <> 1
AND "{MERGE_TARGET_ALIAS.lower()}"."start_date" > CURRENT_DATE + INTERVAL '7' DAY
AND "{MERGE_SOURCE_ALIAS}"."ds" > @start_ds
AND "{MERGE_SOURCE_ALIAS}"."_operation" <> 1
AND "{MERGE_TARGET_ALIAS}"."start_date" > CURRENT_DATE + INTERVAL '7' DAY
),
batch_concurrency 1,
forward_only FALSE,
@@ -8224,7 +8224,7 @@ def test_merge_filter():
rendered_merge_filters = model.render_merge_filter(start="2023-01-01", end="2023-01-02")
assert (
rendered_merge_filters.sql(dialect="hive")
== "(`__merge_source__`.`ds` > (SELECT MAX(`ds`) FROM `db`.`test`) AND `__merge_source__`.`ds` > '2023-01-01' AND `__merge_source__`.`_operation` <> 1 AND `__merge_target__`.`start_date` > CURRENT_DATE + INTERVAL '7' DAY)"
== "(`__MERGE_SOURCE__`.`ds` > (SELECT MAX(`ds`) FROM `db`.`test`) AND `__MERGE_SOURCE__`.`ds` > '2023-01-01' AND `__MERGE_SOURCE__`.`_operation` <> 1 AND `__MERGE_TARGET__`.`start_date` > CURRENT_DATE + INTERVAL '7' DAY)"
)

