Skip to content

Commit b1d4d18

Browse files
aronmattt
authored andcommitted
[async] Include prediction id upload request (#1788)
* Cast TraceContext into Mapping[str, str] to fix linter * Include prediction id upload request Based on #1667 This PR introduces two small changes to the file upload interface. 1. We now allow downstream services to include the destination of the asset in a `Location` header, rather than assuming that it's the same as the final upload url (either the one passed via `--upload-url` or the result of a 307 redirect response. 2. We now include the `X-Prediction-Id` header in upload request, this allows the downstream client to potentially do configuration/routing based on the prediction ID. This ID should be considered unsafe and needs to be validated by the downstream service. * Extract ChunkFileReader into top-level class --------- Co-authored-by: Mattt Zmuda <[email protected]>
1 parent e729dd6 commit b1d4d18

File tree

3 files changed

+158
-28
lines changed

3 files changed

+158
-28
lines changed

Diff for: python/cog/server/clients.py

+60-25
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,17 @@
22
import io
33
import mimetypes
44
import os
5-
from typing import Any, AsyncIterator, Awaitable, Callable, Collection, Dict, Optional
5+
from typing import (
6+
Any,
7+
AsyncIterator,
8+
Awaitable,
9+
Callable,
10+
Collection,
11+
Dict,
12+
Mapping,
13+
Optional,
14+
cast,
15+
)
616
from urllib.parse import urlparse
717

818
import httpx
@@ -61,7 +71,7 @@ def webhook_headers() -> "dict[str, str]":
6171

6272
async def on_request_trace_context_hook(request: httpx.Request) -> None:
6373
ctx = current_trace_context() or {}
64-
request.headers.update(ctx)
74+
request.headers.update(cast(Mapping[str, str], ctx))
6575

6676

6777
def httpx_webhook_client() -> httpx.AsyncClient:
@@ -111,6 +121,22 @@ def httpx_file_client() -> httpx.AsyncClient:
111121
)
112122

113123

124+
class ChunkFileReader:
125+
def __init__(self, fh: io.IOBase) -> None:
126+
self.fh = fh
127+
128+
async def __aiter__(self) -> AsyncIterator[bytes]:
129+
self.fh.seek(0)
130+
while True:
131+
chunk = self.fh.read(1024 * 1024)
132+
if isinstance(chunk, str):
133+
chunk = chunk.encode("utf-8")
134+
if not chunk:
135+
log.info("finished reading file")
136+
break
137+
yield chunk
138+
139+
114140
# there's a case for splitting this apart or inlining parts of it
115141
# I'm somewhat sympathetic to separating webhooks and files, but they both have
116142
# the same semantics of holding a client for the lifetime of runner
@@ -163,10 +189,11 @@ async def sender(response: Any, event: WebhookEvent) -> None:
163189

164190
# files
165191

166-
async def upload_file(self, fh: io.IOBase, url: Optional[str]) -> str:
192+
async def upload_file(
193+
self, fh: io.IOBase, *, url: Optional[str], prediction_id: Optional[str]
194+
) -> str:
167195
"""put file to signed endpoint"""
168196
log.debug("upload_file")
169-
fh.seek(0)
170197
# try to guess the filename of the given object
171198
name = getattr(fh, "name", "file")
172199
filename = os.path.basename(name) or "file"
@@ -184,17 +211,12 @@ async def upload_file(self, fh: io.IOBase, url: Optional[str]) -> str:
184211
# ensure trailing slash
185212
url_with_trailing_slash = url if url.endswith("/") else url + "/"
186213

187-
async def chunk_file_reader() -> AsyncIterator[bytes]:
188-
while 1:
189-
chunk = fh.read(1024 * 1024)
190-
if isinstance(chunk, str):
191-
chunk = chunk.encode("utf-8")
192-
if not chunk:
193-
log.info("finished reading file")
194-
break
195-
yield chunk
196-
197214
url = url_with_trailing_slash + filename
215+
216+
headers = {"Content-Type": content_type}
217+
if prediction_id is not None:
218+
headers["X-Prediction-ID"] = prediction_id
219+
198220
# this is a somewhat unfortunate hack, but it works
199221
# and is critical for upload training/quantization outputs
200222
# if we get multipart uploads working or a separate API route
@@ -204,29 +226,36 @@ async def chunk_file_reader() -> AsyncIterator[bytes]:
204226
resp1 = await self.file_client.put(
205227
url,
206228
content=b"",
207-
headers={"Content-Type": content_type},
229+
headers=headers,
208230
follow_redirects=False,
209231
)
210232
if resp1.status_code == 307 and resp1.headers["Location"]:
211233
log.info("got file upload redirect from api")
212234
url = resp1.headers["Location"]
235+
213236
log.info("doing real upload to %s", url)
214237
resp = await self.file_client.put(
215238
url,
216-
content=chunk_file_reader(),
217-
headers={"Content-Type": content_type},
239+
content=ChunkFileReader(fh),
240+
headers=headers,
218241
)
219242
# TODO: if file size is >1MB, show upload throughput
220243
resp.raise_for_status()
221244

222-
# strip any signing gubbins from the URL
223-
final_url = urlparse(str(resp.url))._replace(query="").geturl()
245+
# Try to extract the final asset URL from the `Location` header
246+
# otherwise fallback to the URL of the final request.
247+
final_url = str(resp.url)
248+
if "location" in resp.headers:
249+
final_url = resp.headers.get("location")
224250

225-
return final_url
251+
# strip any signing gubbins from the URL
252+
return urlparse(final_url)._replace(query="").geturl()
226253

227254
# this previously lived in json.upload_files, but it's clearer here
228255
# this is a great pattern that should be adopted for input files
229-
async def upload_files(self, obj: Any, url: Optional[str]) -> Any:
256+
async def upload_files(
257+
self, obj: Any, *, url: Optional[str], prediction_id: Optional[str]
258+
) -> Any:
230259
"""
231260
Iterates through an object from make_encodeable and uploads any files.
232261
When a file is encountered, it will be passed to upload_file. Any paths will be opened and converted to files.
@@ -238,15 +267,21 @@ async def upload_files(self, obj: Any, url: Optional[str]) -> Any:
238267
# TODO: upload concurrently
239268
if isinstance(obj, dict):
240269
return {
241-
key: await self.upload_files(value, url) for key, value in obj.items()
270+
key: await self.upload_files(
271+
value, url=url, prediction_id=prediction_id
272+
)
273+
for key, value in obj.items()
242274
}
243275
if isinstance(obj, list):
244-
return [await self.upload_files(value, url) for value in obj]
276+
return [
277+
await self.upload_files(value, url=url, prediction_id=prediction_id)
278+
for value in obj
279+
]
245280
if isinstance(obj, Path):
246281
with obj.open("rb") as f:
247-
return await self.upload_file(f, url)
282+
return await self.upload_file(f, url=url, prediction_id=prediction_id)
248283
if isinstance(obj, io.IOBase):
249-
return await self.upload_file(obj, url)
284+
return await self.upload_file(obj, url=url, prediction_id=prediction_id)
250285
return obj
251286

252287
# inputs

Diff for: python/cog/server/runner.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,9 @@ async def _send_webhook(self, event: schema.WebhookEvent) -> None:
288288
async def _upload_files(self, output: Any) -> Any:
289289
try:
290290
# TODO: clean up output files
291-
return await self._client_manager.upload_files(output, self._upload_url)
291+
return await self._client_manager.upload_files(
292+
output, url=self._upload_url, prediction_id=self.p.id
293+
)
292294
except Exception as error:
293295
# If something goes wrong uploading a file, it's irrecoverable.
294296
# The re-raised exception will be caught and cause the prediction

Diff for: python/tests/server/test_clients.py

+95-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import httpx
12
import os
3+
import responses
24
import tempfile
35

46
import cog
@@ -7,12 +9,103 @@
79

810

911
@pytest.mark.asyncio
10-
async def test_upload_files():
12+
async def test_upload_files_without_url():
1113
client_manager = ClientManager()
1214
temp_dir = tempfile.mkdtemp()
1315
temp_path = os.path.join(temp_dir, "my_file.txt")
1416
with open(temp_path, "w") as fh:
1517
fh.write("file content")
1618
obj = {"path": cog.Path(temp_path)}
17-
result = await client_manager.upload_files(obj, None)
19+
result = await client_manager.upload_files(obj, url=None, prediction_id=None)
1820
assert result == {"path": "data:text/plain;base64,ZmlsZSBjb250ZW50"}
21+
22+
23+
@pytest.mark.asyncio
24+
@pytest.mark.respx(base_url="https://example.com")
25+
async def test_upload_files_with_url(respx_mock):
26+
uploader = respx_mock.put("/bucket/my_file.txt").mock(
27+
return_value=httpx.Response(201)
28+
)
29+
30+
client_manager = ClientManager()
31+
temp_dir = tempfile.mkdtemp()
32+
temp_path = os.path.join(temp_dir, "my_file.txt")
33+
with open(temp_path, "w") as fh:
34+
fh.write("file content")
35+
36+
obj = {"path": cog.Path(temp_path)}
37+
result = await client_manager.upload_files(
38+
obj, url="https://example.com/bucket", prediction_id=None
39+
)
40+
assert result == {"path": "https://example.com/bucket/my_file.txt"}
41+
42+
assert uploader.call_count == 1
43+
44+
45+
@pytest.mark.asyncio
46+
@pytest.mark.respx(base_url="https://example.com")
47+
async def test_upload_files_with_prediction_id(respx_mock):
48+
uploader = respx_mock.put(
49+
"/bucket/my_file.txt", headers={"x-prediction-id": "p123"}
50+
).mock(return_value=httpx.Response(201))
51+
52+
client_manager = ClientManager()
53+
temp_dir = tempfile.mkdtemp()
54+
temp_path = os.path.join(temp_dir, "my_file.txt")
55+
with open(temp_path, "w") as fh:
56+
fh.write("file content")
57+
58+
obj = {"path": cog.Path(temp_path)}
59+
result = await client_manager.upload_files(
60+
obj, url="https://example.com/bucket", prediction_id="p123"
61+
)
62+
assert result == {"path": "https://example.com/bucket/my_file.txt"}
63+
64+
assert uploader.call_count == 1
65+
66+
67+
@pytest.mark.asyncio
68+
@pytest.mark.respx(base_url="https://example.com")
69+
async def test_upload_files_with_location_header(respx_mock):
70+
uploader = respx_mock.put("/bucket/my_file.txt").mock(
71+
return_value=httpx.Response(
72+
201, headers={"Location": "https://cdn.example.com/bucket/my_file.txt"}
73+
)
74+
)
75+
76+
client_manager = ClientManager()
77+
temp_dir = tempfile.mkdtemp()
78+
temp_path = os.path.join(temp_dir, "my_file.txt")
79+
with open(temp_path, "w") as fh:
80+
fh.write("file content")
81+
82+
obj = {"path": cog.Path(temp_path)}
83+
result = await client_manager.upload_files(
84+
obj, url="https://example.com/bucket", prediction_id=None
85+
)
86+
assert result == {"path": "https://cdn.example.com/bucket/my_file.txt"}
87+
88+
assert uploader.call_count == 1
89+
90+
91+
@pytest.mark.asyncio
92+
@pytest.mark.respx(base_url="https://example.com")
93+
async def test_upload_files_with_retry(respx_mock):
94+
uploader = respx_mock.put("/bucket/my_file.txt").mock(
95+
return_value=httpx.Response(502)
96+
)
97+
98+
client_manager = ClientManager()
99+
temp_dir = tempfile.mkdtemp()
100+
temp_path = os.path.join(temp_dir, "my_file.txt")
101+
with open(temp_path, "w") as fh:
102+
fh.write("file content")
103+
104+
obj = {"path": cog.Path(temp_path)}
105+
with pytest.raises(httpx.HTTPStatusError):
106+
result = await client_manager.upload_files(
107+
obj, url="https://example.com/bucket", prediction_id=None
108+
)
109+
110+
assert result == {"path": "https://cdn.example.com/bucket/my_file.txt"}
111+
assert uploader.call_count == 3

0 commit comments

Comments
 (0)