Skip to content

Commit 95b536f

Browse files
authored
Fix flag search (#122)
* Edit search api * Fix flags search
1 parent 397c351 commit 95b536f

18 files changed

+531
-452
lines changed

src/paperoni/__main__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -645,7 +645,7 @@ def run(self, coll: "Coll") -> list[Paper]:
645645
start_date=self.start_date,
646646
end_date=self.end_date,
647647
include_flags={f for f in flags if not f.startswith("~")},
648-
exclude_flags={f for f in flags if f.startswith("~")},
648+
exclude_flags={f[1:] for f in flags if f.startswith("~")},
649649
)
650650
]
651651
self.format(papers)

src/paperoni/collection/remotecoll.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,9 @@ def search(
9999
if end_date:
100100
params["end_date"] = end_date.isoformat()
101101
if include_flags:
102-
params["include_flags"] = ",".join(include_flags)
102+
params.setdefault("flags", []).extend(include_flags)
103103
if exclude_flags:
104-
params["exclude_flags"] = ",".join(exclude_flags)
104+
params.setdefault("flags", []).extend([f"~{f}" for f in exclude_flags])
105105
url = f"{self.endpoint}/search"
106106
offset = 0
107107
while True:

src/paperoni/web/assets/search.js

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,23 @@ async function fetchSearchResults(params, offset = 0) {
3131
if (params.venue) queryParams.append('venue', params.venue);
3232
if (params.start_date) queryParams.append('start_date', params.start_date);
3333
if (params.end_date) queryParams.append('end_date', params.end_date);
34-
if (params.validated) queryParams.append('validated', params.validated);
34+
35+
// Convert validated parameter to flags
36+
switch (params.validated) {
37+
case 'true':
38+
// Include papers with 'valid' flag
39+
queryParams.append('flags', 'valid');
40+
break;
41+
case 'false':
42+
// Include papers with 'invalid' flag
43+
queryParams.append('flags', 'invalid');
44+
break;
45+
case 'unset':
46+
// Exclude papers that carry either the 'valid' or the 'invalid' flag (i.e. keep only unprocessed papers)
47+
queryParams.append('flags', '~valid');
48+
queryParams.append('flags', '~invalid');
49+
break;
50+
}
3551

3652
const url = `/api/v1/search?${queryParams.toString()}`;
3753
const response = await fetch(url);

src/paperoni/web/restapi.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from types import NoneType
1010
from typing import Generator, Iterable, Literal
1111

12-
from fastapi import Depends, FastAPI, HTTPException
12+
from fastapi import Depends, FastAPI, HTTPException, Query
1313
from fastapi.responses import StreamingResponse
1414
from serieux import auto_singleton, deserialize, serialize
1515

@@ -33,10 +33,10 @@ def __call__(self, things):
3333
class PagingMixin:
3434
"""Mixin for paging."""
3535

36-
# Results offset
36+
# Pagination offset
3737
offset: int = field(default=0)
38-
# Max number of results to return
39-
size: int = field(default=100)
38+
# Maximum number of results to return
39+
limit: int = field(default=100)
4040

4141
_count: int | None = field(repr=False, compare=False, default=None)
4242
_next_offset: int | None = field(repr=False, compare=False, default=None)
@@ -54,23 +54,23 @@ def slice(
5454
iterable: Iterable,
5555
*,
5656
offset: int = None,
57-
size: int = None,
57+
limit: int = None,
5858
) -> Iterable:
5959
if offset is None:
6060
offset = self.offset or 0
61-
if size is None:
62-
size = self.size or 100
61+
if limit is None:
62+
limit = self.limit or 100
6363

64-
size = min(size, config.server.max_results)
64+
limit = min(limit, config.server.max_results)
6565

6666
self._count = 0
6767
self._next_offset = offset
68-
for entry in itertools.islice(iterable, offset, offset + size):
68+
for entry in itertools.islice(iterable, offset, offset + limit):
6969
self._count += 1
7070
self._next_offset += 1
7171
yield entry
7272

73-
if self._count < size:
73+
if self._count < limit:
7474
# No more results
7575
self._next_offset = None
7676

@@ -331,12 +331,23 @@ async def root():
331331
"version": app.version,
332332
}
333333

334+
def parse_search_request(
335+
request: SearchRequest = Depends(),
336+
flags: set[str] = Query(default=None),
337+
) -> SearchRequest:
338+
"""Parse search request with proper handling of flags list parameter."""
339+
# Add flags if provided (FastAPI's Query() handles set parsing)
340+
if flags:
341+
request.flags = flags
342+
343+
return request
344+
334345
@app.get(
335346
f"{prefix}/search",
336347
response_model=SearchResponse,
337348
dependencies=[Depends(hascap("search"))],
338349
)
339-
async def search_papers(request: SearchRequest = Depends()):
350+
async def search_papers(request: SearchRequest = Depends(parse_search_request)):
340351
"""Search for papers in the collection."""
341352
results, count, next_offset, total = await run_in_process_pool(
342353
_search, serialize(SearchRequest, request)

tests/collection/test_collections.py

Lines changed: 70 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,12 @@ def sample_papers() -> Generator[list[Paper], None, None]:
8686
]
8787
)
8888

89+
papers[0].flags = {"valid"}
90+
papers[1].flags = {"invalid"}
91+
papers[2].flags = {"valid", "invalid"}
92+
papers[3].flags = {"reviewed"}
93+
papers[4].flags = {"valid", "reviewed"}
94+
8995
yield papers
9096

9197

@@ -190,6 +196,27 @@ def __getattr__(self, name):
190196
)
191197

192198

199+
def check_papers(
200+
data_regression: DataRegressionFixture, papers: list[Paper], basename: str = None
201+
):
202+
# Serialize the papers to plain data before the regression check to avoid
203+
# yaml.representer.RepresenterError on DatePrecision
204+
# papers = sort_keys(papers[:5])
205+
# [p.pop("acquired") for p in papers]
206+
papers = serialize(list[Paper], papers)
207+
208+
for paper in papers:
209+
# MongoPaper uses ObjectId which will not be the same each time the test is
210+
# run
211+
paper["id"] = None if not isinstance(paper.get("id", None), int) else paper["id"]
212+
paper.pop("_id", None)
213+
paper.pop("version", None)
214+
# Sort flags to ensure consistent ordering
215+
paper["flags"] = sorted(paper["flags"])
216+
217+
data_regression.check(papers, basename=basename)
218+
219+
193220
def test_add_papers(collection: PaperCollection, sample_papers: list[Paper]):
194221
"""Test adding multiple papers."""
195222
collection.add_papers(sample_papers)
@@ -406,6 +433,48 @@ def test_search_partial_matches(
406433
assert len(results) == 4
407434

408435

436+
@pytest.mark.coll_w_remote
437+
def test_search_by_flags(collection: PaperCollection, sample_papers: list[Paper]):
438+
"""Test searching by flag inclusion and exclusion."""
439+
collection.add_papers(sample_papers)
440+
441+
# Test include_flags: papers must have ALL specified flags
442+
results = sorted(collection.search(include_flags=["valid"]), key=lambda x: x.title)
443+
assert len(results) == 3
444+
assert all({"valid"} <= p.flags for p in results)
445+
446+
# Test include_flags with multiple flags: papers must have ALL specified flags
447+
results = sorted(
448+
collection.search(include_flags=["valid", "reviewed"]), key=lambda x: x.title
449+
)
450+
assert len(results) == 1
451+
assert all({"valid", "reviewed"} <= p.flags for p in results)
452+
453+
# Test exclude_flags: papers must NOT have ANY of the specified flags
454+
results = sorted(collection.search(exclude_flags=["invalid"]), key=lambda x: x.title)
455+
assert len(results) == len(sample_papers) - 2
456+
assert all(not p.flags & {"invalid"} for p in results)
457+
458+
# Test exclude_flags with multiple flags
459+
results = sorted(
460+
collection.search(exclude_flags=["valid", "invalid"]), key=lambda x: x.title
461+
)
462+
assert len(results) == len(sample_papers) - 4
463+
assert all(not p.flags & {"valid", "invalid"} for p in results)
464+
465+
# Test combining include_flags and exclude_flags
466+
results = sorted(
467+
collection.search(include_flags=["valid"], exclude_flags=["invalid"]),
468+
key=lambda x: x.title,
469+
)
470+
assert len(results) == 2
471+
assert all({"valid"} <= p.flags and not p.flags & {"invalid"} for p in results)
472+
473+
# Test with no flags (should return all papers)
474+
results = sorted(collection.search(), key=lambda x: x.title)
475+
assert len(results) == 10
476+
477+
409478
def test_file_collection_is_persistent(tmp_path: Path, sample_papers: list[Paper]):
410479
collection = FileCollection(file=tmp_path / "collection.json")
411480

@@ -442,10 +511,4 @@ def test_make_collection_item(
442511
assert paper.id == papers[0].id
443512
assert eq(paper, papers[0])
444513

445-
paper = serialize(paper_cls, paper)
446-
# MongoPaper uses ObjectId which will not be the same each time the test is
447-
# run
448-
paper["id"] = None if not isinstance(paper["id"], int) else paper["id"]
449-
paper.pop("_id", None)
450-
paper.pop("version")
451-
data_regression.check(paper)
514+
check_papers(data_regression, [paper])
Lines changed: 68 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,67 +1,68 @@
1-
abstract: null
2-
authors:
3-
- affiliations: []
4-
author:
5-
aliases: []
6-
links: []
7-
name: Vijay Prakash Dwivedi
8-
display_name: Vijay Prakash Dwivedi
9-
- affiliations: []
10-
author:
11-
aliases: []
12-
links: []
13-
name: Chaitanya K. Joshi
14-
display_name: Chaitanya K. Joshi
15-
- affiliations: []
16-
author:
17-
aliases: []
18-
links: []
19-
name: Anh Tuan Luu
20-
display_name: Anh Tuan Luu
21-
- affiliations: []
22-
author:
23-
aliases: []
24-
links: []
25-
name: Thomas Laurent
26-
display_name: Thomas Laurent
27-
- affiliations: []
28-
author:
29-
aliases: []
30-
links: []
31-
name: Yoshua Bengio
32-
display_name: Yoshua Bengio
33-
- affiliations: []
34-
author:
35-
aliases: []
36-
links: []
37-
name: Xavier Bresson
38-
display_name: Xavier Bresson
39-
flags: []
40-
id: 0
41-
links:
42-
- link: https://jmlr.org/papers/volume24/22-0567/22-0567.pdf
43-
type: pdf.official
44-
- link: https://jmlr.org/papers/v24/22-0567.bib
45-
type: bibtex
46-
- link: https://jmlr.org/papers/v24/22-0567.html
47-
type: abstract.official
48-
- link: https://jmlr.org/papers/v24/22-0567.html
49-
type: uid
50-
releases:
51-
- pages: 1-48
52-
status: published
53-
venue:
54-
aliases: []
55-
date: '2023-01-01'
56-
date_precision: 1
57-
links: []
58-
name: Journal of Machine Learning Research
59-
open: false
60-
peer_reviewed: true
61-
publisher: JMLR
62-
series: JMLR
63-
short_name: null
64-
type: journal
65-
volume: null
66-
title: Benchmarking Graph Neural Networks
67-
topics: []
1+
- abstract: null
2+
authors:
3+
- affiliations: []
4+
author:
5+
aliases: []
6+
links: []
7+
name: Vijay Prakash Dwivedi
8+
display_name: Vijay Prakash Dwivedi
9+
- affiliations: []
10+
author:
11+
aliases: []
12+
links: []
13+
name: Chaitanya K. Joshi
14+
display_name: Chaitanya K. Joshi
15+
- affiliations: []
16+
author:
17+
aliases: []
18+
links: []
19+
name: Anh Tuan Luu
20+
display_name: Anh Tuan Luu
21+
- affiliations: []
22+
author:
23+
aliases: []
24+
links: []
25+
name: Thomas Laurent
26+
display_name: Thomas Laurent
27+
- affiliations: []
28+
author:
29+
aliases: []
30+
links: []
31+
name: Yoshua Bengio
32+
display_name: Yoshua Bengio
33+
- affiliations: []
34+
author:
35+
aliases: []
36+
links: []
37+
name: Xavier Bresson
38+
display_name: Xavier Bresson
39+
flags:
40+
- valid
41+
id: null
42+
links:
43+
- link: https://jmlr.org/papers/volume24/22-0567/22-0567.pdf
44+
type: pdf.official
45+
- link: https://jmlr.org/papers/v24/22-0567.bib
46+
type: bibtex
47+
- link: https://jmlr.org/papers/v24/22-0567.html
48+
type: abstract.official
49+
- link: https://jmlr.org/papers/v24/22-0567.html
50+
type: uid
51+
releases:
52+
- pages: 1-48
53+
status: published
54+
venue:
55+
aliases: []
56+
date: '2023-01-01'
57+
date_precision: 1
58+
links: []
59+
name: Journal of Machine Learning Research
60+
open: false
61+
peer_reviewed: true
62+
publisher: JMLR
63+
series: JMLR
64+
short_name: null
65+
type: journal
66+
volume: null
67+
title: Benchmarking Graph Neural Networks
68+
topics: []

0 commit comments

Comments
 (0)