Skip to content

Commit ea3a4e9

Browse files
committed
stream output
1 parent 4aaba0d commit ea3a4e9

File tree

2 files changed

+118
-4
lines changed

2 files changed

+118
-4
lines changed

elsevier_coordinate_extraction/download/api.py

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22

33
from __future__ import annotations
44

5+
import inspect
56
import re
6-
from collections.abc import Mapping, Sequence
7-
from typing import Any
7+
from collections.abc import Awaitable, Mapping, Sequence
8+
from typing import Any, Protocol
89

910
import httpx
1011
from lxml import etree
@@ -43,20 +44,39 @@
4344
_CDN_BASE = "https://ars.els-cdn.com/content/image"
4445

4546

47+
class ProgressCallback(
48+
Protocol,
49+
):
50+
"""Callback invoked after each record is processed."""
51+
52+
def __call__(
53+
self,
54+
record: Mapping[str, str],
55+
article: ArticleContent | None,
56+
error: BaseException | None,
57+
) -> Awaitable[None] | None: ...
58+
59+
4660
async def download_articles(
4761
records: Sequence[Mapping[str, str]],
4862
*,
4963
client: ScienceDirectClient | None = None,
5064
cache: Any | None = None,
5165
cache_namespace: str = "articles",
5266
settings: Settings | None = None,
67+
progress_callback: ProgressCallback | None = None,
5368
) -> list[ArticleContent]:
5469
"""Download ScienceDirect articles identified by DOI and/or PubMed ID records.
5570
5671
Each record in ``records`` should contain at least one of the keys ``"doi"`` or ``"pmid"``.
5772
For every record, the downloader first attempts to retrieve the FULL text using the DOI
5873
(when present); if that fails, it retries with the PubMed ID. A successful download using
5974
either identifier stops further attempts for that record.
75+
76+
When ``progress_callback`` is provided it will be invoked after each record finishes processing.
77+
The callback receives the original record, the downloaded ``ArticleContent`` when successful
78+
(``None`` when no payload is returned), and the exception raised while processing
79+
(``None`` on success). Callbacks may be synchronous or async functions.
6080
"""
6181
if not records:
6282
return []
@@ -65,23 +85,39 @@ async def download_articles(
6585
owns_client = client is None
6686
sci_client = client or ScienceDirectClient(cfg)
6787

88+
async def _emit_progress(
89+
record: Mapping[str, str],
90+
article: ArticleContent | None,
91+
error: BaseException | None,
92+
) -> None:
93+
if progress_callback is None:
94+
return
95+
result = progress_callback(record, article, error)
96+
if inspect.isawaitable(result):
97+
await result
98+
6899
async def _runner() -> list[ArticleContent]:
69100
results: list[ArticleContent] = []
70101
for record in records:
102+
article: ArticleContent | None = None
71103
try:
72104
article = await _download_record(
73105
record=record,
74106
client=sci_client,
75107
cache=cache,
76108
cache_namespace=cache_namespace,
77109
)
78-
except httpx.HTTPError:
110+
except httpx.HTTPError as exc:
111+
await _emit_progress(record, None, exc)
79112
raise
80-
except Exception:
113+
except Exception as exc:
114+
await _emit_progress(record, None, exc)
81115
continue
82116
if article is None:
117+
await _emit_progress(record, None, None)
83118
continue
84119
results.append(article)
120+
await _emit_progress(record, article, None)
85121
return results
86122

87123
if owns_client:

tests/download/test_api.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import asyncio
56
from collections.abc import Sequence
67
from pathlib import Path
78

@@ -221,3 +222,80 @@ async def request(
221222
assert article.metadata.get("rate_limit_limit") == 100
222223
assert article.metadata.get("rate_limit_remaining") == 99
223224
assert article.metadata.get("rate_limit_reset_epoch") == 1234567891.0
225+
226+
227+
@pytest.mark.asyncio()
228+
async def test_download_progress_callback_invoked_for_each_record(test_dois: Sequence[str]) -> None:
229+
"""Progress callback fires for every successfully downloaded record."""
230+
231+
cfg = _test_settings()
232+
records = [{"doi": test_dois[0]}, {"doi": test_dois[1]}]
233+
234+
async def handler(request: httpx.Request) -> httpx.Response:
235+
doi = request.url.path.rsplit("/", 1)[-1]
236+
payload = f"""
237+
<article xmlns="http://www.elsevier.com/xml/svapi/article/dtd" xmlns:ce="http://www.elsevier.com/xml/common/dtd">
238+
<item-info>
239+
<doi>{doi}</doi>
240+
<pii>S105381192400679X</pii>
241+
</item-info>
242+
<ce:body><ce:para>{doi}</ce:para></ce:body>
243+
</article>
244+
""".encode("utf-8")
245+
return httpx.Response(
246+
200,
247+
content=payload,
248+
headers={"content-type": "application/xml"},
249+
request=request,
250+
)
251+
252+
transport = httpx.MockTransport(handler)
253+
progress_calls: list[tuple[dict[str, str], ArticleContent | None, BaseException | None]] = []
254+
255+
def progress_cb(
256+
record: dict[str, str],
257+
article: ArticleContent | None,
258+
error: BaseException | None,
259+
) -> None:
260+
progress_calls.append((record, article, error))
261+
262+
async with ScienceDirectClient(cfg, transport=transport) as client:
263+
articles = await download_articles(records, client=client, progress_callback=progress_cb)
264+
265+
assert len(articles) == len(records)
266+
assert len(progress_calls) == len(records)
267+
assert [call[0]["doi"] for call in progress_calls] == [record["doi"] for record in records]
268+
assert all(call[1] is not None for call in progress_calls)
269+
assert all(call[2] is None for call in progress_calls)
270+
271+
272+
@pytest.mark.asyncio()
273+
async def test_download_progress_callback_receives_errors(test_dois: Sequence[str]) -> None:
274+
"""Progress callback should receive exceptions before they propagate."""
275+
276+
cfg = _test_settings()
277+
doi = test_dois[0]
278+
279+
async def handler(request: httpx.Request) -> httpx.Response:
280+
raise httpx.TimeoutException("simulated timeout")
281+
282+
transport = httpx.MockTransport(handler)
283+
progress_calls: list[tuple[dict[str, str], ArticleContent | None, BaseException | None]] = []
284+
285+
async def progress_cb(
286+
record: dict[str, str],
287+
article: ArticleContent | None,
288+
error: BaseException | None,
289+
) -> None:
290+
await asyncio.sleep(0)
291+
progress_calls.append((record, article, error))
292+
293+
with pytest.raises(httpx.TimeoutException):
294+
async with ScienceDirectClient(cfg, transport=transport) as client:
295+
await download_articles([{"doi": doi}], client=client, progress_callback=progress_cb)
296+
297+
assert len(progress_calls) == 1
298+
record, article, error = progress_calls[0]
299+
assert record["doi"] == doi
300+
assert article is None
301+
assert isinstance(error, httpx.TimeoutException)

0 commit comments

Comments
 (0)