Skip to content

Commit 9ec19dd

Browse files
nickdirienzoneozenith
authored andcommitted
aiohttp: fix multiple requests being replayed per request and add support for request_info on mocked responses (#495)
* Fix how redirects are handled so we can have requests with the same URL be used distinctly * Add support for request_info. Remove `past` kwarg. * Remove as e to make linter happy * Add unreleased 3.0.0 to changelog.
1 parent ffd2142 commit 9ec19dd

File tree

3 files changed

+117
-10
lines changed

3 files changed

+117
-10
lines changed

docs/changelog.rst

+4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
Changelog
22
---------
3+
- 3.0.0 (UNRELEASED)
4+
- Fix multiple requests being replayed per single request in aiohttp stub (@nickdirienzo)
5+
- Add support for `request_info` on mocked responses in aiohttp stub (@nickdirienzo)
6+
- ...
37
- 2.1.x (UNRELEASED)
48
- ....
59
- 2.1.1

tests/integration/test_aiohttp.py

+40
Original file line numberDiff line numberDiff line change
@@ -262,3 +262,43 @@ def test_redirect(aiohttp_client, tmpdir):
262262
assert len(cassette_response.history) == len(response.history)
263263
assert len(cassette) == 3
264264
assert cassette.play_count == 3
265+
266+
# Assert that the real response and the cassette response have a similar
267+
# looking request_info.
268+
assert cassette_response.request_info.url == response.request_info.url
269+
assert cassette_response.request_info.method == response.request_info.method
270+
assert {k: v for k, v in cassette_response.request_info.headers.items()} == {
271+
k: v for k, v in response.request_info.headers.items()
272+
}
273+
assert cassette_response.request_info.real_url == response.request_info.real_url
274+
275+
276+
def test_double_requests(tmpdir):
277+
"""We should capture, record, and replay all requests and response chains,
278+
even if there are duplicate ones.
279+
280+
We should replay in the order we saw them.
281+
"""
282+
url = "https://httpbin.org/get"
283+
284+
with vcr.use_cassette(str(tmpdir.join("text.yaml"))):
285+
_, response_text1 = get(url, output="text")
286+
_, response_text2 = get(url, output="text")
287+
288+
with vcr.use_cassette(str(tmpdir.join("text.yaml"))) as cassette:
289+
resp, cassette_response_text = get(url, output="text")
290+
assert resp.status == 200
291+
assert cassette_response_text == response_text1
292+
293+
# We made only one request, so we should only play 1 recording.
294+
assert cassette.play_count == 1
295+
296+
# Now make the second test to url
297+
resp, cassette_response_text = get(url, output="text")
298+
299+
assert resp.status == 200
300+
301+
assert cassette_response_text == response_text2
302+
303+
# Now that we made both requests, we should have played both.
304+
assert cassette.play_count == 2

vcr/stubs/aiohttp_stubs/__init__.py

+73-10
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import logging
77
import json
88

9-
from aiohttp import ClientResponse, streams
9+
from aiohttp import ClientConnectionError, ClientResponse, RequestInfo, streams
1010
from multidict import CIMultiDict, CIMultiDictProxy
1111
from yarl import URL
1212

@@ -20,14 +20,14 @@ class MockStream(asyncio.StreamReader, streams.AsyncStreamReaderMixin):
2020

2121

2222
class MockClientResponse(ClientResponse):
23-
def __init__(self, method, url):
23+
def __init__(self, method, url, request_info=None):
2424
super().__init__(
2525
method=method,
2626
url=url,
2727
writer=None,
2828
continue100=None,
2929
timer=None,
30-
request_info=None,
30+
request_info=request_info,
3131
traces=None,
3232
loop=asyncio.get_event_loop(),
3333
session=None,
@@ -58,7 +58,13 @@ def content(self):
5858

5959

6060
def build_response(vcr_request, vcr_response, history):
61-
response = MockClientResponse(vcr_request.method, URL(vcr_response.get("url")))
61+
request_info = RequestInfo(
62+
url=URL(vcr_request.url),
63+
method=vcr_request.method,
64+
headers=CIMultiDictProxy(CIMultiDict(vcr_request.headers)),
65+
real_url=URL(vcr_request.url),
66+
)
67+
response = MockClientResponse(vcr_request.method, URL(vcr_response.get("url")), request_info=request_info)
6268
response.status = vcr_response["status"]["code"]
6369
response._body = vcr_response["body"].get("string", b"")
6470
response.reason = vcr_response["status"]["message"]
@@ -69,35 +75,92 @@ def build_response(vcr_request, vcr_response, history):
6975
return response
7076

7177

78+
def _serialize_headers(headers):
79+
"""Serialize CIMultiDictProxy to a pickle-able dict because proxy
80+
objects forbid pickling:
81+
82+
https://github.com/aio-libs/multidict/issues/340
83+
"""
84+
# Mark strings as keys so 'istr' types don't show up in
85+
# the cassettes as comments.
86+
return {str(k): v for k, v in headers.items()}
87+
88+
7289
def play_responses(cassette, vcr_request):
7390
history = []
7491
vcr_response = cassette.play_response(vcr_request)
7592
response = build_response(vcr_request, vcr_response, history)
7693

77-
while cassette.can_play_response_for(vcr_request):
94+
# If we're following redirects, continue playing until we reach
95+
# our final destination.
96+
while 300 <= response.status <= 399:
97+
next_url = URL(response.url).with_path(response.headers["location"])
98+
99+
# Make a stub VCR request that we can then use to look up the recorded
100+
# VCR request saved to the cassette. This feels a little hacky and
101+
# may have edge cases based on the headers we're providing (e.g. if
102+
# there's a matcher that is used to filter by headers).
103+
vcr_request = Request("GET", str(next_url), None, _serialize_headers(response.request_info.headers))
104+
vcr_request = cassette.find_requests_with_most_matches(vcr_request)[0][0]
105+
106+
# Tack on the response we saw from the redirect into the history
107+
# list that is added on to the final response.
78108
history.append(response)
79109
vcr_response = cassette.play_response(vcr_request)
80110
response = build_response(vcr_request, vcr_response, history)
81111

82112
return response
83113

84114

85-
async def record_response(cassette, vcr_request, response, past=False):
86-
body = {} if past else {"string": (await response.read())}
87-
headers = {str(key): value for key, value in response.headers.items()}
115+
async def record_response(cassette, vcr_request, response):
116+
"""Record a VCR request-response chain to the cassette."""
117+
118+
try:
119+
body = {"string": (await response.read())}
120+
# aiohttp raises a ClientConnectionError on reads when
121+
# there is no body. We can use this to know to not write one.
122+
except ClientConnectionError:
123+
body = {}
88124

89125
vcr_response = {
90126
"status": {"code": response.status, "message": response.reason},
91-
"headers": headers,
127+
"headers": _serialize_headers(response.headers),
92128
"body": body, # NOQA: E999
93129
"url": str(response.url),
94130
}
131+
95132
cassette.append(vcr_request, vcr_response)
96133

97134

98135
async def record_responses(cassette, vcr_request, response):
136+
"""Because aiohttp follows redirects by default, we must support
137+
them by default. This method is used to write individual
138+
request-response chains that were implicitly followed to get
139+
to the final destination.
140+
"""
141+
99142
for past_response in response.history:
100-
await record_response(cassette, vcr_request, past_response, past=True)
143+
aiohttp_request = past_response.request_info
144+
145+
# No data because it's following a redirect.
146+
past_request = Request(
147+
aiohttp_request.method,
148+
str(aiohttp_request.url),
149+
None,
150+
_serialize_headers(aiohttp_request.headers),
151+
)
152+
await record_response(cassette, past_request, past_response)
153+
154+
# If we're following redirects, then the last request-response
155+
# we record is the one attached to the `response`.
156+
if response.history:
157+
aiohttp_request = response.request_info
158+
vcr_request = Request(
159+
aiohttp_request.method,
160+
str(aiohttp_request.url),
161+
None,
162+
_serialize_headers(aiohttp_request.headers),
163+
)
101164

102165
await record_response(cassette, vcr_request, response)
103166

0 commit comments

Comments
 (0)