-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathmodels.py
More file actions
175 lines (143 loc) · 4.77 KB
/
models.py
File metadata and controls
175 lines (143 loc) · 4.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
"""
arXiv RAG data models
====================
Data structures for papers, text chunks, and retrieval results.
Components:
- ArxivPaper: paper metadata + raw content
- TextChunk: chunk with section info
- RAGResult: retrieval result (paper + relevant chunks)
"""
from dataclasses import dataclass, field
from typing import List, Optional, TYPE_CHECKING
# Avoid circular imports: only import numpy for type checking
if TYPE_CHECKING:
import numpy as np
@dataclass
class ArxivPaper:
    """
    Metadata and processed content for a single arXiv paper.

    Carries the bibliographic fields (id, title, abstract, authors,
    year, url, categories) together with the downloaded LaTeX source,
    the text chunks derived from it, and a query-relevance score.
    """

    arxiv_id: str
    title: str
    abstract: str
    authors: List[str]
    year: int
    url: str
    categories: List[str] = field(default_factory=list)
    # Processed data
    latex_content: str = ""  # raw LaTeX source (single blob)
    latex_contents: List[str] = field(
        default_factory=list
    )  # per-file content blobs for multi-file sources
    chunks: List["TextChunk"] = field(default_factory=list)  # chunking output
    relevance_score: float = 0.0  # relevance to the current query
    # Download state
    source_downloaded: bool = False

    def to_dict(self) -> dict:
        """Serialize to a plain dict (raw LaTeX and chunk bodies omitted)."""
        return dict(
            arxiv_id=self.arxiv_id,
            title=self.title,
            abstract=self.abstract,
            authors=self.authors,
            year=self.year,
            url=self.url,
            categories=self.categories,
            relevance_score=self.relevance_score,
            chunks_count=len(self.chunks),
            source_downloaded=self.source_downloaded,
        )

    def __repr__(self) -> str:
        # Truncate the title so reprs stay one line in logs.
        truncated_title = self.title[:50]
        return "ArxivPaper(" + self.arxiv_id + ": " + truncated_title + "...)"
@dataclass
class TextChunk:
    """
    One chunk of a paper's text.

    Holds the source paper id, the chunk index, the text itself, the
    section it came from, an optional embedding vector used for
    retrieval, and a relevance score.
    """

    paper_id: str
    chunk_id: int
    text: str
    section: str = ""
    embedding: Optional["np.ndarray"] = None
    relevance_score: float = 0.0
    # Metadata
    char_count: int = 0  # derived: always recomputed from `text` below

    def __post_init__(self):
        # char_count is a derived field; any value passed in is replaced.
        self.char_count = len(self.text)

    def to_dict(self) -> dict:
        """Serialize to a plain dict; the embedding is deliberately omitted."""
        serialized_keys = (
            "paper_id",
            "chunk_id",
            "text",
            "section",
            "relevance_score",
            "char_count",
        )
        return {key: getattr(self, key) for key in serialized_keys}

    def __repr__(self) -> str:
        # Keep the repr short: show at most 50 characters of text.
        if len(self.text) > 50:
            preview = self.text[:50] + "..."
        else:
            preview = self.text
        return f"TextChunk({self.paper_id}#{self.chunk_id} [{self.section}]: {preview})"
@dataclass
class RAGResult:
    """
    One RAG retrieval result: a paper plus its most relevant chunks.

    `relevant_chunks` is expected to already be sorted by relevance.
    """

    paper: ArxivPaper
    relevant_chunks: List[TextChunk]

    def to_context(self, max_chunks: int = 3, max_chunk_length: int = 1500) -> str:
        """
        Render this result as a context string for an LLM prompt.

        Args:
            max_chunks: Upper bound on the number of chunks included.
            max_chunk_length: Per-chunk character limit; longer chunk
                text is truncated and suffixed with "...".
        """
        divider = "=" * 60
        lines = [
            divider,
            f"Paper: {self.paper.title} ({self.paper.year})",
            f"arXiv ID: {self.paper.arxiv_id}",
            f"Authors: {', '.join(self.paper.authors[:5])}",
            f"URL: {self.paper.url}",
            f"Relevance: {self.paper.relevance_score:.3f}",
            divider,
            "",
            "Abstract:",
            self.paper.abstract,
            "",
        ]
        selected = self.relevant_chunks[:max_chunks]
        # The abstract is already printed above, so a lone "Abstract"
        # chunk would only duplicate it — skip the excerpt section then.
        abstract_only = len(selected) == 1 and selected[0].section == "Abstract"
        if selected and not abstract_only:
            lines.append("Relevant Excerpts:")
            for chunk in selected:
                lines.append(
                    f"\n--- [{chunk.section}] (relevance: {chunk.relevance_score:.3f}) ---"
                )
                body = chunk.text[:max_chunk_length]
                if len(chunk.text) > max_chunk_length:
                    body += "..."
                lines.append(body)
        return "\n".join(lines)

    def to_dict(self) -> dict:
        """Serialize to a plain dict (paper dict + chunk dicts)."""
        return {
            "paper": self.paper.to_dict(),
            "relevant_chunks": [chunk.to_dict() for chunk in self.relevant_chunks],
        }

    def __repr__(self) -> str:
        return f"RAGResult({self.paper.arxiv_id}, {len(self.relevant_chunks)} chunks)"
__all__ = ["ArxivPaper", "TextChunk", "RAGResult"]