Skip to content

Commit 2413232

Browse files
committed
introduce enum
1 parent 394b35d commit 2413232

File tree

4 files changed

+31
-32
lines changed

4 files changed

+31
-32
lines changed

src/databricks/labs/lakebridge/reconcile/design/expressions.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import typing as t
33

44
import sqlglot.expressions as e
5-
import sqlglot.time as et
65
from sqlglot.dialects import Dialect as SqlglotDialect
76

87
DialectType = t.Union[str, SqlglotDialect, t.Type[SqlglotDialect], None]

src/databricks/labs/lakebridge/reconcile/design/normalizers.py

Lines changed: 18 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,9 @@
33

44
import expressions as e
55
from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils
6+
from utypes import ExternalType, UType, DatabaseTypeName
67

78

8-
class Type:
9-
name: str
10-
11-
def __eq__(self, other):
12-
return isinstance(other, Type) and self.name.lower() == other.name.lower()
13-
def __hash__(self):
14-
return hash(self.name.lower())
15-
16-
@dataclasses.dataclass(frozen=True)
17-
class UType(Type):
18-
name: str
19-
20-
@dataclasses.dataclass(frozen=True)
21-
class ExternalType(Type): # can be subtyped if needed to override equality, e.g. combining all Oracle char columns into one
22-
name: str
23-
249
@dataclasses.dataclass(frozen=True)
2510
class ExternalColumnDefinition:
2611
name: str
@@ -129,11 +114,11 @@ class UDatetimeTypeNormalizer(AbstractTypeNormalizer):
129114

130115
@classmethod
131116
def registry_key(cls) -> str:
132-
return cls.utype().name
117+
return cls.utype().name.name
133118

134119
@classmethod
135120
def utype(cls) -> UType:
136-
return UType(cls.__name__.removesuffix("TypeNormalizer").upper())
121+
return UType(DatabaseTypeName.DATETIME)
137122

138123
def normalize_oracle(self, column: e.ExpressionBuilder, source_col: ExternalColumnDefinition) -> e.ExpressionBuilder:
139124
return column
@@ -150,7 +135,7 @@ class UStringTypeNormalizer(AbstractTypeNormalizer):
150135

151136
@classmethod
152137
def registry_key(cls) -> str:
153-
return cls.utype().name
138+
return cls.utype().name.name
154139

155140
@classmethod
156141
def utype(cls) -> UType:
@@ -180,27 +165,33 @@ def register_normalizer(self, normalizer: AbstractNormalizer): # also subclasses
180165
self.registry[normalizer.registry_key_family()] = {}
181166
self.registry[normalizer.registry_key_family()][normalizer.registry_key()] = normalizer
182167

183-
def get_type_normalizer(self, utype: UType) -> AbstractTypeNormalizer | None:
184-
return self.registry.get(AbstractTypeNormalizer.registry_key_family()).get(utype.name)
168+
def get_type_normalizer(self, name: DatabaseTypeName) -> AbstractTypeNormalizer | None:
169+
return self.registry.get(AbstractTypeNormalizer.registry_key_family()).get(name.name)
185170

186171
class DialectNormalizer(ABC):
187-
type_normalizers: dict[ExternalType, UType] # do we need more info
172+
DbTypeNormalizerType = dict[DatabaseTypeName, DatabaseTypeName]
173+
# Alternatively, map ExternalType to UType. What about extra type information, e.g. scale and precision?
174+
188175
dialect: e.DialectType
189176

177+
@classmethod
178+
def type_normalizers(cls, overrides: DbTypeNormalizerType) -> DbTypeNormalizerType:
179+
return {
180+
DatabaseTypeName("DATE"): UDatetimeTypeNormalizer.utype().name,
181+
**overrides
182+
}
183+
190184
@abstractmethod
191185
def normalize(self, column_def: ExternalColumnDefinition, registry: NormalizersRegistry) -> e.ExpressionBuilder:
192186
pass
193187

194188

195189
class OracleNormalizer(DialectNormalizer):
196-
type_normalizers = {
197-
ExternalType("DATE"): UDatetimeTypeNormalizer.utype()
198-
}
199190
dialect = "oracle"
200191

201192
def normalize(self, column_def: ExternalColumnDefinition, registry: NormalizersRegistry) -> e.ExpressionBuilder:
202193
start = e.ExpressionBuilder(column_def.name, self.dialect, None)
203-
utype = self.type_normalizers.get(column_def.data_type)
194+
utype = self.type_normalizers({}).get(column_def.data_type.name)
204195
if utype:
205196
normalizer = registry.get_type_normalizer(utype)
206197
if normalizer:
@@ -212,13 +203,9 @@ def normalize(self, column_def: ExternalColumnDefinition, registry: NormalizersR
212203
registry.register_normalizer(UDatetimeTypeNormalizer())
213204
oracle = OracleNormalizer()
214205

215-
column = ExternalColumnDefinition("student_id", ExternalType("NCHAR"))
206+
column = ExternalColumnDefinition("student_id", ExternalType(DatabaseTypeName["NCHAR"]))
216207
column_builder = oracle.normalize(column, registry)
217208

218209
sql = column_builder.build()
219210
print(sql)
220211
assert sql == "SELECT student_id"
221-
222-
223-
224-
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from enum import Enum
2+
3+
4+
class AutoName(Enum):
    """
    Base class for enums whose automatic values mirror their member names.

    Members assigned `auto()` in a subclass get the string form of their own
    identifier as the value (e.g. FOO.value results in "FOO").

    Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values
    """

    def _generate_next_value_(name, _start, _count, _last_values):
        # Hook consulted by Enum when it needs an automatic value for a member;
        # the positional bookkeeping arguments are ignored and the member's
        # own name is handed back unchanged.
        return name

0 commit comments

Comments
 (0)