Skip to content

Commit e07485d

Browse files
committed
feat(v3): add improved matcher class to v3 preview
1 parent 5435762 commit e07485d

File tree

15 files changed

+861
-9
lines changed

15 files changed

+861
-9
lines changed

.github/actions/setup/action.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ description: "Install development dependencies"
44
inputs:
55
python-version:
66
description: "Python version to install"
7-
default: "3.12"
7+
default: "3.14"
88

99
runs:
1010
using: "composite"

decoy/errors.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,3 +98,7 @@ def create(
9898

9999
class VerifyOrderError(VerifyError):
100100
"""A [`Decoy.verify_order`][decoy.next.Decoy.verify_order] assertion failed."""
101+
102+
103+
class NoMatcherValueCapturedError(ValueError):
104+
"""An error raised if a [decoy.next.Matcher][] has not captured any matching values."""

decoy/next/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@
55
"""
66

77
from ._internal.decoy import Decoy
8+
from ._internal.matcher import Matcher
89
from ._internal.mock import AsyncMock, Mock
910
from ._internal.verify import Verify
1011
from ._internal.when import Stub, When
1112

1213
__all__ = [
1314
"AsyncMock",
1415
"Decoy",
16+
"Matcher",
1517
"Mock",
1618
"Stub",
1719
"Verify",

decoy/next/_internal/errors.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,9 @@ def createVerifyOrderError(
9595
)
9696

9797
return errors.VerifyOrderError(message)
98+
99+
100+
def createNoMatcherValueCapturedError(
101+
message: str,
102+
) -> errors.NoMatcherValueCapturedError:
103+
return errors.NoMatcherValueCapturedError(message)

decoy/next/_internal/inspect.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,14 @@ def bind_args(
163163
return BoundArguments(bound_args.args, bound_args.kwargs)
164164

165165

166+
def get_func_name(func: Callable[..., object]) -> str:
167+
"""Get the name of a function."""
168+
if isinstance(func, functools.partial):
169+
return func.func.__name__
170+
171+
return func.__name__
172+
173+
166174
def _unwrap_callable(value: object) -> Callable[..., object] | None:
167175
"""Return an object's callable, checking if a class has a `__call__` method."""
168176
if not callable(value):

decoy/next/_internal/matcher.py

Lines changed: 328 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,328 @@
1+
import collections.abc
2+
import functools
3+
import re
4+
import sys
5+
from typing import Any, Callable, Generic, TypeVar, cast, overload
6+
7+
if sys.version_info >= (3, 13):
8+
from typing import TypeIs
9+
else:
10+
from typing_extensions import TypeIs
11+
12+
from .errors import createNoMatcherValueCapturedError
13+
from .inspect import get_func_name
14+
15+
ValueT = TypeVar("ValueT")
16+
MatchT = TypeVar("MatchT")
17+
MappingT = TypeVar("MappingT", bound=collections.abc.Mapping[Any, Any])
18+
SequenceT = TypeVar("SequenceT", bound=collections.abc.Sequence[Any])
19+
ErrorT = TypeVar("ErrorT", bound=BaseException)
20+
21+
TypedMatch = Callable[[object], TypeIs[MatchT]]
22+
UntypedMatch = Callable[[object], bool]
23+
24+
25+
class Matcher(Generic[ValueT]):
26+
"""Create an [argument matcher](./matchers.md).
27+
28+
Arguments:
29+
match: A comparison function that returns a bool or `TypeIs` guard.
30+
name: Optional name for the matcher; defaults to `match.__name__`
31+
description: Optional extra description for the matcher's repr.
32+
33+
Example:
34+
Use a function to create a custom matcher.
35+
36+
```python
37+
def is_even(target: object) -> TypeIs[int]:
38+
return isinstance(target, int) and target % 2 == 0
39+
40+
is_even_matcher = Matcher(is_even)
41+
```
42+
43+
Matchers can also be constructed from built-in inspection functions, like `callable`.
44+
45+
```python
46+
callable_matcher = Matcher(callable)
47+
```
48+
"""
49+
50+
@overload
51+
def __init__(
52+
self: "Matcher[MatchT]",
53+
match: TypedMatch[MatchT],
54+
name: str | None = None,
55+
description: str | None = None,
56+
) -> None: ...
57+
58+
@overload
59+
def __init__(
60+
self: "Matcher[Any]",
61+
match: UntypedMatch,
62+
name: str | None = None,
63+
description: str | None = None,
64+
) -> None: ...
65+
66+
def __init__(
67+
self,
68+
match: TypedMatch[ValueT] | UntypedMatch,
69+
name: str | None = None,
70+
description: str | None = None,
71+
) -> None:
72+
self._match = match
73+
self._name = name or get_func_name(match)
74+
self._description = description
75+
self._values: list[ValueT] = []
76+
77+
def __eq__(self, target: object) -> bool:
78+
if self._match(target):
79+
self._values.append(cast(ValueT, target)) # type: ignore[redundant-cast]
80+
return True
81+
82+
return False
83+
84+
def __repr__(self) -> str:
85+
matcher_name = f"Matcher.{self._name}"
86+
if self._description:
87+
return f"<{matcher_name} {self._description.strip()}>"
88+
89+
return f"<{matcher_name}>"
90+
91+
@property
92+
def arg(self) -> ValueT:
93+
"""Type-cast the matcher as the expected value.
94+
95+
Example:
96+
If the mock expects a `str` argument, using `arg` prevents the type-checker from raising an error.
97+
98+
```python
99+
decoy
100+
.when(mock)
101+
.called_with(Matcher.matches("^(hello|hi)$").arg)
102+
.then_return("world")
103+
```
104+
"""
105+
return cast(ValueT, self)
106+
107+
@property
108+
def value(self) -> ValueT:
109+
"""The latest matching compared value.
110+
111+
Raises:
112+
NoMatcherValueCapturedError: the matcher has not been compared with any matching value.
113+
114+
Example:
115+
You can use `value` to trigger a callback passed to your mock.
116+
117+
```python
118+
callback_matcher = Matcher(callable)
119+
decoy.verify(mock).called_with(callback_matcher)
120+
callback_matcher.value("value")
121+
```
122+
"""
123+
if len(self._values) == 0:
124+
raise createNoMatcherValueCapturedError(
125+
f"{self} has not matched any values"
126+
)
127+
128+
return self._values[-1]
129+
130+
@property
131+
def values(self) -> list[ValueT]:
132+
"""All matching compared values."""
133+
return self._values.copy()
134+
135+
@overload
136+
@staticmethod
137+
def any(
138+
type: type[MatchT],
139+
attrs: collections.abc.Mapping[str, object] | None = None,
140+
) -> "Matcher[MatchT]": ...
141+
142+
@overload
143+
@staticmethod
144+
def any(
145+
type: None = None,
146+
attrs: collections.abc.Mapping[str, object] | None = None,
147+
) -> "Matcher[Any]": ...
148+
149+
@staticmethod
150+
def any(
151+
type: type[MatchT] | None = None,
152+
attrs: collections.abc.Mapping[str, object] | None = None,
153+
) -> "Matcher[MatchT] | Matcher[Any]":
154+
"""Match an argument, optionally by type and/or attributes.
155+
156+
If type and attributes are omitted, will match everything,
157+
including `None`.
158+
159+
Arguments:
160+
type: Type to match, if any.
161+
attrs: Set of attributes to match, if any.
162+
"""
163+
description = ""
164+
165+
if type:
166+
description = type.__name__
167+
168+
if attrs:
169+
description = f"{description} attrs={attrs!r}"
170+
171+
return Matcher(
172+
match=functools.partial(any, type, attrs),
173+
description=description,
174+
)
175+
176+
@staticmethod
177+
def is_not(value: object) -> "Matcher[Any]":
178+
"""Match any value that does not `==` the given value.
179+
180+
Arguments:
181+
value: The value that the matcher rejects.
182+
"""
183+
return Matcher(
184+
lambda t: t != value,
185+
name="is_not",
186+
description=repr(value),
187+
)
188+
189+
@overload
190+
@staticmethod
191+
def contains(values: MappingT) -> "Matcher[MappingT]": ...
192+
193+
@overload
194+
@staticmethod
195+
def contains(values: SequenceT, in_order: bool = False) -> "Matcher[SequenceT]": ...
196+
197+
@staticmethod
198+
def contains(
199+
values: MappingT | SequenceT,
200+
in_order: bool = False,
201+
) -> "Matcher[MappingT] | Matcher[SequenceT]":
202+
"""Match a dict, list, or string with a partial value.
203+
204+
Arguments:
205+
values: Partial value to match.
206+
in_order: Match list values in order.
207+
"""
208+
description = repr(values)
209+
210+
if in_order:
211+
description = f"{description} in_order={in_order}"
212+
213+
return Matcher(
214+
match=functools.partial(contains, values, in_order),
215+
description=description,
216+
)
217+
218+
@staticmethod
219+
def matches(pattern: str) -> "Matcher[str]":
220+
"""Match a string by a pattern.
221+
222+
Arguments:
223+
pattern: Regular expression pattern.
224+
"""
225+
pattern_re = re.compile(pattern)
226+
227+
return Matcher(
228+
match=functools.partial(matches, pattern_re),
229+
description=repr(pattern),
230+
)
231+
232+
@staticmethod
233+
def error(type: type[ErrorT], message: str | None = None) -> "Matcher[ErrorT]":
234+
"""Match an exception object.
235+
236+
Arguments:
237+
type: The type of exception to match.
238+
message: An optional regular expression pattern to match.
239+
"""
240+
message_re = re.compile(message or "")
241+
description = type.__name__
242+
243+
if message:
244+
description = f"{description} message={message!r}"
245+
246+
return Matcher(
247+
match=functools.partial(error, type, message_re),
248+
name="error",
249+
description=description,
250+
)
251+
252+
253+
def any(
254+
match_type: type[Any] | None,
255+
attrs: collections.abc.Mapping[str, object] | None,
256+
target: object,
257+
) -> bool:
258+
return (match_type is None or isinstance(target, match_type)) and (
259+
attrs is None or _has_attrs(attrs, target)
260+
)
261+
262+
263+
def _has_attrs(
264+
attributes: collections.abc.Mapping[str, object],
265+
target: object,
266+
) -> bool:
267+
return all(
268+
hasattr(target, attr_name) and getattr(target, attr_name) == attr_value
269+
for attr_name, attr_value in attributes.items()
270+
)
271+
272+
273+
def contains(
274+
values: collections.abc.Mapping[object, object] | collections.abc.Sequence[object],
275+
in_order: bool,
276+
target: object,
277+
) -> bool:
278+
if isinstance(values, collections.abc.Mapping):
279+
return _dict_containing(values, target)
280+
if isinstance(values, str):
281+
return isinstance(target, str) and values in target
282+
283+
return _list_containing(values, in_order, target)
284+
285+
286+
def _dict_containing(
287+
values: collections.abc.Mapping[object, object],
288+
target: object,
289+
) -> bool:
290+
try:
291+
return all(
292+
attr_name in target and target[attr_name] == attr_value # type: ignore[index,operator]
293+
for attr_name, attr_value in values.items()
294+
)
295+
except TypeError:
296+
return False
297+
298+
299+
def _list_containing(
300+
values: collections.abc.Sequence[object],
301+
in_order: bool,
302+
target: object,
303+
) -> bool:
304+
target_index = 0
305+
306+
try:
307+
for value in values:
308+
if in_order:
309+
target = target[target_index:] # type: ignore[index]
310+
311+
target_index = target.index(value) # type: ignore[attr-defined]
312+
313+
except (AttributeError, TypeError, ValueError):
314+
return False
315+
316+
return True
317+
318+
319+
def error(
320+
type: type[ErrorT],
321+
message_pattern: re.Pattern[str],
322+
target: object,
323+
) -> bool:
324+
return isinstance(target, type) and message_pattern.search(str(target)) is not None
325+
326+
327+
def matches(pattern: re.Pattern[str], target: object) -> bool:
328+
return isinstance(target, str) and pattern.search(target) is not None

0 commit comments

Comments
 (0)