2
2
import io
import mimetypes
import os
from typing import (
    Any,
    AsyncIterator,
    Awaitable,
    Callable,
    Collection,
    Dict,
    Mapping,
    Optional,
    cast,
)
from urllib.parse import urlparse

import httpx
@@ -61,7 +71,7 @@ def webhook_headers() -> "dict[str, str]":
61
71
62
72
async def on_request_trace_context_hook(request: httpx.Request) -> None:
    """Merge the current trace context into an outgoing request's headers.

    Calls ``current_trace_context()`` and updates ``request.headers`` with
    whatever it returns. A falsy result is replaced by an empty dict, in
    which case the headers are left unchanged.

    Args:
        request: the httpx request whose headers are updated in place.
    """
    ctx = current_trace_context() or {}
    # cast() has no runtime effect; it only presents the context to the type
    # checker as a str -> str mapping for Headers.update().
    request.headers.update(cast(Mapping[str, str], ctx))
65
75
66
76
67
77
def httpx_webhook_client () -> httpx .AsyncClient :
@@ -111,6 +121,22 @@ def httpx_file_client() -> httpx.AsyncClient:
111
121
)
112
122
113
123
124
class ChunkFileReader:
    """Async-iterable wrapper that yields a file object's content in chunks.

    Each iteration first seeks the file back to offset 0, so iterating the
    same instance again yields the full content again. Chunks read as ``str``
    are encoded as UTF-8, so only ``bytes`` are yielded.

    NOTE(review): ``fh.read`` is an ordinary blocking call made from inside a
    coroutine; confirm that is acceptable for the file objects passed in.

    Args:
        fh: the file object to read; it must support ``seek(0)``.
        chunk_size: maximum size passed to each ``fh.read`` call. Defaults to
            1 MiB.

    Raises:
        ValueError: if ``chunk_size`` is not a positive integer.
    """

    def __init__(self, fh: io.IOBase, chunk_size: int = 1024 * 1024) -> None:
        # read(0) always returns an empty chunk (nothing would be yielded) and
        # a negative size reads everything at once, so reject both up front.
        if chunk_size <= 0:
            raise ValueError("chunk_size must be a positive integer")
        self.fh = fh
        self.chunk_size = chunk_size

    async def __aiter__(self) -> AsyncIterator[bytes]:
        # Rewind on every iteration so each pass starts from the beginning.
        self.fh.seek(0)
        while True:
            chunk = self.fh.read(self.chunk_size)
            if isinstance(chunk, str):
                # Text-mode file objects return str; yield bytes only.
                chunk = chunk.encode("utf-8")
            if not chunk:
                log.info("finished reading file")
                break
            yield chunk
138
+
139
+
114
140
# there's a case for splitting this apart or inlining parts of it
115
141
# I'm somewhat sympathetic to separating webhooks and files, but they both have
116
142
# the same semantics of holding a client for the lifetime of runner
@@ -163,10 +189,11 @@ async def sender(response: Any, event: WebhookEvent) -> None:
163
189
164
190
# files
165
191
166
- async def upload_file (self , fh : io .IOBase , url : Optional [str ]) -> str :
192
+ async def upload_file (
193
+ self , fh : io .IOBase , * , url : Optional [str ], prediction_id : Optional [str ]
194
+ ) -> str :
167
195
"""put file to signed endpoint"""
168
196
log .debug ("upload_file" )
169
- fh .seek (0 )
170
197
# try to guess the filename of the given object
171
198
name = getattr (fh , "name" , "file" )
172
199
filename = os .path .basename (name ) or "file"
@@ -184,17 +211,12 @@ async def upload_file(self, fh: io.IOBase, url: Optional[str]) -> str:
184
211
# ensure trailing slash
185
212
url_with_trailing_slash = url if url .endswith ("/" ) else url + "/"
186
213
187
- async def chunk_file_reader () -> AsyncIterator [bytes ]:
188
- while 1 :
189
- chunk = fh .read (1024 * 1024 )
190
- if isinstance (chunk , str ):
191
- chunk = chunk .encode ("utf-8" )
192
- if not chunk :
193
- log .info ("finished reading file" )
194
- break
195
- yield chunk
196
-
197
214
url = url_with_trailing_slash + filename
215
+
216
+ headers = {"Content-Type" : content_type }
217
+ if prediction_id is not None :
218
+ headers ["X-Prediction-ID" ] = prediction_id
219
+
198
220
# this is a somewhat unfortunate hack, but it works
199
221
# and is critical for upload training/quantization outputs
200
222
# if we get multipart uploads working or a separate API route
@@ -204,29 +226,36 @@ async def chunk_file_reader() -> AsyncIterator[bytes]:
204
226
resp1 = await self .file_client .put (
205
227
url ,
206
228
content = b"" ,
207
- headers = { "Content-Type" : content_type } ,
229
+ headers = headers ,
208
230
follow_redirects = False ,
209
231
)
210
232
if resp1 .status_code == 307 and resp1 .headers ["Location" ]:
211
233
log .info ("got file upload redirect from api" )
212
234
url = resp1 .headers ["Location" ]
235
+
213
236
log .info ("doing real upload to %s" , url )
214
237
resp = await self .file_client .put (
215
238
url ,
216
- content = chunk_file_reader ( ),
217
- headers = { "Content-Type" : content_type } ,
239
+ content = ChunkFileReader ( fh ),
240
+ headers = headers ,
218
241
)
219
242
# TODO: if file size is >1MB, show upload throughput
220
243
resp .raise_for_status ()
221
244
222
- # strip any signing gubbins from the URL
223
- final_url = urlparse (str (resp .url ))._replace (query = "" ).geturl ()
245
+ # Try to extract the final asset URL from the `Location` header
246
+ # otherwise fallback to the URL of the final request.
247
+ final_url = str (resp .url )
248
+ if "location" in resp .headers :
249
+ final_url = resp .headers .get ("location" )
224
250
225
- return final_url
251
+ # strip any signing gubbins from the URL
252
+ return urlparse (final_url )._replace (query = "" ).geturl ()
226
253
227
254
# this previously lived in json.upload_files, but it's clearer here
228
255
# this is a great pattern that should be adopted for input files
229
- async def upload_files (self , obj : Any , url : Optional [str ]) -> Any :
256
+ async def upload_files (
257
+ self , obj : Any , * , url : Optional [str ], prediction_id : Optional [str ]
258
+ ) -> Any :
230
259
"""
231
260
Iterates through an object from make_encodeable and uploads any files.
232
261
When a file is encountered, it will be passed to upload_file. Any paths will be opened and converted to files.
@@ -238,15 +267,21 @@ async def upload_files(self, obj: Any, url: Optional[str]) -> Any:
238
267
# TODO: upload concurrently
239
268
if isinstance (obj , dict ):
240
269
return {
241
- key : await self .upload_files (value , url ) for key , value in obj .items ()
270
+ key : await self .upload_files (
271
+ value , url = url , prediction_id = prediction_id
272
+ )
273
+ for key , value in obj .items ()
242
274
}
243
275
if isinstance (obj , list ):
244
- return [await self .upload_files (value , url ) for value in obj ]
276
+ return [
277
+ await self .upload_files (value , url = url , prediction_id = prediction_id )
278
+ for value in obj
279
+ ]
245
280
if isinstance (obj , Path ):
246
281
with obj .open ("rb" ) as f :
247
- return await self .upload_file (f , url )
282
+ return await self .upload_file (f , url = url , prediction_id = prediction_id )
248
283
if isinstance (obj , io .IOBase ):
249
- return await self .upload_file (obj , url )
284
+ return await self .upload_file (obj , url = url , prediction_id = prediction_id )
250
285
return obj
251
286
252
287
# inputs
0 commit comments