6
6
import logging
7
7
import json
8
8
9
- from aiohttp import ClientResponse , streams
9
+ from aiohttp import ClientConnectionError , ClientResponse , RequestInfo , streams
10
10
from multidict import CIMultiDict , CIMultiDictProxy
11
11
from yarl import URL
12
12
@@ -20,14 +20,14 @@ class MockStream(asyncio.StreamReader, streams.AsyncStreamReaderMixin):
20
20
21
21
22
22
class MockClientResponse (ClientResponse ):
23
- def __init__ (self , method , url ):
23
+ def __init__ (self , method , url , request_info = None ):
24
24
super ().__init__ (
25
25
method = method ,
26
26
url = url ,
27
27
writer = None ,
28
28
continue100 = None ,
29
29
timer = None ,
30
- request_info = None ,
30
+ request_info = request_info ,
31
31
traces = None ,
32
32
loop = asyncio .get_event_loop (),
33
33
session = None ,
@@ -58,7 +58,13 @@ def content(self):
58
58
59
59
60
60
def build_response (vcr_request , vcr_response , history ):
61
- response = MockClientResponse (vcr_request .method , URL (vcr_response .get ("url" )))
61
+ request_info = RequestInfo (
62
+ url = URL (vcr_request .url ),
63
+ method = vcr_request .method ,
64
+ headers = CIMultiDictProxy (CIMultiDict (vcr_request .headers )),
65
+ real_url = URL (vcr_request .url ),
66
+ )
67
+ response = MockClientResponse (vcr_request .method , URL (vcr_response .get ("url" )), request_info = request_info )
62
68
response .status = vcr_response ["status" ]["code" ]
63
69
response ._body = vcr_response ["body" ].get ("string" , b"" )
64
70
response .reason = vcr_response ["status" ]["message" ]
@@ -69,35 +75,92 @@ def build_response(vcr_request, vcr_response, history):
69
75
return response
70
76
71
77
78
+ def _serialize_headers (headers ):
79
+ """Serialize CIMultiDictProxy to a pickle-able dict because proxy
80
+ objects forbid pickling:
81
+
82
+ https://github.com/aio-libs/multidict/issues/340
83
+ """
84
+ # Mark strings as keys so 'istr' types don't show up in
85
+ # the cassettes as comments.
86
+ return {str (k ): v for k , v in headers .items ()}
87
+
88
+
72
89
def play_responses (cassette , vcr_request ):
73
90
history = []
74
91
vcr_response = cassette .play_response (vcr_request )
75
92
response = build_response (vcr_request , vcr_response , history )
76
93
77
- while cassette .can_play_response_for (vcr_request ):
94
+ # If we're following redirects, continue playing until we reach
95
+ # our final destination.
96
+ while 300 <= response .status <= 399 :
97
+ next_url = URL (response .url ).with_path (response .headers ["location" ])
98
+
99
+ # Make a stub VCR request that we can then use to look up the recorded
100
+ # VCR request saved to the cassette. This feels a little hacky and
101
+ # may have edge cases based on the headers we're providing (e.g. if
102
+ # there's a matcher that is used to filter by headers).
103
+ vcr_request = Request ("GET" , str (next_url ), None , _serialize_headers (response .request_info .headers ))
104
+ vcr_request = cassette .find_requests_with_most_matches (vcr_request )[0 ][0 ]
105
+
106
+ # Tack on the response we saw from the redirect into the history
107
+ # list that is added on to the final response.
78
108
history .append (response )
79
109
vcr_response = cassette .play_response (vcr_request )
80
110
response = build_response (vcr_request , vcr_response , history )
81
111
82
112
return response
83
113
84
114
85
- async def record_response (cassette , vcr_request , response , past = False ):
86
- body = {} if past else {"string" : (await response .read ())}
87
- headers = {str (key ): value for key , value in response .headers .items ()}
115
+ async def record_response (cassette , vcr_request , response ):
116
+ """Record a VCR request-response chain to the cassette."""
117
+
118
+ try :
119
+ body = {"string" : (await response .read ())}
120
+ # aiohttp raises a ClientConnectionError on reads when
121
+ # there is no body. We can use this to know to not write one.
122
+ except ClientConnectionError :
123
+ body = {}
88
124
89
125
vcr_response = {
90
126
"status" : {"code" : response .status , "message" : response .reason },
91
- "headers" : headers ,
127
+ "headers" : _serialize_headers ( response . headers ) ,
92
128
"body" : body , # NOQA: E999
93
129
"url" : str (response .url ),
94
130
}
131
+
95
132
cassette .append (vcr_request , vcr_response )
96
133
97
134
98
135
async def record_responses (cassette , vcr_request , response ):
136
+ """Because aiohttp follows redirects by default, we must support
137
+ them by default. This method is used to write individual
138
+ request-response chains that were implicitly followed to get
139
+ to the final destination.
140
+ """
141
+
99
142
for past_response in response .history :
100
- await record_response (cassette , vcr_request , past_response , past = True )
143
+ aiohttp_request = past_response .request_info
144
+
145
+ # No data because it's following a redirect.
146
+ past_request = Request (
147
+ aiohttp_request .method ,
148
+ str (aiohttp_request .url ),
149
+ None ,
150
+ _serialize_headers (aiohttp_request .headers ),
151
+ )
152
+ await record_response (cassette , past_request , past_response )
153
+
154
+ # If we're following redirects, then the last request-response
155
+ # we record is the one attached to the `response`.
156
+ if response .history :
157
+ aiohttp_request = response .request_info
158
+ vcr_request = Request (
159
+ aiohttp_request .method ,
160
+ str (aiohttp_request .url ),
161
+ None ,
162
+ _serialize_headers (aiohttp_request .headers ),
163
+ )
101
164
102
165
await record_response (cassette , vcr_request , response )
103
166
0 commit comments