Skip to content

Commit 1e44bed

Browse files
committed
fix: Fields refactor
1 parent 529c675 commit 1e44bed

2 files changed

Lines changed: 213 additions & 129 deletions

File tree

dataclasses_avroschema/fields.py

Lines changed: 212 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -43,83 +43,10 @@
4343

4444

4545
@dataclasses.dataclass
46-
class Field:
46+
class BaseField:
4747
name: str
4848
type: typing.Any # store the python type (Field)
4949
default: typing.Any = dataclasses.MISSING
50-
default_factory: typing.Any = None
51-
52-
# for avro array field
53-
items_type: typing.Any = None
54-
55-
# for avro enum field
56-
symbols: typing.Any = None
57-
58-
# for avro map field
59-
values_type: typing.Any = None
60-
61-
# avro type storing
62-
avro_type: typing.Any = None
63-
64-
def __post_init__(self):
65-
if isinstance(self.type, typing._GenericAlias) and not self.is_self_referenced(self.type):
66-
# Means that could be a list, tuple or dict
67-
origin = self.type.__origin__
68-
processor = self.get_processor(origin)
69-
processor()
70-
71-
self.type = origin
72-
73-
def get_processor(self, origin):
74-
"""
75-
Get processor for a specific type.
76-
77-
Supported: tuple, list, dict and typing.Type (custom types)
78-
"""
79-
if origin is list:
80-
return self._process_list_type
81-
elif origin is dict:
82-
return self._process_dict_type
83-
elif origin is tuple:
84-
return self._process_tuple_type
85-
else:
86-
# we do not accept any other typing._GenericAlias like a set
87-
# we should raise an exception
88-
raise ValueError(
89-
f"Invalid Type for field {self.name}. Accepted types are list, tuple or dict")
90-
91-
def _process_list_type(self):
92-
# because avro can have only one type, we take the first one
93-
items_type = self.type.__args__[0]
94-
95-
if items_type in PYTHON_PRIMITIVE_TYPES:
96-
self.items_type = PYTHON_TYPE_TO_AVRO[items_type]
97-
elif self.is_self_referenced(items_type):
98-
# Checking for a self reference. Maybe is a typing.ForwardRef
99-
self.items_type = self._get_self_reference_type(items_type)
100-
else:
101-
# Is Avro Record Type
102-
self.items_type = schema_generator.SchemaGenerator(
103-
items_type).avro_schema_to_python()
104-
105-
def _process_dict_type(self):
106-
"""
107-
Process typing.Dict. Avro assumes that the key of a map is always a string,
108-
so we take the second argument to determine the value type
109-
"""
110-
values_type = self.type.__args__[1]
111-
112-
if values_type in PYTHON_PRIMITIVE_TYPES:
113-
self.values_type = PYTHON_TYPE_TO_AVRO[values_type]
114-
elif self.is_self_referenced(values_type):
115-
# Checking for a self reference. Maybe is a typing.ForwardRef
116-
self.values_type = self._get_self_reference_type(values_type)
117-
else:
118-
self.values_type = schema_generator.SchemaGenerator(
119-
values_type).avro_schema_to_python()
120-
121-
def _process_tuple_type(self):
122-
self.symbols = list(self.default)
12350

12451
@staticmethod
12552
def _get_self_reference_type(a_type):
@@ -140,60 +67,6 @@ def get_singular_name(name):
14067
return singular
14168
return name
14269

143-
def get_avro_type(self) -> PythonPrimitiveTypes:
144-
if self.is_self_referenced(self.type):
145-
return self._get_self_reference_type(self.type)
146-
147-
avro_type = PYTHON_TYPE_TO_AVRO.get(self.type)
148-
149-
if self.type in PYTHON_INMUTABLE_TYPES:
150-
if self.default is not dataclasses.MISSING and self.type is not tuple:
151-
if self.default is not None:
152-
return [avro_type, NULL]
153-
# means that default value is None
154-
return [NULL, avro_type]
155-
156-
return avro_type
157-
elif self.type in PYTHON_PRIMITIVE_CONTAINERS:
158-
if self.items_type:
159-
avro_type["items"] = self.items_type
160-
elif self.values_type:
161-
avro_type["values"] = self.values_type
162-
elif self.symbols:
163-
avro_type["symbols"] = self.symbols
164-
165-
avro_type["name"] = self.get_singular_name(self.name)
166-
return avro_type
167-
else:
168-
# Assuming that is a Avro Record type.
169-
return schema_generator.SchemaGenerator(self.type).avro_schema_to_python()
170-
171-
def get_default_value(self):
172-
if self.default is not dataclasses.MISSING:
173-
if self.type in PYTHON_INMUTABLE_TYPES:
174-
if self.default is None:
175-
return NULL
176-
return self.default
177-
elif self.type is list:
178-
if self.default is None:
179-
return []
180-
elif self.type is dict:
181-
if self.default is None:
182-
return {}
183-
elif self.default_factory not in (dataclasses.MISSING, None):
184-
if self.type is list:
185-
# expeting a callable
186-
default = self.default_factory()
187-
assert isinstance(default, list), f"List is required as default for field {self.name}"
188-
189-
return default
190-
elif self.type is dict:
191-
# expeting a callable
192-
default = self.default_factory()
193-
assert isinstance(default, dict), f"Dict is required as default for field {self.name}"
194-
195-
return default
196-
19770
def render(self) -> OrderedDict:
19871
"""
19972
Render the fields base on the avro field
@@ -224,8 +97,219 @@ def render(self) -> OrderedDict:
22497

22598
return template
22699

100+
def get_default_value(self):
101+
return
102+
227103
def to_json(self) -> str:
228104
return json.dumps(self.render())
229105

230106
def to_dict(self) -> dict:
231107
return json.loads(self.to_json())
108+
109+
110+
class InmutableField(BaseField):
111+
112+
def get_avro_type(self) -> PythonPrimitiveTypes:
113+
if self.default is not dataclasses.MISSING:
114+
if self.default is not None:
115+
return [self.avro_type, NULL]
116+
# means that default value is None
117+
return [NULL, self.avro_type]
118+
119+
return self.avro_type
120+
121+
def get_default_value(self):
122+
if self.default is not dataclasses.MISSING:
123+
if self.default is None:
124+
return NULL
125+
return self.default
126+
127+
128+
@dataclasses.dataclass
129+
class StringField(InmutableField):
130+
avro_type: typing.ClassVar = STRING
131+
132+
133+
@dataclasses.dataclass
134+
class IntegerField(InmutableField):
135+
avro_type: typing.ClassVar = INT
136+
137+
138+
@dataclasses.dataclass
139+
class BooleanField(InmutableField):
140+
avro_type: typing.ClassVar = BOOLEAN
141+
142+
143+
@dataclasses.dataclass
144+
class FloatField(InmutableField):
145+
avro_type: typing.ClassVar = FLOAT
146+
147+
148+
@dataclasses.dataclass
149+
class BytesField(InmutableField):
150+
avro_type: typing.ClassVar = BYTES
151+
152+
153+
@dataclasses.dataclass
154+
class TupleField(BaseField):
155+
avro_type: typing.ClassVar = ENUM
156+
symbols: typing.Any = None
157+
default_factory: typing.Any = None
158+
159+
def __post_init__(self):
160+
self.generate_symbols()
161+
162+
def get_avro_type(self) -> PythonPrimitiveTypes:
163+
avro_type = {
164+
"type": ENUM,
165+
"symbols": self.symbols
166+
}
167+
168+
avro_type["name"] = self.get_singular_name(self.name)
169+
return avro_type
170+
171+
def generate_symbols(self):
172+
self.symbols = list(self.default)
173+
174+
175+
@dataclasses.dataclass
176+
class ListField(BaseField):
177+
avro_type: typing.ClassVar = ARRAY
178+
items_type: typing.Any = None
179+
default_factory: typing.Any = None
180+
181+
def __post_init__(self):
182+
self.generate_items_type()
183+
184+
def get_avro_type(self) -> PythonPrimitiveTypes:
185+
avro_type = {
186+
"type": ARRAY,
187+
"items": self.items_type
188+
}
189+
190+
avro_type["name"] = self.get_singular_name(self.name)
191+
return avro_type
192+
193+
def get_default_value(self):
194+
if self.default is not dataclasses.MISSING:
195+
if self.default is None:
196+
return []
197+
elif self.default_factory not in (dataclasses.MISSING, None):
198+
# expeting a callable
199+
default = self.default_factory()
200+
assert isinstance(default, list), f"List is required as default for field {self.name}"
201+
202+
return default
203+
204+
def generate_items_type(self):
205+
# because avro can have only one type, we take the first one
206+
items_type = self.type.__args__[0]
207+
208+
if items_type in PYTHON_PRIMITIVE_TYPES:
209+
self.items_type = PYTHON_TYPE_TO_AVRO[items_type]
210+
elif self.is_self_referenced(items_type):
211+
# Checking for a self reference. Maybe is a typing.ForwardRef
212+
self.items_type = self._get_self_reference_type(items_type)
213+
else:
214+
# Is Avro Record Type
215+
self.items_type = schema_generator.SchemaGenerator(
216+
items_type).avro_schema_to_python()
217+
218+
219+
@dataclasses.dataclass
220+
class DictField(BaseField):
221+
avro_type: typing.ClassVar = MAP
222+
default_factory: typing.Any = None
223+
values_type: typing.Any = None
224+
225+
def __post_init__(self):
226+
self.generate_values_type()
227+
228+
def get_avro_type(self) -> PythonPrimitiveTypes:
229+
avro_type = {
230+
"type": MAP,
231+
"values": self.values_type
232+
}
233+
234+
avro_type["name"] = self.get_singular_name(self.name)
235+
return avro_type
236+
237+
def get_default_value(self):
238+
if self.default is not dataclasses.MISSING:
239+
if self.default is None:
240+
return {}
241+
elif self.default_factory not in (dataclasses.MISSING, None):
242+
# expeting a callable
243+
default = self.default_factory()
244+
assert isinstance(default, dict), f"Dict is required as default for field {self.name}"
245+
246+
return default
247+
248+
def generate_values_type(self):
249+
"""
250+
Process typing.Dict. Avro assumes that the key of a map is always a string,
251+
so we take the second argument to determine the value type
252+
"""
253+
values_type = self.type.__args__[1]
254+
255+
if values_type in PYTHON_PRIMITIVE_TYPES:
256+
self.values_type = PYTHON_TYPE_TO_AVRO[values_type]
257+
elif self.is_self_referenced(values_type):
258+
# Checking for a self reference. Maybe is a typing.ForwardRef
259+
self.values_type = self._get_self_reference_type(values_type)
260+
else:
261+
self.values_type = schema_generator.SchemaGenerator(
262+
values_type).avro_schema_to_python()
263+
264+
265+
@dataclasses.dataclass
266+
class SelfReferenceField(BaseField):
267+
268+
def get_avro_type(self):
269+
return self._get_self_reference_type(self.type)
270+
271+
272+
@dataclasses.dataclass
273+
class RecordField(BaseField):
274+
275+
def get_avro_type(self):
276+
return schema_generator.SchemaGenerator(self.type).avro_schema_to_python()
277+
278+
279+
INMUTABLE_FIELDS_CLASSES = {
280+
bool: BooleanField,
281+
int: IntegerField,
282+
float: FloatField,
283+
bytes: BytesField,
284+
str: StringField,
285+
}
286+
287+
CONTAINER_FIELDS_CLASSES = {
288+
tuple: TupleField,
289+
list: ListField,
290+
dict: DictField
291+
}
292+
293+
294+
def field_factory(name: str, native_type: typing.Any, default: typing.Any = dataclasses.MISSING,
295+
default_factory: typing.Any = None):
296+
297+
if native_type in PYTHON_INMUTABLE_TYPES:
298+
klass = INMUTABLE_FIELDS_CLASSES[native_type]
299+
return klass(name=name, type=native_type, default=default)
300+
elif BaseField.is_self_referenced(native_type):
301+
return SelfReferenceField(name=name, type=native_type, default=default)
302+
elif isinstance(native_type, typing._GenericAlias):
303+
origin = native_type.__origin__
304+
305+
if origin not in (tuple, list, dict):
306+
raise ValueError(
307+
f"Invalid Type for field {name}. Accepted types are list, tuple or dict")
308+
309+
klass = CONTAINER_FIELDS_CLASSES[origin]
310+
return klass(name=name, type=native_type, default=default, default_factory=default_factory)
311+
else:
312+
return RecordField(name=name, type=native_type, default=default)
313+
314+
315+
Field = field_factory

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from setuptools import setup, find_packages
77

8-
__version__ = "0.4.0"
8+
__version__ = "0.4.1"
99

1010
with open("README.md") as readme_file:
1111
long_description = readme_file.read()

0 commit comments

Comments
 (0)