2
2
3
3
import contextlib
4
4
import functools
5
+ from collections import defaultdict
5
6
from enum import Enum
6
7
from types import TracebackType
7
- from typing import Callable , List , Optional , Union
8
+ from typing import Callable , Dict , Iterable , List , Optional , Union
8
9
9
10
import requests_mock
10
11
@@ -40,7 +41,7 @@ class HttpMocker(contextlib.ContextDecorator):
40
41
41
42
def __init__ (self ) -> None :
42
43
self ._mocker = requests_mock .Mocker ()
43
- self ._matchers : List [HttpRequestMatcher ] = []
44
+ self ._matchers : Dict [ SupportedHttpMethods , List [HttpRequestMatcher ]] = defaultdict ( list )
44
45
45
46
def __enter__ (self ) -> "HttpMocker" :
46
47
self ._mocker .__enter__ ()
@@ -55,7 +56,7 @@ def __exit__(
55
56
self ._mocker .__exit__ (exc_type , exc_val , exc_tb )
56
57
57
58
def _validate_all_matchers_called (self ) -> None :
58
- for matcher in self ._matchers :
59
+ for matcher in self ._get_matchers () :
59
60
if not matcher .has_expected_match_count ():
60
61
raise ValueError (f"Invalid number of matches for `{ matcher } `" )
61
62
@@ -69,9 +70,9 @@ def _mock_request_method(
69
70
responses = [responses ]
70
71
71
72
matcher = HttpRequestMatcher (request , len (responses ))
72
- if matcher in self ._matchers :
73
+ if matcher in self ._matchers [ method ] :
73
74
raise ValueError (f"Request { matcher .request } already mocked" )
74
- self ._matchers .append (matcher )
75
+ self ._matchers [ method ] .append (matcher )
75
76
76
77
getattr (self ._mocker , method )(
77
78
requests_mock .ANY ,
@@ -129,7 +130,7 @@ def matches(requests_mock_request: requests_mock.request._RequestObjectProxy) ->
129
130
130
131
def assert_number_of_calls (self , request : HttpRequest , number_of_calls : int ) -> None :
131
132
corresponding_matchers = list (
132
- filter (lambda matcher : matcher .request == request , self ._matchers )
133
+ filter (lambda matcher : matcher .request is request , self ._get_matchers () )
133
134
)
134
135
if len (corresponding_matchers ) != 1 :
135
136
raise ValueError (
@@ -150,7 +151,7 @@ def wrapper(*args, **kwargs): # type: ignore # this is a very generic wrapper
150
151
result = f (* args , ** kwargs )
151
152
except requests_mock .NoMockAddress as no_mock_exception :
152
153
matchers_as_string = "\n \t " .join (
153
- map (lambda matcher : str (matcher .request ), self ._matchers )
154
+ map (lambda matcher : str (matcher .request ), self ._get_matchers () )
154
155
)
155
156
raise ValueError (
156
157
f"No matcher matches { no_mock_exception .args [0 ]} with headers `{ no_mock_exception .request .headers } ` "
@@ -175,6 +176,10 @@ def wrapper(*args, **kwargs): # type: ignore # this is a very generic wrapper
175
176
176
177
return wrapper
177
178
179
+ def _get_matchers (self ) -> Iterable [HttpRequestMatcher ]:
180
+ for matchers in self ._matchers .values ():
181
+ yield from matchers
182
+
178
183
def clear_all_matchers (self ) -> None :
179
184
"""Clears all stored matchers by resetting the _matchers list to an empty state."""
180
- self ._matchers = []
185
+ self ._matchers = defaultdict ( list )
0 commit comments