Skip to content

Commit 3300730

Browse files
authored
Discovery of Paperoni v2, auto-validation, collection diff (#131)
* Add `coll validate` to insert validated v2 papers back into the db
* Use the paperoni.mila.quebec v2 endpoint and [in]valid flags
* Add `coll diff` CLI
1 parent e227b2f commit 3300730

File tree

12 files changed

+1402
-9
lines changed

12 files changed

+1402
-9
lines changed

config/basic.yaml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ paperoni:
55
api_keys:
66
gemini: _
77
scraperapi: _
8+
paperoni_v2: _
89
fetch:
910
$class: RulesFetcher
1011
rules:
@@ -41,10 +42,17 @@ paperoni:
4142
author:
4243
score: 1
4344
institution_score_threshold: 1
45+
autovalidate:
46+
score_threshold: 10.0
4447
discovery:
4548
scrape:
4649
urls:
4750
- https://dadelani.github.io/publications
51+
v2:
52+
$class: paperoni.discovery.paperoni_v2:PaperoniV2
53+
endpoint: https://paperoni.mila.quebec
54+
token: ${paperoni.api_keys.paperoni_v2}
55+
cache: ${paperoni.data_path}/paperoniv2.json
4856
refine:
4957
prompt:
5058
$class: paperoni.prompt:GenAIPrompt

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ pmlr = "paperoni.discovery.pmlr:PMLR"
7171
jmlr = "paperoni.discovery.jmlr:JMLR"
7272
synth = "paperoni.discovery.synth:Synth"
7373
scrape = "paperoni.discovery.scrape:Scrape"
74+
v2 = "paperoni.discovery.paperoni_v2:PaperoniV2"
7475

7576
[project.entry-points."outsight.fixtures"]
7677
dash = "paperoni.dash:Dash"

src/paperoni/__main__.py

Lines changed: 119 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from .collection.remotecoll import RemoteCollection
4444
from .config import config
4545
from .dash import History
46+
from .discovery.paperoni_v2 import PaperoniV2
4647
from .display import display, print_field, terminal_width
4748
from .fulltext.locate import URL, locate_all
4849
from .fulltext.pdf import PDF, CachePolicies, get_pdf
@@ -247,6 +248,7 @@ class Configure:
247248
"""Configure the workset."""
248249

249250
n: int
251+
drop_zero: bool = True
250252
clear: bool = False
251253

252254
async def run(self, work: "Work"):
@@ -255,14 +257,15 @@ async def run(self, work: "Work"):
255257
top = deserialize(
256258
Top[Scored[CommentRec[PaperWorkingSet, float]]], work_file
257259
)
260+
top.drop_zero = self.drop_zero
258261
if self.clear:
259262
top.entries = []
260263
elif top.n > self.n:
261264
top.entries = list(top)[: self.n]
262265
top.resort()
263266
top.n = self.n
264267
else:
265-
top = Top(self.n)
268+
top = Top(self.n, drop_zero=self.drop_zero)
266269
work.save(top)
267270
print(f"Configured {work_file.resolve()} for n={self.n}")
268271

@@ -701,8 +704,122 @@ async def run(self, coll: "Coll"):
701704
elif not len(coll.collection) and not len(await coll.collection.exclusions()):
702705
logging.warning("Collection is not empty. Use --force to drop it.")
703706

707+
@dataclass
708+
class Validate:
709+
"""Validate the papers in the collection using the paperoni v2 database."""
710+
711+
# The paperoni v2 database
712+
# [optional]
713+
# [metavar v2]
714+
paperoni_v2: Auto[PaperoniV2.query] = None
715+
716+
# Validate papers having a score greater than the threshold
717+
# [metavar FLOAT]
718+
threshold: float = None
719+
720+
async def iterate(
721+
self, coll: "Coll" = None, **kwargs
722+
) -> AsyncGenerator[Paper, None]:
723+
if self.paperoni_v2 is not None:
724+
validated = 0
725+
total = 0
726+
async for paper_v2 in self.paperoni_v2(**kwargs):
727+
paper_v2: Paper
728+
total += 1
729+
730+
if "valid" not in paper_v2.flags:
731+
continue
732+
733+
validated += 1
734+
735+
yield paper_v2
736+
737+
send(progress=("Validated v2 papers", validated, total))
738+
739+
send(progress=("Validated v2 papers", None, total))
740+
741+
else:
742+
score_threshold = self.threshold or config.autovalidate.score_threshold
743+
744+
async for paper in coll.collection.search():
745+
paper: Paper
746+
747+
if (
748+
"valid" in paper.flags
749+
or (paper.score or config.focuses.score(paper)) < score_threshold
750+
):
751+
continue
752+
753+
yield paper
754+
755+
async def run(self, coll: "Coll"):
756+
ignored = 0
757+
validated = 0
758+
count = 0
759+
760+
async for paper in self.iterate(coll=coll):
761+
count += 1
762+
763+
if coll_paper := await coll.collection.find_paper(paper):
764+
if "invalid" in coll_paper.flags:
765+
ignored += 1
766+
continue
767+
768+
validated += 1
769+
coll_paper.flags.add("valid")
770+
await coll.collection.edit_paper(coll_paper)
771+
772+
if ignored and ignored != count:
773+
send(progress=("Ignored papers", ignored, count))
774+
775+
send(progress=("Validated papers", validated, count))
776+
777+
@dataclass
class Diff:
    """Diff the paper collection and another collection.

    The output directory will contain two files, named after the chosen format:
    - missing.<format>: Papers in the other collection that are not in the current collection
    - extra.<format>: Papers in the current collection that are not in the other collection
    """

    # The other collection
    # [positional]
    other_collection_path: str

    # Output directory
    out: Path

    # Format of the output files
    # [alias: --fmt]
    format: Literal["json", "yaml"] = "json"

    async def run(self, coll: "Coll"):
        """Compute both sides of the diff and write them to the output directory."""
        other_collection = FileCollection(file=Path(self.other_collection_path))

        # Papers of the other collection that the current one does not have.
        missings = [
            paper
            async for paper in other_collection.search()
            if not await coll.collection.find_paper(paper)
        ]

        self.out.mkdir(exist_ok=True, parents=True)
        await self._write("missing", missings)

        # Papers of the current collection that the other one does not have.
        extras = [
            paper
            async for paper in coll.collection.search()
            if not await other_collection.find_paper(paper)
        ]

        await self._write("extra", extras)

    async def _write(self, stem: str, papers: list):
        """Replace ``<out>/<stem>.<format>`` with a file collection of ``papers``."""
        target = self.out / f"{stem}.{self.format}"
        # Remove any previous output first so the result only holds this run's papers.
        target.unlink(missing_ok=True)
        await FileCollection(file=target).add_papers(papers)
820+
704821
# Command to execute
705-
command: TaggedUnion[Search, Import, Export, Drop]
822+
command: TaggedUnion[Search, Import, Export, Drop, Validate, Diff]
706823

707824
# Collection dir
708825
# [alias: -c]

src/paperoni/collection/abc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from datetime import date
2-
from typing import AsyncIterable, Iterable
2+
from typing import AsyncGenerator, Iterable
33

44
from ..model.classes import Paper
55

@@ -62,7 +62,7 @@ async def search(
6262
include_flags: list[str] = None,
6363
# Flags that must be False
6464
exclude_flags: list[str] = None,
65-
) -> AsyncIterable[Paper]:
65+
) -> AsyncGenerator[Paper, None]:
6666
raise NotImplementedError()
6767

6868
def __len__(self) -> int:

src/paperoni/collection/remotecoll.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
from dataclasses import dataclass, field
33
from datetime import date
4-
from typing import Iterable
4+
from typing import AsyncGenerator, Iterable
55

66
from fastapi import HTTPException
77
from serieux import deserialize
@@ -81,7 +81,7 @@ async def search(
8181
include_flags: list[str] = None,
8282
# Flags that must be False
8383
exclude_flags: list[str] = None,
84-
):
84+
) -> AsyncGenerator[Paper, None]:
8585
params = {}
8686
if paper_id:
8787
params["paper_id"] = paper_id
@@ -106,7 +106,7 @@ async def search(
106106
while True:
107107
query_params = params.copy()
108108
query_params["offset"] = offset
109-
resp = await self.fetch.read(
109+
resp: dict = await self.fetch.read(
110110
url,
111111
format="json",
112112
cache_into=None,

src/paperoni/config.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import gifnoc
88
from easy_oauth import OAuthManager
99
from rapporteur.report import Reporter
10-
from serieux import JSON, TaggedSubclass
10+
from serieux import TaggedSubclass
1111
from serieux.features.encrypt import Secret
1212

1313
from .collection.abc import PaperCollection
@@ -52,16 +52,21 @@ def __post_init__(self):
5252
self.process_pool = ProcessPoolExecutor(**self.process_pool_executor)
5353

5454

55+
@dataclass
class AutoValidate:
    """Settings for the automatic validation of papers."""

    # Score threshold for auto-validation; defaults to 10.0.
    score_threshold: float = 10.0
58+
59+
5560
@dataclass
5661
class PaperoniConfig:
5762
cache_path: Path = None
5863
data_path: Path = None
5964
mailto: str = ""
6065
api_keys: Keys[str, Secret[str]] = field(default_factory=Keys)
61-
discovery: JSON = None
6266
fetch: TaggedSubclass[Fetcher] = field(default_factory=RequestsFetcher)
6367
focuses: Focuses = field(default_factory=Focuses)
6468
autofocus: AutoFocus[str, AutoFocus.Author] = field(default_factory=AutoFocus)
69+
autovalidate: AutoValidate = field(default_factory=AutoValidate)
6570
refine: Refine = None
6671
work_file: Path = None
6772
collection: TaggedSubclass[PaperCollection] = None

0 commit comments

Comments
 (0)