Skip to content

Commit 317d29e

Browse files
committed
fix: share download response handling
1 parent 59f79cd commit 317d29e

2 files changed

Lines changed: 141 additions & 133 deletions

File tree

astrbot/core/utils/io.py

Lines changed: 107 additions & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,102 @@ async def _emit_download_progress(progress_callback, payload: dict) -> None:
178178
await result
179179

180180

181+
class DownloadFileHTTPError(RuntimeError):
182+
"""Raised when a file download returns an unsuccessful HTTP status."""
183+
184+
185+
async def _download_response_to_file(
186+
resp,
187+
url: str,
188+
path: str,
189+
show_progress: bool,
190+
progress_callback,
191+
show_downloading_label: bool = True,
192+
) -> None:
193+
"""Write a successful download response to a local file.
194+
195+
Args:
196+
resp: aiohttp response object to read from.
197+
url: Source URL used for progress events and sanitized errors.
198+
path: Local destination path.
199+
show_progress: Whether to print progress to stdout.
200+
progress_callback: Optional callback for progress payloads.
201+
show_downloading_label: Whether to use the standard download heading.
202+
203+
Raises:
204+
DownloadFileHTTPError: If the HTTP response status is not 200.
205+
"""
206+
207+
if resp.status != 200:
208+
logger.error(
209+
"Failed to download file from %s. HTTP status code: %s",
210+
_safe_url_for_log(url),
211+
resp.status,
212+
)
213+
raise DownloadFileHTTPError(
214+
"Failed to download file from "
215+
f"{_safe_url_for_log(url)}. HTTP status code: {resp.status}"
216+
)
217+
total_size = int(resp.headers.get("content-length", 0))
218+
downloaded_size = 0
219+
start_time = time.time()
220+
if show_progress:
221+
if show_downloading_label:
222+
print(
223+
f"Downloading: {_safe_url_for_log(url)} | "
224+
f"Size: {total_size / 1024:.2f} KB"
225+
)
226+
else:
227+
print(f"Size: {total_size / 1024:.2f} KB | URL: {_safe_url_for_log(url)}")
228+
await _emit_download_progress(
229+
progress_callback,
230+
{
231+
"url": url,
232+
"downloaded": 0,
233+
"total": total_size,
234+
"percent": 0,
235+
"speed": 0,
236+
},
237+
)
238+
with open(path, "wb") as f:
239+
while True:
240+
chunk = await resp.content.read(8192)
241+
if not chunk:
242+
break
243+
f.write(chunk)
244+
downloaded_size += len(chunk)
245+
elapsed_time = (
246+
time.time() - start_time if time.time() - start_time > 0 else 1
247+
)
248+
speed = downloaded_size / 1024 / elapsed_time # KB/s
249+
percent = downloaded_size / total_size if total_size > 0 else 0
250+
await _emit_download_progress(
251+
progress_callback,
252+
{
253+
"url": url,
254+
"downloaded": downloaded_size,
255+
"total": total_size,
256+
"percent": percent,
257+
"speed": speed,
258+
},
259+
)
260+
if show_progress:
261+
print(
262+
f"\rProgress: {percent:.2%} Speed: {speed:.2f} KB/s",
263+
end="",
264+
)
265+
await _emit_download_progress(
266+
progress_callback,
267+
{
268+
"url": url,
269+
"downloaded": downloaded_size,
270+
"total": total_size,
271+
"percent": 1,
272+
"speed": 0,
273+
},
274+
)
275+
276+
181277
async def download_file(
182278
url: str,
183279
path: str,
@@ -209,72 +305,12 @@ async def download_file(
209305
connector=connector,
210306
) as session:
211307
async with session.get(url, timeout=1800) as resp:
212-
if resp.status != 200:
213-
logger.error(
214-
"Failed to download file from %s. HTTP status code: %s",
215-
_safe_url_for_log(url),
216-
resp.status,
217-
)
218-
raise RuntimeError(
219-
"Failed to download file from "
220-
f"{_safe_url_for_log(url)}. HTTP status code: {resp.status}"
221-
)
222-
total_size = int(resp.headers.get("content-length", 0))
223-
downloaded_size = 0
224-
start_time = time.time()
225-
if show_progress:
226-
print(
227-
f"Downloading: {_safe_url_for_log(url)} | "
228-
f"Size: {total_size / 1024:.2f} KB"
229-
)
230-
await _emit_download_progress(
308+
await _download_response_to_file(
309+
resp,
310+
url,
311+
path,
312+
show_progress,
231313
progress_callback,
232-
{
233-
"url": url,
234-
"downloaded": 0,
235-
"total": total_size,
236-
"percent": 0,
237-
"speed": 0,
238-
},
239-
)
240-
with open(path, "wb") as f:
241-
while True:
242-
chunk = await resp.content.read(8192)
243-
if not chunk:
244-
break
245-
f.write(chunk)
246-
downloaded_size += len(chunk)
247-
elapsed_time = (
248-
time.time() - start_time
249-
if time.time() - start_time > 0
250-
else 1
251-
)
252-
speed = downloaded_size / 1024 / elapsed_time # KB/s
253-
percent = downloaded_size / total_size if total_size > 0 else 0
254-
await _emit_download_progress(
255-
progress_callback,
256-
{
257-
"url": url,
258-
"downloaded": downloaded_size,
259-
"total": total_size,
260-
"percent": percent,
261-
"speed": speed,
262-
},
263-
)
264-
if show_progress:
265-
print(
266-
f"\rProgress: {percent:.2%} Speed: {speed:.2f} KB/s",
267-
end="",
268-
)
269-
await _emit_download_progress(
270-
progress_callback,
271-
{
272-
"url": url,
273-
"downloaded": downloaded_size,
274-
"total": total_size,
275-
"percent": 1,
276-
"speed": 0,
277-
},
278314
)
279315
except (aiohttp.ClientConnectorSSLError, aiohttp.ClientConnectorCertificateError):
280316
if not allow_insecure_ssl_fallback:
@@ -295,72 +331,13 @@ async def download_file(
295331
ssl_context.verify_mode = ssl.CERT_NONE
296332
async with aiohttp.ClientSession() as session:
297333
async with session.get(url, ssl=ssl_context, timeout=120) as resp:
298-
if resp.status != 200:
299-
logger.error(
300-
"Failed to download file from %s. HTTP status code: %s",
301-
_safe_url_for_log(url),
302-
resp.status,
303-
)
304-
raise RuntimeError(
305-
"Failed to download file from "
306-
f"{_safe_url_for_log(url)}. HTTP status code: {resp.status}"
307-
)
308-
total_size = int(resp.headers.get("content-length", 0))
309-
downloaded_size = 0
310-
start_time = time.time()
311-
if show_progress:
312-
print(
313-
f"Size: {total_size / 1024:.2f} KB | "
314-
f"URL: {_safe_url_for_log(url)}"
315-
)
316-
await _emit_download_progress(
317-
progress_callback,
318-
{
319-
"url": url,
320-
"downloaded": 0,
321-
"total": total_size,
322-
"percent": 0,
323-
"speed": 0,
324-
},
325-
)
326-
with open(path, "wb") as f:
327-
while True:
328-
chunk = await resp.content.read(8192)
329-
if not chunk:
330-
break
331-
f.write(chunk)
332-
downloaded_size += len(chunk)
333-
elapsed_time = (
334-
time.time() - start_time
335-
if time.time() - start_time > 0
336-
else 1
337-
)
338-
speed = downloaded_size / 1024 / elapsed_time # KB/s
339-
percent = downloaded_size / total_size if total_size > 0 else 0
340-
await _emit_download_progress(
341-
progress_callback,
342-
{
343-
"url": url,
344-
"downloaded": downloaded_size,
345-
"total": total_size,
346-
"percent": percent,
347-
"speed": speed,
348-
},
349-
)
350-
if show_progress:
351-
print(
352-
f"\rProgress: {percent:.2%} Speed: {speed:.2f} KB/s",
353-
end="",
354-
)
355-
await _emit_download_progress(
334+
await _download_response_to_file(
335+
resp,
336+
url,
337+
path,
338+
show_progress,
356339
progress_callback,
357-
{
358-
"url": url,
359-
"downloaded": downloaded_size,
360-
"total": total_size,
361-
"percent": 1,
362-
"speed": 0,
363-
},
340+
show_downloading_label=False,
364341
)
365342
if show_progress:
366343
print()

tests/unit/test_io_download_file.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ async def __aexit__(self, exc_type, exc, tb):
2727

2828

2929
class _FakeSession:
30-
def __init__(self, response: _FakeResponse):
30+
def __init__(self, response: _FakeResponse | Exception):
3131
self._response = response
3232

3333
async def __aenter__(self):
@@ -37,15 +37,21 @@ async def __aexit__(self, exc_type, exc, tb):
3737
return False
3838

3939
def get(self, *_args, **_kwargs):
40+
if isinstance(self._response, Exception):
41+
raise self._response
4042
return self._response
4143

4244

4345
def _patch_download_session(monkeypatch, response: _FakeResponse):
46+
_patch_download_sessions(monkeypatch, [response])
47+
48+
49+
def _patch_download_sessions(monkeypatch, responses: list[_FakeResponse | Exception]):
4450
monkeypatch.setattr(io.aiohttp, "TCPConnector", lambda **_kwargs: object())
4551
monkeypatch.setattr(
4652
io.aiohttp,
4753
"ClientSession",
48-
lambda **_kwargs: _FakeSession(response),
54+
lambda **_kwargs: _FakeSession(responses.pop(0)),
4955
)
5056

5157

@@ -57,7 +63,32 @@ async def test_download_file_rejects_non_200_response(monkeypatch, tmp_path):
5763
_FakeResponse(status=404, chunks=[b"not found"]),
5864
)
5965

60-
with pytest.raises(RuntimeError, match="HTTP status code: 404"):
66+
with pytest.raises(io.DownloadFileHTTPError, match="HTTP status code: 404"):
67+
await io.download_file("https://example.test/missing", str(target_path))
68+
69+
assert not target_path.exists()
70+
71+
72+
@pytest.mark.asyncio
73+
async def test_download_file_rejects_non_200_response_after_ssl_fallback(
74+
monkeypatch,
75+
tmp_path,
76+
):
77+
class FakeSSLError(Exception):
78+
pass
79+
80+
target_path = tmp_path / "missing.bin"
81+
_patch_download_sessions(
82+
monkeypatch,
83+
[
84+
FakeSSLError(),
85+
_FakeResponse(status=404, chunks=[b"not found"]),
86+
],
87+
)
88+
monkeypatch.setattr(io.aiohttp, "ClientConnectorSSLError", FakeSSLError)
89+
monkeypatch.setattr(io.aiohttp, "ClientConnectorCertificateError", FakeSSLError)
90+
91+
with pytest.raises(io.DownloadFileHTTPError, match="HTTP status code: 404"):
6192
await io.download_file("https://example.test/missing", str(target_path))
6293

6394
assert not target_path.exists()

0 commit comments

Comments
 (0)