|
2 | 2 | import collections |
3 | 3 | import dataclasses |
4 | 4 | import datetime |
| 5 | +import decimal |
5 | 6 | import json |
6 | 7 | import random |
7 | 8 | import typing |
|
12 | 13 | from faker import Faker |
13 | 14 | from pytz import utc |
14 | 15 |
|
15 | | -from dataclasses_avroschema import types, utils |
| 16 | +from dataclasses_avroschema import serialization, types, utils |
16 | 17 |
|
17 | 18 | fake = Faker() |
18 | 19 | p = inflect.engine() |
|
33 | 34 | TIME_MILLIS = "time-millis" |
34 | 35 | TIMESTAMP_MILLIS = "timestamp-millis" |
35 | 36 | UUID = "uuid" |
| 37 | +DECIMAL = "decimal" |
36 | 38 | LOGICAL_DATE = {"type": INT, "logicalType": DATE} |
37 | 39 | LOGICAL_TIME = {"type": INT, "logicalType": TIME_MILLIS} |
38 | 40 | LOGICAL_DATETIME = {"type": LONG, "logicalType": TIMESTAMP_MILLIS} |
|
61 | 63 |
|
62 | 64 | PYTHON_PRIMITIVE_CONTAINERS = (list, tuple, dict) |
63 | 65 |
|
64 | | -PYTHON_LOGICAL_TYPES = ( |
65 | | - datetime.date, |
66 | | - datetime.time, |
67 | | - datetime.datetime, |
68 | | - uuid.uuid4, |
69 | | - uuid.UUID, |
70 | | -) |
| 66 | +PYTHON_LOGICAL_TYPES = (datetime.date, datetime.time, datetime.datetime, uuid.uuid4, uuid.UUID, decimal.Decimal) |
71 | 67 |
|
72 | 68 | PYTHON_PRIMITIVE_TYPES = PYTHON_INMUTABLE_TYPES + PYTHON_PRIMITIVE_CONTAINERS |
73 | 69 |
|
74 | 70 | PRIMITIVE_AND_LOGICAL_TYPES = PYTHON_INMUTABLE_TYPES + PYTHON_LOGICAL_TYPES |
75 | 71 |
|
76 | 72 | PythonImnutableTypes = typing.Union[ |
77 | | - str, int, bool, float, list, tuple, dict, datetime.date, datetime.time, datetime.datetime, uuid.UUID |
| 73 | + str, |
| 74 | + int, |
| 75 | + bool, |
| 76 | + float, |
| 77 | + list, |
| 78 | + tuple, |
| 79 | + dict, |
| 80 | + datetime.date, |
| 81 | + datetime.time, |
| 82 | + datetime.datetime, |
| 83 | + uuid.UUID, |
| 84 | + decimal.Decimal, |
78 | 85 | ] |
79 | 86 |
|
80 | 87 |
|
@@ -638,6 +645,71 @@ def fake(self) -> typing.Any: |
638 | 645 | return self.type.fake() |
639 | 646 |
|
640 | 647 |
|
@dataclasses.dataclass
class DecimalField(BaseField):
    """Field for ``decimal.Decimal`` values, emitted as Avro ``bytes`` with the ``decimal`` logical type.

    ``precision`` and ``scale`` are always derived from the field default in
    ``__post_init__``; any values passed for them at construction time are
    overwritten. The default must therefore be a ``decimal.Decimal`` or a
    ``types.Decimal``.
    """

    precision: int = -1
    scale: int = 0

    def __post_init__(self) -> None:
        self.set_precision_scale()

    def set_precision_scale(self) -> None:
        """Derive ``precision`` and ``scale`` from the default and validate them.

        Raises:
            ValueError: if no default was given, if the default is neither a
                ``decimal.Decimal`` nor a ``types.Decimal``, if a
                ``decimal.Decimal`` default is not finite, or if the resulting
                precision/scale pair fails validation.
        """
        default = self.default

        if default == types.MissingSentinel:
            raise ValueError(
                "decimal.Decimal default types must be specified to provide precision and scale,"
                " and must be either decimal.Decimal or types.Decimal"
            )

        if isinstance(default, decimal.Decimal):
            # NaN and Infinity carry a *string* exponent ('n', 'N', 'F') in as_tuple(),
            # which cannot be turned into a scale; reject them explicitly instead of
            # failing later with a TypeError in the range checks.
            if not default.is_finite():
                raise ValueError("decimal.Decimal default must be a finite number to provide precision and scale")

            _sign, digits, exponent = default.as_tuple()
            # The exponent is negative for fractional digits; negate it to get the scale.
            self.scale = -exponent
            # The context precision does not affect how many digits a Decimal stores
            # (see https://docs.python.org/3/library/decimal.html), so the number of
            # stored digits is used as the precision rather than decimal.Context().prec.
            # NOTE(review): Decimal("0.001") stores a single digit with exponent -3 and is
            # therefore rejected below (precision 1 < scale 3) - confirm this is intended.
            self.precision = len(digits)
        elif isinstance(default, types.Decimal):
            self.scale = default.scale
            self.precision = default.precision
        else:
            raise ValueError("decimal.Decimal default types must be either decimal.Decimal or types.Decimal")

        # Validation on precision and scale per Avro schema
        if self.precision <= 0:
            raise ValueError("Precision must be a positive integer greater than zero")

        if self.scale < 0 or self.precision < self.scale:
            raise ValueError("Scale must be zero or a positive integer less than or equal to the precision.")

    def get_avro_type(self) -> typing.Dict[str, typing.Any]:
        """Return the Avro type mapping: ``bytes`` annotated with the ``decimal`` logical type."""
        return {"type": BYTES, "logicalType": DECIMAL, "precision": self.precision, "scale": self.scale}

    def get_default_value(self) -> typing.Union[str, dataclasses._MISSING_TYPE, None]:
        """Return the default converted by ``serialization.decimal_to_str``, or ``dataclasses.MISSING`` if unset."""
        default = self.default
        if isinstance(default, types.Decimal):
            # types.Decimal wraps the actual value in its ``default`` attribute.
            default = default.default

        if default == types.MissingSentinel:
            return dataclasses.MISSING
        return serialization.decimal_to_str(default, self.precision, self.scale)

    def fake(self) -> decimal.Decimal:
        """Return a random decimal with ``scale`` fractional digits and ``precision - scale`` integer digits."""
        return fake.pydecimal(right_digits=self.scale, left_digits=self.precision - self.scale)
| 711 | + |
| 712 | + |
641 | 713 | INMUTABLE_FIELDS_CLASSES = { |
642 | 714 | bool: BooleanField, |
643 | 715 | int: LongField, |
@@ -665,6 +737,7 @@ def fake(self) -> typing.Any: |
665 | 737 | uuid.uuid4: UUIDField, |
666 | 738 | uuid.UUID: UUIDField, |
667 | 739 | bytes: BytesField, |
| 740 | + decimal.Decimal: DecimalField, |
668 | 741 | } |
669 | 742 |
|
670 | 743 | PRIMITIVE_LOGICAL_TYPES_FIELDS_CLASSES = { |
@@ -742,7 +815,7 @@ def field_factory( |
742 | 815 | default_factory=default_factory, |
743 | 816 | ) |
744 | 817 | elif native_type in PYTHON_LOGICAL_TYPES: |
745 | | - klass = LOGICAL_TYPES_FIELDS_CLASSES[native_type] |
| 818 | + klass = LOGICAL_TYPES_FIELDS_CLASSES[native_type] # type: ignore |
746 | 819 | return klass(name=name, type=native_type, default=default, metadata=metadata) |
747 | 820 | else: |
748 | 821 | return RecordField(name=name, type=native_type, default=default, metadata=metadata) |
|
0 commit comments