Skip to content

Commit c48ac1f

Browse files
authored
feat: alias nested schemas added (#110)
1 parent fdccbbd commit c48ac1f

10 files changed

Lines changed: 245 additions & 28 deletions

dataclasses_avroschema/fields.py

Lines changed: 43 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,10 @@ class BaseField:
103103
type: typing.Any # store the python primitive type
104104
default: typing.Any
105105
metadata: typing.Mapping = dataclasses.field(default_factory=dict)
106+
model_metadata: typing.Optional[utils.SchemaMetadata] = None
107+
108+
def __post_init__(self) -> None:
109+
self.model_metadata = self.model_metadata or utils.SchemaMetadata()
106110

107111
@staticmethod
108112
def _get_self_reference_type(a_type: typing.Any) -> str:
@@ -296,10 +300,14 @@ def generate_items_type(self) -> typing.Any:
296300

297301
if utils.is_union(items_type):
298302
self.items_type = UnionField(
299-
name, items_type, default=self.default, default_factory=self.default_factory
303+
name,
304+
items_type,
305+
default=self.default,
306+
default_factory=self.default_factory,
307+
model_metadata=self.model_metadata,
300308
).get_avro_type()
301309
else:
302-
self.internal_field = AvroField(name, items_type)
310+
self.internal_field = AvroField(name, items_type, model_metadata=self.model_metadata)
303311
self.items_type = self.internal_field.get_avro_type()
304312

305313
def fake(self) -> typing.List:
@@ -347,7 +355,7 @@ def generate_values_type(self) -> typing.Any:
347355
values_type = self.type.__args__[1]
348356

349357
name = self.get_singular_name(self.name)
350-
self.internal_field = AvroField(name, values_type)
358+
self.internal_field = AvroField(name, values_type, model_metadata=self.model_metadata)
351359
self.values_type = self.internal_field.get_avro_type()
352360

353361
def fake(self) -> typing.Dict[str, typing.Any]:
@@ -388,13 +396,15 @@ def generate_unions_type(self) -> typing.List:
388396
unions.insert(0, NULL)
389397
elif type(self.default) is not dataclasses._MISSING_TYPE:
390398
default_type = type(self.default)
391-
default_field = AvroField(name, default_type)
399+
default_field = AvroField(name, default_type, model_metadata=self.model_metadata)
392400
unions.append(default_field.get_avro_type())
393401
self.internal_fields.append(default_field)
394402

395403
for element in elements:
404+
405+
print(self.model_metadata, "jjsjs")
396406
# create the field and get the avro type
397-
field = AvroField(name, element)
407+
field = AvroField(name, element, model_metadata=self.model_metadata)
398408

399409
if field.get_avro_type() not in unions:
400410
unions.append(field.get_avro_type())
@@ -646,20 +656,28 @@ def fake(self) -> uuid.UUID:
646656
class RecordField(BaseField):
647657
def get_avro_type(self) -> typing.Union[typing.List, typing.Dict]:
648658
record_type = self.type.avro_schema_to_python()
659+
record_type["name"] = self.record_name
660+
661+
if self.default is None:
662+
return [NULL, record_type]
663+
return record_type
664+
665+
@property
666+
def record_name(self) -> str:
667+
alias = self.model_metadata.get_alias(self.name) # type: ignore
668+
669+
if alias:
670+
return alias
649671

650672
# when there is a nested record replace its name
651-
# to avoid name colisions
673+
# to avoid name collisions
652674
record_name = self.type.__name__.lower()
653675
if record_name not in self.name:
654676
name = f"{self.name}_{record_name}_record"
655677
else:
656678
name = f"{self.name}_record"
657679

658-
record_type["name"] = name
659-
660-
if self.default is None:
661-
return [NULL, record_type]
662-
return record_type
680+
return name
663681

664682
def fake(self) -> typing.Any:
665683
return self.type.fake()
@@ -797,16 +815,21 @@ def field_factory(
797815
default: typing.Any = dataclasses.MISSING,
798816
default_factory: typing.Any = dataclasses.MISSING,
799817
metadata: typing.Mapping = dataclasses.field(default_factory=dict),
818+
model_metadata: typing.Optional[utils.SchemaMetadata] = None,
800819
) -> FieldType:
801820
if native_type in PYTHON_INMUTABLE_TYPES:
802821
klass = INMUTABLE_FIELDS_CLASSES[native_type]
803-
return klass(name=name, type=native_type, default=default, metadata=metadata)
822+
return klass(name=name, type=native_type, default=default, metadata=metadata, model_metadata=model_metadata)
804823
elif utils.is_self_referenced(native_type):
805-
return SelfReferenceField(name=name, type=native_type, default=default, metadata=metadata)
824+
return SelfReferenceField(
825+
name=name, type=native_type, default=default, metadata=metadata, model_metadata=model_metadata
826+
)
806827
elif native_type is types.Fixed:
807-
return FixedField(name=name, type=native_type, default=default, metadata=metadata)
828+
return FixedField(
829+
name=name, type=native_type, default=default, metadata=metadata, model_metadata=model_metadata
830+
)
808831
elif native_type is types.Enum:
809-
return EnumField(name=name, type=native_type, default=default, metadata=metadata)
832+
return EnumField(name=name, type=native_type, default=default, metadata=metadata, model_metadata=model_metadata)
810833
elif isinstance(native_type, GenericAlias): # type: ignore
811834
origin = native_type.__origin__
812835

@@ -833,12 +856,15 @@ def field_factory(
833856
default=default,
834857
metadata=metadata,
835858
default_factory=default_factory,
859+
model_metadata=model_metadata,
836860
)
837861
elif native_type in PYTHON_LOGICAL_TYPES:
838862
klass = LOGICAL_TYPES_FIELDS_CLASSES[native_type] # type: ignore
839-
return klass(name=name, type=native_type, default=default, metadata=metadata)
863+
return klass(name=name, type=native_type, default=default, metadata=metadata, model_metadata=model_metadata)
840864
elif inspect.isclass(native_type) and issubclass(native_type, schema_generator.AvroModel):
841-
return RecordField(name=name, type=native_type, default=default, metadata=metadata)
865+
return RecordField(
866+
name=name, type=native_type, default=default, metadata=metadata, model_metadata=model_metadata
867+
)
842868
else:
843869
msg = (
844870
f"Type {native_type} is unknown. Please check the valid types at "

dataclasses_avroschema/schema_definition.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,8 @@ def is_faust_record(self) -> bool:
4646

4747
@dataclasses.dataclass
4848
class AvroSchemaDefinition(BaseSchemaDefinition):
49-
aliases: typing.List[str] = dataclasses.field(default_factory=list)
50-
namespace: typing.Optional[str] = None
49+
metadata: utils.SchemaMetadata
5150
fields: typing.List[FieldType] = dataclasses.field(default_factory=list)
52-
metadata: utils.SchemaMetadata = dataclasses.field(default_factory=utils.SchemaMetadata)
5351

5452
def __post_init__(self) -> None:
5553
self.fields = self.parse_dataclasses_fields()
@@ -67,6 +65,7 @@ def parse_fields(self) -> typing.List[FieldType]:
6765
dataclass_field.default,
6866
dataclass_field.default_factory, # type: ignore # TODO: resolve mypy
6967
dataclass_field.metadata,
68+
self.metadata,
7069
)
7170
for dataclass_field in dataclasses.fields(self.klass)
7271
]
@@ -87,7 +86,11 @@ def parse_faust_record_fields(self) -> typing.List[FieldType]:
8786
default_factory = default.default_factory # type: ignore # TODO: resolve mypy
8887
default = dataclasses.MISSING
8988

90-
schema_fields.append(AvroField(dataclass_field.name, dataclass_field.type, default, default_factory))
89+
schema_fields.append(
90+
AvroField(
91+
dataclass_field.name, dataclass_field.type, default, default_factory, model_metadata=self.metadata
92+
)
93+
)
9194

9295
return schema_fields
9396

dataclasses_avroschema/schema_generator.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,6 @@ def generate_dataclass(cls: typing.Any) -> typing.Any:
3333
def generate_metadata(cls: typing.Any) -> SchemaMetadata:
3434
meta = getattr(cls.klass, "Meta", None)
3535

36-
if meta is None:
37-
return SchemaMetadata()
3836
return SchemaMetadata.create(meta)
3937

4038
@classmethod

dataclasses_avroschema/utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import typing
2-
from dataclasses import dataclass
2+
from dataclasses import dataclass, field
33
from datetime import datetime
44

55
from pytz import utc
@@ -56,6 +56,7 @@ class SchemaMetadata:
5656
schema_doc: bool = True
5757
namespace: typing.Optional[typing.List[str]] = None
5858
aliases: typing.Optional[typing.List[str]] = None
59+
alias_nested_items: typing.Dict[str, str] = field(default_factory=dict)
5960

6061
@classmethod
6162
def create(cls, klass: typing.Any) -> typing.Any:
@@ -64,8 +65,12 @@ def create(cls, klass: typing.Any) -> typing.Any:
6465
schema_doc=getattr(klass, "schema_doc", True),
6566
namespace=getattr(klass, "namespace", None),
6667
aliases=getattr(klass, "aliases", None),
68+
alias_nested_items=getattr(klass, "alias_nested_items", {}),
6769
)
6870

71+
def get_alias(self, name: str) -> typing.Optional[str]:
72+
return self.alias_nested_items.get(name)
73+
6974

7075
epoch: datetime = datetime(1970, 1, 1, tzinfo=utc)
7176
epoch_naive: datetime = datetime(1970, 1, 1)

docs/complex_types.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -401,16 +401,18 @@ User.avro_schema()
401401

402402
#### Class Meta
403403

404-
The `class Meta` is used to specify schema attributes that are not represented by the class fields like `namespace`, `aliases` and whether to include the `schema documentation`. One can also provide a custom schema name (the default is the class' name) via `schema_name` attribute.
404+
The `class Meta` is used to specify schema attributes that are not represented by the class fields, like `namespace`, `aliases` and whether to include the `schema documentation`. One can also provide a custom schema name (the default is the class' name) via the `schema_name` attribute, and custom names for nested items via the `alias_nested_items` attribute.
405405

406406
```python
407407
class Meta:
408408
schema_name = "Name other than the class name"
409409
schema_doc = False
410410
namespace = "test.com.ar/user/v1"
411411
aliases = ["User", "My favorite User"]
412+
alias_nested_items = {"address": "Address"}
412413
```
413414

414415
`schema_doc (boolean)`: Whether to include the `schema documentation` generated from `docstrings`. Default `True`
415416
`namespace (optional[str])`: Schema namespace. Default `None`
416417
`aliases (optional[List[str]])`: Schema aliases. Default `None`
418+
`alias_nested_items (optional[Dict[str, str]])`: Custom names for nested items, as a mapping from field name to the name to use for the nested schema. Default `{}`

docs/schema_relationships.md

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -261,10 +261,10 @@ User.avro_schema()
261261
}'
262262
```
263263

264-
## Avoid name colision in multiple relationships
264+
## Avoid name collision in multiple relationships
265265

266266
Sometimes we have relationships where a class is related more than once with a particular class,
267-
and the name for the nested schemas must be diferent, otherwise we will generate an invalid `avro schema`.
267+
and the name for the nested schemas must be different, otherwise we will generate an invalid `avro schema`.
268268

269269
For example:
270270

@@ -290,7 +290,7 @@ class Trip(AvroModel):
290290
finish_location: Location # second relationship
291291
```
292292

293-
In order to avoid name colisions, the nested name is generated in the following way:
293+
In order to avoid name collision, the nested name is generated in the following way:
294294

295295
1. Get the lower name of the related class
296296
2. Get the field name
@@ -361,3 +361,62 @@ Example for start_location:
361361
"doc": "Trip(start_time: datetime.datetime, start_location: __main__.Location, finish_time: datetime.datetime, finish_location: __main__.Location)"
362362
}'
363363
```
364+
365+
You can also give nested items (`nested records`, `arrays` or `maps`) a custom name using the `alias_nested_items` property in `class Meta`:
366+
367+
```python
368+
from dataclasses_avroschema import AvroModel
369+
370+
371+
class Address(AvroModel):
372+
"An Address"
373+
street: str
374+
street_number: int
375+
376+
class User(AvroModel):
377+
"An User with Address"
378+
name: str
379+
age: int
380+
address: Address # default name address_record
381+
382+
class Meta:
383+
alias_nested_items = {"address": "MySuperAddress"}
384+
```
385+
386+
`User.avro_schema()` will generate:
387+
388+
```json
389+
{
390+
"type": "record",
391+
"name": "User",
392+
"fields": [
393+
{
394+
"name": "name",
395+
"type": "string"
396+
},
397+
{
398+
"name": "age",
399+
"type": "long"
400+
},
401+
{
402+
"name": "address",
403+
"type": {
404+
"type": "record",
405+
"name": "MySuperAddress", // renamed it using alias_nested_items
406+
"fields": [
407+
{
408+
"name": "street",
409+
"type": "string"
410+
},
411+
{
412+
"name": "street_number",
413+
"type": "long"
414+
}
415+
],
416+
"doc": "An Address"
417+
}
418+
}
419+
],
420+
"doc": "An User with Address"
421+
}
422+
```
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
{
2+
"type": "record",
3+
"name": "User",
4+
"fields": [
5+
{
6+
"name": "name",
7+
"type": "string"
8+
},
9+
{
10+
"name": "age",
11+
"type": "long"
12+
},
13+
{
14+
"name": "addresses",
15+
"type": {
16+
"type": "map",
17+
"values": {
18+
"type": "record",
19+
"name": "Address",
20+
"fields": [
21+
{
22+
"name": "street",
23+
"type": "string"
24+
},
25+
{
26+
"name": "street_number",
27+
"type": "long"
28+
}
29+
],
30+
"doc": "An Address"
31+
},
32+
"name": "address"
33+
}
34+
}
35+
],
36+
"doc": "User with multiple Address"
37+
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
{
2+
"type": "record",
3+
"name": "User",
4+
"fields": [
5+
{
6+
"name": "name",
7+
"type": "string"
8+
},
9+
{
10+
"name": "age",
11+
"type": "long"
12+
},
13+
{
14+
"name": "address",
15+
"type": {
16+
"type": "record",
17+
"name": "Address",
18+
"fields": [
19+
{
20+
"name": "street",
21+
"type": "string"
22+
},
23+
{
24+
"name": "street_number",
25+
"type": "long"
26+
}
27+
],
28+
"doc": "An Address"
29+
}
30+
}
31+
],
32+
"doc": "An User with Address"
33+
}

0 commit comments

Comments
 (0)