Skip to content

Commit 59f79cd

Browse files
committed
fix: reject non-200 download responses
1 parent 6067a70 commit 59f79cd

2 files changed

Lines changed: 90 additions & 0 deletions

File tree

astrbot/core/utils/io.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,10 @@ async def download_file(
215215
_safe_url_for_log(url),
216216
resp.status,
217217
)
218+
raise RuntimeError(
219+
"Failed to download file from "
220+
f"{_safe_url_for_log(url)}. HTTP status code: {resp.status}"
221+
)
218222
total_size = int(resp.headers.get("content-length", 0))
219223
downloaded_size = 0
220224
start_time = time.time()
@@ -291,6 +295,16 @@ async def download_file(
291295
ssl_context.verify_mode = ssl.CERT_NONE
292296
async with aiohttp.ClientSession() as session:
293297
async with session.get(url, ssl=ssl_context, timeout=120) as resp:
298+
if resp.status != 200:
299+
logger.error(
300+
"Failed to download file from %s. HTTP status code: %s",
301+
_safe_url_for_log(url),
302+
resp.status,
303+
)
304+
raise RuntimeError(
305+
"Failed to download file from "
306+
f"{_safe_url_for_log(url)}. HTTP status code: {resp.status}"
307+
)
294308
total_size = int(resp.headers.get("content-length", 0))
295309
downloaded_size = 0
296310
start_time = time.time()
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
import pytest
2+
3+
from astrbot.core.utils import io
4+
5+
6+
class _FakeContent:
7+
def __init__(self, chunks: list[bytes]):
8+
self._chunks = chunks
9+
10+
async def read(self, _size: int) -> bytes:
11+
if self._chunks:
12+
return self._chunks.pop(0)
13+
return b""
14+
15+
16+
class _FakeResponse:
17+
def __init__(self, *, status: int, chunks: list[bytes]):
18+
self.status = status
19+
self.headers = {"content-length": str(sum(len(chunk) for chunk in chunks))}
20+
self.content = _FakeContent(chunks)
21+
22+
async def __aenter__(self):
23+
return self
24+
25+
async def __aexit__(self, exc_type, exc, tb):
26+
return False
27+
28+
29+
class _FakeSession:
30+
def __init__(self, response: _FakeResponse):
31+
self._response = response
32+
33+
async def __aenter__(self):
34+
return self
35+
36+
async def __aexit__(self, exc_type, exc, tb):
37+
return False
38+
39+
def get(self, *_args, **_kwargs):
40+
return self._response
41+
42+
43+
def _patch_download_session(monkeypatch, response: _FakeResponse):
44+
monkeypatch.setattr(io.aiohttp, "TCPConnector", lambda **_kwargs: object())
45+
monkeypatch.setattr(
46+
io.aiohttp,
47+
"ClientSession",
48+
lambda **_kwargs: _FakeSession(response),
49+
)
50+
51+
52+
@pytest.mark.asyncio
53+
async def test_download_file_rejects_non_200_response(monkeypatch, tmp_path):
54+
target_path = tmp_path / "missing.bin"
55+
_patch_download_session(
56+
monkeypatch,
57+
_FakeResponse(status=404, chunks=[b"not found"]),
58+
)
59+
60+
with pytest.raises(RuntimeError, match="HTTP status code: 404"):
61+
await io.download_file("https://example.test/missing", str(target_path))
62+
63+
assert not target_path.exists()
64+
65+
66+
@pytest.mark.asyncio
67+
async def test_download_file_writes_successful_response(monkeypatch, tmp_path):
68+
target_path = tmp_path / "ok.bin"
69+
_patch_download_session(
70+
monkeypatch,
71+
_FakeResponse(status=200, chunks=[b"hello", b" world"]),
72+
)
73+
74+
await io.download_file("https://example.test/ok.bin", str(target_path))
75+
76+
assert target_path.read_bytes() == b"hello world"

0 commit comments

Comments
 (0)