Skip to content

Commit 04d1dce

Browse files
committed
feat(v3): add improved matcher class to v3 preview
1 parent 5435762 commit 04d1dce

File tree

15 files changed

+859
-9
lines changed

15 files changed

+859
-9
lines changed

.github/actions/setup/action.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ description: "Install development dependencies"
44
inputs:
55
python-version:
66
description: "Python version to install"
7-
default: "3.12"
7+
default: "3.14"
88

99
runs:
1010
using: "composite"

decoy/errors.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,3 +98,7 @@ def create(
9898

9999
class VerifyOrderError(VerifyError):
100100
"""A [`Decoy.verify_order`][decoy.next.Decoy.verify_order] assertion failed."""
101+
102+
103+
class NoMatcherValueCapturedError(ValueError):
104+
"""An error raised if a [decoy.next.Matcher][] has not captured any matching values."""

decoy/next/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@
55
"""
66

77
from ._internal.decoy import Decoy
8+
from ._internal.matcher import Matcher
89
from ._internal.mock import AsyncMock, Mock
910
from ._internal.verify import Verify
1011
from ._internal.when import Stub, When
1112

1213
__all__ = [
1314
"AsyncMock",
1415
"Decoy",
16+
"Matcher",
1517
"Mock",
1618
"Stub",
1719
"Verify",

decoy/next/_internal/errors.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,9 @@ def createVerifyOrderError(
9595
)
9696

9797
return errors.VerifyOrderError(message)
98+
99+
100+
def createNoMatcherValueCapturedError(
101+
message: str,
102+
) -> errors.NoMatcherValueCapturedError:
103+
return errors.NoMatcherValueCapturedError(message)

decoy/next/_internal/inspect.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,14 @@ def bind_args(
163163
return BoundArguments(bound_args.args, bound_args.kwargs)
164164

165165

166+
def get_func_name(func: Callable[..., object]) -> str:
167+
"""Get the name of a function."""
168+
if isinstance(func, functools.partial):
169+
return func.func.__name__
170+
171+
return func.__name__
172+
173+
166174
def _unwrap_callable(value: object) -> Callable[..., object] | None:
167175
"""Return an object's callable, checking if a class has a `__call__` method."""
168176
if not callable(value):

decoy/next/_internal/matcher.py

Lines changed: 326 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,326 @@
1+
import collections.abc
2+
import functools
3+
import re
4+
import sys
5+
from typing import Any, Callable, Generic, TypeVar, cast, overload
6+
7+
if sys.version_info >= (3, 13):
8+
from typing import TypeIs
9+
else:
10+
from typing_extensions import TypeIs
11+
12+
from ...errors import NoMatcherValueCapturedError
13+
from .inspect import get_func_name
14+
15+
ValueT = TypeVar("ValueT")
16+
MatchT = TypeVar("MatchT")
17+
MappingT = TypeVar("MappingT", bound=collections.abc.Mapping[Any, Any])
18+
SequenceT = TypeVar("SequenceT", bound=collections.abc.Sequence[Any])
19+
ErrorT = TypeVar("ErrorT", bound=BaseException)
20+
21+
TypedMatch = Callable[[object], TypeIs[MatchT]]
22+
UntypedMatch = Callable[[object], bool]
23+
24+
25+
class Matcher(Generic[ValueT]):
26+
"""Create an [argument matcher](./matchers.md).
27+
28+
Arguments:
29+
match: A comparison function that returns a bool or `TypeIs` guard.
30+
name: Optional name for the matcher; defaults to `match.__name__`
31+
description: Optional extra description for the matcher's repr.
32+
33+
Example:
34+
Use a function to create a custom matcher.
35+
36+
```python
37+
def is_even(target: object) -> TypeIs[int]:
38+
return isinstance(target, int) and target % 2 == 0
39+
40+
is_even_matcher = Matcher(is_even)
41+
```
42+
43+
Matchers can also be constructed from built-in inspection functions, like `callable`.
44+
45+
```python
46+
callable_matcher = Matcher(callable)
47+
```
48+
"""
49+
50+
@overload
51+
def __init__(
52+
self: "Matcher[MatchT]",
53+
match: TypedMatch[MatchT],
54+
name: str | None = None,
55+
description: str | None = None,
56+
) -> None: ...
57+
58+
@overload
59+
def __init__(
60+
self: "Matcher[Any]",
61+
match: UntypedMatch,
62+
name: str | None = None,
63+
description: str | None = None,
64+
) -> None: ...
65+
66+
def __init__(
67+
self,
68+
match: TypedMatch[ValueT] | UntypedMatch,
69+
name: str | None = None,
70+
description: str | None = None,
71+
) -> None:
72+
self._match = match
73+
self._name = name or get_func_name(match)
74+
self._description = description
75+
self._values: list[ValueT] = []
76+
77+
def __eq__(self, target: object) -> bool:
78+
if self._match(target):
79+
self._values.append(cast(ValueT, target)) # type: ignore[redundant-cast]
80+
return True
81+
82+
return False
83+
84+
def __repr__(self) -> str:
85+
matcher_name = f"Matcher.{self._name}"
86+
if self._description:
87+
return f"<{matcher_name} {self._description.strip()}>"
88+
89+
return f"<{matcher_name}>"
90+
91+
@property
92+
def arg(self) -> ValueT:
93+
"""Type-cast the matcher as the expected value.
94+
95+
Example:
96+
If the mock expects a `str` argument, using `arg` prevents the type-checker from raising an error.
97+
98+
```python
99+
decoy
100+
.when(mock)
101+
.called_with(Matcher.matches("^(hello|hi)$").arg)
102+
.then_return("world")
103+
```
104+
"""
105+
return cast(ValueT, self)
106+
107+
@property
108+
def value(self) -> ValueT:
109+
"""The latest matching compared value.
110+
111+
Raises:
112+
NoMatcherValueCapturedError: the matcher has not been compared with any matching value.
113+
114+
Example:
115+
You can use `value` to trigger a callback passed to your mock.
116+
117+
```python
118+
callback_matcher = Matcher(callable)
119+
decoy.verify(mock).called_with(callback_matcher)
120+
callback_matcher.value("value")
121+
```
122+
"""
123+
if len(self._values) == 0:
124+
raise NoMatcherValueCapturedError(f"{self} has not matched any values")
125+
126+
return self._values[-1]
127+
128+
@property
129+
def values(self) -> list[ValueT]:
130+
"""All matching compared values."""
131+
return self._values.copy()
132+
133+
@staticmethod
134+
@overload
135+
def any(
136+
type: type[MatchT],
137+
attrs: collections.abc.Mapping[str, object] | None = None,
138+
) -> "Matcher[MatchT]": ...
139+
140+
@staticmethod
141+
@overload
142+
def any(
143+
type: None = None,
144+
attrs: collections.abc.Mapping[str, object] | None = None,
145+
) -> "Matcher[Any]": ...
146+
147+
@staticmethod
148+
def any(
149+
type: type[MatchT] | None = None,
150+
attrs: collections.abc.Mapping[str, object] | None = None,
151+
) -> "Matcher[MatchT] | Matcher[Any]":
152+
"""Match an argument, optionally by type and/or attributes.
153+
154+
If type and attributes are omitted, will match everything,
155+
including `None`.
156+
157+
Arguments:
158+
type: Type to match, if any.
159+
attrs: Set of attributes to match, if any.
160+
"""
161+
description = ""
162+
163+
if type:
164+
description = type.__name__
165+
166+
if attrs:
167+
description = f"{description} attrs={attrs!r}"
168+
169+
return Matcher(
170+
match=functools.partial(any, type, attrs),
171+
description=description,
172+
)
173+
174+
@staticmethod
175+
def is_not(value: object) -> "Matcher[Any]":
176+
"""Match any value that does not `==` the given value.
177+
178+
Arguments:
179+
value: The value that the matcher rejects.
180+
"""
181+
return Matcher(
182+
lambda t: t != value,
183+
name="is_not",
184+
description=repr(value),
185+
)
186+
187+
@staticmethod
188+
@overload
189+
def contains(values: MappingT) -> "Matcher[MappingT]": ...
190+
191+
@staticmethod
192+
@overload
193+
def contains(values: SequenceT, in_order: bool = False) -> "Matcher[SequenceT]": ...
194+
195+
@staticmethod
196+
def contains(
197+
values: MappingT | SequenceT,
198+
in_order: bool = False,
199+
) -> "Matcher[MappingT] | Matcher[SequenceT]":
200+
"""Match a dict, list, or string with a partial value.
201+
202+
Arguments:
203+
values: Partial value to match.
204+
in_order: Match list values in order.
205+
"""
206+
description = repr(values)
207+
208+
if in_order:
209+
description = f"{description} in_order={in_order}"
210+
211+
return Matcher(
212+
match=functools.partial(contains, values, in_order),
213+
description=description,
214+
)
215+
216+
@staticmethod
217+
def matches(pattern: str) -> "Matcher[str]":
218+
"""Match a string by a pattern.
219+
220+
Arguments:
221+
pattern: Regular expression pattern.
222+
"""
223+
pattern_re = re.compile(pattern)
224+
225+
return Matcher(
226+
match=functools.partial(matches, pattern_re),
227+
description=repr(pattern),
228+
)
229+
230+
@staticmethod
231+
def error(type: type[ErrorT], message: str | None = None) -> "Matcher[ErrorT]":
232+
"""Match an exception object.
233+
234+
Arguments:
235+
type: The type of exception to match.
236+
message: An optional regular expression pattern to match.
237+
"""
238+
message_re = re.compile(message or "")
239+
description = type.__name__
240+
241+
if message:
242+
description = f"{description} message={message!r}"
243+
244+
return Matcher(
245+
match=functools.partial(error, type, message_re),
246+
name="error",
247+
description=description,
248+
)
249+
250+
251+
def any(
252+
match_type: type[Any] | None,
253+
attrs: collections.abc.Mapping[str, object] | None,
254+
target: object,
255+
) -> bool:
256+
return (match_type is None or isinstance(target, match_type)) and (
257+
attrs is None or _has_attrs(attrs, target)
258+
)
259+
260+
261+
def _has_attrs(
262+
attributes: collections.abc.Mapping[str, object],
263+
target: object,
264+
) -> bool:
265+
return all(
266+
hasattr(target, attr_name) and getattr(target, attr_name) == attr_value
267+
for attr_name, attr_value in attributes.items()
268+
)
269+
270+
271+
def contains(
272+
values: collections.abc.Mapping[object, object] | collections.abc.Sequence[object],
273+
in_order: bool,
274+
target: object,
275+
) -> bool:
276+
if isinstance(values, str):
277+
return isinstance(target, str) and values in target
278+
if isinstance(values, collections.abc.Mapping):
279+
return _dict_containing(values, target)
280+
if isinstance(values, collections.abc.Sequence):
281+
return _list_containing(values, in_order, target)
282+
283+
284+
def _dict_containing(
285+
values: collections.abc.Mapping[object, object],
286+
target: object,
287+
) -> bool:
288+
try:
289+
return all(
290+
attr_name in target and target[attr_name] == attr_value # type: ignore[index,operator]
291+
for attr_name, attr_value in values.items()
292+
)
293+
except TypeError:
294+
return False
295+
296+
297+
def _list_containing(
298+
values: collections.abc.Sequence[object],
299+
in_order: bool,
300+
target: object,
301+
) -> bool:
302+
target_index = 0
303+
304+
try:
305+
for value in values:
306+
if in_order:
307+
target = target[target_index:] # type: ignore[index]
308+
309+
target_index = target.index(value) # type: ignore[attr-defined]
310+
311+
except (AttributeError, TypeError, ValueError):
312+
return False
313+
314+
return True
315+
316+
317+
def error(
318+
type: type[ErrorT],
319+
message_pattern: re.Pattern[str],
320+
target: object,
321+
) -> bool:
322+
return isinstance(target, type) and message_pattern.search(str(target)) is not None
323+
324+
325+
def matches(pattern: re.Pattern[str], target: object) -> bool:
326+
return isinstance(target, str) and pattern.search(target) is not None

0 commit comments

Comments
 (0)