Skip to content

Commit 0cac04f

Browse files
committed
parallelize the coordinate processing
1 parent 92f5464 commit 0cac04f

File tree

5 files changed

+82
-45
lines changed

5 files changed

+82
-45
lines changed

elsevier_coordinate_extraction/extract/coordinates.py

Lines changed: 73 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -3,69 +3,98 @@
33
from __future__ import annotations
44

55
from collections.abc import Iterable
6+
from concurrent.futures import ProcessPoolExecutor, as_completed
67
from typing import Any, Tuple
78

9+
import os
810
import pandas as pd
911
from lxml import etree
1012
from pubget._coordinate_space import _neurosynth_guess_space
1113
from pubget._coordinates import _extract_coordinates_from_table
1214

1315
from elsevier_coordinate_extraction.table_extraction import extract_tables_from_article
1416
from elsevier_coordinate_extraction.types import ArticleContent, TableMetadata
17+
from elsevier_coordinate_extraction import settings
1518

1619

1720
def extract_coordinates(articles: Iterable[ArticleContent]) -> dict:
    """Extract coordinate tables from the supplied articles.

    Args:
        articles: Article payloads to scan for coordinate tables.

    Returns:
        A studyset mapping ``{"studyset": {"studies": [...]}}`` where each
        study is the dict produced by ``_build_study``, in the same order
        as the input articles.
    """
    article_list = list(articles)
    if not article_list:
        return {"studyset": {"studies": []}}

    cfg = settings.get_settings()
    # A non-positive setting means "auto": one worker per CPU
    # (os.cpu_count() may return None on exotic platforms, hence `or 1`),
    # capped at the number of articles so we never spawn idle processes.
    if cfg.extraction_workers <= 0:
        worker_count = min(len(article_list), os.cpu_count() or 1)
    else:
        worker_count = min(len(article_list), cfg.extraction_workers)

    if worker_count == 1:
        # Not worth paying process start-up cost for a single worker.
        studies = [_build_study(article) for article in article_list]
    else:
        # Executor.map yields results in submission order, so the
        # original future->index bookkeeping and post-hoc sort are
        # unnecessary; any worker exception propagates here as before.
        with ProcessPoolExecutor(max_workers=worker_count) as pool:
            studies = list(pool.map(_build_study, article_list))

    return {"studyset": {"studies": studies}}
6749

6850

51+
def _build_study(article: ArticleContent) -> dict[str, Any]:
    """Process a single article into a study representation.

    Extracts every coordinate table from the article, resolves a
    stereotactic space per table (header/metadata heuristics first, then a
    full-text neurosynth guess), and packages the results as one study
    dict with ``doi``, ``analyses`` and ``metadata`` keys.  Runs in a
    worker process, so it must stay a picklable module-level function.
    """
    analyses: list[dict[str, Any]] = []
    tables = extract_tables_from_article(article.payload)
    if not tables:
        # Fall back to the manual parser when the primary extractor
        # finds nothing.
        tables = _manual_extract_tables(article.payload)
    # Full article text is materialised lazily, only if some table needs
    # the neurosynth space guess, and is cached across tables.
    article_text: str | None = None
    for metadata, df in tables:
        meta_text = _metadata_text(metadata)
        coords = _extract_coordinates_from_dataframe(df, meta_text.lower())
        if not coords:
            continue
        header_text = " ".join(str(col).lower() for col in df.columns)
        space = _heuristic_space(header_text, meta_text)
        if space is None:
            if article_text is None:
                article_text = _article_text(article.payload)
            guessed = _neurosynth_guess_space(article_text)
            if guessed != "UNKNOWN":
                space = guessed
        analysis_metadata = {
            "table_label": metadata.label,
            "table_id": metadata.identifier,
            "raw_table_xml": metadata.raw_xml,
        }
        # One point per coordinate triplet; ``space`` may remain None when
        # neither heuristic nor guess succeeded.  ``coords`` is non-empty
        # here, so ``points`` is too — the original
        # ``if not points: continue`` guard was unreachable and is dropped.
        points = [
            {"coordinates": triplet, "space": space}
            for triplet in coords
        ]
        analysis_name = _analysis_name(metadata)
        analyses.append(
            {"name": analysis_name, "points": points, "metadata": analysis_metadata}
        )
    return {
        "doi": article.doi,
        "analyses": analyses,
        "metadata": dict(article.metadata),
    }
96+
97+
6998
def _heuristic_space(header_text: str, meta_text: str) -> str | None:
7099
combined = f"{header_text} {meta_text}".strip()
71100
if not combined:
@@ -395,4 +424,3 @@ def _extract_numbers(text: str) -> list[float]:
395424

396425
matches = re.findall(r"[-+]?\d+(?:\.\d+)?", text.replace("−", "-"))
397426
return [float(match) for match in matches]
398-

elsevier_coordinate_extraction/settings.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
_DEFAULT_CACHE_DIR: Final[str] = ".elsevier_cache"
1616
_DEFAULT_USER_AGENT: Final[str] = "elsevierCoordinateExtraction/0.1.0"
1717
_DEFAULT_MAX_RATE_LIMIT_WAIT: Final[float] = 3600.0 # 1 hour
18+
_DEFAULT_EXTRACTION_WORKERS: Final[int] = 0
1819

1920
_CACHED_SETTINGS: Settings | None = None
2021

@@ -34,6 +35,7 @@ class Settings:
3435
https_proxy: str | None
3536
use_proxy: bool
3637
max_rate_limit_wait: float | None
38+
extraction_workers: int
3739

3840

3941
_TRUE_VALUES: Final[set[str]] = {"1", "true", "yes", "on"}
@@ -110,5 +112,6 @@ def get_settings(*, force_reload: bool = False) -> Settings:
110112
https_proxy=https_proxy,
111113
use_proxy=use_proxy,
112114
max_rate_limit_wait=max_rate_limit_wait,
115+
extraction_workers=int(os.getenv("ELSEVIER_EXTRACTION_WORKERS", _DEFAULT_EXTRACTION_WORKERS)),
113116
)
114117
return _CACHED_SETTINGS

tests/download/test_api.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def _test_settings() -> Settings:
3131
https_proxy=None,
3232
use_proxy=False,
3333
max_rate_limit_wait=cfg.max_rate_limit_wait,
34+
extraction_workers=cfg.extraction_workers,
3435
)
3536

3637

tests/test_client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def _make_test_settings(
4141
https_proxy=None,
4242
use_proxy=False,
4343
max_rate_limit_wait=resolved_wait,
44+
extraction_workers=cfg.extraction_workers,
4445
)
4546

4647

tests/test_settings.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def test_get_settings_reads_environment(
2727
assert cfg_a.api_key == "unit-test-key"
2828
assert cfg_a.insttoken is None
2929
assert cfg_a.use_proxy is False
30+
assert cfg_a.extraction_workers == 0
3031
assert cfg_a is cfg_b
3132

3233

@@ -66,6 +67,7 @@ def test_use_proxy_flag_disables_proxies(monkeypatch: pytest.MonkeyPatch) -> Non
6667
cfg = settings.get_settings(force_reload=True)
6768
assert cfg.http_proxy == "socks5://localhost:1080"
6869
assert cfg.use_proxy is False
70+
assert cfg.extraction_workers == 0
6971

7072

7173
def test_max_rate_limit_wait_env(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
@@ -76,8 +78,10 @@ def test_max_rate_limit_wait_env(monkeypatch: pytest.MonkeyPatch, tmp_path: Path
7678
blank_env.write_text("")
7779
monkeypatch.setenv("ELSEVIER_DOTENV_PATH", str(blank_env))
7880
monkeypatch.setenv("ELSEVIER_MAX_RATE_LIMIT_WAIT_SECONDS", "120")
81+
monkeypatch.setenv("ELSEVIER_EXTRACTION_WORKERS", "8")
7982
cfg = settings.get_settings(force_reload=True)
8083
assert cfg.max_rate_limit_wait == 120.0
84+
assert cfg.extraction_workers == 8
8185

8286

8387
def test_max_rate_limit_wait_unlimited(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:

0 commit comments

Comments
 (0)