Skip to content

Commit b13be16

Browse files
authored
Merge pull request #97 from crgwbr/fix-dja-flags
fix: editing of FlagField via EnumFlagField form field
2 parents 4a1688d + 549a1f1 commit b13be16

File tree

5 files changed

+156
-8
lines changed

5 files changed

+156
-8
lines changed

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ rest = ["djangorestframework>=3.9,<4.0"]
7777

7878
[dependency-groups]
7979
dev = [
80+
"doc8>=1.1.2",
8081
"beautifulsoup4>=4.13.3",
8182
"coverage>=7.6.12",
8283
"darglint>=1.8.1",
@@ -104,7 +105,6 @@ dev = [
104105
"typing-extensions>=4.12.2",
105106
]
106107
docs = [
107-
"doc8>=1.1.2",
108108
"docutils>=0.21.2",
109109
"furo>=2024.8.6",
110110
"readme-renderer[md]>=44.0",

src/django_enum/fields.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -673,17 +673,20 @@ def formfield(self, form_class=None, choices_form_class=None, **kwargs):
673673
)
674674

675675
is_multi = self.enum and issubclass(self.enum, Flag)
676-
if is_multi and self.enum:
676+
if is_multi:
677677
kwargs["empty_value"] = self.enum(0)
678678
# why fail? - does this fail for single select too?
679679
# kwargs['show_hidden_initial'] = True
680680

681681
if not self.strict:
682682
kwargs.setdefault(
683-
"widget", NonStrictSelectMultiple if is_multi else NonStrictSelect
683+
"widget",
684+
NonStrictSelectMultiple(enum=self.enum)
685+
if is_multi
686+
else NonStrictSelect,
684687
)
685688
elif is_multi:
686-
kwargs.setdefault("widget", FlagSelectMultiple)
689+
kwargs.setdefault("widget", FlagSelectMultiple(enum=self.enum))
687690

688691
form_field = super().formfield(
689692
form_class=form_class,

src/django_enum/forms.py

+50-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
"""Enumeration support for django model forms"""
22

3+
import sys
34
from copy import copy
45
from decimal import DecimalException
56
from enum import Enum, Flag
7+
from functools import reduce
8+
from operator import or_
69
from typing import Any, Iterable, List, Optional, Protocol, Sequence, Tuple, Type, Union
710

811
from django.core.exceptions import ValidationError
@@ -85,8 +88,39 @@ class FlagSelectMultiple(SelectMultiple):
8588
A SelectMultiple widget for EnumFlagFields.
8689
"""
8790

91+
enum: Optional[Type[Flag]]
8892

89-
class NonStrictSelectMultiple(NonStrictMixin, SelectMultiple):
93+
def __init__(self, enum: Optional[Type[Flag]] = None, **kwargs):
94+
self.enum = enum
95+
super().__init__(**kwargs)
96+
97+
def format_value(self, value):
98+
"""
99+
Return a list of the flag's values.
100+
"""
101+
if not isinstance(value, list):
102+
# see impl of ChoiceWidget.optgroups
103+
# it compares the string conversion of the value of each
104+
# choice tuple to the string conversion of the value
105+
# to determine selected options
106+
if self.enum:
107+
if sys.version_info < (3, 11):
108+
return [
109+
str(flg.value)
110+
for flg in self.enum
111+
if flg in self.enum(value) and flg is not self.enum(0)
112+
]
113+
else:
114+
return [str(en.value) for en in self.enum(value)]
115+
if isinstance(value, int):
116+
# automagically work for IntFlags even if we weren't given the enum
117+
return [
118+
str(1 << i) for i in range(value.bit_length()) if (value >> i) & 1
119+
]
120+
return value
121+
122+
123+
class NonStrictSelectMultiple(NonStrictMixin, FlagSelectMultiple):
90124
"""
91125
A SelectMultiple widget for non-strict EnumFlagFields that includes any
92126
existing non-conforming value as a choice option.
@@ -314,6 +348,8 @@ class EnumFlagField(ChoiceFieldMixin, TypedMultipleChoiceField): # type: ignore
314348
if strict=False, values can be outside of the enumerations
315349
"""
316350

351+
widget = FlagSelectMultiple
352+
317353
def __init__(
318354
self,
319355
enum: Optional[Type[Flag]] = None,
@@ -324,6 +360,10 @@ def __init__(
324360
choices: _ChoicesParameter = (),
325361
**kwargs,
326362
):
363+
kwargs.setdefault(
364+
"widget",
365+
self.widget(enum=enum) if strict else NonStrictSelectMultiple(enum=enum),
366+
)
327367
super().__init__(
328368
enum=enum,
329369
empty_value=(
@@ -334,3 +374,12 @@ def __init__(
334374
choices=choices,
335375
**kwargs,
336376
)
377+
378+
def _coerce(self, value: Any) -> Any:
379+
"""Combine the values into a single flag using |"""
380+
if self.enum and isinstance(value, self.enum):
381+
return value
382+
values = TypedMultipleChoiceField._coerce(self, value) # type: ignore[attr-defined]
383+
if values:
384+
return reduce(or_, values)
385+
return self.empty_value

tests/test_forms_ep.py

+97-1
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@
22

33
pytest.importorskip("enum_properties")
44
from tests.test_forms import FormTests, TestFormField
5-
from tests.enum_prop.models import EnumTester
5+
from tests.enum_prop.models import EnumTester, BitFieldModel
66
from tests.enum_prop.forms import EnumTesterForm
7+
from tests.examples.models import FlagExample
8+
from django_enum.forms import EnumFlagField, FlagSelectMultiple
9+
from django.forms import ModelForm
710

811

912
class EnumPropertiesFormTests(FormTests):
@@ -34,6 +37,99 @@ def model_params(self):
3437
"no_coerce": "Value 1",
3538
}
3639

40+
def test_flag_choices_admin_form(self):
41+
from django.contrib import admin
42+
43+
admin_class = admin.site._registry.get(BitFieldModel)
44+
self.assertIsInstance(
45+
admin_class.get_form(None).base_fields.get("bit_field_small"), EnumFlagField
46+
)
47+
48+
def test_flag_choices_model_form(self):
49+
from tests.examples.models.flag import Permissions
50+
from tests.enum_prop.enums import GNSSConstellation
51+
52+
class FlagChoicesModelForm(ModelForm):
53+
class Meta(EnumTesterForm.Meta):
54+
model = BitFieldModel
55+
56+
form = FlagChoicesModelForm(
57+
data={"bit_field_small": [GNSSConstellation.GPS, GNSSConstellation.GLONASS]}
58+
)
59+
60+
form.full_clean()
61+
self.assertTrue(form.is_valid())
62+
self.assertEqual(
63+
form.cleaned_data["bit_field_small"],
64+
GNSSConstellation.GPS | GNSSConstellation.GLONASS,
65+
)
66+
self.assertIsInstance(form.base_fields["bit_field_small"], EnumFlagField)
67+
68+
def test_extern_flag_admin_form(self):
69+
from django.contrib import admin
70+
71+
admin_class = admin.site._registry.get(FlagExample)
72+
self.assertIsInstance(
73+
admin_class.get_form(None).base_fields.get("permissions"), EnumFlagField
74+
)
75+
76+
def test_extern_flag_model_form(self):
77+
from tests.examples.models.flag import Permissions
78+
79+
class FlagModelForm(ModelForm):
80+
class Meta(EnumTesterForm.Meta):
81+
model = FlagExample
82+
83+
form = FlagModelForm(
84+
data={"permissions": [Permissions.READ, Permissions.WRITE]}
85+
)
86+
87+
form.full_clean()
88+
self.assertTrue(form.is_valid())
89+
self.assertEqual(
90+
form.cleaned_data["permissions"], Permissions.READ | Permissions.WRITE
91+
)
92+
self.assertIsInstance(form.base_fields["permissions"], EnumFlagField)
93+
94+
def test_flag_select_multiple_format(self):
95+
from tests.examples.models.flag import Permissions
96+
97+
widget = FlagSelectMultiple() # no enum
98+
self.assertEqual(
99+
widget.format_value(Permissions.READ | Permissions.WRITE),
100+
[str(Permissions.READ.value), str(Permissions.WRITE.value)],
101+
)
102+
self.assertEqual(
103+
widget.format_value(Permissions.READ | Permissions.EXECUTE),
104+
[str(Permissions.READ.value), str(Permissions.EXECUTE.value)],
105+
)
106+
self.assertEqual(
107+
widget.format_value(Permissions.EXECUTE | Permissions.WRITE),
108+
[str(Permissions.WRITE.value), str(Permissions.EXECUTE.value)],
109+
)
110+
111+
widget = FlagSelectMultiple(enum=Permissions) # no enum
112+
self.assertEqual(
113+
widget.format_value(Permissions.READ | Permissions.WRITE),
114+
[str(Permissions.READ.value), str(Permissions.WRITE.value)],
115+
)
116+
self.assertEqual(
117+
widget.format_value(Permissions.READ | Permissions.EXECUTE),
118+
[str(Permissions.READ.value), str(Permissions.EXECUTE.value)],
119+
)
120+
self.assertEqual(
121+
widget.format_value(Permissions.EXECUTE | Permissions.WRITE),
122+
[str(Permissions.WRITE.value), str(Permissions.EXECUTE.value)],
123+
)
124+
125+
# check pass through
126+
self.assertEqual(
127+
widget.format_value(
128+
[str(Permissions.WRITE.value), str(Permissions.EXECUTE.value)]
129+
),
130+
[str(Permissions.WRITE.value), str(Permissions.EXECUTE.value)],
131+
)
132+
37133

38134
FormTests = None
39135
TestFormField = None

uv.lock

+2-2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)