Skip to content

Commit 95a49b9

Browse files
committed
fix has_changed form bug, fix admin tests
1 parent abfa0f4 commit 95a49b9

File tree

2 files changed

+83
-58
lines changed

2 files changed

+83
-58
lines changed

src/django_enum/forms.py

+49-31
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
"""Enumeration support for django model forms"""
22

3+
import typing as t
34
from copy import copy
45
from decimal import DecimalException
56
from enum import Enum, Flag
67
from functools import reduce
78
from operator import or_
8-
from typing import Any, Iterable, List, Optional, Protocol, Sequence, Tuple, Type, Union
99

1010
from django.core.exceptions import ValidationError
1111
from django.forms.fields import (
@@ -42,22 +42,24 @@
4242
]
4343

4444

45-
_SelectChoices = Iterable[Union[Tuple[Any, Any], Tuple[str, Iterable[Tuple[Any, Any]]]]]
45+
_SelectChoices = t.Iterable[
46+
t.Union[t.Tuple[t.Any, t.Any], t.Tuple[str, t.Iterable[t.Tuple[t.Any, t.Any]]]]
47+
]
4648

47-
_Choice = Tuple[Any, Any]
48-
_ChoiceNamedGroup = Tuple[str, Iterable[_Choice]]
49-
_FieldChoices = Iterable[Union[_Choice, _ChoiceNamedGroup]]
49+
_Choice = t.Tuple[t.Any, t.Any]
50+
_ChoiceNamedGroup = t.Tuple[str, t.Iterable[_Choice]]
51+
_FieldChoices = t.Iterable[t.Union[_Choice, _ChoiceNamedGroup]]
5052

5153

52-
class _ChoicesCallable(Protocol):
54+
class _ChoicesCallable(t.Protocol):
5355
def __call__(self) -> _FieldChoices: ... # pragma: no cover
5456

5557

56-
_ChoicesParameter = Union[_FieldChoices, _ChoicesCallable]
58+
_ChoicesParameter = t.Union[_FieldChoices, _ChoicesCallable]
5759

5860

59-
class _CoerceCallable(Protocol):
60-
def __call__(self, value: Any, /) -> Any: ... # pragma: no cover
61+
class _CoerceCallable(t.Protocol):
62+
def __call__(self, value: t.Any, /) -> t.Any: ... # pragma: no cover
6163

6264

6365
class _Unspecified:
@@ -81,7 +83,7 @@ def render(self, *args, **kwargs):
8183
one of our choices, we add it as an option.
8284
"""
8385

84-
value: Any = getattr(kwargs.get("value"), "value", kwargs.get("value"))
86+
value: t.Any = getattr(kwargs.get("value"), "value", kwargs.get("value"))
8587
if value not in EnumChoiceField.empty_values and value not in (
8688
choice[0] for choice in self.choices
8789
):
@@ -135,9 +137,9 @@ class FlagMixin:
135137
This mixin adapts a widget to work with :class:`~enum.IntFlag` types.
136138
"""
137139

138-
enum: Optional[Type[Flag]]
140+
enum: t.Optional[t.Type[Flag]]
139141

140-
def __init__(self, enum: Optional[Type[Flag]] = None, **kwargs):
142+
def __init__(self, enum: t.Optional[t.Type[Flag]] = None, **kwargs):
141143
self.enum = enum
142144
super().__init__(**kwargs)
143145

@@ -223,29 +225,29 @@ class ChoiceFieldMixin(
223225
:param kwargs: Any additional parameters to pass to ChoiceField base class.
224226
"""
225227

226-
_enum_: Optional[Type[Enum]] = None
227-
_primitive_: Optional[Type] = None
228+
_enum_: t.Optional[t.Type[Enum]] = None
229+
_primitive_: t.Optional[t.Type] = None
228230
_strict_: bool = True
229-
empty_value: Any = ""
230-
empty_values: Sequence[Any] = list(TypedChoiceField.empty_values)
231+
empty_value: t.Any = ""
232+
empty_values: t.Sequence[t.Any] = list(TypedChoiceField.empty_values)
231233

232234
_empty_value_overridden_: bool = False
233235
_empty_values_overridden_: bool = False
234236

235237
choices: _ChoicesParameter
236238

237-
non_strict_widget: Optional[Type[ChoiceWidget]] = NonStrictSelect
239+
non_strict_widget: t.Optional[t.Type[ChoiceWidget]] = NonStrictSelect
238240

239241
def __init__(
240242
self,
241-
enum: Optional[Type[Enum]] = _enum_,
242-
primitive: Optional[Type] = _primitive_,
243+
enum: t.Optional[t.Type[Enum]] = _enum_,
244+
primitive: t.Optional[t.Type] = _primitive_,
243245
*,
244-
empty_value: Any = _Unspecified,
246+
empty_value: t.Any = _Unspecified,
245247
strict: bool = _strict_,
246-
empty_values: Union[List[Any], Type[_Unspecified]] = _Unspecified,
248+
empty_values: t.Union[t.List[t.Any], t.Type[_Unspecified]] = _Unspecified,
247249
choices: _ChoicesParameter = (),
248-
coerce: Optional[_CoerceCallable] = None,
250+
coerce: t.Optional[_CoerceCallable] = None,
249251
**kwargs,
250252
):
251253
self._strict_ = strict
@@ -328,30 +330,30 @@ def enum(self, enum):
328330
f"specify a non-conflicting empty_value."
329331
)
330332

331-
def _coerce_to_value_type(self, value: Any) -> Any:
333+
def _coerce_to_value_type(self, value: t.Any) -> t.Any:
332334
"""Coerce the value to the enumerations value type"""
333335
return self.primitive(value)
334336

335-
def prepare_value(self, value: Any) -> Any:
337+
def prepare_value(self, value: t.Any) -> t.Any:
336338
"""Must return the raw enumeration value type"""
337339
value = self._coerce(value)
338340
return super().prepare_value(
339341
value.value if isinstance(value, self.enum) else value
340342
)
341343

342-
def to_python(self, value: Any) -> Any:
344+
def to_python(self, value: t.Any) -> t.Any:
343345
"""Return the value as its full enumeration object"""
344346
return self._coerce(value)
345347

346-
def valid_value(self, value: Any) -> bool:
348+
def valid_value(self, value: t.Any) -> bool:
347349
"""Return false if this value is not valid"""
348350
try:
349351
self._coerce(value)
350352
return True
351353
except ValidationError:
352354
return False
353355

354-
def default_coerce(self, value: Any) -> Any:
356+
def default_coerce(self, value: t.Any) -> t.Any:
355357
"""
356358
Attempt conversion of value to an enumeration value and return it
357359
if successful.
@@ -421,6 +423,10 @@ class EnumMultipleChoiceField( # type: ignore
421423

422424
non_strict_widget = NonStrictSelectMultiple
423425

426+
def has_changed(self, initial, data):
427+
# TODO
428+
return super().has_changed(initial, data)
429+
424430

425431
class EnumFlagField(ChoiceFieldMixin, TypedMultipleChoiceField): # type: ignore
426432
"""
@@ -441,11 +447,11 @@ class EnumFlagField(ChoiceFieldMixin, TypedMultipleChoiceField): # type: ignore
441447

442448
def __init__(
443449
self,
444-
enum: Optional[Type[Flag]] = None,
450+
enum: t.Optional[t.Type[Flag]] = None,
445451
*,
446-
empty_value: Any = _Unspecified,
452+
empty_value: t.Any = _Unspecified,
447453
strict: bool = ChoiceFieldMixin._strict_,
448-
empty_values: Union[List[Any], Type[_Unspecified]] = _Unspecified,
454+
empty_values: t.Union[t.List[t.Any], t.Type[_Unspecified]] = _Unspecified,
449455
choices: _ChoicesParameter = (),
450456
**kwargs,
451457
):
@@ -466,11 +472,23 @@ def __init__(
466472
**kwargs,
467473
)
468474

469-
def _coerce(self, value: Any) -> Any:
475+
def _coerce(self, value: t.Any) -> t.Any:
470476
"""Combine the values into a single flag using |"""
471477
if self.enum and isinstance(value, self.enum):
472478
return value
473479
values = TypedMultipleChoiceField._coerce(self, value) # type: ignore[attr-defined]
474480
if values:
475481
return reduce(or_, values)
476482
return self.empty_value
483+
484+
def has_changed(self, initial, data):
485+
return super().has_changed(
486+
*(
487+
[str(en.value) for en in decompose(initial)]
488+
if isinstance(initial, Flag)
489+
else initial,
490+
[str(en.value) for en in decompose(data)]
491+
if isinstance(data, Flag)
492+
else data,
493+
)
494+
)

tests/test_admin.py

+34-27
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from django.test import LiveServerTestCase
99
from django.contrib.staticfiles.testing import StaticLiveServerTestCase
1010
from django_enum import EnumField
11+
from django_enum.utils import values
1112
from tests.djenum.models import (
1213
AdminDisplayBug35,
1314
EnumTester,
@@ -168,41 +169,47 @@ def setUp(self):
168169
def set_form_value(
169170
self, field_name: str, value: t.Optional[t.Union[Enum, str]], flag=False
170171
):
171-
try:
172-
if value is None and None in self.enum(field_name):
173-
value = self.enum(field_name)(value)
174-
# should override this if needed
175-
if getattr(value, "value", value) is None and not flag:
176-
if self.use_radio:
177-
self.page.click(f"input[name='{field_name}'][value='']")
178-
else:
179-
self.page.select_option(f"select[name='{field_name}']", "")
180-
elif flag:
181-
if self.use_checkbox:
182-
for checkbox in self.page.locator(
183-
f"input[type='checkbox'][name='{field_name}']"
184-
).all():
185-
if checkbox.is_checked():
186-
checkbox.uncheck()
172+
# if field_name == "constellation_null" and value is None:
173+
# import ipdb
174+
# ipdb.set_trace()
175+
if value is None and None in values(self.enum(field_name)):
176+
value = self.enum(field_name)(value)
177+
# should override this if needed
178+
if getattr(value, "value", value) is None and not flag:
179+
if self.use_radio:
180+
self.page.click(f"input[name='{field_name}'][value='']")
181+
else:
182+
self.page.select_option(f"select[name='{field_name}']", "")
183+
elif flag:
184+
if self.use_checkbox:
185+
for checkbox in self.page.locator(
186+
f"input[type='checkbox'][name='{field_name}']"
187+
).all():
188+
if checkbox.is_checked():
189+
checkbox.uncheck()
190+
if value is not None:
191+
assert isinstance(value, Flag)
187192
for flag in decompose(value):
188193
self.page.check(
189194
f"input[name='{field_name}'][value='{flag.value}']"
190195
)
191-
else:
196+
else:
197+
if value is not None:
198+
assert isinstance(value, Flag)
192199
self.page.select_option(
193200
f"select[name='{field_name}']",
194201
[str(flag.value) for flag in decompose(value)],
195202
)
196-
else:
197-
if self.use_radio:
198-
self.page.click(f"input[name='{field_name}'][value='{value}']")
199203
else:
200-
self.page.select_option(
201-
f"select[name='{field_name}']",
202-
str(getattr(value, "value", value)),
203-
)
204-
except Exception:
205-
self.page.pause()
204+
self.page.select_option(f"select[name='{field_name}']", [])
205+
else:
206+
if self.use_radio:
207+
self.page.click(f"input[name='{field_name}'][value='{value}']")
208+
else:
209+
self.page.select_option(
210+
f"select[name='{field_name}']",
211+
str(getattr(value, "value", value)),
212+
)
206213

207214
def verify_changes(self, obj: Model, expected: t.Dict[str, t.Any]):
208215
count = 0
@@ -559,7 +566,7 @@ def verify_labels(inputs, expected):
559566
text_null_radios = self.page.locator("input[type='radio'][name='text_null']")
560567
self.assertEqual(text_null_radios.count(), len(TextEnum) + 1)
561568
verify_labels(
562-
text_null_radios, [BLANK_CHOICE_DASH] + [en.label for en in TextEnum]
569+
text_null_radios, [BLANK_CHOICE_DASH[0][1]] + [en.label for en in TextEnum]
563570
)
564571

565572
# text_non_strict

0 commit comments

Comments
 (0)