Skip to content

Commit 1eb7ada

Browse files
authored
implement or and ior operators (pallets#2979)
2 parents 862cb19 + b65b587 commit 1eb7ada

File tree

5 files changed

+126
-1
lines changed

5 files changed

+126
-1
lines changed

CHANGES.rst

+2
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ Unreleased
2525
:issue:`2970`
2626
- ``MultiDict.getlist`` catches ``TypeError`` in addition to ``ValueError``
2727
when doing type conversion. :issue:`2976`
28+
- Implement ``|`` and ``|=`` operators for ``MultiDict``, ``Headers``, and
29+
``CallbackDict``, and disallow ``|=`` on immutable types. :issue:`2977`
2830

2931

3032
Version 3.0.6

src/werkzeug/datastructures/headers.py

+31
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ class Headers(cabc.MutableMapping[str, str]):
4141
4242
:param defaults: The list of default values for the :class:`Headers`.
4343
44+
.. versionchanged:: 3.1
45+
Implement ``|`` and ``|=`` operators.
46+
4447
.. versionchanged:: 2.1.0
4548
Default values are validated the same as values added later.
4649
@@ -524,6 +527,31 @@ def update( # type: ignore[override]
524527
else:
525528
self.set(key, value)
526529

530+
def __or__(
531+
self, other: cabc.Mapping[str, t.Any | cabc.Collection[t.Any]]
532+
) -> te.Self:
533+
if not isinstance(other, cabc.Mapping):
534+
return NotImplemented
535+
536+
rv = self.copy()
537+
rv.update(other)
538+
return rv
539+
540+
def __ior__(
541+
self,
542+
other: (
543+
cabc.Mapping[str, t.Any | cabc.Collection[t.Any]]
544+
| cabc.Iterable[tuple[str, t.Any]]
545+
),
546+
) -> te.Self:
547+
if not isinstance(other, (cabc.Mapping, cabc.Iterable)) or isinstance(
548+
other, str
549+
):
550+
return NotImplemented
551+
552+
self.update(other)
553+
return self
554+
527555
def to_wsgi_list(self) -> list[tuple[str, str]]:
528556
"""Convert the headers into a list suitable for WSGI.
529557
@@ -620,6 +648,9 @@ def __iter__(self) -> cabc.Iterator[tuple[str, str]]: # type: ignore[override]
620648
def copy(self) -> t.NoReturn:
621649
raise TypeError(f"cannot create {type(self).__name__!r} copies")
622650

651+
def __or__(self, other: t.Any) -> t.NoReturn:
652+
raise TypeError(f"cannot create {type(self).__name__!r} copies")
653+
623654

624655
# circular dependencies
625656
from .. import http

src/werkzeug/datastructures/mixins.py

+21
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,9 @@ def sort(self, key: t.Any = None, reverse: t.Any = False) -> t.NoReturn:
7676
class ImmutableDictMixin(t.Generic[K, V]):
7777
"""Makes a :class:`dict` immutable.
7878
79+
.. versionchanged:: 3.1
80+
Disallow ``|=`` operator.
81+
7982
.. versionadded:: 0.5
8083
8184
:private:
@@ -117,6 +120,9 @@ def setdefault(self, key: t.Any, default: t.Any = None) -> t.NoReturn:
117120
def update(self, arg: t.Any, /, **kwargs: t.Any) -> t.NoReturn:
118121
_immutable_error(self)
119122

123+
def __ior__(self, other: t.Any) -> t.NoReturn:
124+
_immutable_error(self)
125+
120126
def pop(self, key: t.Any, default: t.Any = None) -> t.NoReturn:
121127
_immutable_error(self)
122128

@@ -168,6 +174,9 @@ class ImmutableHeadersMixin:
168174
hashable though since the only usecase for this datastructure
169175
in Werkzeug is a view on a mutable structure.
170176
177+
.. versionchanged:: 3.1
178+
Disallow ``|=`` operator.
179+
171180
.. versionadded:: 0.5
172181
173182
:private:
@@ -200,6 +209,9 @@ def extend(self, arg: t.Any, /, **kwargs: t.Any) -> t.NoReturn:
200209
def update(self, arg: t.Any, /, **kwargs: t.Any) -> t.NoReturn:
201210
_immutable_error(self)
202211

212+
def __ior__(self, other: t.Any) -> t.NoReturn:
213+
_immutable_error(self)
214+
203215
def insert(self, pos: t.Any, value: t.Any) -> t.NoReturn:
204216
_immutable_error(self)
205217

@@ -233,6 +245,9 @@ def wrapper(
233245
class UpdateDictMixin(dict[K, V]):
234246
"""Makes dicts call `self.on_update` on modifications.
235247
248+
.. versionchanged:: 3.1
249+
Implement ``|=`` operator.
250+
236251
.. versionadded:: 0.5
237252
238253
:private:
@@ -294,3 +309,9 @@ def update( # type: ignore[override]
294309
super().update(**kwargs)
295310
else:
296311
super().update(arg, **kwargs)
312+
313+
@_always_update
314+
def __ior__( # type: ignore[override]
315+
self, other: cabc.Mapping[K, V] | cabc.Iterable[tuple[K, V]]
316+
) -> te.Self:
317+
return super().__ior__(other)

src/werkzeug/datastructures/structures.py

+25
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,9 @@ class MultiDict(TypeConversionDict[K, V]):
170170
:param mapping: the initial value for the :class:`MultiDict`. Either a
171171
regular dict, an iterable of ``(key, value)`` tuples
172172
or `None`.
173+
174+
.. versionchanged:: 3.1
175+
Implement ``|`` and ``|=`` operators.
173176
"""
174177

175178
def __init__(
@@ -435,6 +438,28 @@ def update( # type: ignore[override]
435438
for key, value in iter_multi_items(mapping):
436439
self.add(key, value)
437440

441+
def __or__( # type: ignore[override]
442+
self, other: cabc.Mapping[K, V | cabc.Collection[V]]
443+
) -> MultiDict[K, V]:
444+
if not isinstance(other, cabc.Mapping):
445+
return NotImplemented
446+
447+
rv = self.copy()
448+
rv.update(other)
449+
return rv
450+
451+
def __ior__( # type: ignore[override]
452+
self,
453+
other: cabc.Mapping[K, V | cabc.Collection[V]] | cabc.Iterable[tuple[K, V]],
454+
) -> te.Self:
455+
if not isinstance(other, (cabc.Mapping, cabc.Iterable)) or isinstance(
456+
other, str
457+
):
458+
return NotImplemented
459+
460+
self.update(other)
461+
return self
462+
438463
@t.overload
439464
def pop(self, key: K) -> V: ...
440465
@t.overload

tests/test_datastructures.py

+47-1
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,17 @@ def test_basic_interface(self):
258258
md.setlist("foo", [1, 2])
259259
assert md.getlist("foo") == [1, 2]
260260

261+
def test_or(self) -> None:
262+
a = self.storage_class({"x": 1})
263+
b = a | {"y": 2}
264+
assert isinstance(b, self.storage_class)
265+
assert "x" in b and "y" in b
266+
267+
def test_ior(self) -> None:
268+
a = self.storage_class({"x": 1})
269+
a |= {"y": 2}
270+
assert "x" in a and "y" in a
271+
261272

262273
class _ImmutableDictTests:
263274
storage_class: type[dict]
@@ -305,6 +316,17 @@ def test_dict_is_hashable(self):
305316
assert immutable in x
306317
assert immutable2 in x
307318

319+
def test_or(self) -> None:
320+
a = self.storage_class({"x": 1})
321+
b = a | {"y": 2}
322+
assert "x" in b and "y" in b
323+
324+
def test_ior(self) -> None:
325+
a = self.storage_class({"x": 1})
326+
327+
with pytest.raises(TypeError):
328+
a |= {"y": 2}
329+
308330

309331
class TestImmutableTypeConversionDict(_ImmutableDictTests):
310332
storage_class = ds.ImmutableTypeConversionDict
@@ -799,6 +821,17 @@ def test_equality(self):
799821

800822
assert h1 == h2
801823

824+
def test_or(self) -> None:
825+
a = ds.Headers({"x": 1})
826+
b = a | {"y": 2}
827+
assert isinstance(b, ds.Headers)
828+
assert "x" in b and "y" in b
829+
830+
def test_ior(self) -> None:
831+
a = ds.Headers({"x": 1})
832+
a |= {"y": 2}
833+
assert "x" in a and "y" in a
834+
802835

803836
class TestEnvironHeaders:
804837
storage_class = ds.EnvironHeaders
@@ -840,6 +873,18 @@ def test_return_type_is_str(self):
840873
assert headers["Foo"] == "\xe2\x9c\x93"
841874
assert next(iter(headers)) == ("Foo", "\xe2\x9c\x93")
842875

876+
def test_or(self) -> None:
877+
headers = ds.EnvironHeaders({"x": "1"})
878+
879+
with pytest.raises(TypeError):
880+
headers | {"y": "2"}
881+
882+
def test_ior(self) -> None:
883+
headers = ds.EnvironHeaders({})
884+
885+
with pytest.raises(TypeError):
886+
headers |= {"y": "2"}
887+
843888

844889
class TestHeaderSet:
845890
storage_class = ds.HeaderSet
@@ -927,7 +972,7 @@ def test_callback_dict_writes(self):
927972
assert_calls, func = make_call_asserter()
928973
initial = {"a": "foo", "b": "bar"}
929974
dct = self.storage_class(initial=initial, on_update=func)
930-
with assert_calls(8, "callback not triggered by write method"):
975+
with assert_calls(9, "callback not triggered by write method"):
931976
# always-write methods
932977
dct["z"] = 123
933978
dct["z"] = 123 # must trigger again
@@ -937,6 +982,7 @@ def test_callback_dict_writes(self):
937982
dct.popitem()
938983
dct.update([])
939984
dct.clear()
985+
dct |= {}
940986
with assert_calls(0, "callback triggered by failed del"):
941987
pytest.raises(KeyError, lambda: dct.__delitem__("x"))
942988
with assert_calls(0, "callback triggered by failed pop"):

0 commit comments

Comments
 (0)