Skip to content

Commit 2b7a83a

Browse files
committed
add base design for normalizers
1 parent 9f20d55 commit 2b7a83a

File tree

4 files changed

+459
-0
lines changed

4 files changed

+459
-0
lines changed
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import dataclasses
2+
import typing as t
3+
4+
import sqlglot.expressions as e
5+
from sqlglot.dialects import Dialect as SqlglotDialect
6+
7+
DialectType = t.Union[str, SqlglotDialect, t.Type[SqlglotDialect], None]
8+
9+
@dataclasses.dataclass(frozen=True)
10+
class ExpressionTransformation:
11+
func: t.Callable # isnt this Func
12+
args: dict
13+
14+
15+
class ExpressionBuilder:
16+
_expression: e.Expression
17+
18+
def __init__(self, column_name: str, dialect: str, table_name: str | None):
19+
self._column_name = column_name
20+
self._table_name = table_name
21+
self._dialect = dialect
22+
self._transformations: list[ExpressionTransformation] = []
23+
24+
def build(self) -> str:
25+
if self._table_name:
26+
id_exp = e.Identifier(this=self.column_name, table=self._table_name)
27+
else:
28+
id_exp = e.Identifier(this=self.column_name)
29+
column = e.Column(this=id_exp)
30+
exp = self._apply_transformations(column)
31+
return exp.sql(dialect=self._dialect)
32+
33+
def _apply_transformations(self, column: e.Column) -> e.Expression:
34+
exp = column
35+
for transformation in self._transformations:
36+
exp = transformation.func(exp.copy(), **transformation.args) # add error handling
37+
return exp
38+
39+
def column_name(self, name: str):
40+
self._column_name = name
41+
return self
42+
43+
def table_name(self, name: str):
44+
self._column_name = name
45+
return self
46+
47+
def transform(self, func: t.Callable, **kwargs):
48+
transform = ExpressionTransformation(func, kwargs)
49+
self._transformations.append(transform)
50+
return self
51+
52+
def coalesce(column: ExpressionBuilder, default=0, is_string=False) -> ExpressionBuilder:
53+
expressions = [e.Literal(this=default, is_string=is_string)]
54+
return column.transform(e.Coalesce, expressions=expressions)
55+
56+
def trim(column: ExpressionBuilder) -> ExpressionBuilder:
57+
return column.transform(e.Trim)
58+
59+
def unix_time(column: ExpressionBuilder):
60+
return column.transform(e.TimeStrToUnix) #placeholder
Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
import dataclasses
2+
from abc import ABC, abstractmethod
3+
4+
import expressions as e
5+
from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils
6+
from utypes import ExternalType, UType, DatabaseTypeName
7+
8+
9+
@dataclasses.dataclass(frozen=True)
10+
class ExternalColumnDefinition:
11+
name: str
12+
data_type: ExternalType
13+
encoding: str = "utf-8"
14+
15+
@dataclasses.dataclass(frozen=True)
16+
class DatetimeColumnDefinition(ExternalColumnDefinition):
17+
timezone: str = "UTC"
18+
19+
20+
class AbstractNormalizer(ABC):
21+
@classmethod
22+
@abstractmethod
23+
def registry_key_family(cls) -> str:
24+
pass
25+
26+
@classmethod
27+
@abstractmethod
28+
def registry_key(cls) -> str:
29+
pass
30+
31+
@abstractmethod
32+
def normalize(self, column: e.ExpressionBuilder, dialect: e.DialectType, column_def: ExternalColumnDefinition) -> e.ExpressionBuilder:
33+
pass
34+
35+
class UniversalNormalizer(AbstractNormalizer, ABC):
36+
@classmethod
37+
def registry_key_family(cls) -> str:
38+
return cls.__name__
39+
40+
class HandleNullsAndTrimNormalizer(UniversalNormalizer):
41+
@classmethod
42+
def registry_key(cls) -> str:
43+
return cls.__name__
44+
45+
def normalize(self, column: e.ExpressionBuilder, dialect: e.DialectType, column_def: ExternalColumnDefinition) -> e.ExpressionBuilder:
46+
return e.coalesce(e.trim(column), "__null_recon__", is_string=True)
47+
48+
class QuoteIdentifierNormalizer(UniversalNormalizer):
49+
@classmethod
50+
def registry_key(cls) -> str:
51+
return cls.__name__
52+
53+
def normalize(self, column: e.ExpressionBuilder, dialect: e.DialectType, column_def: ExternalColumnDefinition) -> e.ExpressionBuilder:
54+
match dialect:
55+
case "oracle": return self.normalize_oracle(column, column_def)
56+
case "databricks": return self.normalize_databricks(column, column_def)
57+
case "snowflake": return self.normalize_snowflake(column, column_def)
58+
case _: return column # instead of error, return as is
59+
60+
def normalize_oracle(self, column: e.ExpressionBuilder, column_def: ExternalColumnDefinition) -> e.ExpressionBuilder:
61+
normalized = DialectUtils.normalize_identifier(
62+
column_def.name,
63+
source_start_delimiter='"',
64+
source_end_delimiter='"',
65+
).source_normalized
66+
return column.column_name(normalized)
67+
68+
def normalize_databricks(self, column: e.ExpressionBuilder, column_def: ExternalColumnDefinition) -> e.ExpressionBuilder:
69+
normalized = DialectUtils.ansi_normalize_identifier(column_def.name)
70+
return column.column_name(normalized)
71+
72+
def normalize_snowflake(self, column: e.ExpressionBuilder, column_def: ExternalColumnDefinition) -> e.ExpressionBuilder:
73+
normalized = DialectUtils.normalize_identifier(
74+
column_def.name,
75+
source_start_delimiter='"',
76+
source_end_delimiter='"',
77+
).source_normalized
78+
return column.column_name(normalized)
79+
80+
81+
class AbstractTypeNormalizer(AbstractNormalizer):
82+
@classmethod
83+
def registry_key_family(cls) -> str:
84+
return cls.__name__
85+
86+
@classmethod
87+
@abstractmethod
88+
def utype(cls) -> UType:
89+
pass
90+
91+
def normalize(self, column: e.ExpressionBuilder, dialect: str, column_def: ExternalColumnDefinition) -> e.ExpressionBuilder:
92+
match dialect:
93+
case "oracle": return self.normalize_oracle(column, column_def)
94+
case "databricks": return self.normalize_databricks(column, column_def)
95+
case "snowflake": return self.normalize_snowflake(column, column_def)
96+
case _: return column # instead of error, return as is
97+
98+
@abstractmethod
99+
def normalize_oracle(self, column: e.ExpressionBuilder, column_def: ExternalColumnDefinition) -> e.ExpressionBuilder:
100+
pass
101+
102+
@abstractmethod
103+
def normalize_databricks(self, column: e.ExpressionBuilder, column_def: ExternalColumnDefinition) -> e.ExpressionBuilder:
104+
pass
105+
106+
@abstractmethod
107+
def normalize_snowflake(self, column: e.ExpressionBuilder, column_def: ExternalColumnDefinition) -> e.ExpressionBuilder:
108+
pass
109+
110+
class UDatetimeTypeNormalizer(AbstractTypeNormalizer):
111+
"""
112+
transform all dialects to unix time
113+
"""
114+
115+
@classmethod
116+
def registry_key(cls) -> str:
117+
return cls.utype().name.name
118+
119+
@classmethod
120+
def utype(cls) -> UType:
121+
return UType(DatabaseTypeName.DATETIME)
122+
123+
def normalize_oracle(self, column: e.ExpressionBuilder, source_col: ExternalColumnDefinition) -> e.ExpressionBuilder:
124+
return column
125+
126+
def normalize_databricks(self, column: e.ExpressionBuilder, source_col: ExternalColumnDefinition) -> e.ExpressionBuilder:
127+
return e.unix_time(column)
128+
129+
def normalize_snowflake(self, column: e.ExpressionBuilder, source_col: ExternalColumnDefinition) -> e.ExpressionBuilder:
130+
return column
131+
132+
class UStringTypeNormalizer(AbstractTypeNormalizer):
133+
134+
_delegate = HandleNullsAndTrimNormalizer()
135+
136+
@classmethod
137+
def registry_key(cls) -> str:
138+
return cls.utype().name.name
139+
140+
@classmethod
141+
def utype(cls) -> UType:
142+
return UType(cls.__name__.removesuffix("TypeNormalizer").upper())
143+
144+
def normalize_oracle(self, column: e.ExpressionBuilder,
145+
column_def: ExternalColumnDefinition) -> e.ExpressionBuilder:
146+
return self._delegate.normalize(column, "", column_def)
147+
148+
def normalize_databricks(self, column: e.ExpressionBuilder,
149+
column_def: ExternalColumnDefinition) -> e.ExpressionBuilder:
150+
return self._delegate.normalize(column, "", column_def)
151+
152+
def normalize_snowflake(self, column: e.ExpressionBuilder,
153+
column_def: ExternalColumnDefinition) -> e.ExpressionBuilder:
154+
return self._delegate.normalize(column, "", column_def)
155+
156+
157+
class NormalizersRegistry:
158+
registry: dict[str,dict[str, AbstractNormalizer]] = {} # can we type this to subclass of AbstractTypeNormalizer
159+
160+
def register_normalizer(self, normalizer: AbstractNormalizer): # also subclasses
161+
family = self.registry.get(normalizer.registry_key_family(), {})
162+
if family.get(normalizer.registry_key()):
163+
raise ValueError(f"Normalizer already registered for utype: {normalizer.registry_key_family()},{normalizer.registry_key()}")
164+
if not family:
165+
self.registry[normalizer.registry_key_family()] = {}
166+
self.registry[normalizer.registry_key_family()][normalizer.registry_key()] = normalizer
167+
168+
def get_type_normalizer(self, name: DatabaseTypeName) -> AbstractTypeNormalizer | None:
169+
return self.registry.get(AbstractTypeNormalizer.registry_key_family()).get(name.name)
170+
171+
class DialectNormalizer(ABC):
172+
DbTypeNormalizerType = dict[DatabaseTypeName, DatabaseTypeName]
173+
# or ExternalType to UType. what about extra type information e.g scale, precision?
174+
175+
dialect: e.DialectType
176+
177+
@classmethod
178+
def type_normalizers(cls, overrides: DbTypeNormalizerType) -> DbTypeNormalizerType:
179+
return {
180+
DatabaseTypeName("DATE"): UDatetimeTypeNormalizer.utype().name,
181+
**overrides
182+
}
183+
184+
@abstractmethod
185+
def normalize(self, column_def: ExternalColumnDefinition, registry: NormalizersRegistry) -> e.ExpressionBuilder:
186+
pass
187+
188+
189+
class OracleNormalizer(DialectNormalizer):
190+
dialect = "oracle"
191+
192+
def normalize(self, column_def: ExternalColumnDefinition, registry: NormalizersRegistry) -> e.ExpressionBuilder:
193+
start = e.ExpressionBuilder(column_def.name, self.dialect, None)
194+
utype = self.type_normalizers({}).get(column_def.data_type.name)
195+
if utype:
196+
normalizer = registry.get_type_normalizer(utype)
197+
if normalizer:
198+
return normalizer.normalize_oracle(start, column_def)
199+
return start
200+
201+
if __name__ == "__main__":
202+
registry = NormalizersRegistry()
203+
registry.register_normalizer(UDatetimeTypeNormalizer())
204+
oracle = OracleNormalizer()
205+
206+
column = ExternalColumnDefinition("student_id", ExternalType(DatabaseTypeName["NCHAR"]))
207+
column_builder = oracle.normalize(column, registry)
208+
209+
sql = column_builder.build()
210+
print(sql)
211+
assert sql == "SELECT student_id"

0 commit comments

Comments
 (0)