Skip to content

Commit ed7b3cf

Browse files
committed
Fix: Normalize when_matched and merge_filter expressions to the source dialect
1 parent 09eb408 commit ed7b3cf

File tree

8 files changed

+247
-86
lines changed

8 files changed

+247
-86
lines changed

sqlmesh/core/dialect.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1388,17 +1388,27 @@ def is_meta_expression(v: t.Any) -> bool:
13881388
return isinstance(v, (Audit, Metric, Model))
13891389

13901390

1391-
def replace_merge_table_aliases(expression: exp.Expression) -> exp.Expression:
1391+
def replace_merge_table_aliases(
1392+
expression: exp.Expression, dialect: t.Optional[str] = None
1393+
) -> exp.Expression:
13921394
"""
13931395
Resolves references from the "source" and "target" tables (or their DBT equivalents)
13941396
with the corresponding SQLMesh merge aliases (MERGE_SOURCE_ALIAS and MERGE_TARGET_ALIAS)
13951397
"""
13961398
from sqlmesh.core.engine_adapter.base import MERGE_SOURCE_ALIAS, MERGE_TARGET_ALIAS
13971399

1400+
normalized_merge_source_alias = quote_identifiers(
1401+
normalize_identifiers(exp.to_identifier(MERGE_SOURCE_ALIAS), dialect), dialect=dialect
1402+
)
1403+
1404+
normalized_merge_target_alias = quote_identifiers(
1405+
normalize_identifiers(exp.to_identifier(MERGE_TARGET_ALIAS), dialect), dialect=dialect
1406+
)
1407+
13981408
if isinstance(expression, exp.Column) and (first_part := expression.parts[0]):
13991409
if first_part.this.lower() in ("target", "dbt_internal_dest", "__merge_target__"):
1400-
first_part.replace(exp.to_identifier(MERGE_TARGET_ALIAS))
1410+
first_part.replace(normalized_merge_target_alias)
14011411
elif first_part.this.lower() in ("source", "dbt_internal_source", "__merge_source__"):
1402-
first_part.replace(exp.to_identifier(MERGE_SOURCE_ALIAS))
1412+
first_part.replace(normalized_merge_source_alias)
14031413

14041414
return expression

sqlmesh/core/model/definition.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -653,7 +653,7 @@ def render_merge_filter(
653653
)
654654
if len(rendered_exprs) != 1:
655655
raise SQLMeshError(f"Expected one expression but got {len(rendered_exprs)}")
656-
return rendered_exprs[0].transform(d.replace_merge_table_aliases)
656+
return rendered_exprs[0].transform(d.replace_merge_table_aliases, dialect=self.dialect)
657657

658658
def _render_properties(
659659
self, properties: t.Dict[str, exp.Expression] | SessionProperties, **render_kwargs: t.Any

sqlmesh/core/model/kind.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
get_dialect,
3535
validate_string,
3636
positive_int_validator,
37+
validate_expression,
3738
)
3839

3940

@@ -467,15 +468,20 @@ def _when_matched_validator(
467468
return v
468469
if isinstance(v, list):
469470
v = " ".join(v)
471+
472+
dialect = get_dialect(info.data)
473+
470474
if isinstance(v, str):
471475
# Whens wrap the WHEN clauses, but the parentheses aren't parsed by sqlglot
472476
v = v.strip()
473477
if v.startswith("("):
474478
v = v[1:-1]
475479

476-
return t.cast(exp.Whens, d.parse_one(v, into=exp.Whens, dialect=get_dialect(info.data)))
480+
v = t.cast(exp.Whens, d.parse_one(v, into=exp.Whens, dialect=dialect))
481+
else:
482+
v = t.cast(exp.Whens, v.transform(d.replace_merge_table_aliases, dialect=dialect))
477483

478-
return t.cast(exp.Whens, v.transform(d.replace_merge_table_aliases))
484+
return validate_expression(v, dialect=dialect)
479485

480486
@field_validator("merge_filter", mode="before")
481487
def _merge_filter_validator(
@@ -485,11 +491,16 @@ def _merge_filter_validator(
485491
) -> t.Optional[exp.Expression]:
486492
if v is None:
487493
return v
494+
495+
dialect = get_dialect(info.data)
496+
488497
if isinstance(v, str):
489498
v = v.strip()
490-
return d.parse_one(v, dialect=get_dialect(info.data))
499+
v = d.parse_one(v, dialect=dialect)
500+
else:
501+
v = v.transform(d.replace_merge_table_aliases, dialect=dialect)
491502

492-
return v.transform(d.replace_merge_table_aliases)
503+
return validate_expression(v, dialect=dialect)
493504

494505
@property
495506
def data_hash_values(self) -> t.List[t.Optional[str]]:
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
"""
2+
Normalize and quote the when_matched and merge_filter properties of IncrementalByUniqueKeyKind
3+
to match how other properties (such as time_column and partitioned_by) are handled and to
4+
prevent un-normalized identifiers from being quoted at the EngineAdapter level
5+
"""
6+
7+
8+
def migrate(state_sync, **kwargs): # type: ignore
9+
pass

sqlmesh/utils/pydantic.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
from sqlmesh.utils import str_to_bool
1717

1818
if t.TYPE_CHECKING:
19+
from sqlglot._typing import E
20+
1921
Model = t.TypeVar("Model", bound="PydanticModel")
2022

2123

@@ -193,6 +195,12 @@ def validate_string(v: t.Any) -> str:
193195
return str(v)
194196

195197

198+
def validate_expression(expression: E, dialect: str) -> E:
199+
# this normalizes and quotes identifiers in the given expression according to the specified dialect
200+
# it also sets expression.meta["dialect"] so that when we serialize for state, the expression is serialized in the correct dialect
201+
return _get_field(expression, {"dialect": dialect}) # type: ignore
202+
203+
196204
def bool_validator(v: t.Any) -> bool:
197205
if isinstance(v, exp.Boolean):
198206
return v.this

0 commit comments

Comments
 (0)