1- import asyncio
21import functools
32import logging
43from collections import defaultdict
5- from collections .abc import AsyncIterable , Iterable
64
75from httpcore import Response
86from httpcore ._models import ByteStream
1513_logger = logging .getLogger (__name__ )
1614
1715
18- async def _convert_byte_stream (stream ):
19- if isinstance (stream , Iterable ):
20- return list (stream )
21-
22- if isinstance (stream , AsyncIterable ):
23- return [part async for part in stream ]
24-
25- raise TypeError (
26- f"_convert_byte_stream: stream must be Iterable or AsyncIterable, got { type (stream ).__name__ } " ,
27- )
28-
29-
3016def _serialize_headers (real_response ):
3117 """
3218 Some headers can appear multiple times, like "Set-Cookie".
@@ -41,21 +27,17 @@ def _serialize_headers(real_response):
4127 return dict (headers )
4228
4329
44- async def _serialize_response (real_response ):
30+ def _serialize_response (real_response , real_response_content ):
4531 # The reason_phrase may not exist
4632 try :
4733 reason_phrase = real_response .extensions ["reason_phrase" ].decode ("ascii" )
4834 except KeyError :
4935 reason_phrase = None
5036
51- # Reading the response stream consumes the iterator, so we need to restore it afterwards
52- content = b"" .join (await _convert_byte_stream (real_response .stream ))
53- real_response .stream = ByteStream (content )
54-
5537 return {
5638 "status" : {"code" : real_response .status , "message" : reason_phrase },
5739 "headers" : _serialize_headers (real_response ),
58- "body" : {"string" : content },
40+ "body" : {"string" : real_response_content },
5941 }
6042
6143
@@ -99,11 +81,7 @@ def _deserialize_response(vcr_response):
9981 )
10082
10183
102- async def _make_vcr_request (real_request ):
103- # Reading the request stream consumes the iterator, so we need to restore it afterwards
104- body = b"" .join (await _convert_byte_stream (real_request .stream ))
105- real_request .stream = ByteStream (body )
106-
84+ def _make_vcr_request (real_request , real_request_body ):
10785 uri = bytes (real_request .url ).decode ("ascii" )
10886
10987 # As per HTTPX: If there are multiple headers with the same key, then we concatenate them with commas
@@ -114,28 +92,25 @@ async def _make_vcr_request(real_request):
11492
11593 headers = {name : ", " .join (values ) for name , values in headers .items ()}
11694
117- return VcrRequest (real_request .method .decode ("ascii" ), uri , body , headers )
95+ return VcrRequest (real_request .method .decode ("ascii" ), uri , real_request_body , headers )
11896
11997
120- async def _vcr_request (cassette , real_request ):
121- vcr_request = await _make_vcr_request (real_request )
98+ def _vcr_request (cassette , real_request , real_request_body ):
99+ vcr_request = _make_vcr_request (real_request , real_request_body )
122100
123101 if cassette .can_play_response_for (vcr_request ):
124102 return vcr_request , _play_responses (cassette , vcr_request )
125103
126104 if cassette .write_protected and cassette .filter_request (vcr_request ):
127- raise CannotOverwriteExistingCassetteException (
128- cassette = cassette ,
129- failed_request = vcr_request ,
130- )
105+ raise CannotOverwriteExistingCassetteException (cassette = cassette , failed_request = vcr_request )
131106
132107 _logger .info ("%s not in cassette, sending to real server" , vcr_request )
133108
134109 return vcr_request , None
135110
136111
137- async def _record_responses (cassette , vcr_request , real_response ):
138- cassette .append (vcr_request , await _serialize_response (real_response ))
112+ def _record_responses (cassette , vcr_request , real_response , real_response_content ):
113+ cassette .append (vcr_request , _serialize_response (real_response , real_response_content ))
139114
140115
141116def _play_responses (cassette , vcr_request ):
@@ -145,64 +120,52 @@ def _play_responses(cassette, vcr_request):
145120 return real_response
146121
147122
148- async def _vcr_handle_async_request (
149- cassette ,
150- real_handle_async_request ,
151- self ,
152- real_request ,
153- ):
154- vcr_request , vcr_response = await _vcr_request (cassette , real_request )
123+ async def _vcr_handle_async_request (cassette , real_handle_async_request , self , real_request ):
124+ # Reading the request stream consumes the iterator, so we need to restore it afterwards
125+ real_request_body = b"" .join ([part async for part in real_request .stream ])
126+ real_request .stream = ByteStream (real_request_body )
127+
128+ vcr_request , vcr_response = _vcr_request (cassette , real_request , real_request_body )
155129
156130 if vcr_response :
157131 return vcr_response
158132
159133 real_response = await real_handle_async_request (self , real_request )
160- await _record_responses (cassette , vcr_request , real_response )
134+
135+ # Reading the response stream consumes the iterator, so we need to restore it afterwards
136+ real_response_content = b"" .join ([part async for part in real_response .stream ])
137+ real_response .stream = ByteStream (real_response_content )
138+
139+ _record_responses (cassette , vcr_request , real_response , real_response_content )
161140
162141 return real_response
163142
164143
165144def vcr_handle_async_request (cassette , real_handle_async_request ):
166145 @functools .wraps (real_handle_async_request )
167146 def _inner_handle_async_request (self , real_request ):
168- return _vcr_handle_async_request (
169- cassette ,
170- real_handle_async_request ,
171- self ,
172- real_request ,
173- )
147+ return _vcr_handle_async_request (cassette , real_handle_async_request , self , real_request )
174148
175149 return _inner_handle_async_request
176150
177151
178- def _run_async_function (sync_func , * args , ** kwargs ):
179- """
180- Safely run an asynchronous function from a synchronous context.
181- Handles both cases:
182- - An event loop is already running.
183- - No event loop exists yet.
184- """
185- try :
186- asyncio .get_running_loop ()
187- except RuntimeError :
188- return asyncio .run (sync_func (* args , ** kwargs ))
189- else :
190- # If inside a running loop, create a task and wait for it
191- return asyncio .ensure_future (sync_func (* args , ** kwargs ))
192-
193-
194152def _vcr_handle_request (cassette , real_handle_request , self , real_request ):
195- vcr_request , vcr_response = _run_async_function (
196- _vcr_request ,
197- cassette ,
198- real_request ,
199- )
153+ # Reading the request stream consumes the iterator, so we need to restore it afterwards
154+ real_request_body = b"" . join ( list ( real_request . stream ))
155+ real_request . stream = ByteStream ( real_request_body )
156+
157+ vcr_request , vcr_response = _vcr_request ( cassette , real_request , real_request_body )
200158
201159 if vcr_response :
202160 return vcr_response
203161
204162 real_response = real_handle_request (self , real_request )
205- _run_async_function (_record_responses , cassette , vcr_request , real_response )
163+
164+ # Reading the response stream consumes the iterator, so we need to restore it afterwards
165+ real_response_content = b"" .join (list (real_response .stream ))
166+ real_response .stream = ByteStream (real_response_content )
167+
168+ _record_responses (cassette , vcr_request , real_response , real_response_content )
206169
207170 return real_response
208171
0 commit comments