Skip to content

Commit 71c9760

Browse files
committed
feat(v3): add improved matcher class to v3 preview
1 parent 5435762 commit 71c9760

File tree

14 files changed

+751
-9
lines changed

14 files changed

+751
-9
lines changed

.github/actions/setup/action.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ description: "Install development dependencies"
44
inputs:
55
python-version:
66
description: "Python version to install"
7-
default: "3.12"
7+
default: "3.14"
88

99
runs:
1010
using: "composite"

decoy/errors.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,3 +98,7 @@ def create(
9898

9999
class VerifyOrderError(VerifyError):
100100
"""A [`Decoy.verify_order`][decoy.next.Decoy.verify_order] assertion failed."""
101+
102+
103+
class NoMatcherValueCapturedError(ValueError):
104+
"""An error raised if a [decoy.next.Matcher][] has not captured any matching values."""

decoy/next/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@
55
"""
66

77
from ._internal.decoy import Decoy
8+
from ._internal.matcher import Matcher
89
from ._internal.mock import AsyncMock, Mock
910
from ._internal.verify import Verify
1011
from ._internal.when import Stub, When
1112

1213
__all__ = [
1314
"AsyncMock",
1415
"Decoy",
16+
"Matcher",
1517
"Mock",
1618
"Stub",
1719
"Verify",

decoy/next/_internal/errors.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,9 @@ def createVerifyOrderError(
9595
)
9696

9797
return errors.VerifyOrderError(message)
98+
99+
100+
def createNoMatcherValueCapturedError(
101+
message: str,
102+
) -> errors.NoMatcherValueCapturedError:
103+
return errors.NoMatcherValueCapturedError(message)

decoy/next/_internal/matcher.py

Lines changed: 315 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,315 @@
1+
import collections.abc
2+
import re
3+
import sys
4+
from typing import Any, Callable, Generic, TypeVar, cast, overload
5+
6+
if sys.version_info >= (3, 13):
7+
from typing import TypeIs
8+
else:
9+
from typing_extensions import TypeIs
10+
11+
from ...errors import NoMatcherValueCapturedError
12+
13+
ValueT = TypeVar("ValueT")
14+
MatchT = TypeVar("MatchT")
15+
MappingT = TypeVar("MappingT", bound=collections.abc.Mapping[Any, Any])
16+
SequenceT = TypeVar("SequenceT", bound=collections.abc.Sequence[Any])
17+
ErrorT = TypeVar("ErrorT", bound=BaseException)
18+
19+
TypedMatch = Callable[[object], TypeIs[MatchT]]
20+
UntypedMatch = Callable[[object], bool]
21+
22+
23+
class Matcher(Generic[ValueT]):
24+
"""Create an [argument matcher](./matchers.md).
25+
26+
Arguments:
27+
match: A comparison function that returns a bool or `TypeIs` guard.
28+
name: Optional name for the matcher; defaults to `match.__name__`
29+
description: Optional extra description for the matcher's repr.
30+
31+
Example:
32+
Use a function to create a custom matcher.
33+
34+
```python
35+
def is_even(target: object) -> TypeIs[int]:
36+
return isinstance(target, int) and target % 2 == 0
37+
38+
is_even_matcher = Matcher(is_even)
39+
```
40+
41+
Matchers can also be constructed from built-in inspection functions, like `callable`.
42+
43+
```python
44+
callable_matcher = Matcher(callable)
45+
```
46+
"""
47+
48+
@overload
49+
def __init__(
50+
self: "Matcher[MatchT]",
51+
match: TypedMatch[MatchT],
52+
name: str | None = None,
53+
description: str | None = None,
54+
) -> None: ...
55+
56+
@overload
57+
def __init__(
58+
self: "Matcher[Any]",
59+
match: UntypedMatch,
60+
name: str | None = None,
61+
description: str | None = None,
62+
) -> None: ...
63+
64+
def __init__(
65+
self,
66+
match: TypedMatch[ValueT] | UntypedMatch,
67+
name: str | None = None,
68+
description: str | None = None,
69+
) -> None:
70+
self._match = match
71+
self._name = name or match.__name__
72+
self._description = description
73+
self._values: list[ValueT] = []
74+
75+
def __eq__(self, target: object) -> bool:
76+
if self._match(target):
77+
self._values.append(cast(ValueT, target)) # type: ignore[redundant-cast]
78+
return True
79+
80+
return False
81+
82+
def __repr__(self) -> str:
83+
matcher_name = f"Matcher.{self._name}"
84+
if self._description:
85+
return f"<{matcher_name} {self._description.strip()}>"
86+
87+
return f"<{matcher_name}>"
88+
89+
@property
90+
def arg(self) -> ValueT:
91+
"""Type-cast the matcher as the expected value.
92+
93+
Example:
94+
If the mock expects a `str` argument, using `arg` prevents the type-checker from raising an error.
95+
96+
```python
97+
decoy
98+
.when(mock)
99+
.called_with(Matcher.matches("^(hello|hi)$").arg)
100+
.then_return("world")
101+
```
102+
"""
103+
return cast(ValueT, self)
104+
105+
@property
106+
def value(self) -> ValueT:
107+
"""The latest matching compared value.
108+
109+
Raises:
110+
NoMatcherValueCapturedError: the matcher has not been compared with any matching value.
111+
112+
Example:
113+
You can use `value` to trigger a callback passed to your mock.
114+
115+
```python
116+
callback_matcher = Matcher(callable)
117+
decoy.verify(mock).called_with(callback_matcher)
118+
callback_matcher.value("value")
119+
```
120+
"""
121+
if len(self._values) == 0:
122+
raise NoMatcherValueCapturedError(f"{self} has not matched any values")
123+
124+
return self._values[-1]
125+
126+
@property
127+
def values(self) -> list[ValueT]:
128+
"""All matching compared values."""
129+
return self._values.copy()
130+
131+
@staticmethod
132+
@overload
133+
def any(
134+
type: type[MatchT],
135+
attrs: collections.abc.Mapping[str, object] | None = None,
136+
) -> "Matcher[MatchT]": ...
137+
138+
@staticmethod
139+
@overload
140+
def any(
141+
type: None = None,
142+
attrs: collections.abc.Mapping[str, object] | None = None,
143+
) -> "Matcher[Any]": ...
144+
145+
@staticmethod
146+
def any(
147+
type: type[MatchT] | None = None,
148+
attrs: collections.abc.Mapping[str, object] | None = None,
149+
) -> "Matcher[MatchT] | Matcher[Any]":
150+
"""Match an argument, optionally by type and/or attributes.
151+
152+
If type and attributes are omitted, will match everything,
153+
including `None`.
154+
155+
Arguments:
156+
type: Type to match, if any.
157+
attrs: Set of attributes to match, if any.
158+
"""
159+
description = ""
160+
161+
if type:
162+
description = type.__name__
163+
164+
if attrs:
165+
description = f"{description} attrs={attrs!r}"
166+
167+
return Matcher(
168+
lambda t: _any(t, type, attrs),
169+
name="any",
170+
description=description,
171+
)
172+
173+
@staticmethod
174+
def is_not(value: object) -> "Matcher[Any]":
175+
"""Match any value that does not `==` the given value.
176+
177+
Arguments:
178+
value: The value that the matcher rejects.
179+
"""
180+
return Matcher(
181+
lambda t: t != value,
182+
name="is_not",
183+
description=repr(value),
184+
)
185+
186+
@staticmethod
187+
@overload
188+
def contains(values: MappingT) -> "Matcher[MappingT]": ...
189+
190+
@staticmethod
191+
@overload
192+
def contains(values: SequenceT, in_order: bool = False) -> "Matcher[SequenceT]": ...
193+
194+
@staticmethod
195+
def contains(
196+
values: MappingT | SequenceT,
197+
in_order: bool = False,
198+
) -> "Matcher[MappingT] | Matcher[SequenceT]":
199+
"""Match a dict, list, or string with a partial value.
200+
201+
Arguments:
202+
values: Partial value to match.
203+
in_order: Match list values in order.
204+
"""
205+
description = repr(values)
206+
207+
if in_order:
208+
description = f"{description} in_order={in_order}"
209+
210+
return Matcher(
211+
lambda t: _contains(t, values, in_order),
212+
name="contains",
213+
description=description,
214+
)
215+
216+
@staticmethod
217+
def matches(pattern: str) -> "Matcher[str]":
218+
"""Match a string by a pattern.
219+
220+
Arguments:
221+
pattern: Regular expression pattern.
222+
"""
223+
pattern_re = re.compile(pattern)
224+
225+
return Matcher(
226+
lambda t: isinstance(t, str) and pattern_re.search(t) is not None,
227+
name="matches",
228+
description=repr(pattern),
229+
)
230+
231+
@staticmethod
232+
def error(type: type[ErrorT], message: str | None = None) -> "Matcher[ErrorT]":
233+
"""Match an exception object.
234+
235+
Arguments:
236+
type: The type of exception to match.
237+
message: An optional regular expression pattern to match.
238+
"""
239+
message_re = re.compile(message or "")
240+
description = type.__name__
241+
242+
if message:
243+
description = f"{description} message={message!r}"
244+
245+
return Matcher(
246+
lambda t: isinstance(t, type) and message_re.search(str(t)) is not None,
247+
name="error",
248+
description=description,
249+
)
250+
251+
252+
def _any(
253+
target: object,
254+
match_type: type[Any] | None,
255+
attrs: collections.abc.Mapping[str, object] | None,
256+
) -> bool:
257+
return (match_type is None or isinstance(target, match_type)) and (
258+
attrs is None or _has_attrs(target, attrs)
259+
)
260+
261+
262+
def _has_attrs(
263+
target: object,
264+
attributes: collections.abc.Mapping[str, object],
265+
) -> bool:
266+
return all(
267+
hasattr(target, attr_name) and getattr(target, attr_name) == attr_value
268+
for attr_name, attr_value in attributes.items()
269+
)
270+
271+
272+
def _contains(
273+
target: object,
274+
values: collections.abc.Mapping[object, object] | collections.abc.Sequence[object],
275+
in_order: bool,
276+
) -> bool:
277+
if isinstance(values, str):
278+
return isinstance(target, str) and values in target
279+
if isinstance(values, collections.abc.Mapping):
280+
return _dict_containing(target, values)
281+
if isinstance(values, collections.abc.Sequence):
282+
return _list_containing(target, values, in_order)
283+
284+
285+
def _dict_containing(
286+
target: object,
287+
values: collections.abc.Mapping[object, object],
288+
) -> bool:
289+
try:
290+
return all(
291+
attr_name in target and target[attr_name] == attr_value # type: ignore[index,operator]
292+
for attr_name, attr_value in values.items()
293+
)
294+
except TypeError:
295+
return False
296+
297+
298+
def _list_containing(
299+
target: object,
300+
values: collections.abc.Sequence[object],
301+
in_order: bool,
302+
) -> bool:
303+
target_index = 0
304+
305+
try:
306+
for value in values:
307+
if in_order:
308+
target = target[target_index:] # type: ignore[index]
309+
310+
target_index = target.index(value) # type: ignore[attr-defined]
311+
312+
except (AttributeError, TypeError, ValueError):
313+
return False
314+
315+
return True

0 commit comments

Comments
 (0)