Skip to content

Commit 1289985

Browse files
committed
Add MCP server
1 parent 9eabda1 commit 1289985

File tree

8 files changed

+1097
-68
lines changed

8 files changed

+1097
-68
lines changed

config/basic.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,3 +91,7 @@ paperoni:
9191
search: []
9292
validate: [search]
9393
dev: []
94+
mcp:
95+
api_client:
96+
$class: paperoni.mcp.client:PaperoniAPIClient
97+
endpoint: ${paperoni.server.protocol}://${paperoni.server.host}:${paperoni.server.port}

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ dependencies = [
3939
"filelock>=3.20.0",
4040
"jinja2>=3.1.6",
4141
"easy-oauth",
42+
"fastmcp>=2.14.2",
4243
]
4344

4445
[project.urls]

src/paperoni/__main__.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -871,7 +871,7 @@ class Login:
871871
"""Retrieve an access token from the paperoni server."""
872872

873873
# Endpoint to login to
874-
endpoint: str = "http://localhost:8000"
874+
endpoint: str = None
875875

876876
# Whether to use headless mode
877877
headless: bool = False
@@ -880,8 +880,29 @@ def run(self):
880880
print_field("Access token", login(self.endpoint, self.headless))
881881

882882

883+
@dataclass
class MCP:
    """MCP server for paperoni."""

    # Paperoni API endpoint; passed as-is to create_mcp (None by default)
    endpoint: str = None

    # Transport used to expose the MCP server
    transport: Literal["stdio", "http"] = "stdio"
    # Host to bind to (only used by the "http" transport)
    host: str = "localhost"
    # Port to bind to (only used by the "http" transport)
    port: int = 9000

    def run(self):
        """Create the MCP server and run it on the selected transport.

        Raises:
            ValueError: If ``transport`` is neither "stdio" nor "http".
        """
        # Fail loudly on an unsupported transport instead of silently
        # returning without ever starting a server.
        if self.transport not in ("stdio", "http"):
            raise ValueError(
                f"Unsupported MCP transport: {self.transport!r} "
                "(expected 'stdio' or 'http')"
            )

        # Imported lazily so the MCP modules are only loaded when this
        # command is actually run.
        from .mcp.server import create_mcp

        mcp = create_mcp(self.endpoint)
        if self.transport == "stdio":
            mcp.run(transport="stdio")
        else:
            mcp.run(transport="http", host=self.host, port=self.port)
902+
903+
883904
PaperoniCommand = TaggedUnion[
884-
Discover, Refine, Fulltext, Work, Coll, Batch, Focus, Serve, Login
905+
Discover, Refine, Fulltext, Work, Coll, Batch, Focus, Serve, Login, MCP
885906
]
886907

887908

src/paperoni/config.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from .collection.abc import PaperCollection
1414
from .embed.cfg import Embedding
1515
from .get import Fetcher, RequestsFetcher
16+
from .mcp.client import PaperoniAPIClient
1617
from .model.focus import AutoFocus, Focuses
1718
from .prompt import GenAIPrompt, Prompt
1819

@@ -52,6 +53,13 @@ def __post_init__(self):
5253
self.process_pool = ProcessPoolExecutor(**self.process_pool_executor)
5354

5455

56+
@dataclass(kw_only=True)
class MCP:
    """Configuration of the MCP server."""

    # Client used to query the paperoni REST API.
    # PaperoniAPIClient has a required ``endpoint`` field, so the class
    # itself cannot serve as the default factory: calling it with no
    # arguments raises TypeError, which would break every configuration
    # that does not define an ``mcp`` section. The default is therefore a
    # client without an endpoint; a real endpoint must then be provided
    # through the configuration or when the MCP server is created.
    api_client: TaggedSubclass[PaperoniAPIClient] = field(
        default_factory=lambda: PaperoniAPIClient(endpoint=None)
    )
61+
62+
5563
@dataclass
5664
class PaperoniConfig:
5765
cache_path: Path = None
@@ -67,6 +75,7 @@ class PaperoniConfig:
6775
reporters: list[TaggedSubclass[Reporter]] = field(default_factory=list)
6876
embedding: TaggedSubclass[Embedding] = field(default_factory=Embedding)
6977
server: Server = field(default_factory=Server)
78+
mcp: MCP = field(default_factory=MCP)
7079

7180
def __post_init__(self):
7281
self.metadata: Meta[Path | list[Path] | Meta | Any] = Meta()

src/paperoni/mcp/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""MCP server for paperoni."""

src/paperoni/mcp/client.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
"""HTTP client for paperoni REST API."""
2+
3+
import os
4+
from dataclasses import dataclass, field
5+
from datetime import date
6+
from typing import Any
7+
8+
from fastapi import HTTPException
9+
from serieux.features.encrypt import Secret
10+
11+
from ..get import Fetcher, RequestsFetcher
12+
13+
14+
@dataclass
class PaperoniAPIClient:
    """Client for interacting with paperoni REST API."""

    # Base URL of the paperoni server ("/api/v1/search" is appended to it)
    endpoint: str
    # Bearer token; the default is read from PAPERONI_TOKEN once, when the
    # class is defined
    token: Secret[str] = os.getenv("PAPERONI_TOKEN")
    # Fetcher used to perform the HTTP requests
    fetch: Fetcher = field(default_factory=RequestsFetcher)

    def __post_init__(self):
        """Build the request headers, adding the bearer token if there is one."""
        self.headers = {"Accept": "application/json"}
        if self.token is not None:
            self.headers["Authorization"] = f"Bearer {self.token}"

    def _fetch_page(
        self,
        *,
        paper_id: int | None = None,
        title: str | None = None,
        institution: str | None = None,
        author: str | None = None,
        venue: str | None = None,
        start_date: date | None = None,
        end_date: date | None = None,
        include_flags: list[str] | None = None,
        exclude_flags: list[str] | None = None,
        query: str | None = None,
        similarity_threshold: float | None = 0.75,
        offset: int = 0,
        limit: int = 100,
    ) -> dict[str, Any]:
        """Query the search endpoint and return the unmodified response.

        Unlike ``search_papers``, the response is returned as received, so
        any "total" key sent by the server is still present.

        Returns:
            The decoded JSON response, or an empty result page when the
            request fails with a 404.
        """
        params: dict[str, Any] = {"offset": offset, "limit": limit}

        # Unset or empty filters are mapped to None and left out of the
        # query string.
        optional = {
            "paper_id": paper_id or None,
            "title": title or None,
            "institution": institution or None,
            "author": author or None,
            "venue": venue or None,
            "start_date": start_date.isoformat() if start_date else None,
            "end_date": end_date.isoformat() if end_date else None,
            "include_flags": ",".join(include_flags) if include_flags else None,
            "exclude_flags": ",".join(exclude_flags) if exclude_flags else None,
            "query": query or None,
            # Only None means "unset" here: an explicit threshold of 0.0 is
            # a meaningful value and must be forwarded.
            "similarity_threshold": similarity_threshold,
        }
        params.update(
            (key, value) for key, value in optional.items() if value is not None
        )

        url = f"{self.endpoint}/api/v1/search"

        try:
            return self.fetch.read(
                url,
                format="json",
                cache_into=None,
                headers=self.headers,
                params=params,
            )
        except HTTPException as exc:
            # NOTE(review): assumes the fetcher surfaces HTTP errors as
            # fastapi.HTTPException -- confirm against the Fetcher in use.
            if exc.status_code != 404:
                raise
            # A 404 is reported as an empty result page.
            return {
                "results": [],
                "similarities": None,
                "count": 0,
                "next_offset": None,
            }

    def search_papers(
        self,
        paper_id: int | None = None,
        title: str | None = None,
        institution: str | None = None,
        author: str | None = None,
        venue: str | None = None,
        start_date: date | None = None,
        end_date: date | None = None,
        include_flags: list[str] | None = None,
        exclude_flags: list[str] | None = None,
        query: str | None = None,
        similarity_threshold: float | None = 0.75,
        offset: int = 0,
        limit: int = 100,
    ) -> dict[str, Any]:
        """Search for papers and return one page of results.

        Filters left unset (or empty) are not sent to the server.

        Returns:
            The server response without its "total" key. On a 404 an empty
            page is returned, with the keys "results", "similarities",
            "count" and "next_offset".
        """
        page = self._fetch_page(
            paper_id=paper_id,
            title=title,
            institution=institution,
            author=author,
            venue=venue,
            start_date=start_date,
            end_date=end_date,
            include_flags=include_flags,
            exclude_flags=exclude_flags,
            query=query,
            similarity_threshold=similarity_threshold,
            offset=offset,
            limit=limit,
        )
        page.pop("total", None)
        return page

    def count_papers(
        self,
        paper_id: int | None = None,
        title: str | None = None,
        institution: str | None = None,
        author: str | None = None,
        venue: str | None = None,
        start_date: date | None = None,
        end_date: date | None = None,
        include_flags: list[str] | None = None,
        exclude_flags: list[str] | None = None,
    ) -> int:
        """Count the papers matching the given filters.

        Returns:
            The "total" reported by the server, or 0 when the response
            carries no total (including the 404 case).
        """
        # Fetch a single-result page through the raw helper: going through
        # search_papers would strip "total" and make the count always 0.
        page = self._fetch_page(
            paper_id=paper_id,
            title=title,
            institution=institution,
            author=author,
            venue=venue,
            start_date=start_date,
            end_date=end_date,
            include_flags=include_flags,
            exclude_flags=exclude_flags,
            offset=0,
            limit=1,  # Only need the total count
        )
        return page.get("total", 0)

src/paperoni/mcp/server.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
"""MCP server implementation using fastmcp."""
2+
3+
from datetime import date
4+
5+
from fastmcp import FastMCP
6+
7+
from ..config import config
8+
from .client import PaperoniAPIClient
9+
10+
11+
def create_mcp(endpoint: str | None = None) -> FastMCP:
    """Create the paperoni MCP server and register its tools.

    The optional tool parameters are annotated ``X | None`` so that the
    schema derived from the signatures accepts an explicit null as "unset",
    rather than only accepting an omitted argument.

    Args:
        endpoint: Base URL of the paperoni REST API. When empty, the
            endpoint of the configured ``mcp.api_client`` is used.

    Returns:
        A FastMCP server exposing the ``search_papers`` and
        ``count_papers`` tools.
    """
    # Token and fetcher always come from the configuration; only the
    # endpoint can be overridden by the caller.
    api_client = PaperoniAPIClient(
        endpoint or config.mcp.api_client.endpoint,
        token=config.mcp.api_client.token,
        fetch=config.mcp.api_client.fetch,
    )

    mcp = FastMCP(name="Paperoni")

    @mcp.tool
    def search_papers(
        paper_id: int | None = None,
        title: str | None = None,
        institution: str | None = None,
        author: str | None = None,
        venue: str | None = None,
        start_date: date | None = None,
        end_date: date | None = None,
        include_flags: list[str] | None = None,
        exclude_flags: list[str] | None = None,
        query: str | None = None,
        similarity_threshold: float = 0.75,
        offset: int = 0,
        limit: int = 100,
    ):
        """Search for papers in the paperoni collection.

        This tool allows searching for papers using various filters including
        semantic search, institution, venue, author, and date ranges. Results
        include paper metadata (title, abstract, authors, venues, topics) but
        exclude PDF content.

        If a query is provided, the results will be sorted by similarity score.
        Note that unrelated papers may still be returned with a low similarity
        score.

        Args:
            paper_id: Paper ID
            title: Title of the paper
            institution: Institution of an author
            author: Author of the paper
            venue: Venue name (long or short)
            start_date: Start date to consider
            end_date: End date to consider
            include_flags: Flags that must be present
            exclude_flags: Flags that must not be present
            query: Semantic search query
            similarity_threshold: Similarity threshold (default: 0.75)
            offset: Pagination offset (default: 0)
            limit: Maximum number of results to return (default: 100)

        Returns:
            Dictionary containing:
            - results: List of paper objects with metadata
            - similarities: List of similarity scores (None if no query was provided)
            - count: Number of results in this page
            - next_offset: Offset for next page (None if no more pages)
        """

        return api_client.search_papers(
            paper_id=paper_id,
            title=title,
            institution=institution,
            author=author,
            venue=venue,
            start_date=start_date,
            end_date=end_date,
            include_flags=include_flags,
            exclude_flags=exclude_flags,
            query=query,
            similarity_threshold=similarity_threshold,
            offset=offset,
            limit=limit,
        )

    @mcp.tool
    def count_papers(
        paper_id: int | None = None,
        title: str | None = None,
        institution: str | None = None,
        author: str | None = None,
        venue: str | None = None,
        start_date: date | None = None,
        end_date: date | None = None,
        include_flags: list[str] | None = None,
        exclude_flags: list[str] | None = None,
    ) -> int:
        """Count papers matching criteria without fetching all results.

        Args:
            paper_id: Paper ID
            title: Title of the paper
            institution: Institution of an author
            author: Author of the paper
            venue: Venue name (long or short)
            start_date: Start date to consider
            end_date: End date to consider
            include_flags: Flags that must be present
            exclude_flags: Flags that must not be present

        Returns:
            Total count of matching papers
        """

        return api_client.count_papers(
            paper_id=paper_id,
            title=title,
            institution=institution,
            author=author,
            venue=venue,
            start_date=start_date,
            end_date=end_date,
            include_flags=include_flags,
            exclude_flags=exclude_flags,
        )

    return mcp

0 commit comments

Comments
 (0)