Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions config/basic.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,14 @@ paperoni:
collection:
$class: paperoni.collection.filecoll:FileCollection
file: ${paperoni.data_path}/collection.json
embedding:
$class: paperoni.embed.cfg:Embedding
model: gemini-embedding-001
api_key: ${paperoni.api_keys.gemini}
server:
host: localhost
port: 8000
protocol: http
max_results: 10000
process_pool_executor:
max_workers: 4
Expand All @@ -83,3 +90,7 @@ paperoni:
search: []
validate: [search]
dev: []
mcp:
api_client:
$class: paperoni.mcp.client:PaperoniAPIClient
endpoint: ${paperoni.server.protocol}://${paperoni.server.host}:${paperoni.server.port}
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ dependencies = [
"jinja2>=3.1.6",
"easy-oauth",
"markdown>=3.10",
"fastmcp>=2.14.2",
]

[project.urls]
Expand Down
45 changes: 42 additions & 3 deletions src/paperoni/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from .config import config
from .dash import History
from .display import display, print_field, terminal_width
from .embed.embeddings import PaperEmbedding
from .fulltext.locate import URL, locate_all
from .fulltext.pdf import PDF, CachePolicies, get_pdf
from .heuristics import simplify_paper
Expand Down Expand Up @@ -617,6 +618,12 @@ class Search:
# [alias: -f]
flags: set[str] = None

# Semantic search query
query: str = None

# Similarity threshold
similarity_threshold: float = 0.75

# Whether to expand links
expand_links: bool = False

Expand Down Expand Up @@ -648,8 +655,19 @@ def run(self, coll: "Coll") -> list[Paper]:
exclude_flags={f for f in flags if f.startswith("~")},
)
]

if self.query:
papers, similarities = zip(
*PaperEmbedding.semantic_search(
papers, self.query, self.similarity_threshold
)
)
else:
similarities = None

self.format(papers)
return papers

return papers, similarities

@dataclass
class Import:
Expand Down Expand Up @@ -851,7 +869,7 @@ class Login:
"""Retrieve an access token from the paperoni server."""

# Endpoint to login to
endpoint: str = "http://localhost:8000"
endpoint: str = None

# Whether to use headless mode
headless: bool = False
Expand All @@ -860,8 +878,29 @@ def run(self):
print_field("Access token", login(self.endpoint, self.headless))


@dataclass
class MCP:
    """Run the MCP (Model Context Protocol) server for paperoni."""

    # Paperoni API endpoint the MCP server talks to (None = configured default)
    endpoint: str = None

    # How the MCP server communicates with its clients
    transport: Literal["stdio", "http"] = "stdio"
    # Host/port; only used with the "http" transport
    host: str = "localhost"
    port: int = 9000

    def run(self):
        # Imported lazily so the fastmcp dependency is only loaded for this command.
        from .mcp.server import create_mcp

        mcp = create_mcp(self.endpoint)
        if self.transport == "stdio":
            mcp.run(transport="stdio")
        elif self.transport == "http":
            mcp.run(transport="http", host=self.host, port=self.port)
        else:
            # The Literal annotation should prevent this, but fail loudly
            # rather than silently doing nothing if a bad value slips through.
            raise ValueError(f"Unsupported transport: {self.transport!r}")


PaperoniCommand = TaggedUnion[
Discover, Refine, Fulltext, Work, Coll, Batch, Focus, Serve, Login
Discover, Refine, Fulltext, Work, Coll, Batch, Focus, Serve, Login, MCP
]


Expand Down
11 changes: 11 additions & 0 deletions src/paperoni/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
from serieux.features.encrypt import Secret

from .collection.abc import PaperCollection
from .embed.cfg import Embedding
from .get import Fetcher, RequestsFetcher
from .mcp.client import PaperoniAPIClient
from .model.focus import AutoFocus, Focuses
from .prompt import GenAIPrompt, Prompt

Expand Down Expand Up @@ -52,6 +54,13 @@ def __post_init__(self):
self.process_pool = ProcessPoolExecutor(**self.process_pool_executor)


@dataclass(kw_only=True)
class MCP:
    """Configuration for the paperoni MCP integration."""

    # Client used to reach the paperoni API; TaggedSubclass lets config files
    # select a concrete subclass (presumably via a `$class:` key, as the yaml
    # config does for other components — confirm serieux semantics).
    api_client: TaggedSubclass[PaperoniAPIClient] = field(
        default_factory=PaperoniAPIClient
    )


@dataclass
class PaperoniConfig:
cache_path: Path = None
Expand All @@ -66,7 +75,9 @@ class PaperoniConfig:
work_file: Path = None
collection: TaggedSubclass[PaperCollection] = None
reporters: list[TaggedSubclass[Reporter]] = field(default_factory=list)
embedding: TaggedSubclass[Embedding] = field(default_factory=Embedding)
server: Server = field(default_factory=Server)
mcp: MCP = field(default_factory=MCP)

def __post_init__(self):
self.metadata: Meta[Path | list[Path] | Meta | Any] = Meta()
Expand Down
75 changes: 75 additions & 0 deletions src/paperoni/embed/cfg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""Embedding service for semantic search using Google GenAI."""

import json
from dataclasses import dataclass
from pathlib import Path
from typing import BinaryIO

from google import genai
from google.genai.types import EmbedContentResponse
from paperazzi.utils import _make_key as paperazzi_make_key, disk_cache, disk_store
from serieux.features.encrypt import Secret


@dataclass
class EmbedContentResponseSerializer:
    """(De)serialize ``EmbedContentResponse`` objects as pretty-printed UTF-8 JSON."""

    @staticmethod
    def dump(response: EmbedContentResponse, file_obj: BinaryIO):
        """Write *response* to *file_obj* as JSON; returns the byte count written."""
        payload = json.dumps(response.model_dump(), indent=2, ensure_ascii=False)
        return file_obj.write(payload.encode("utf-8"))

    @staticmethod
    def load(file_obj: BinaryIO) -> EmbedContentResponse:
        """Read a JSON payload from *file_obj* and rebuild the response object."""
        return EmbedContentResponse.model_validate(json.load(file_obj))


@dataclass
class Embedding:
    """Service for generating and caching embeddings using Google GenAI.

    Instances are config-constructed (see ``config/basic.yaml``: model and
    api_key come from the paperoni configuration). Embedding responses are
    cached on disk via paperazzi's ``disk_cache``/``disk_store`` decorators.
    """

    # Pre-built GenAI client; when None, one is created in __post_init__.
    client: genai.Client = None
    # API key used to build the client; Secret keeps it out of plain dumps.
    api_key: Secret[str] = None
    # Embedding model name, e.g. "gemini-embedding-001".
    model: str = None

    def __post_init__(self):
        # Lazily build the client so configs only need to supply an api_key.
        if self.client is None:
            self.client = genai.Client(api_key=self.api_key or None)

    def embed(
        self, contents: list[str], cache_dir: Path = None, **kwargs
    ) -> list[EmbedContentResponse]:
        """Embed content.

        One request per element of *contents*; responses are returned in the
        same order. When *cache_dir* is given, the disk cache is redirected
        there via ``.update(...)`` (presumably a per-call override provided by
        paperazzi's disk_cache — TODO confirm); otherwise the decorator's
        default cache directory is used.
        """
        if cache_dir is not None:
            embed = self._embed.update(cache_dir=cache_dir)
        else:
            embed = self._embed

        return list(
            map(
                lambda c: embed(self.client, model=self.model, content=c, **kwargs),
                contents,
            )
        )

    @staticmethod
    def _make_key(_: tuple, kwargs: dict) -> str:
        """Build the cache key from the call's kwargs, ignoring the client.

        The client is excluded so the same (model, content) pair hits the
        cache regardless of which client instance made the call.
        """
        kwargs = kwargs.copy()
        kwargs.pop("client", None)

        return paperazzi_make_key(None, kwargs)

    # NOTE(review): decorator order matters here — disk_cache wraps the
    # staticmethod object and disk_store wraps the cached callable; this
    # relies on paperazzi accepting a staticmethod — confirm against its docs.
    @disk_store
    @disk_cache(
        cache_dir=Path("/tmp/embeddings"),
        serializer=EmbedContentResponseSerializer,
        make_key=_make_key,
    )
    @staticmethod
    def _embed(
        client: genai.Client, *, model: str, content: str, **kwargs
    ) -> EmbedContentResponse:
        """Get embedding for text (single content item per request)."""
        return client.models.embed_content(model=model, contents=[content], **kwargs)
62 changes: 62 additions & 0 deletions src/paperoni/embed/embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from typing import Iterable

import numpy as np
from google import genai
from google.genai.types import EmbedContentResponse

from ..config import config
from ..model.classes import Paper


class PaperEmbedding:
    """Semantic-search helpers built on top of ``config.embedding``."""

    @staticmethod
    def semantic_search(
        papers: Iterable[Paper], query: str, similarity_threshold: float = 0.75
    ) -> list[tuple[Paper, float]]:
        """Rank *papers* by cosine similarity to *query*.

        Returns (paper, similarity) pairs sorted by decreasing similarity,
        keeping only those at or above *similarity_threshold*; the list may
        be empty.
        """
        # Materialize first: `papers` may be a one-shot iterator, and we both
        # iterate it (to embed) and index into it (to build the result).
        papers = list(papers)
        if not papers:
            return []

        responses: Iterable[EmbedContentResponse] = map(
            PaperEmbedding.get_paper_embedding, papers
        )
        store = np.array([r.embeddings[0].values for r in responses])
        query_embedding = np.array(
            config.embedding.embed(
                [query],
                cache_dir=None,
                config=genai.types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY"),
            )[0]
            .embeddings[0]
            .values
        )

        # Calculate Cosine Similarity
        # Formula: (A dot B) / (||A|| * ||B||)
        dot_products = np.dot(store, query_embedding)
        norms = np.linalg.norm(store, axis=1) * np.linalg.norm(query_embedding)
        # Guard against zero-length vectors: report similarity 0 instead of
        # dividing by zero (which would produce NaN/inf and a warning).
        similarities = np.divide(
            dot_products, norms, out=np.zeros_like(dot_products), where=norms != 0
        )

        # Return sorted results, best match first.
        sorted_indices = np.argsort(similarities)[::-1]
        return [
            (papers[i], float(similarities[i]))
            for i in sorted_indices
            if similarities[i] >= similarity_threshold
        ]

    @staticmethod
    def get_paper_embedding(paper: Paper) -> EmbedContentResponse:
        """Embed a paper (title + abstract + topics); cached under data_path.

        Note: returns the full embed response — callers read
        ``.embeddings[0].values`` — not a bare vector.
        """
        parts = [paper.title]

        if paper.abstract:
            parts.append(paper.abstract)

        if paper.topics:
            topic_names = ", ".join(sorted(t.name.lower() for t in paper.topics))
            parts.append(f"Topics: {topic_names}")
        content = "\n\n".join(parts)

        return config.embedding.embed(
            [content],
            cache_dir=config.data_path / "embeddings",
            config=genai.types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY"),
        )[0]
1 change: 1 addition & 0 deletions src/paperoni/mcp/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""MCP server for paperoni."""
Loading
Loading