Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/paperoni/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,7 @@ def run(self, coll: "Coll") -> list[Paper]:
start_date=self.start_date,
end_date=self.end_date,
include_flags={f for f in flags if not f.startswith("~")},
exclude_flags={f for f in flags if f.startswith("~")},
exclude_flags={f[1:] for f in flags if f.startswith("~")},
)
]
self.format(papers)
Expand Down
4 changes: 2 additions & 2 deletions src/paperoni/collection/remotecoll.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,9 @@ def search(
if end_date:
params["end_date"] = end_date.isoformat()
if include_flags:
params["include_flags"] = ",".join(include_flags)
params.setdefault("flags", []).extend(include_flags)
if exclude_flags:
params["exclude_flags"] = ",".join(exclude_flags)
params.setdefault("flags", []).extend([f"~{f}" for f in exclude_flags])
url = f"{self.endpoint}/search"
offset = 0
while True:
Expand Down
18 changes: 17 additions & 1 deletion src/paperoni/web/assets/search.js
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,23 @@ async function fetchSearchResults(params, offset = 0) {
if (params.venue) queryParams.append('venue', params.venue);
if (params.start_date) queryParams.append('start_date', params.start_date);
if (params.end_date) queryParams.append('end_date', params.end_date);
if (params.validated) queryParams.append('validated', params.validated);

// Convert validated parameter to flags
switch (params.validated) {
case 'true':
// Include papers with 'valid' flag
queryParams.append('flags', 'valid');
break;
case 'false':
// Include papers with 'invalid' flag
queryParams.append('flags', 'invalid');
break;
case 'unset':
// Exclude papers with both 'valid' and 'invalid' flags (unprocessed)
queryParams.append('flags', '~valid');
queryParams.append('flags', '~invalid');
break;
}

const url = `/api/v1/search?${queryParams.toString()}`;
const response = await fetch(url);
Expand Down
33 changes: 22 additions & 11 deletions src/paperoni/web/restapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from types import NoneType
from typing import Generator, Iterable, Literal

from fastapi import Depends, FastAPI, HTTPException
from fastapi import Depends, FastAPI, HTTPException, Query
from fastapi.responses import StreamingResponse
from serieux import auto_singleton, deserialize, serialize

Expand All @@ -33,10 +33,10 @@ def __call__(self, things):
class PagingMixin:
"""Mixin for paging."""

# Results offset
# Pagination offset
offset: int = field(default=0)
# Max number of results to return
size: int = field(default=100)
# Maximum number of results to return
limit: int = field(default=100)

_count: int | None = field(repr=False, compare=False, default=None)
_next_offset: int | None = field(repr=False, compare=False, default=None)
Expand All @@ -54,23 +54,23 @@ def slice(
iterable: Iterable,
*,
offset: int = None,
size: int = None,
limit: int = None,
) -> Iterable:
if offset is None:
offset = self.offset or 0
if size is None:
size = self.size or 100
if limit is None:
limit = self.limit or 100

size = min(size, config.server.max_results)
limit = min(limit, config.server.max_results)

self._count = 0
self._next_offset = offset
for entry in itertools.islice(iterable, offset, offset + size):
for entry in itertools.islice(iterable, offset, offset + limit):
self._count += 1
self._next_offset += 1
yield entry

if self._count < size:
if self._count < limit:
# No more results
self._next_offset = None

Expand Down Expand Up @@ -331,12 +331,23 @@ async def root():
"version": app.version,
}

def parse_search_request(
    request: SearchRequest = Depends(),
    flags: set[str] = Query(default=None),
) -> SearchRequest:
    """Build a SearchRequest, folding in the repeated ``flags`` query parameter.

    FastAPI's ``Query`` collects every ``flags=...`` occurrence in the URL
    into a single set; when none were supplied the parsed request is
    returned untouched.
    """
    if not flags:
        return request
    request.flags = flags
    return request

@app.get(
f"{prefix}/search",
response_model=SearchResponse,
dependencies=[Depends(hascap("search"))],
)
async def search_papers(request: SearchRequest = Depends()):
async def search_papers(request: SearchRequest = Depends(parse_search_request)):
"""Search for papers in the collection."""
results, count, next_offset, total = await run_in_process_pool(
_search, serialize(SearchRequest, request)
Expand Down
77 changes: 70 additions & 7 deletions tests/collection/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,12 @@ def sample_papers() -> Generator[list[Paper], None, None]:
]
)

papers[0].flags = {"valid"}
papers[1].flags = {"invalid"}
papers[2].flags = {"valid", "invalid"}
papers[3].flags = {"reviewed"}
papers[4].flags = {"valid", "reviewed"}

yield papers


Expand Down Expand Up @@ -190,6 +196,27 @@ def __getattr__(self, name):
)


def check_papers(
    data_regression: DataRegressionFixture,
    papers: list[Paper],
    basename: str | None = None,
):
    """Serialize *papers*, strip unstable fields, and regression-check them.

    Fields that differ between test runs (Mongo ObjectIds, document
    versions) are removed, and set-valued fields are sorted so the
    regression data is deterministic.

    :param data_regression: pytest-regressions fixture used for the check.
    :param papers: papers to serialize and compare.
    :param basename: optional regression-file basename override.
    """
    serialized = serialize(list[Paper], papers)

    for paper in serialized:
        # MongoPaper uses ObjectId, which will not be the same each time the
        # test is run; only integer ids are stable enough to keep.
        paper["id"] = paper["id"] if isinstance(paper.get("id"), int) else None
        paper.pop("_id", None)
        paper.pop("version", None)
        # Sets serialize in arbitrary order; sort for consistent output.
        # .get() guards papers serialized without a flags field, matching the
        # tolerant handling of id/_id/version above.
        paper["flags"] = sorted(paper.get("flags", []))

    data_regression.check(serialized, basename=basename)


def test_add_papers(collection: PaperCollection, sample_papers: list[Paper]):
"""Test adding multiple papers."""
collection.add_papers(sample_papers)
Expand Down Expand Up @@ -406,6 +433,48 @@ def test_search_partial_matches(
assert len(results) == 4


@pytest.mark.coll_w_remote
def test_search_by_flags(collection: PaperCollection, sample_papers: list[Paper]):
    """Test searching by flag inclusion and exclusion."""
    collection.add_papers(sample_papers)

    def by_title(found):
        # Deterministic ordering for the comparisons below.
        return sorted(found, key=lambda p: p.title)

    # include_flags: a paper must carry ALL of the listed flags.
    found = by_title(collection.search(include_flags=["valid"]))
    assert len(found) == 3
    assert all({"valid"} <= p.flags for p in found)

    found = by_title(collection.search(include_flags=["valid", "reviewed"]))
    assert len(found) == 1
    assert all({"valid", "reviewed"} <= p.flags for p in found)

    # exclude_flags: a paper must carry NONE of the listed flags.
    found = by_title(collection.search(exclude_flags=["invalid"]))
    assert len(found) == len(sample_papers) - 2
    assert not any(p.flags & {"invalid"} for p in found)

    found = by_title(collection.search(exclude_flags=["valid", "invalid"]))
    assert len(found) == len(sample_papers) - 4
    assert not any(p.flags & {"valid", "invalid"} for p in found)

    # Inclusion and exclusion combine conjunctively.
    found = by_title(
        collection.search(include_flags=["valid"], exclude_flags=["invalid"])
    )
    assert len(found) == 2
    assert all({"valid"} <= p.flags and not p.flags & {"invalid"} for p in found)

    # No flag filters: every paper in the collection comes back.
    assert len(by_title(collection.search())) == 10


def test_file_collection_is_persistent(tmp_path: Path, sample_papers: list[Paper]):
collection = FileCollection(file=tmp_path / "collection.json")

Expand Down Expand Up @@ -442,10 +511,4 @@ def test_make_collection_item(
assert paper.id == papers[0].id
assert eq(paper, papers[0])

paper = serialize(paper_cls, paper)
# MongoPaper uses ObjectId which will not be the same each time the test is
# run
paper["id"] = None if not isinstance(paper["id"], int) else paper["id"]
paper.pop("_id", None)
paper.pop("version")
data_regression.check(paper)
check_papers(data_regression, [paper])
Original file line number Diff line number Diff line change
@@ -1,67 +1,68 @@
abstract: null
authors:
- affiliations: []
author:
aliases: []
links: []
name: Vijay Prakash Dwivedi
display_name: Vijay Prakash Dwivedi
- affiliations: []
author:
aliases: []
links: []
name: Chaitanya K. Joshi
display_name: Chaitanya K. Joshi
- affiliations: []
author:
aliases: []
links: []
name: Anh Tuan Luu
display_name: Anh Tuan Luu
- affiliations: []
author:
aliases: []
links: []
name: Thomas Laurent
display_name: Thomas Laurent
- affiliations: []
author:
aliases: []
links: []
name: Yoshua Bengio
display_name: Yoshua Bengio
- affiliations: []
author:
aliases: []
links: []
name: Xavier Bresson
display_name: Xavier Bresson
flags: []
id: 0
links:
- link: https://jmlr.org/papers/volume24/22-0567/22-0567.pdf
type: pdf.official
- link: https://jmlr.org/papers/v24/22-0567.bib
type: bibtex
- link: https://jmlr.org/papers/v24/22-0567.html
type: abstract.official
- link: https://jmlr.org/papers/v24/22-0567.html
type: uid
releases:
- pages: 1-48
status: published
venue:
aliases: []
date: '2023-01-01'
date_precision: 1
links: []
name: Journal of Machine Learning Research
open: false
peer_reviewed: true
publisher: JMLR
series: JMLR
short_name: null
type: journal
volume: null
title: Benchmarking Graph Neural Networks
topics: []
- abstract: null
authors:
- affiliations: []
author:
aliases: []
links: []
name: Vijay Prakash Dwivedi
display_name: Vijay Prakash Dwivedi
- affiliations: []
author:
aliases: []
links: []
name: Chaitanya K. Joshi
display_name: Chaitanya K. Joshi
- affiliations: []
author:
aliases: []
links: []
name: Anh Tuan Luu
display_name: Anh Tuan Luu
- affiliations: []
author:
aliases: []
links: []
name: Thomas Laurent
display_name: Thomas Laurent
- affiliations: []
author:
aliases: []
links: []
name: Yoshua Bengio
display_name: Yoshua Bengio
- affiliations: []
author:
aliases: []
links: []
name: Xavier Bresson
display_name: Xavier Bresson
flags:
- valid
id: null
links:
- link: https://jmlr.org/papers/volume24/22-0567/22-0567.pdf
type: pdf.official
- link: https://jmlr.org/papers/v24/22-0567.bib
type: bibtex
- link: https://jmlr.org/papers/v24/22-0567.html
type: abstract.official
- link: https://jmlr.org/papers/v24/22-0567.html
type: uid
releases:
- pages: 1-48
status: published
venue:
aliases: []
date: '2023-01-01'
date_precision: 1
links: []
name: Journal of Machine Learning Research
open: false
peer_reviewed: true
publisher: JMLR
series: JMLR
short_name: null
type: journal
volume: null
title: Benchmarking Graph Neural Networks
topics: []
Loading
Loading