Skip to content

Commit 5be7569

Browse files
authored
Merge pull request #353 from lamenezes/support-aiohttp-over-3.1.0
Fix aiohttp stub to support version >= 3.1.0
2 parents f890709 + b10b92b commit 5be7569

File tree

8 files changed

+53
-42
lines changed

8 files changed

+53
-42
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
*.pyc
22
.tox
33
.cache
4+
.pytest_cache/
45
build/
56
dist/
67
*.egg/

.travis.yml

+2
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ env:
1616
matrix:
1717
allow_failures:
1818
- env: TOX_SUFFIX="boto3"
19+
- env: TOX_SUFFIX="aiohttp"
20+
python: "pypy3.5-5.9.0"
1921
exclude:
2022
# Only run flakes on a single Python 2.x and a single 3.x
2123
- env: TOX_SUFFIX="flakes"

tests/integration/aiohttp_utils.py

+18-10
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,23 @@
1+
# flake8: noqa
12
import asyncio
3+
24
import aiohttp
35

46

57
@asyncio.coroutine
6-
def aiohttp_request(loop, method, url, output='text', **kwargs):
7-
with aiohttp.ClientSession(loop=loop) as session:
8-
response = yield from session.request(method, url, **kwargs) # NOQA: E999
9-
if output == 'text':
10-
content = yield from response.text() # NOQA: E999
11-
elif output == 'json':
12-
content = yield from response.json() # NOQA: E999
13-
elif output == 'raw':
14-
content = yield from response.read() # NOQA: E999
15-
return response, content
8+
def aiohttp_request(loop, method, url, output='text', encoding='utf-8', **kwargs):
9+
session = aiohttp.ClientSession(loop=loop)
10+
response_ctx = session.request(method, url, **kwargs)
11+
12+
response = yield from response_ctx.__aenter__()
13+
if output == 'text':
14+
content = yield from response.text()
15+
elif output == 'json':
16+
content = yield from response.json(encoding=encoding)
17+
elif output == 'raw':
18+
content = yield from response.read()
19+
20+
response_ctx._resp.close()
21+
yield from session.close()
22+
23+
return response, content

tests/integration/async_def.py

-13
This file was deleted.

tests/integration/test_aiohttp.py

+15-16
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,12 @@
1+
import contextlib
2+
13
import pytest
4+
asyncio = pytest.importorskip("asyncio")
25
aiohttp = pytest.importorskip("aiohttp")
36

4-
import asyncio # noqa: E402
5-
import contextlib # noqa: E402
6-
7-
import pytest # noqa: E402
87
import vcr # noqa: E402
9-
108
from .aiohttp_utils import aiohttp_request # noqa: E402
119

12-
try:
13-
from .async_def import test_http # noqa: F401
14-
except SyntaxError:
15-
pass
16-
1710

1811
def run_in_loop(fn):
1912
with contextlib.closing(asyncio.new_event_loop()) as loop:
@@ -78,11 +71,13 @@ def test_text(tmpdir, scheme):
7871

7972
def test_json(tmpdir, scheme):
8073
url = scheme + '://httpbin.org/get'
74+
headers = {'Content-Type': 'application/json'}
75+
8176
with vcr.use_cassette(str(tmpdir.join('json.yaml'))):
82-
_, response_json = get(url, output='json')
77+
_, response_json = get(url, output='json', headers=headers)
8378

8479
with vcr.use_cassette(str(tmpdir.join('json.yaml'))) as cassette:
85-
_, cassette_response_json = get(url, output='json')
80+
_, cassette_response_json = get(url, output='json', headers=headers)
8681
assert cassette_response_json == response_json
8782
assert cassette.play_count == 1
8883

@@ -112,24 +107,28 @@ def test_post(tmpdir, scheme):
112107

113108
def test_params(tmpdir, scheme):
114109
url = scheme + '://httpbin.org/get'
110+
headers = {'Content-Type': 'application/json'}
115111
params = {'a': 1, 'b': False, 'c': 'c'}
112+
116113
with vcr.use_cassette(str(tmpdir.join('get.yaml'))) as cassette:
117-
_, response_json = get(url, output='json', params=params)
114+
_, response_json = get(url, output='json', params=params, headers=headers)
118115

119116
with vcr.use_cassette(str(tmpdir.join('get.yaml'))) as cassette:
120-
_, cassette_response_json = get(url, output='json', params=params)
117+
_, cassette_response_json = get(url, output='json', params=params, headers=headers)
121118
assert cassette_response_json == response_json
122119
assert cassette.play_count == 1
123120

124121

125122
def test_params_same_url_distinct_params(tmpdir, scheme):
126123
url = scheme + '://httpbin.org/get'
124+
headers = {'Content-Type': 'application/json'}
127125
params = {'a': 1, 'b': False, 'c': 'c'}
126+
128127
with vcr.use_cassette(str(tmpdir.join('get.yaml'))) as cassette:
129-
_, response_json = get(url, output='json', params=params)
128+
_, response_json = get(url, output='json', params=params, headers=headers)
130129

131130
with vcr.use_cassette(str(tmpdir.join('get.yaml'))) as cassette:
132-
_, cassette_response_json = get(url, output='json', params=params)
131+
_, cassette_response_json = get(url, output='json', params=params, headers=headers)
133132
assert cassette_response_json == response_json
134133
assert cassette.play_count == 1
135134

tests/integration/test_requests.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,8 @@ def test_post_chunked_binary(tmpdir, httpbin):
116116
assert req1 == req2
117117

118118

119-
@pytest.mark.xfail('sys.version_info >= (3, 6)', strict=True, raises=ConnectionError)
120-
@pytest.mark.xfail((3, 5) < sys.version_info < (3, 6) and
119+
@pytest.mark.xskip('sys.version_info >= (3, 6)', strict=True, raises=ConnectionError)
120+
@pytest.mark.xskip((3, 5) < sys.version_info < (3, 6) and
121121
platform.python_implementation() == 'CPython',
122122
reason='Fails on CPython 3.5')
123123
def test_post_chunked_binary_secure(tmpdir, httpbin_secure):

tox.ini

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ deps =
2525
{py27,py35,py36,pypy}-tornado4: pytest-tornado
2626
{py27,py35,py36}-tornado4: pycurl
2727
boto3: boto3
28-
aiohttp: aiohttp<3
28+
aiohttp: aiohttp
2929
aiohttp: pytest-asyncio
3030

3131
[flake8]

vcr/stubs/aiohttp_stubs/__init__.py

+14
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,20 @@
1212

1313

1414
class MockClientResponse(ClientResponse):
15+
def __init__(self, method, url):
16+
super().__init__(
17+
method=method,
18+
url=url,
19+
writer=None,
20+
continue100=None,
21+
timer=None,
22+
request_info=None,
23+
auto_decompress=None,
24+
traces=None,
25+
loop=asyncio.get_event_loop(),
26+
session=None,
27+
)
28+
1529
# TODO: get encoding from header
1630
@asyncio.coroutine
1731
def json(self, *, encoding='utf-8', loads=json.loads): # NOQA: E999

0 commit comments

Comments
 (0)