Skip to content

Commit 403bbb6

Browse files
committed
fix: share download response handling
1 parent 59f79cd commit 403bbb6

2 files changed

Lines changed: 146 additions & 135 deletions

File tree

astrbot/core/utils/io.py

Lines changed: 112 additions & 132 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,101 @@ async def _emit_download_progress(progress_callback, payload: dict) -> None:
178178
await result
179179

180180

181+
class DownloadFileHTTPError(RuntimeError):
182+
"""Raised when a file download returns an unsuccessful HTTP status."""
183+
184+
185+
def _raise_for_download_status(resp, url: str) -> None:
186+
if resp.status == 200:
187+
return
188+
logger.error(
189+
"Failed to download file from %s. HTTP status code: %s",
190+
_safe_url_for_log(url),
191+
resp.status,
192+
)
193+
raise DownloadFileHTTPError(
194+
"Failed to download file from "
195+
f"{_safe_url_for_log(url)}. HTTP status code: {resp.status}"
196+
)
197+
198+
199+
async def _download_response_to_file(
200+
resp,
201+
file_obj,
202+
url: str,
203+
show_progress: bool,
204+
progress_callback,
205+
show_downloading_label: bool = True,
206+
) -> None:
207+
"""Write a successful download response to a local file.
208+
209+
Args:
210+
resp: aiohttp response object to read from.
211+
file_obj: Open writable binary file object.
212+
url: Source URL used for progress events and sanitized errors.
213+
show_progress: Whether to print progress to stdout.
214+
progress_callback: Optional callback for progress payloads.
215+
show_downloading_label: Whether to use the standard download heading.
216+
217+
"""
218+
219+
total_size = int(resp.headers.get("content-length", 0))
220+
downloaded_size = 0
221+
start_time = time.time()
222+
if show_progress:
223+
if show_downloading_label:
224+
print(
225+
f"Downloading: {_safe_url_for_log(url)} | "
226+
f"Size: {total_size / 1024:.2f} KB"
227+
)
228+
else:
229+
print(f"Size: {total_size / 1024:.2f} KB | URL: {_safe_url_for_log(url)}")
230+
await _emit_download_progress(
231+
progress_callback,
232+
{
233+
"url": url,
234+
"downloaded": 0,
235+
"total": total_size,
236+
"percent": 0,
237+
"speed": 0,
238+
},
239+
)
240+
while True:
241+
chunk = await resp.content.read(8192)
242+
if not chunk:
243+
break
244+
file_obj.write(chunk)
245+
downloaded_size += len(chunk)
246+
elapsed_time = time.time() - start_time if time.time() - start_time > 0 else 1
247+
speed = downloaded_size / 1024 / elapsed_time # KB/s
248+
percent = downloaded_size / total_size if total_size > 0 else 0
249+
await _emit_download_progress(
250+
progress_callback,
251+
{
252+
"url": url,
253+
"downloaded": downloaded_size,
254+
"total": total_size,
255+
"percent": percent,
256+
"speed": speed,
257+
},
258+
)
259+
if show_progress:
260+
print(
261+
f"\rProgress: {percent:.2%} Speed: {speed:.2f} KB/s",
262+
end="",
263+
)
264+
await _emit_download_progress(
265+
progress_callback,
266+
{
267+
"url": url,
268+
"downloaded": downloaded_size,
269+
"total": total_size,
270+
"percent": 1,
271+
"speed": 0,
272+
},
273+
)
274+
275+
181276
async def download_file(
182277
url: str,
183278
path: str,
@@ -209,73 +304,15 @@ async def download_file(
209304
connector=connector,
210305
) as session:
211306
async with session.get(url, timeout=1800) as resp:
212-
if resp.status != 200:
213-
logger.error(
214-
"Failed to download file from %s. HTTP status code: %s",
215-
_safe_url_for_log(url),
216-
resp.status,
217-
)
218-
raise RuntimeError(
219-
"Failed to download file from "
220-
f"{_safe_url_for_log(url)}. HTTP status code: {resp.status}"
221-
)
222-
total_size = int(resp.headers.get("content-length", 0))
223-
downloaded_size = 0
224-
start_time = time.time()
225-
if show_progress:
226-
print(
227-
f"Downloading: {_safe_url_for_log(url)} | "
228-
f"Size: {total_size / 1024:.2f} KB"
229-
)
230-
await _emit_download_progress(
231-
progress_callback,
232-
{
233-
"url": url,
234-
"downloaded": 0,
235-
"total": total_size,
236-
"percent": 0,
237-
"speed": 0,
238-
},
239-
)
307+
_raise_for_download_status(resp, url)
240308
with open(path, "wb") as f:
241-
while True:
242-
chunk = await resp.content.read(8192)
243-
if not chunk:
244-
break
245-
f.write(chunk)
246-
downloaded_size += len(chunk)
247-
elapsed_time = (
248-
time.time() - start_time
249-
if time.time() - start_time > 0
250-
else 1
251-
)
252-
speed = downloaded_size / 1024 / elapsed_time # KB/s
253-
percent = downloaded_size / total_size if total_size > 0 else 0
254-
await _emit_download_progress(
255-
progress_callback,
256-
{
257-
"url": url,
258-
"downloaded": downloaded_size,
259-
"total": total_size,
260-
"percent": percent,
261-
"speed": speed,
262-
},
263-
)
264-
if show_progress:
265-
print(
266-
f"\rProgress: {percent:.2%} Speed: {speed:.2f} KB/s",
267-
end="",
268-
)
269-
await _emit_download_progress(
270-
progress_callback,
271-
{
272-
"url": url,
273-
"downloaded": downloaded_size,
274-
"total": total_size,
275-
"percent": 1,
276-
"speed": 0,
277-
},
278-
)
309+
await _download_response_to_file(
310+
resp,
311+
f,
312+
url,
313+
show_progress,
314+
progress_callback,
315+
)
279316
except (aiohttp.ClientConnectorSSLError, aiohttp.ClientConnectorCertificateError):
280317
if not allow_insecure_ssl_fallback:
281318
raise
@@ -295,73 +332,16 @@ async def download_file(
295332
ssl_context.verify_mode = ssl.CERT_NONE
296333
async with aiohttp.ClientSession() as session:
297334
async with session.get(url, ssl=ssl_context, timeout=120) as resp:
298-
if resp.status != 200:
299-
logger.error(
300-
"Failed to download file from %s. HTTP status code: %s",
301-
_safe_url_for_log(url),
302-
resp.status,
303-
)
304-
raise RuntimeError(
305-
"Failed to download file from "
306-
f"{_safe_url_for_log(url)}. HTTP status code: {resp.status}"
307-
)
308-
total_size = int(resp.headers.get("content-length", 0))
309-
downloaded_size = 0
310-
start_time = time.time()
311-
if show_progress:
312-
print(
313-
f"Size: {total_size / 1024:.2f} KB | "
314-
f"URL: {_safe_url_for_log(url)}"
315-
)
316-
await _emit_download_progress(
317-
progress_callback,
318-
{
319-
"url": url,
320-
"downloaded": 0,
321-
"total": total_size,
322-
"percent": 0,
323-
"speed": 0,
324-
},
325-
)
335+
_raise_for_download_status(resp, url)
326336
with open(path, "wb") as f:
327-
while True:
328-
chunk = await resp.content.read(8192)
329-
if not chunk:
330-
break
331-
f.write(chunk)
332-
downloaded_size += len(chunk)
333-
elapsed_time = (
334-
time.time() - start_time
335-
if time.time() - start_time > 0
336-
else 1
337-
)
338-
speed = downloaded_size / 1024 / elapsed_time # KB/s
339-
percent = downloaded_size / total_size if total_size > 0 else 0
340-
await _emit_download_progress(
341-
progress_callback,
342-
{
343-
"url": url,
344-
"downloaded": downloaded_size,
345-
"total": total_size,
346-
"percent": percent,
347-
"speed": speed,
348-
},
349-
)
350-
if show_progress:
351-
print(
352-
f"\rProgress: {percent:.2%} Speed: {speed:.2f} KB/s",
353-
end="",
354-
)
355-
await _emit_download_progress(
356-
progress_callback,
357-
{
358-
"url": url,
359-
"downloaded": downloaded_size,
360-
"total": total_size,
361-
"percent": 1,
362-
"speed": 0,
363-
},
364-
)
337+
await _download_response_to_file(
338+
resp,
339+
f,
340+
url,
341+
show_progress,
342+
progress_callback,
343+
show_downloading_label=False,
344+
)
365345
if show_progress:
366346
print()
367347

tests/unit/test_io_download_file.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ async def __aexit__(self, exc_type, exc, tb):
2727

2828

2929
class _FakeSession:
30-
def __init__(self, response: _FakeResponse):
30+
def __init__(self, response: _FakeResponse | Exception):
3131
self._response = response
3232

3333
async def __aenter__(self):
@@ -37,15 +37,21 @@ async def __aexit__(self, exc_type, exc, tb):
3737
return False
3838

3939
def get(self, *_args, **_kwargs):
40+
if isinstance(self._response, Exception):
41+
raise self._response
4042
return self._response
4143

4244

4345
def _patch_download_session(monkeypatch, response: _FakeResponse):
46+
_patch_download_sessions(monkeypatch, [response])
47+
48+
49+
def _patch_download_sessions(monkeypatch, responses: list[_FakeResponse | Exception]):
4450
monkeypatch.setattr(io.aiohttp, "TCPConnector", lambda **_kwargs: object())
4551
monkeypatch.setattr(
4652
io.aiohttp,
4753
"ClientSession",
48-
lambda **_kwargs: _FakeSession(response),
54+
lambda **_kwargs: _FakeSession(responses.pop(0)),
4955
)
5056

5157

@@ -57,7 +63,32 @@ async def test_download_file_rejects_non_200_response(monkeypatch, tmp_path):
5763
_FakeResponse(status=404, chunks=[b"not found"]),
5864
)
5965

60-
with pytest.raises(RuntimeError, match="HTTP status code: 404"):
66+
with pytest.raises(io.DownloadFileHTTPError, match="HTTP status code: 404"):
67+
await io.download_file("https://example.test/missing", str(target_path))
68+
69+
assert not target_path.exists()
70+
71+
72+
@pytest.mark.asyncio
73+
async def test_download_file_rejects_non_200_response_after_ssl_fallback(
74+
monkeypatch,
75+
tmp_path,
76+
):
77+
class FakeSSLError(Exception):
78+
pass
79+
80+
target_path = tmp_path / "missing.bin"
81+
_patch_download_sessions(
82+
monkeypatch,
83+
[
84+
FakeSSLError(),
85+
_FakeResponse(status=404, chunks=[b"not found"]),
86+
],
87+
)
88+
monkeypatch.setattr(io.aiohttp, "ClientConnectorSSLError", FakeSSLError)
89+
monkeypatch.setattr(io.aiohttp, "ClientConnectorCertificateError", FakeSSLError)
90+
91+
with pytest.raises(io.DownloadFileHTTPError, match="HTTP status code: 404"):
6192
await io.download_file("https://example.test/missing", str(target_path))
6293

6394
assert not target_path.exists()

0 commit comments

Comments
 (0)