Skip to content

Commit e67c31a

Browse files
Use dialect explicitly while parsing and transforming in Reconciliation (#1365)
Use the dialect explicitly while parsing and transforming in Reconciliation. Currently, the dialect is not passed explicitly when parsing user transformations or rendering expressions to SQL. --------- Co-authored-by: SundarShankar89 <[email protected]>
1 parent a90290b commit e67c31a

File tree

2 files changed

+19
-9
lines changed

2 files changed

+19
-9
lines changed

src/databricks/labs/remorph/reconcile/query_builder/base.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44
import sqlglot.expressions as exp
55
from sqlglot import Dialect, parse_one
66

7-
from databricks.labs.remorph.transpiler.sqlglot.dialect_utils import SQLGLOT_DIALECTS
87
from databricks.labs.remorph.reconcile.exception import InvalidInputException
98
from databricks.labs.remorph.reconcile.query_builder.expression_generator import (
109
DataType_transform_mapping,
1110
transform_expression,
1211
)
1312
from databricks.labs.remorph.reconcile.recon_config import Schema, Table, Aggregate
13+
from databricks.labs.remorph.transpiler.sqlglot.dialect_utils import get_dialect, SQLGLOT_DIALECTS
1414

1515
logger = logging.getLogger(__name__)
1616

@@ -91,12 +91,12 @@ def _apply_user_transformation(self, aliases: list[exp.Expression]) -> list[exp.
9191
with_transform.append(alias.transform(self._user_transformer, self.user_transformations))
9292
return with_transform
9393

94-
@staticmethod
95-
def _user_transformer(node: exp.Expression, user_transformations: dict[str, str]) -> exp.Expression:
94+
def _user_transformer(self, node: exp.Expression, user_transformations: dict[str, str]) -> exp.Expression:
9695
if isinstance(node, exp.Column) and user_transformations:
96+
dialect = self.engine if self.layer == "source" else get_dialect("databricks")
9797
column_name = node.name
9898
if column_name in user_transformations.keys():
99-
return parse_one(user_transformations.get(column_name, column_name))
99+
return parse_one(user_transformations.get(column_name, column_name), read=dialect)
100100
return node
101101

102102
def _apply_default_transformation(

src/databricks/labs/remorph/reconcile/query_builder/expression_generator.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def array_sort(expr: exp.Expression, asc=True) -> exp.Expression:
8585
return _apply_func_expr(expr, exp.ArraySort, expression=exp.Boolean(this=asc))
8686

8787

88-
def anonymous(expr: exp.Column, func: str, is_expr: bool = False) -> exp.Expression:
88+
def anonymous(expr: exp.Column, func: str, is_expr: bool = False, dialect=None) -> exp.Expression:
8989
"""
9090
9191
This function used in cases where the sql functions are not available in sqlGlot expressions
@@ -104,6 +104,8 @@ def anonymous(expr: exp.Column, func: str, is_expr: bool = False) -> exp.Express
104104
105105
"""
106106
if is_expr:
107+
if dialect:
108+
return exp.Column(this=func.format(expr.sql(dialect=dialect)))
107109
return exp.Column(this=func.format(expr))
108110
is_terminal = isinstance(expr, exp.Column)
109111
new_expr = expr.copy()
@@ -235,11 +237,17 @@ def _get_is_string(column_types_dict: dict[str, DataType], column_name: str) ->
235237
"universal": {"default": [partial(coalesce, default='_null_recon_', is_string=True), partial(trim)]},
236238
"snowflake": {exp.DataType.Type.ARRAY.value: [partial(array_to_string), partial(array_sort)]},
237239
"oracle": {
238-
exp.DataType.Type.NCHAR.value: [partial(anonymous, func="NVL(TRIM(TO_CHAR({})),'_null_recon_')")],
239-
exp.DataType.Type.NVARCHAR.value: [partial(anonymous, func="NVL(TRIM(TO_CHAR({})),'_null_recon_')")],
240+
exp.DataType.Type.NCHAR.value: [
241+
partial(anonymous, func="NVL(TRIM(TO_CHAR({})),'_null_recon_')", dialect=get_dialect("oracle"))
242+
],
243+
exp.DataType.Type.NVARCHAR.value: [
244+
partial(anonymous, func="NVL(TRIM(TO_CHAR({})),'_null_recon_')", dialect=get_dialect("oracle"))
245+
],
240246
},
241247
"databricks": {
242-
exp.DataType.Type.ARRAY.value: [partial(anonymous, func="CONCAT_WS(',', SORT_ARRAY({}))")],
248+
exp.DataType.Type.ARRAY.value: [
249+
partial(anonymous, func="CONCAT_WS(',', SORT_ARRAY({}))", dialect=get_dialect("databricks"))
250+
],
243251
},
244252
}
245253

@@ -250,7 +258,9 @@ def _get_is_string(column_types_dict: dict[str, DataType], column_name: str) ->
250258
target=sha256_partial,
251259
),
252260
get_dialect("oracle"): HashAlgoMapping(
253-
source=partial(anonymous, func="RAWTOHEX(STANDARD_HASH({}, 'SHA256'))", is_expr=True),
261+
source=partial(
262+
anonymous, func="RAWTOHEX(STANDARD_HASH({}, 'SHA256'))", is_expr=True, dialect=get_dialect("oracle")
263+
),
254264
target=sha256_partial,
255265
),
256266
get_dialect("databricks"): HashAlgoMapping(

0 commit comments

Comments
 (0)