Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions sqlmesh/core/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -1388,17 +1388,27 @@ def is_meta_expression(v: t.Any) -> bool:
return isinstance(v, (Audit, Metric, Model))


def replace_merge_table_aliases(expression: exp.Expression) -> exp.Expression:
def replace_merge_table_aliases(
expression: exp.Expression, dialect: t.Optional[str] = None
) -> exp.Expression:
"""
Resolves references from the "source" and "target" tables (or their DBT equivalents)
with the corresponding SQLMesh merge aliases (MERGE_SOURCE_ALIAS and MERGE_TARGET_ALIAS)
"""
from sqlmesh.core.engine_adapter.base import MERGE_SOURCE_ALIAS, MERGE_TARGET_ALIAS

normalized_merge_source_alias = quote_identifiers(
normalize_identifiers(exp.to_identifier(MERGE_SOURCE_ALIAS), dialect), dialect=dialect
)

normalized_merge_target_alias = quote_identifiers(
normalize_identifiers(exp.to_identifier(MERGE_TARGET_ALIAS), dialect), dialect=dialect
)

if isinstance(expression, exp.Column) and (first_part := expression.parts[0]):
if first_part.this.lower() in ("target", "dbt_internal_dest", "__merge_target__"):
first_part.replace(exp.to_identifier(MERGE_TARGET_ALIAS))
first_part.replace(normalized_merge_target_alias)
elif first_part.this.lower() in ("source", "dbt_internal_source", "__merge_source__"):
first_part.replace(exp.to_identifier(MERGE_SOURCE_ALIAS))
first_part.replace(normalized_merge_source_alias)

return expression
2 changes: 1 addition & 1 deletion sqlmesh/core/model/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,7 +653,7 @@ def render_merge_filter(
)
if len(rendered_exprs) != 1:
raise SQLMeshError(f"Expected one expression but got {len(rendered_exprs)}")
return rendered_exprs[0].transform(d.replace_merge_table_aliases)
return rendered_exprs[0].transform(d.replace_merge_table_aliases, dialect=self.dialect)

def _render_properties(
self, properties: t.Dict[str, exp.Expression] | SessionProperties, **render_kwargs: t.Any
Expand Down
19 changes: 15 additions & 4 deletions sqlmesh/core/model/kind.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
get_dialect,
validate_string,
positive_int_validator,
validate_expression,
)


Expand Down Expand Up @@ -467,15 +468,20 @@ def _when_matched_validator(
return v
if isinstance(v, list):
v = " ".join(v)

dialect = get_dialect(info.data)

if isinstance(v, str):
# Whens wrap the WHEN clauses, but the parentheses aren't parsed by sqlglot
v = v.strip()
if v.startswith("("):
v = v[1:-1]

return t.cast(exp.Whens, d.parse_one(v, into=exp.Whens, dialect=get_dialect(info.data)))
v = t.cast(exp.Whens, d.parse_one(v, into=exp.Whens, dialect=dialect))
else:
v = t.cast(exp.Whens, v.transform(d.replace_merge_table_aliases, dialect=dialect))

return t.cast(exp.Whens, v.transform(d.replace_merge_table_aliases))
return validate_expression(v, dialect=dialect)

@field_validator("merge_filter", mode="before")
def _merge_filter_validator(
Expand All @@ -485,11 +491,16 @@ def _merge_filter_validator(
) -> t.Optional[exp.Expression]:
if v is None:
return v

dialect = get_dialect(info.data)

if isinstance(v, str):
v = v.strip()
return d.parse_one(v, dialect=get_dialect(info.data))
v = d.parse_one(v, dialect=dialect)
else:
v = v.transform(d.replace_merge_table_aliases, dialect=dialect)

return v.transform(d.replace_merge_table_aliases)
return validate_expression(v, dialect=dialect)

@property
def data_hash_values(self) -> t.List[t.Optional[str]]:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""
Normalize and quote the when_matched and merge_filter properties of IncrementalByUniqueKeyKind
to match how other properties (such as time_column and partitioned_by) are handled and to
prevent un-normalized identifiers being quoted at the EngineAdapter level
"""


def migrate(state_sync, **kwargs): # type: ignore
pass
8 changes: 8 additions & 0 deletions sqlmesh/utils/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from sqlmesh.utils import str_to_bool

if t.TYPE_CHECKING:
from sqlglot._typing import E

Model = t.TypeVar("Model", bound="PydanticModel")


Expand Down Expand Up @@ -193,6 +195,12 @@ def validate_string(v: t.Any) -> str:
return str(v)


def validate_expression(expression: E, dialect: str) -> E:
# this normalizes and quotes identifiers in the given expression according the specified dialect
# it also sets expression.meta["dialect"] so that when we serialize for state, the expression is serialized in the correct dialect
return _get_field(expression, {"dialect": dialect}) # type: ignore


def bool_validator(v: t.Any) -> bool:
if isinstance(v, exp.Boolean):
return v.this
Expand Down
Loading