Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions src/paperoni/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,29 +390,37 @@ async def run(self, work: "Work"):
it = list(itertools.islice(work.top, self.n)) if self.n else work.top

for i in range(self.loops):
for sws in prog(it, name=f"refine{i + 1 if i else ''}"):

async def fetch_and_add(sws, i):
statuses.update(
{
(name, key): "done"
for pinfo in sws.value.collected
for name, key in pinfo.info.get("refined_by", {}).items()
for paper in sws.value.collected
for name, key in paper.info.get("refined_by", {}).items()
}
)
# Loop a bit because refiners can add new links to refine further
links = [(lnk.type, lnk.link) for lnk in sws.value.current.links]
links.append(("title", sws.value.current.title))
if i == 0:
send(to_refine=links)
async for pinfo in fetch_all(
async for paper in fetch_all(
links,
group=";".join([f"{type}:{link}" for type, link in links]),
tags=self.tags,
force=self.force,
statuses=statuses,
):
send(refinement=pinfo)
sws.value.add(pinfo)
sws.score = work.focuses.score(sws.value)
send(refinement=paper)
sws.value.add(paper)
sws.score = work.focuses.score(sws.value)
return sws

coros = [
fetch_and_add(sws, i)
for sws in prog(it, name=f"refine{i + 1 if i else ''}")
]
await asyncio.gather(*coros)

work.top.resort()
work.save()
Expand Down
30 changes: 26 additions & 4 deletions src/paperoni/get.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
import os
import re
from contextlib import asynccontextmanager
from dataclasses import dataclass
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from functools import cached_property
from pathlib import Path
from typing import Literal
from urllib.parse import urlparse

import chardet
import hishel
Expand All @@ -20,7 +21,7 @@
from requests import Session
from serieux import TaggedSubclass
from serieux.features.encrypt import Secret
from tenacity import retry, stop_after_delay, wait_exponential
from tenacity import retry, stop_after_delay, wait_exponential, wait_random

ERRORS = (httpx.HTTPStatusError, requests.RequestException)
ua = UserAgent()
Expand Down Expand Up @@ -146,7 +147,7 @@ def is_cache_valid(path: Path, expiry: timedelta):
return parse(content, format)

@retry(
wait=wait_exponential(multiplier=1, exp_base=2),
wait=wait_exponential(multiplier=1, exp_base=2) + wait_random(0, 0.5),
stop=stop_after_delay(30),
retry=lambda retry_state: not _giveup(retry_state.outcome.exception()),
reraise=True,
Expand Down Expand Up @@ -326,10 +327,31 @@ async def generic(self, method, url, **kwargs):
class RulesFetcher(Fetcher):
rules: dict[re.Pattern, str]
fetchers: dict[str, TaggedSubclass[Fetcher]]
simultaneous: dict[str, int] = field(default_factory=dict)

# [serieux: ignore]
_semaphores: dict[str, asyncio.Semaphore] = field(default_factory=dict, repr=False)

def _get_semaphore(self, hostname: str) -> asyncio.Semaphore | None:
"""Get or create a semaphore for the given hostname."""
if hostname not in self._semaphores:
limit = self.simultaneous.get(hostname, None) or self.simultaneous.get(
"*", -1
)
if limit == -1:
return None
self._semaphores[hostname] = asyncio.Semaphore(limit)

return self._semaphores[hostname]

async def generic(self, method, url, **kwargs):
    """Dispatch a request to the fetcher selected by the first matching rule.

    ``self.rules`` is scanned in order; the first pattern for which
    ``pattern.search(url)`` succeeds selects the fetcher stored under the
    associated key in ``self.fetchers``. When ``self._get_semaphore``
    returns a semaphore for the URL's hostname, the delegated call runs
    while holding it; otherwise the call is made directly.

    Args:
        method: Passed through unchanged to the selected fetcher.
        url: URL to fetch; also used to pick the rule and the hostname.
        **kwargs: Passed through unchanged to the selected fetcher.

    Returns:
        Whatever the selected fetcher's ``generic`` returns.

    Raises:
        ValueError: If no rule pattern matches ``url``.
    """
    for pattern, fetcher_key in self.rules.items():
        if not pattern.search(url):
            continue
        f = self.fetchers[fetcher_key]
        # urlparse() yields hostname=None for URLs without a network
        # location; fall back to "" so the lookup key is always a string.
        hostname = urlparse(url).hostname or ""
        semaphore = self._get_semaphore(hostname)
        if semaphore is None:
            return await f.generic(method, url, **kwargs)
        # The delegated call must happen inside the semaphore; returning
        # before acquiring it would bypass the per-host limit entirely.
        async with semaphore:
            return await f.generic(method, url, **kwargs)
    raise ValueError(f"No fetcher rule matches URL: {url}")
16 changes: 11 additions & 5 deletions src/paperoni/refinement/fetch.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import inspect
from dataclasses import replace
from typing import Callable
Expand Down Expand Up @@ -63,7 +64,7 @@ async def fetch_all(links, group="composite", statuses=None, tags=None, force=Fa
for type, link in links:
funcs.append((f"{type}:{link}", (type, link), fetch.resolve_all(type, link)))

async def go(key, args, f):
async def go(key, name, nk, args, f):
__trace__ = f"refine:{key}" # noqa: F841
with soft_fail(f"Refinement of {key}"):
try:
Expand All @@ -81,16 +82,21 @@ async def go(key, args, f):
statuses[nk] = "error"
raise

tasks = []

for key, args, fs in funcs:
for f in fs:
if not _test_tags(getattr(f.func, "tags", {"normal"}), tags):
continue

name = getattr(f.func, "description", "???")
nk = f"{name}/{key}"
if nk in statuses:
continue
statuses[nk] = "pending"
result = await go(key, args, f)
if result is not None:
yield result
coro = go(key, name, nk, args, f)
tasks.append(asyncio.create_task(coro))

for task in tasks:
result = await task
if result is not None:
yield result
4 changes: 2 additions & 2 deletions src/paperoni/refinement/title.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ async def crossref_title(typ: Literal["title"], link: str):
encoded_title = quote(title.strip())

try:
data = await config.fetch.read(
data = await config.fetch.read_retry(
f"https://api.crossref.org/works?query.title={encoded_title}&rows=1",
format="json",
)
Expand Down Expand Up @@ -67,7 +67,7 @@ async def arxiv_title(type: Literal["title"], link: str):
title = link

try:
soup = await config.fetch.read(
soup = await config.fetch.read_retry(
"https://export.arxiv.org/api/query",
params={
"search_query": f'title:"{title}"',
Expand Down
2 changes: 2 additions & 0 deletions tests/config/test-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ paperoni:
cloudflare:
$class: CloudFlareFetcher
user_agent: chrome
simultaneous:
"*": 1
focuses:
- "!institution :: Mila :: 10"
- "!author :: Yoshua Bengio :: 3"
Expand Down
Loading