-
Notifications
You must be signed in to change notification settings - Fork 402
/
Copy pathhttpx_stubs.py
192 lines (144 loc) · 6.32 KB
/
httpx_stubs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
import asyncio
import functools
import inspect
import logging
import warnings
from unittest.mock import MagicMock, patch
import httpx
from vcr.errors import CannotOverwriteExistingCassetteException
from vcr.filters import decode_response
from vcr.request import Request as VcrRequest
from vcr.serializers.compat import convert_body_to_bytes
_httpx_signature = inspect.signature(httpx.Client.request)

# Pick whichever redirect parameter the installed httpx.Client.request
# declares: "follow_redirects" when present, otherwise "allow_redirects".
if "follow_redirects" in _httpx_signature.parameters:
    HTTPX_REDIRECT_PARAM = _httpx_signature.parameters["follow_redirects"]
else:
    HTTPX_REDIRECT_PARAM = _httpx_signature.parameters["allow_redirects"]

# Logger named after this module.
_logger = logging.getLogger(__name__)
def _transform_headers(httpx_response):
"""
Some headers can appear multiple times, like "Set-Cookie".
Therefore transform to every header key to list of values.
"""
out = {}
for key, var in httpx_response.headers.raw:
decoded_key = key.decode("utf-8")
out.setdefault(decoded_key, [])
out[decoded_key].append(var.decode("utf-8"))
return out
async def _to_serialized_response(resp, aread):
    """
    Read the body of ``resp`` without content decoding and return the
    response as a dict with ``status``, ``headers`` and ``body`` entries.

    :param resp: response object whose body has not been read yet
        (presumably an ``httpx.Response`` -- confirm against callers).
    :param aread: when truthy the body is read with ``await resp.aread()``,
        otherwise with the blocking ``resp.read()``.
    :return: dict with ``status`` (``code`` and ``message``), ``headers``
        (as built by ``_transform_headers``) and ``body`` (``string`` holding
        ``resp.content`` as read here, before it is decoded below).
    :raises AssertionError: if ``resp`` already has a ``_decoder`` attribute.

    After the dict is built, ``resp._content`` is replaced with the decoded
    content so the caller still gets a decoded body from ``resp``.
    """
    # The content shouldn't already have been read in by HTTPX.
    assert not hasattr(resp, "_decoder")

    # Retrieve the content, but without decoding it.
    # The Content-Encoding header is blanked only for the duration of the
    # read; patch.dict restores the original headers when the block exits.
    with patch.dict(resp.headers, {"Content-Encoding": ""}):
        if aread:
            await resp.aread()
        else:
            resp.read()

    # Snapshot taken while resp.content still holds the undecoded body.
    result = {
        "status": {"code": resp.status_code, "message": resp.reason_phrase},
        "headers": _transform_headers(resp),
        "body": {"string": resp.content},
    }

    # As the content wasn't decoded, we restore the response to a state which
    # will be capable of decoding the content for the consumer.
    # NOTE(review): presumably `_decoder` was cached during the read above
    # with the blanked header, and deleting it lets `_get_content_decoder()`
    # choose a decoder from the restored headers -- confirm against the
    # httpx internals in use.
    del resp._decoder
    resp._content = resp._get_content_decoder().decode(resp.content)
    return result
def _from_serialized_headers(headers):
"""
httpx accepts headers as list of tuples of header key and value.
"""
header_list = []
for key, values in headers.items():
for v in values:
header_list.append((key, v))
return header_list
@patch("httpx.Response.close", MagicMock())
@patch("httpx.Response.read", MagicMock())
def _from_serialized_response(request, serialized_response, history=None):
    """
    Build an ``httpx.Response`` for ``request`` from a recorded response.

    Two cassette layouts are accepted: the regular one (``status``,
    ``headers``, ``body``) and the one written for HTTPX requests by older
    versions of vcrpy (``status_code``, ``headers``, ``content``). The older
    layout carries no status message, so no ``reason_phrase`` extension is
    set for it.

    ``httpx.Response.close`` and ``httpx.Response.read`` are replaced by
    ``MagicMock`` objects for the duration of the call.
    """
    is_legacy_layout = "status_code" in serialized_response
    if is_legacy_layout:
        # Reshape the old layout so it resembles a regular cassette entry
        # before running it through the usual conversion helpers.
        regular_layout = {
            "headers": serialized_response["headers"],
            "body": {"string": serialized_response["content"]},
            "status": {"code": serialized_response["status_code"]},
        }
        serialized_response = decode_response(convert_body_to_bytes(regular_layout))
        extensions = None
    else:
        reason = serialized_response["status"]["message"]
        extensions = {"reason_phrase": reason.encode()}

    return httpx.Response(
        status_code=serialized_response["status"]["code"],
        request=request,
        headers=_from_serialized_headers(serialized_response["headers"]),
        content=serialized_response["body"]["string"],
        history=history or [],
        extensions=extensions,
    )
def _make_vcr_request(httpx_request, **kwargs):
    """
    Convert an httpx request into the ``VcrRequest`` used by the cassette.

    The request body is decoded as UTF-8. When the payload is not valid
    UTF-8 the undecodable bytes are dropped and a warning is emitted, because
    the recording may then have lost bytes.

    :param httpx_request: request exposing ``read()``, ``url``, ``headers``
        and ``method``.
    :param kwargs: accepted so callers can forward their keyword arguments;
        not used here.
    :return: ``VcrRequest(method, uri, body, headers)``.
    """
    # Read the payload once and reuse it for the lossy fallback decode.
    payload = httpx_request.read()
    try:
        body = payload.decode("utf-8")
    except UnicodeDecodeError as e:
        body = payload.decode("utf-8", errors="ignore")
        warnings.warn(
            f"Could not decode full request payload as UTF8, recording may have lost bytes. {e}",
            # Attribute the warning to the caller rather than to this helper.
            stacklevel=2,
        )
    uri = str(httpx_request.url)
    headers = dict(httpx_request.headers)
    return VcrRequest(httpx_request.method, uri, body, headers)
def _shared_vcr_send(cassette, real_send, *args, **kwargs):
    """
    Decide, for both the sync and async send paths, whether a request can be
    answered from the cassette.

    ``args[0]`` is passed on as the client and ``args[1]`` is the request
    being sent. ``real_send`` is not used here.

    :return: ``(vcr_request, response)`` when the cassette can play a
        response for the request, otherwise ``(vcr_request, None)``.
    :raises CannotOverwriteExistingCassetteException: when there is nothing
        to play, the cassette is write protected and
        ``cassette.filter_request(vcr_request)`` is truthy.
    """
    client, real_request = args[0], args[1]
    vcr_request = _make_vcr_request(real_request, **kwargs)

    if cassette.can_play_response_for(vcr_request):
        replayed = _play_responses(cassette, real_request, vcr_request, client, kwargs)
        return vcr_request, replayed

    recording_forbidden = cassette.write_protected and cassette.filter_request(vcr_request)
    if recording_forbidden:
        raise CannotOverwriteExistingCassetteException(cassette=cassette, failed_request=vcr_request)

    _logger.info("%s not in cassette, sending to real server", vcr_request)
    return vcr_request, None
async def _record_responses(cassette, vcr_request, real_response, aread):
    """
    Append a live response, preceded by every response in its ``history``,
    to the cassette and return ``real_response``.

    ``aread`` is forwarded to ``_to_serialized_response`` for each response
    that is serialized.
    """
    redirect_chain = real_response.history

    # Record each earlier response together with the request that produced it.
    for hop in redirect_chain:
        hop_request = _make_vcr_request(hop.request)
        hop_serialized = await _to_serialized_response(hop, aread)
        cassette.append(hop_request, hop_serialized)

    if redirect_chain:
        # If there was a redirection, we want to keep the request which
        # holds the final redirect value.
        vcr_request = _make_vcr_request(real_response.request)

    final_serialized = await _to_serialized_response(real_response, aread)
    cassette.append(vcr_request, final_serialized)
    return real_response
def _play_responses(cassette, request, vcr_request, client, kwargs):
    """
    Play the recorded response for ``vcr_request`` from the cassette and
    rebuild it as a response attached to ``request``.

    ``client`` and ``kwargs`` are accepted but not used here.
    """
    recorded = cassette.play_response(vcr_request)
    return _from_serialized_response(request, recorded)
async def _async_vcr_send(cassette, real_send, *args, **kwargs):
    """
    Async send path: answer from the cassette when possible, otherwise await
    ``real_send`` and record its response.

    ``args[0]`` is the client whose cookie store is updated on playback.
    """
    vcr_request, replayed = _shared_vcr_send(cassette, real_send, *args, **kwargs)

    if not replayed:
        # Nothing to play back: perform the real request and record it.
        live_response = await real_send(*args, **kwargs)
        await _record_responses(cassette, vcr_request, live_response, aread=True)
        return live_response

    # add cookies from response to session cookie store
    client = args[0]
    client.cookies.extract_cookies(replayed)
    return replayed
def async_vcr_send(cassette, real_send):
    """
    Return a replacement for ``real_send`` that routes every call through
    ``_async_vcr_send`` with ``cassette``; the replacement carries the
    metadata of ``real_send`` via ``functools.wraps``.
    """

    def _inner_send(*args, **kwargs):
        return _async_vcr_send(cassette, real_send, *args, **kwargs)

    return functools.wraps(real_send)(_inner_send)
def _run_coroutine_blocking(coro):
    """
    Run ``coro`` to completion from synchronous code and return its result.

    ``asyncio.run`` raises ``RuntimeError`` when the calling thread already
    has a running event loop, so in that case the coroutine is run by
    ``asyncio.run`` on a short-lived worker thread while this thread waits
    for the result.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No event loop is running in this thread: asyncio.run is safe.
        return asyncio.run(coro)

    # Local import: only needed on this fallback path.
    from concurrent.futures import ThreadPoolExecutor

    with ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro).result()


def _sync_vcr_send(cassette, real_send, *args, **kwargs):
    """
    Sync send path: answer from the cassette when possible, otherwise call
    ``real_send`` and record its response.

    ``args[0]`` is the client whose cookie store is updated on playback.
    Recording is finished before the live response is returned, also when
    this function is called from a thread with a running event loop.
    """
    vcr_request, response = _shared_vcr_send(cassette, real_send, *args, **kwargs)
    if response:
        # add cookies from response to session cookie store
        args[0].cookies.extract_cookies(response)
        return response

    real_response = real_send(*args, **kwargs)
    # _record_responses is a coroutine function shared with the async path;
    # drive it to completion here without assuming no loop is running.
    _run_coroutine_blocking(_record_responses(cassette, vcr_request, real_response, aread=False))
    return real_response
def sync_vcr_send(cassette, real_send):
    """
    Return a replacement for ``real_send`` that routes every call through
    ``_sync_vcr_send`` with ``cassette``; the replacement carries the
    metadata of ``real_send`` via ``functools.wraps``.
    """

    def _inner_send(*args, **kwargs):
        return _sync_vcr_send(cassette, real_send, *args, **kwargs)

    return functools.wraps(real_send)(_inner_send)