33
44import expressions as e
55from databricks .labs .lakebridge .reconcile .connectors .dialect_utils import DialectUtils
6+ from utypes import ExternalType , UType , DatabaseTypeName
67
78
8- class Type :
9- name : str
10-
11- def __eq__ (self , other ):
12- return isinstance (other , Type ) and self .name .lower () == other .name .lower ()
13- def __hash__ (self ):
14- return hash (self .name .lower ())
15-
16- @dataclasses .dataclass (frozen = True )
17- class UType (Type ):
18- name : str
19-
20- @dataclasses .dataclass (frozen = True )
21- class ExternalType (Type ): # can be subtyped if needed to override equality, e.g. to treat all Oracle char column types as one
22- name : str
23-
249@dataclasses .dataclass (frozen = True )
2510class ExternalColumnDefinition :
2611 name : str
@@ -129,11 +114,11 @@ class UDatetimeTypeNormalizer(AbstractTypeNormalizer):
129114
130115 @classmethod
131116 def registry_key (cls ) -> str :
132- return cls .utype ().name
117+ return cls .utype ().name . name
133118
134119 @classmethod
135120 def utype (cls ) -> UType :
136- return UType (cls . __name__ . removesuffix ( "TypeNormalizer" ). upper () )
121+ return UType (DatabaseTypeName . DATETIME )
137122
138123 def normalize_oracle (self , column : e .ExpressionBuilder , source_col : ExternalColumnDefinition ) -> e .ExpressionBuilder :
139124 return column
@@ -150,7 +135,7 @@ class UStringTypeNormalizer(AbstractTypeNormalizer):
150135
151136 @classmethod
152137 def registry_key (cls ) -> str :
153- return cls .utype ().name
138+ return cls .utype ().name . name
154139
155140 @classmethod
156141 def utype (cls ) -> UType :
@@ -180,27 +165,33 @@ def register_normalizer(self, normalizer: AbstractNormalizer): # also subclasses
180165 self .registry [normalizer .registry_key_family ()] = {}
181166 self .registry [normalizer .registry_key_family ()][normalizer .registry_key ()] = normalizer
182167
183- def get_type_normalizer (self , utype : UType ) -> AbstractTypeNormalizer | None :
184- return self .registry .get (AbstractTypeNormalizer .registry_key_family ()).get (utype .name )
168+ def get_type_normalizer (self , name : DatabaseTypeName ) -> AbstractTypeNormalizer | None :
169+ return self .registry .get (AbstractTypeNormalizer .registry_key_family ()).get (name .name )
185170
186171class DialectNormalizer (ABC ):
187- type_normalizers : dict [ExternalType , UType ] # do we need more info
172+ DbTypeNormalizerType = dict [DatabaseTypeName , DatabaseTypeName ]
173+ # or ExternalType to UType. What about extra type information, e.g. scale and precision?
174+
188175 dialect : e .DialectType
189176
177+ @classmethod
178+ def type_normalizers (cls , overrides : DbTypeNormalizerType ) -> DbTypeNormalizerType :
179+ return {
180+ DatabaseTypeName ("DATE" ): UDatetimeTypeNormalizer .utype ().name ,
181+ ** overrides
182+ }
183+
190184 @abstractmethod
191185 def normalize (self , column_def : ExternalColumnDefinition , registry : NormalizersRegistry ) -> e .ExpressionBuilder :
192186 pass
193187
194188
195189class OracleNormalizer (DialectNormalizer ):
196- type_normalizers = {
197- ExternalType ("DATE" ): UDatetimeTypeNormalizer .utype ()
198- }
199190 dialect = "oracle"
200191
201192 def normalize (self , column_def : ExternalColumnDefinition , registry : NormalizersRegistry ) -> e .ExpressionBuilder :
202193 start = e .ExpressionBuilder (column_def .name , self .dialect , None )
203- utype = self .type_normalizers .get (column_def .data_type )
194+ utype = self .type_normalizers ({}) .get (column_def .data_type . name )
204195 if utype :
205196 normalizer = registry .get_type_normalizer (utype )
206197 if normalizer :
@@ -212,13 +203,9 @@ def normalize(self, column_def: ExternalColumnDefinition, registry: NormalizersR
212203 registry .register_normalizer (UDatetimeTypeNormalizer ())
213204 oracle = OracleNormalizer ()
214205
215- column = ExternalColumnDefinition ("student_id" , ExternalType ("NCHAR" ))
206+ column = ExternalColumnDefinition ("student_id" , ExternalType (DatabaseTypeName [ "NCHAR" ] ))
216207 column_builder = oracle .normalize (column , registry )
217208
218209 sql = column_builder .build ()
219210 print (sql )
220211 assert sql == "SELECT student_id"
221-
222-
223-
224-
0 commit comments