Skip to content

Commit 28416d6

Browse files
committed
Add support for repeated get params in aiohttp
1 parent b92be4e commit 28416d6

File tree

3 files changed

+64
-14
lines changed

3 files changed

+64
-14
lines changed

tests/integration/test_aiohttp.py

Lines changed: 51 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
import contextlib
22

33
import pytest
4+
import vcr # noqa: E402
5+
6+
7+
multidict = pytest.importorskip("multidict")
48
asyncio = pytest.importorskip("asyncio")
59
aiohttp = pytest.importorskip("aiohttp")
610

7-
import vcr # noqa: E402
811
from .aiohttp_utils import aiohttp_app, aiohttp_request # noqa: E402
912

1013

@@ -97,7 +100,9 @@ def test_stream(tmpdir, scheme):
97100
url = scheme + '://httpbin.org/get'
98101

99102
with vcr.use_cassette(str(tmpdir.join('stream.yaml'))):
100-
resp, body = get(url, output='raw') # Do not use stream here, as the stream is exhausted by vcr
103+
resp, body = get(
104+
url, output='raw'
105+
) # Do not use stream here, as the stream is exhausted by vcr
101106

102107
with vcr.use_cassette(str(tmpdir.join('stream.yaml'))) as cassette:
103108
cassette_resp, cassette_body = get(url, output='stream')
@@ -123,10 +128,16 @@ def test_params(tmpdir, scheme):
123128
params = {'a': 1, 'b': False, 'c': 'c'}
124129

125130
with vcr.use_cassette(str(tmpdir.join('get.yaml'))) as cassette:
126-
_, response_json = get(url, output='json', params=params, headers=headers)
131+
_, response_json = get(url,
132+
output='json',
133+
params=params,
134+
headers=headers)
127135

128136
with vcr.use_cassette(str(tmpdir.join('get.yaml'))) as cassette:
129-
_, cassette_response_json = get(url, output='json', params=params, headers=headers)
137+
_, cassette_response_json = get(url,
138+
output='json',
139+
params=params,
140+
headers=headers)
130141
assert cassette_response_json == response_json
131142
assert cassette.play_count == 1
132143

@@ -137,16 +148,24 @@ def test_params_same_url_distinct_params(tmpdir, scheme):
137148
params = {'a': 1, 'b': False, 'c': 'c'}
138149

139150
with vcr.use_cassette(str(tmpdir.join('get.yaml'))) as cassette:
140-
_, response_json = get(url, output='json', params=params, headers=headers)
151+
_, response_json = get(url,
152+
output='json',
153+
params=params,
154+
headers=headers)
141155

142156
with vcr.use_cassette(str(tmpdir.join('get.yaml'))) as cassette:
143-
_, cassette_response_json = get(url, output='json', params=params, headers=headers)
157+
_, cassette_response_json = get(url,
158+
output='json',
159+
params=params,
160+
headers=headers)
144161
assert cassette_response_json == response_json
145162
assert cassette.play_count == 1
146163

147164
other_params = {'other': 'params'}
148165
with vcr.use_cassette(str(tmpdir.join('get.yaml'))) as cassette:
149-
response, cassette_response_text = get(url, output='text', params=other_params)
166+
response, cassette_response_text = get(url,
167+
output='text',
168+
params=other_params)
150169
assert 'No match for the request' in cassette_response_text
151170
assert response.status == 599
152171

@@ -164,6 +183,17 @@ def test_params_on_url(tmpdir, scheme):
164183
_, cassette_response_json = get(url, output='json', headers=headers)
165184
request = cassette.requests[0]
166185
assert request.url == url
186+
187+
188+
def test_repeated_params(tmpdir, scheme):
189+
url = scheme + '://httpbin.org/get'
190+
params = [('a', '1'), ('c', 'c'), ('a', '2')]
191+
192+
with vcr.use_cassette(str(tmpdir.join('get.yaml'))) as cassette:
193+
_, response_json = get(url, as_text=False, params=params)
194+
195+
with vcr.use_cassette(str(tmpdir.join('get.yaml'))) as cassette:
196+
_, cassette_response_json = get(url, as_text=False, params=params)
167197
assert cassette_response_json == response_json
168198
assert cassette.play_count == 1
169199

@@ -214,3 +244,17 @@ def test_aiohttp_test_client_json(aiohttp_client, tmpdir):
214244
response_json = loop.run_until_complete(response.json())
215245
assert response_json is None
216246
assert cassette.play_count == 1
247+
248+
249+
def test_repeated_params_multidict(tmpdir, scheme):
250+
url = scheme + '://httpbin.org/get'
251+
params_list = [('a', 1), ('c', 'c'), ('a', '2')]
252+
params = multidict.MultiDict(params_list)
253+
254+
with vcr.use_cassette(str(tmpdir.join('get.yaml'))) as cassette:
255+
_, response_json = get(url, as_text=False, params=params)
256+
257+
with vcr.use_cassette(str(tmpdir.join('get.yaml'))) as cassette:
258+
_, cassette_response_json = get(url, as_text=False, params=params)
259+
assert cassette_response_json == response_json
260+
assert cassette.play_count == 1

tox.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ deps =
2929
aiohttp: aiohttp
3030
aiohttp: pytest-asyncio
3131
aiohttp: pytest-aiohttp
32+
aiohttp: multidict
3233

3334
[flake8]
3435
max_line_length = 110

vcr/stubs/aiohttp_stubs/__init__.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
import functools
66
import json
77

8+
import multidict
89
from aiohttp import ClientResponse, streams
9-
from yarl import URL
10-
1110
from vcr.request import Request
11+
from yarl import URL
1212

1313

1414
class MockStream(asyncio.StreamReader, streams.AsyncStreamReaderMixin):
@@ -29,7 +29,8 @@ def __init__(self, method, url):
2929
session=None,
3030
)
3131

32-
async def json(self, *, encoding='utf-8', loads=json.loads, **kwargs): # NOQA: E999
32+
async def json(self, *, encoding='utf-8', loads=json.loads,
33+
**kwargs): # NOQA: E999
3334
stripped = self._body.strip()
3435
if not stripped:
3536
return None
@@ -63,9 +64,10 @@ async def new_request(self, method, url, **kwargs):
6364

6465
request_url = URL(url)
6566
if params:
67+
new_params = multidict.MultiDict()
6668
for k, v in params.items():
67-
params[k] = str(v)
68-
request_url = URL(url).with_query(params)
69+
new_params.add(k, str(v))
70+
request_url = URL(url).with_query(new_params)
6971

7072
vcr_request = Request(method, str(request_url), data, headers)
7173

@@ -91,15 +93,18 @@ async def new_request(self, method, url, **kwargs):
9193
response.close()
9294
return response
9395

94-
response = await real_request(self, method, url, **kwargs) # NOQA: E999
96+
response = await real_request(self, method, url,
97+
**kwargs) # NOQA: E999
9598

9699
vcr_response = {
97100
'status': {
98101
'code': response.status,
99102
'message': response.reason,
100103
},
101104
'headers': dict(response.headers),
102-
'body': {'string': (await response.read())}, # NOQA: E999
105+
'body': {
106+
'string': (await response.read())
107+
}, # NOQA: E999
103108
'url': response.url,
104109
}
105110
cassette.append(vcr_request, vcr_response)

0 commit comments

Comments
 (0)