Skip to content

Commit 4bb0335

Browse files
committed
feat(v3): add improved matcher class to v3 preview
1 parent 8a0fa12 commit 4bb0335

File tree

12 files changed

+736
-1
lines changed

12 files changed

+736
-1
lines changed

codebook.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
words = [
2+
"matcher's",
23
"matchers",
34
"mundo",
45
"stubbings",

decoy/errors.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,3 +125,12 @@ class VerifyOrderError(VerifyError):
125125
126126
[spying with verify]: usage/verify.md
127127
"""
128+
129+
130+
class NoMatcherValueCapturedError(ValueError):
131+
"""An error raised if a [decoy.next.Matcher][] has not captured any matching values.
132+
133+
See the [matchers guide][] for more details.
134+
135+
[matchers guide]: ./v3/matchers.md
136+
"""

decoy/next/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@
55
"""
66

77
from ._internal.decoy import Decoy
8+
from ._internal.matcher import Matcher
89
from ._internal.mock import AsyncMock, Mock
910
from ._internal.verify import Verify
1011
from ._internal.when import Stub, When
1112

1213
__all__ = [
1314
"AsyncMock",
1415
"Decoy",
16+
"Matcher",
1517
"Mock",
1618
"Stub",
1719
"Verify",

decoy/next/_internal/errors.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,9 @@ def createVerifyOrderError(
9595
)
9696

9797
return errors.VerifyOrderError(message)
98+
99+
100+
def createNoMatcherValueCapturedError(
101+
message: str,
102+
) -> errors.NoMatcherValueCapturedError:
103+
return errors.NoMatcherValueCapturedError(message)

decoy/next/_internal/matcher.py

Lines changed: 307 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,307 @@
1+
import collections.abc
2+
import re
3+
from typing import Any, Callable, Generic, TypeVar, cast, overload
4+
5+
try:
6+
from typing import TypeIs
7+
except ImportError:
8+
from typing_extensions import TypeIs
9+
10+
from .errors import createNoMatcherValueCapturedError
11+
12+
ValueT = TypeVar("ValueT")
13+
MatchT = TypeVar("MatchT")
14+
MappingT = TypeVar("MappingT", bound=collections.abc.Mapping[Any, Any])
15+
SequenceT = TypeVar("SequenceT", bound=collections.abc.Sequence[Any])
16+
ErrorT = TypeVar("ErrorT", bound=BaseException)
17+
18+
TypedMatch = Callable[[object], TypeIs[MatchT]]
19+
UntypedMatch = Callable[[object], bool]
20+
21+
22+
class Matcher(Generic[ValueT]):
23+
"""Create a matcher from a comparison function.
24+
25+
Arguments:
26+
match: A comparison function that returns a bool or `TypeIs` guard.
27+
name: Optional name for the matcher; defaults to `match.__name__`
28+
description: Optional extra description for the matcher's repr.
29+
30+
Example:
31+
Matchers can be constructed from built-in inspection functions, like `callable`.
32+
33+
```python
34+
callable_matcher = Matcher(callable)
35+
```
36+
"""
37+
38+
@overload
39+
def __init__(
40+
self: "Matcher[MatchT]",
41+
match: TypedMatch[MatchT],
42+
name: str | None = None,
43+
description: str | None = None,
44+
) -> None: ...
45+
46+
@overload
47+
def __init__(
48+
self: "Matcher[Any]",
49+
match: UntypedMatch,
50+
name: str | None = None,
51+
description: str | None = None,
52+
) -> None: ...
53+
54+
def __init__(
55+
self,
56+
match: TypedMatch[ValueT] | UntypedMatch,
57+
name: str | None = None,
58+
description: str | None = None,
59+
) -> None:
60+
self._match = match
61+
self._name = name or match.__name__
62+
self._description = description
63+
self._values: list[ValueT] = []
64+
65+
def __eq__(self, target: object) -> bool:
66+
if self._match(target):
67+
self._values.append(cast(ValueT, target)) # type: ignore[redundant-cast]
68+
return True
69+
70+
return False
71+
72+
def __repr__(self) -> str:
73+
matcher_name = f"Matcher.{self._name}"
74+
if self._description:
75+
return f"<{matcher_name} {self._description.strip()}>"
76+
77+
return f"<{matcher_name}>"
78+
79+
@property
80+
def arg(self) -> ValueT:
81+
"""Type-cast the matcher as the expected value.
82+
83+
Example:
84+
If the mock expects a `str` argument, using `arg` prevents the type-checker from raising an error.
85+
86+
```python
87+
decoy
88+
.when(mock)
89+
.called_with(Matcher.matches("^(hello|hi)$").arg)
90+
.then_return("world")
91+
```
92+
"""
93+
return cast(ValueT, self)
94+
95+
@property
96+
def value(self) -> ValueT:
97+
"""The latest matching compared value.
98+
99+
Raises:
100+
NoMatcherValueCapturedError: the matcher has not been compared with any matching value.
101+
102+
Example:
103+
You can use `value` to trigger a callback passed to your mock.
104+
105+
```python
106+
callback_matcher = Matcher(callable)
107+
decoy.verify(mock).called_with(callback_matcher)
108+
callback_matcher.value("value")
109+
```
110+
"""
111+
if len(self._values) == 0:
112+
raise createNoMatcherValueCapturedError(
113+
f"{self} has not matched any values"
114+
)
115+
116+
return self._values[-1]
117+
118+
@property
119+
def values(self) -> list[ValueT]:
120+
"""All matching compared values."""
121+
return self._values.copy()
122+
123+
@staticmethod
124+
@overload
125+
def any(
126+
type: type[MatchT],
127+
attrs: collections.abc.Mapping[str, object] | None = None,
128+
) -> "Matcher[MatchT]": ...
129+
130+
@staticmethod
131+
@overload
132+
def any(
133+
type: None = None,
134+
attrs: collections.abc.Mapping[str, object] | None = None,
135+
) -> "Matcher[Any]": ...
136+
137+
@staticmethod
138+
def any(
139+
type: type[MatchT] | None = None,
140+
attrs: collections.abc.Mapping[str, object] | None = None,
141+
) -> "Matcher[MatchT] | Matcher[Any]":
142+
"""Match an argument, optionally by type and/or attributes.
143+
144+
If type and attributes are omitted, will match everything,
145+
including `None`.
146+
147+
Arguments:
148+
type: Type to match, if any.
149+
attrs: Set of attributes to match, if any.
150+
"""
151+
description = ""
152+
153+
if type:
154+
description = type.__name__
155+
156+
if attrs:
157+
description = f"{description} {attrs!r}"
158+
159+
return Matcher(
160+
lambda t: _any(t, type, attrs),
161+
name="any",
162+
description=description,
163+
)
164+
165+
@staticmethod
166+
def is_not(value: object) -> "Matcher[Any]":
167+
"""Match any value that does not `==` the given value.
168+
169+
Arguments:
170+
value: The value that the matcher rejects.
171+
"""
172+
return Matcher(
173+
lambda t: t != value,
174+
name="is_not",
175+
description=repr(value),
176+
)
177+
178+
@staticmethod
179+
@overload
180+
def contains(values: MappingT) -> "Matcher[MappingT]": ...
181+
182+
@staticmethod
183+
@overload
184+
def contains(values: SequenceT, in_order: bool = False) -> "Matcher[SequenceT]": ...
185+
186+
@staticmethod
187+
def contains(
188+
values: MappingT | SequenceT,
189+
in_order: bool = False,
190+
) -> "Matcher[MappingT] | Matcher[SequenceT]":
191+
"""Match a dict, list, or string with a partial value.
192+
193+
Arguments:
194+
values: Partial value to match.
195+
in_order: Match list values in order.
196+
"""
197+
description = repr(values)
198+
199+
if in_order:
200+
description = f"{description} {in_order=}"
201+
202+
return Matcher(
203+
lambda t: _contains(t, values, in_order),
204+
name="contains",
205+
description=description,
206+
)
207+
208+
@staticmethod
209+
def matches(pattern: str) -> "Matcher[str]":
210+
"""Match a string by a pattern.
211+
212+
Arguments:
213+
pattern: Regular expression pattern.
214+
"""
215+
pattern_re = re.compile(pattern)
216+
217+
return Matcher(
218+
lambda t: isinstance(t, str) and pattern_re.search(t) is not None,
219+
name="matches",
220+
description=repr(pattern),
221+
)
222+
223+
@staticmethod
224+
def error(type: type[ErrorT], message: str | None = None) -> "Matcher[ErrorT]":
225+
"""Match an exception object.
226+
227+
Arguments:
228+
type: The type of exception to match.
229+
message: An optional regular expression pattern to match.
230+
"""
231+
message_re = re.compile(message or "")
232+
description = type.__name__
233+
234+
if message:
235+
description = f"{description} {message!r}"
236+
237+
return Matcher(
238+
lambda t: isinstance(t, type) and message_re.search(str(t)) is not None,
239+
name="error",
240+
description=description,
241+
)
242+
243+
244+
def _any(
245+
target: object,
246+
match_type: type[Any] | None,
247+
attrs: collections.abc.Mapping[str, object] | None,
248+
) -> bool:
249+
return (match_type is None or isinstance(target, match_type)) and (
250+
attrs is None or _has_attrs(target, attrs)
251+
)
252+
253+
254+
def _has_attrs(
255+
target: object,
256+
attributes: collections.abc.Mapping[str, object],
257+
) -> bool:
258+
return all(
259+
hasattr(target, attr_name) and getattr(target, attr_name) == attr_value
260+
for attr_name, attr_value in attributes.items()
261+
)
262+
263+
264+
def _contains(
265+
target: object,
266+
values: collections.abc.Mapping[object, object] | collections.abc.Sequence[object],
267+
in_order: bool,
268+
) -> bool:
269+
if isinstance(values, str):
270+
return isinstance(target, str) and values in target
271+
if isinstance(values, collections.abc.Mapping):
272+
return _dict_containing(target, values)
273+
if isinstance(values, collections.abc.Sequence):
274+
return _list_containing(target, values, in_order)
275+
276+
277+
def _dict_containing(
278+
target: object,
279+
values: collections.abc.Mapping[object, object],
280+
) -> bool:
281+
try:
282+
return all(
283+
attr_name in target and target[attr_name] == attr_value # type: ignore[index,operator]
284+
for attr_name, attr_value in values.items()
285+
)
286+
except TypeError:
287+
return False
288+
289+
290+
def _list_containing(
291+
target: object,
292+
values: collections.abc.Sequence[object],
293+
in_order: bool,
294+
) -> bool:
295+
target_index = 0
296+
297+
try:
298+
for value in values:
299+
if in_order:
300+
target = target[target_index:] # type: ignore[index]
301+
302+
target_index = target.index(value) # type: ignore[attr-defined]
303+
304+
except (AttributeError, TypeError, ValueError):
305+
return False
306+
307+
return True

0 commit comments

Comments
 (0)