Skip to content

Commit 728bc26

Browse files
committed
Split sync and async handlers
1 parent e4858b0 commit 728bc26

File tree

2 files changed

+49
-70
lines changed

2 files changed

+49
-70
lines changed

tests/integration/test_httpx.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,3 +352,19 @@ def test_gzip__decode_compressed_response_true(do_request, tmpdir, httpbin):
352352
# As the content is uncompressed, it should have a bigger
353353
# length than the compressed version.
354354
assert r.headers["content-length"] > content_length
355+
356+
357+
def test_sync_in_async_context(tmpdir, httpbin, do_request):
358+
async def run():
359+
url = httpbin.url
360+
361+
with vcr.use_cassette(str(tmpdir.join("sync_in_async_context.yaml"))):
362+
do_request()("GET", url)
363+
364+
with vcr.use_cassette(str(tmpdir.join("sync_in_async_context.yaml"))) as cassette:
365+
do_request()("GET", url)
366+
assert cassette.play_count == 1
367+
368+
# Only test sync requests in async contexts
369+
if do_request is DoSyncRequest:
370+
asyncio.run(run())

vcr/stubs/httpcore_stubs.py

Lines changed: 33 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
1-
import asyncio
21
import functools
32
import logging
43
from collections import defaultdict
5-
from collections.abc import AsyncIterable, Iterable
64

75
from httpcore import Response
86
from httpcore._models import ByteStream
@@ -15,18 +13,6 @@
1513
_logger = logging.getLogger(__name__)
1614

1715

18-
async def _convert_byte_stream(stream):
19-
if isinstance(stream, Iterable):
20-
return list(stream)
21-
22-
if isinstance(stream, AsyncIterable):
23-
return [part async for part in stream]
24-
25-
raise TypeError(
26-
f"_convert_byte_stream: stream must be Iterable or AsyncIterable, got {type(stream).__name__}",
27-
)
28-
29-
3016
def _serialize_headers(real_response):
3117
"""
3218
Some headers can appear multiple times, like "Set-Cookie".
@@ -41,21 +27,17 @@ def _serialize_headers(real_response):
4127
return dict(headers)
4228

4329

44-
async def _serialize_response(real_response):
30+
def _serialize_response(real_response, real_response_content):
4531
# The reason_phrase may not exist
4632
try:
4733
reason_phrase = real_response.extensions["reason_phrase"].decode("ascii")
4834
except KeyError:
4935
reason_phrase = None
5036

51-
# Reading the response stream consumes the iterator, so we need to restore it afterwards
52-
content = b"".join(await _convert_byte_stream(real_response.stream))
53-
real_response.stream = ByteStream(content)
54-
5537
return {
5638
"status": {"code": real_response.status, "message": reason_phrase},
5739
"headers": _serialize_headers(real_response),
58-
"body": {"string": content},
40+
"body": {"string": real_response_content},
5941
}
6042

6143

@@ -99,11 +81,7 @@ def _deserialize_response(vcr_response):
9981
)
10082

10183

102-
async def _make_vcr_request(real_request):
103-
# Reading the request stream consumes the iterator, so we need to restore it afterwards
104-
body = b"".join(await _convert_byte_stream(real_request.stream))
105-
real_request.stream = ByteStream(body)
106-
84+
def _make_vcr_request(real_request, real_request_body):
10785
uri = bytes(real_request.url).decode("ascii")
10886

10987
# As per HTTPX: If there are multiple headers with the same key, then we concatenate them with commas
@@ -114,28 +92,25 @@ async def _make_vcr_request(real_request):
11492

11593
headers = {name: ", ".join(values) for name, values in headers.items()}
11694

117-
return VcrRequest(real_request.method.decode("ascii"), uri, body, headers)
95+
return VcrRequest(real_request.method.decode("ascii"), uri, real_request_body, headers)
11896

11997

120-
async def _vcr_request(cassette, real_request):
121-
vcr_request = await _make_vcr_request(real_request)
98+
def _vcr_request(cassette, real_request, real_request_body):
99+
vcr_request = _make_vcr_request(real_request, real_request_body)
122100

123101
if cassette.can_play_response_for(vcr_request):
124102
return vcr_request, _play_responses(cassette, vcr_request)
125103

126104
if cassette.write_protected and cassette.filter_request(vcr_request):
127-
raise CannotOverwriteExistingCassetteException(
128-
cassette=cassette,
129-
failed_request=vcr_request,
130-
)
105+
raise CannotOverwriteExistingCassetteException(cassette=cassette, failed_request=vcr_request)
131106

132107
_logger.info("%s not in cassette, sending to real server", vcr_request)
133108

134109
return vcr_request, None
135110

136111

137-
async def _record_responses(cassette, vcr_request, real_response):
138-
cassette.append(vcr_request, await _serialize_response(real_response))
112+
def _record_responses(cassette, vcr_request, real_response, real_response_content):
113+
cassette.append(vcr_request, _serialize_response(real_response, real_response_content))
139114

140115

141116
def _play_responses(cassette, vcr_request):
@@ -145,64 +120,52 @@ def _play_responses(cassette, vcr_request):
145120
return real_response
146121

147122

148-
async def _vcr_handle_async_request(
149-
cassette,
150-
real_handle_async_request,
151-
self,
152-
real_request,
153-
):
154-
vcr_request, vcr_response = await _vcr_request(cassette, real_request)
123+
async def _vcr_handle_async_request(cassette, real_handle_async_request, self, real_request):
124+
# Reading the request stream consumes the iterator, so we need to restore it afterwards
125+
real_request_body = b"".join([part async for part in real_request.stream])
126+
real_request.stream = ByteStream(real_request_body)
127+
128+
vcr_request, vcr_response = _vcr_request(cassette, real_request, real_request_body)
155129

156130
if vcr_response:
157131
return vcr_response
158132

159133
real_response = await real_handle_async_request(self, real_request)
160-
await _record_responses(cassette, vcr_request, real_response)
134+
135+
# Reading the response stream consumes the iterator, so we need to restore it afterwards
136+
real_response_content = b"".join([part async for part in real_response.stream])
137+
real_response.stream = ByteStream(real_response_content)
138+
139+
_record_responses(cassette, vcr_request, real_response, real_response_content)
161140

162141
return real_response
163142

164143

165144
def vcr_handle_async_request(cassette, real_handle_async_request):
166145
@functools.wraps(real_handle_async_request)
167146
def _inner_handle_async_request(self, real_request):
168-
return _vcr_handle_async_request(
169-
cassette,
170-
real_handle_async_request,
171-
self,
172-
real_request,
173-
)
147+
return _vcr_handle_async_request(cassette, real_handle_async_request, self, real_request)
174148

175149
return _inner_handle_async_request
176150

177151

178-
def _run_async_function(sync_func, *args, **kwargs):
179-
"""
180-
Safely run an asynchronous function from a synchronous context.
181-
Handles both cases:
182-
- An event loop is already running.
183-
- No event loop exists yet.
184-
"""
185-
try:
186-
asyncio.get_running_loop()
187-
except RuntimeError:
188-
return asyncio.run(sync_func(*args, **kwargs))
189-
else:
190-
# If inside a running loop, create a task and wait for it
191-
return asyncio.ensure_future(sync_func(*args, **kwargs))
192-
193-
194152
def _vcr_handle_request(cassette, real_handle_request, self, real_request):
195-
vcr_request, vcr_response = _run_async_function(
196-
_vcr_request,
197-
cassette,
198-
real_request,
199-
)
153+
# Reading the request stream consumes the iterator, so we need to restore it afterwards
154+
real_request_body = b"".join(list(real_request.stream))
155+
real_request.stream = ByteStream(real_request_body)
156+
157+
vcr_request, vcr_response = _vcr_request(cassette, real_request, real_request_body)
200158

201159
if vcr_response:
202160
return vcr_response
203161

204162
real_response = real_handle_request(self, real_request)
205-
_run_async_function(_record_responses, cassette, vcr_request, real_response)
163+
164+
# Reading the response stream consumes the iterator, so we need to restore it afterwards
165+
real_response_content = b"".join(list(real_response.stream))
166+
real_response.stream = ByteStream(real_response_content)
167+
168+
_record_responses(cassette, vcr_request, real_response, real_response_content)
206169

207170
return real_response
208171

0 commit comments

Comments
 (0)