Skip to content

Commit 0c58e6a

Browse files
committed
add some precision tolerance to float enum conversions
1 parent 2038ddd commit 0c58e6a

File tree

2 files changed

+53
-2
lines changed

2 files changed

+53
-2
lines changed

README.rst

+3-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@
3131
.. |Code Style| image:: https://img.shields.io/badge/code%20style-black-000000.svg
3232
:target: https://github.com/psf/black
3333

34-
.. |Postgres| image:: https://img.shields.io/badge/Postgres-9.5%2B-blue
34+
---------------------------------------------------------------------------------------------------
35+
36+
.. |Postgres| image:: https://img.shields.io/badge/Postgres-9.6%2B-blue
3537
:target: https://www.postgresql.org/
3638

3739
.. |MySQL| image:: https://img.shields.io/badge/MySQL-5.7%2B-blue

django_enum/fields.py

+50-1
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,10 @@ class to ensure the same object layout and then use the same "weird"
453453
obj.__dict__ = self.__dict__.copy()
454454
return obj
455455

456+
def _fallback(self, value: Any) -> Any:
457+
"""Allow deriving classes to implement a final fallback coercion attempt."""
458+
return value
459+
456460
def _try_coerce(self, value: Any, force: bool = False) -> Union[Enum, Any]:
457461
"""
458462
Attempt coercion of value to enumeration type instance, if unsuccessful
@@ -482,7 +486,10 @@ def _try_coerce(self, value: Any, force: bool = False) -> Union[Enum, Any]:
482486
)
483487
except Exception: # pylint: disable=W0703
484488
pass
485-
if self.strict or not isinstance(value, self.primitive):
489+
value = self._fallback(value)
490+
if not isinstance(value, self.enum) and (
491+
self.strict or not isinstance(value, self.primitive)
492+
):
486493
raise ValueError(
487494
f"'{value}' is not a valid "
488495
f"{self.enum.__name__} required by field "
@@ -774,10 +781,52 @@ def __init__(
774781
class EnumFloatField(EnumField[Type[float]], FloatField):
775782
"""A database field supporting enumerations with floating point values"""
776783

784+
_tolerance_: float
785+
_value_primitives_: List[Tuple[float, Enum]]
786+
777787
@property
778788
def primitive(self):
779789
return EnumField.primitive.fget(self) or float # type: ignore
780790

791+
def _fallback(self, value: Any) -> Any:
792+
if value and isinstance(value, float):
793+
for en_value, en in self._value_primitives_:
794+
if abs(en_value - value) < self._tolerance_:
795+
return en
796+
return value
797+
798+
def __init__(
799+
self,
800+
enum: Optional[Type[Enum]] = None,
801+
primitive: Optional[Type[float]] = None,
802+
strict: bool = EnumField._strict_,
803+
coerce: bool = EnumField._coerce_,
804+
constrained: Optional[bool] = None,
805+
**kwargs,
806+
):
807+
super().__init__(
808+
enum=enum,
809+
primitive=primitive,
810+
strict=strict,
811+
coerce=coerce,
812+
constrained=constrained,
813+
**kwargs,
814+
)
815+
# some database backends (earlier supported versions of Postgres)
816+
# can't rely on straight equality because of floating point imprecision
817+
if self.enum:
818+
self._value_primitives_ = []
819+
for en in self.enum:
820+
if en.value is not None:
821+
self._value_primitives_.append(
822+
(self._coerce_to_value_type(en.value), en)
823+
)
824+
self._tolerance_ = (
825+
min((prim[0] * 1e-6 for prim in self._value_primitives_))
826+
if self._value_primitives_
827+
else 0.0
828+
)
829+
781830

782831
class IntEnumField(EnumField[Type[int]]):
783832
"""

0 commit comments

Comments
 (0)