Skip to content

Commit 70a52ea

Browse files
authored
fix: reject non-200 download responses (#9085)
* fix: reject non-200 download responses
* fix: share download response handling
1 parent 029e9c8 commit 70a52ea

2 files changed

Lines changed: 219 additions & 118 deletions

File tree

astrbot/core/utils/io.py

Lines changed: 112 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,101 @@ async def _emit_download_progress(progress_callback, payload: dict) -> None:
178178
await result
179179

180180

181+
class DownloadFileHTTPError(RuntimeError):
    """Error signalling that a file download got an unsuccessful HTTP status.

    Derives from ``RuntimeError``, so handlers written for ``RuntimeError``
    also catch it.
    """
183+
184+
185+
def _raise_for_download_status(resp, url: str) -> None:
    """Reject any download response whose HTTP status is not exactly 200.

    Args:
        resp: Response object exposing a ``status`` attribute.
        url: Source URL; it is passed through ``_safe_url_for_log`` before
            appearing in the log line or the exception message.

    Raises:
        DownloadFileHTTPError: If ``resp.status`` is anything other than 200.
    """
    status = resp.status
    if status != 200:
        logger.error(
            "Failed to download file from %s. HTTP status code: %s",
            _safe_url_for_log(url),
            status,
        )
        message = (
            f"Failed to download file from {_safe_url_for_log(url)}. "
            f"HTTP status code: {status}"
        )
        raise DownloadFileHTTPError(message)
197+
198+
199+
async def _download_response_to_file(
    resp,
    file_obj,
    url: str,
    show_progress: bool,
    progress_callback,
    show_downloading_label: bool = True,
) -> None:
    """Write a successful download response to a local file.

    The body is read in 8192-byte chunks. A progress payload with the keys
    ``url``, ``downloaded``, ``total``, ``percent`` and ``speed`` is emitted
    once before the first read, once after every chunk, and once more when
    the body is exhausted (with ``percent`` fixed at 1 and ``speed`` at 0).

    Args:
        resp: aiohttp response object to read from.
        file_obj: Open writable binary file object.
        url: Source URL used for progress events and sanitized output.
        show_progress: Whether to print progress to stdout.
        progress_callback: Optional callback for progress payloads.
        show_downloading_label: Whether to use the standard download heading.
    """
    # A missing or malformed Content-Length means "size unknown" (0) rather
    # than a reason to abort a download whose status was already accepted.
    try:
        total_size = int(resp.headers.get("content-length", 0))
    except (TypeError, ValueError):
        total_size = 0
    downloaded_size = 0
    # Monotonic clock: elapsed time cannot jump or go negative when the
    # system wall clock is adjusted mid-download.
    start_time = time.monotonic()
    if show_progress:
        if show_downloading_label:
            print(
                f"Downloading: {_safe_url_for_log(url)} | "
                f"Size: {total_size / 1024:.2f} KB"
            )
        else:
            print(f"Size: {total_size / 1024:.2f} KB | URL: {_safe_url_for_log(url)}")
    await _emit_download_progress(
        progress_callback,
        {
            "url": url,
            "downloaded": 0,
            "total": total_size,
            "percent": 0,
            "speed": 0,
        },
    )
    while True:
        chunk = await resp.content.read(8192)
        if not chunk:
            break
        file_obj.write(chunk)
        downloaded_size += len(chunk)
        # Fall back to 1 second when no measurable time has passed, so the
        # speed computation never divides by zero.
        elapsed_time = (time.monotonic() - start_time) or 1
        speed = downloaded_size / 1024 / elapsed_time  # KB/s
        # Content-Length can under-report the bytes actually read (for
        # example when the body is transparently decompressed), so cap the
        # ratio at 100%.
        percent = min(downloaded_size / total_size, 1.0) if total_size > 0 else 0
        await _emit_download_progress(
            progress_callback,
            {
                "url": url,
                "downloaded": downloaded_size,
                "total": total_size,
                "percent": percent,
                "speed": speed,
            },
        )
        if show_progress:
            print(
                f"\rProgress: {percent:.2%} Speed: {speed:.2f} KB/s",
                end="",
            )
    await _emit_download_progress(
        progress_callback,
        {
            "url": url,
            "downloaded": downloaded_size,
            "total": total_size,
            "percent": 1,
            "speed": 0,
        },
    )
274+
275+
181276
async def download_file(
182277
url: str,
183278
path: str,
@@ -209,69 +304,15 @@ async def download_file(
209304
connector=connector,
210305
) as session:
211306
async with session.get(url, timeout=1800) as resp:
212-
if resp.status != 200:
213-
logger.error(
214-
"Failed to download file from %s. HTTP status code: %s",
215-
_safe_url_for_log(url),
216-
resp.status,
217-
)
218-
total_size = int(resp.headers.get("content-length", 0))
219-
downloaded_size = 0
220-
start_time = time.time()
221-
if show_progress:
222-
print(
223-
f"Downloading: {_safe_url_for_log(url)} | "
224-
f"Size: {total_size / 1024:.2f} KB"
225-
)
226-
await _emit_download_progress(
227-
progress_callback,
228-
{
229-
"url": url,
230-
"downloaded": 0,
231-
"total": total_size,
232-
"percent": 0,
233-
"speed": 0,
234-
},
235-
)
307+
_raise_for_download_status(resp, url)
236308
with open(path, "wb") as f:
237-
while True:
238-
chunk = await resp.content.read(8192)
239-
if not chunk:
240-
break
241-
f.write(chunk)
242-
downloaded_size += len(chunk)
243-
elapsed_time = (
244-
time.time() - start_time
245-
if time.time() - start_time > 0
246-
else 1
247-
)
248-
speed = downloaded_size / 1024 / elapsed_time # KB/s
249-
percent = downloaded_size / total_size if total_size > 0 else 0
250-
await _emit_download_progress(
251-
progress_callback,
252-
{
253-
"url": url,
254-
"downloaded": downloaded_size,
255-
"total": total_size,
256-
"percent": percent,
257-
"speed": speed,
258-
},
259-
)
260-
if show_progress:
261-
print(
262-
f"\rProgress: {percent:.2%} Speed: {speed:.2f} KB/s",
263-
end="",
264-
)
265-
await _emit_download_progress(
266-
progress_callback,
267-
{
268-
"url": url,
269-
"downloaded": downloaded_size,
270-
"total": total_size,
271-
"percent": 1,
272-
"speed": 0,
273-
},
274-
)
309+
await _download_response_to_file(
310+
resp,
311+
f,
312+
url,
313+
show_progress,
314+
progress_callback,
315+
)
275316
except (aiohttp.ClientConnectorSSLError, aiohttp.ClientConnectorCertificateError):
276317
if not allow_insecure_ssl_fallback:
277318
raise
@@ -291,63 +332,16 @@ async def download_file(
291332
ssl_context.verify_mode = ssl.CERT_NONE
292333
async with aiohttp.ClientSession() as session:
293334
async with session.get(url, ssl=ssl_context, timeout=120) as resp:
294-
total_size = int(resp.headers.get("content-length", 0))
295-
downloaded_size = 0
296-
start_time = time.time()
297-
if show_progress:
298-
print(
299-
f"Size: {total_size / 1024:.2f} KB | "
300-
f"URL: {_safe_url_for_log(url)}"
301-
)
302-
await _emit_download_progress(
303-
progress_callback,
304-
{
305-
"url": url,
306-
"downloaded": 0,
307-
"total": total_size,
308-
"percent": 0,
309-
"speed": 0,
310-
},
311-
)
335+
_raise_for_download_status(resp, url)
312336
with open(path, "wb") as f:
313-
while True:
314-
chunk = await resp.content.read(8192)
315-
if not chunk:
316-
break
317-
f.write(chunk)
318-
downloaded_size += len(chunk)
319-
elapsed_time = (
320-
time.time() - start_time
321-
if time.time() - start_time > 0
322-
else 1
323-
)
324-
speed = downloaded_size / 1024 / elapsed_time # KB/s
325-
percent = downloaded_size / total_size if total_size > 0 else 0
326-
await _emit_download_progress(
327-
progress_callback,
328-
{
329-
"url": url,
330-
"downloaded": downloaded_size,
331-
"total": total_size,
332-
"percent": percent,
333-
"speed": speed,
334-
},
335-
)
336-
if show_progress:
337-
print(
338-
f"\rProgress: {percent:.2%} Speed: {speed:.2f} KB/s",
339-
end="",
340-
)
341-
await _emit_download_progress(
342-
progress_callback,
343-
{
344-
"url": url,
345-
"downloaded": downloaded_size,
346-
"total": total_size,
347-
"percent": 1,
348-
"speed": 0,
349-
},
350-
)
337+
await _download_response_to_file(
338+
resp,
339+
f,
340+
url,
341+
show_progress,
342+
progress_callback,
343+
show_downloading_label=False,
344+
)
351345
if show_progress:
352346
print()
353347

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
import pytest
2+
3+
from astrbot.core.utils import io
4+
5+
6+
class _FakeContent:
    """Test double for a response body stream.

    Hands out the supplied chunks one per ``read`` call, in order, and then
    returns ``b""`` on every later call.
    """

    def __init__(self, chunks: list[bytes]):
        # Snapshot the chunks: reads must not drain the caller's list, and an
        # iterator avoids the O(n) cost of ``list.pop(0)`` on every read.
        self._chunks = iter(tuple(chunks))

    async def read(self, _size: int) -> bytes:
        """Return the next queued chunk, or ``b""`` once all are consumed.

        ``_size`` is accepted to match the reader call signature and ignored.
        """
        return next(self._chunks, b"")
14+
15+
16+
class _FakeResponse:
    """Test double for an HTTP response usable with ``async with``.

    Carries a ``status``, a ``content-length`` header equal to the combined
    byte length of ``chunks``, and a ``content`` reader over those chunks.
    """

    def __init__(self, *, status: int, chunks: list[bytes]):
        body_length = sum(map(len, chunks))
        self.status = status
        self.headers = {"content-length": str(body_length)}
        self.content = _FakeContent(chunks)

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        # Falsy result: an exception raised inside the block propagates.
        return False
27+
28+
29+
class _FakeSession:
    """Test double for a client session usable with ``async with``.

    ``get`` either returns the configured fake response or, when the session
    was configured with an exception instance, raises it.
    """

    def __init__(self, response: _FakeResponse | Exception):
        self._response = response

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        # Falsy result: an exception raised inside the block propagates.
        return False

    def get(self, *_args, **_kwargs):
        """Return the configured response, or raise the configured exception."""
        outcome = self._response
        if not isinstance(outcome, Exception):
            return outcome
        raise outcome
43+
44+
45+
def _patch_download_session(monkeypatch, response: _FakeResponse):
    """Patch the session factory so one session yielding ``response`` is available."""
    single_response_queue = [response]
    _patch_download_sessions(monkeypatch, single_response_queue)
47+
48+
49+
def _patch_download_sessions(monkeypatch, responses: list[_FakeResponse | Exception]):
    """Replace ``io.aiohttp``'s connector and session factories with fakes.

    Every ``ClientSession(...)`` call removes the first entry of ``responses``
    and wraps it in a ``_FakeSession``, so the list order is the order in
    which sessions are created.
    """

    def make_connector(**_kwargs):
        # The fake session never uses the connector; any object will do.
        return object()

    def make_session(**_kwargs):
        return _FakeSession(responses.pop(0))

    monkeypatch.setattr(io.aiohttp, "TCPConnector", make_connector)
    monkeypatch.setattr(io.aiohttp, "ClientSession", make_session)
56+
57+
58+
@pytest.mark.asyncio
async def test_download_file_rejects_non_200_response(monkeypatch, tmp_path):
    """A 404 response raises ``DownloadFileHTTPError`` and creates no file."""
    target_path = tmp_path / "missing.bin"
    not_found = _FakeResponse(status=404, chunks=[b"not found"])
    _patch_download_session(monkeypatch, not_found)

    with pytest.raises(io.DownloadFileHTTPError, match="HTTP status code: 404"):
        await io.download_file("https://example.test/missing", str(target_path))

    assert not target_path.exists()
70+
71+
72+
@pytest.mark.asyncio
async def test_download_file_rejects_non_200_response_after_ssl_fallback(
    monkeypatch,
    tmp_path,
):
    """A 404 on the second session, after an SSL error on the first, is rejected.

    The first session's ``get`` raises the fake SSL error; the second session
    returns a 404 response. The call must raise ``DownloadFileHTTPError`` and
    create no file.
    """

    class FakeSSLError(Exception):
        pass

    target_path = tmp_path / "missing.bin"
    session_outcomes = [
        FakeSSLError(),
        _FakeResponse(status=404, chunks=[b"not found"]),
    ]
    _patch_download_sessions(monkeypatch, session_outcomes)
    # Make both SSL exception names that download_file catches resolve to the fake.
    for error_name in ("ClientConnectorSSLError", "ClientConnectorCertificateError"):
        monkeypatch.setattr(io.aiohttp, error_name, FakeSSLError)

    with pytest.raises(io.DownloadFileHTTPError, match="HTTP status code: 404"):
        await io.download_file("https://example.test/missing", str(target_path))

    assert not target_path.exists()
95+
96+
97+
@pytest.mark.asyncio
async def test_download_file_writes_successful_response(monkeypatch, tmp_path):
    """A 200 response has its chunks written to the target path in order."""
    target_path = tmp_path / "ok.bin"
    ok_response = _FakeResponse(status=200, chunks=[b"hello", b" world"])
    _patch_download_session(monkeypatch, ok_response)

    await io.download_file("https://example.test/ok.bin", str(target_path))

    written = target_path.read_bytes()
    assert written == b"hello world"

0 commit comments

Comments
 (0)