-
Notifications
You must be signed in to change notification settings - Fork 18
Expand file tree
/
Copy pathsearch.py
More file actions
259 lines (219 loc) · 8.39 KB
/
search.py
File metadata and controls
259 lines (219 loc) · 8.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
"""Lit search tools that use the Semantic Scholar API.
Based on https://github.com/allenai/nora-corpusqa-demo/blob/f0be05b6fc24a2c600e3f3c66852c782e22f9a1d/api/tool/utils.py
"""
import logging
import os
from datetime import datetime, timedelta
from typing import Any, Literal
import httpx
from httpx import HTTPStatusError, Timeout
from inspect_ai.tool import ToolError
# Module-level logger; used to report failed S2 API requests.
logger = logging.getLogger(__name__)
# API key sent as the "x-api-key" header on every request; falls back to an
# empty string when the ASTA_TOOL_KEY environment variable is not set.
ASTA_TOOL_KEY = os.getenv("ASTA_TOOL_KEY", "")
# Endpoint paths are appended to this string directly, so it must keep its
# trailing slash.
S2_API_BASE_URL = "https://api.semanticscholar.org/graph/v1/"
# Comma-separated value passed as the "fields" query parameter of paper/search.
METADATA_FIELDS = "title,abstract,corpusId,authors,year,venue,citationCount,referenceCount,influentialCitationCount"
# Default value of the `limit` argument of both search functions.
DEFAULT_RESULTS_LIMIT = 20
def make_paper_search(inserted_before: str | None = None):
    """
    Build a ``paper_search`` coroutine function with a fixed date cutoff.

    The returned function closes over ``inserted_before``, so its callers
    only supply the query text and the result limit.

    Args:
        inserted_before: Optional cutoff date to include only results inserted before this date.
            Accepts formats YYYY, YYYY-MM, or YYYY-MM-DD.

    Returns:
        The async ``paper_search(kquery, limit)`` function.
    """

    # The inner docstring is kept in its original wording: it describes the
    # callable handed out to callers and may be read at runtime.
    async def paper_search(
        kquery: str, limit: int = DEFAULT_RESULTS_LIMIT
    ) -> list[dict[str, Any]]:
        """
        Search for papers using the Semantic Scholar API.

        Args:
            kquery: The search query.
            limit: The maximum number of results to return.

        Returns:
            A list of dictionaries containing paper metadata.
        """
        return await _paper_search_impl(
            kquery=kquery,
            limit=limit,
            inserted_before=inserted_before,
        )

    return paper_search
def make_snippet_search(inserted_before: str | None = None):
    """
    Build a ``snippet_search`` coroutine function with a fixed date cutoff.

    The returned function closes over ``inserted_before``, so its callers
    only supply the query text, the result limit and, optionally, the corpus
    IDs to restrict the search to.

    Args:
        inserted_before: Optional cutoff date to include only results inserted before this date.
            Accepts formats YYYY, YYYY-MM, or YYYY-MM-DD.

    Returns:
        The async ``snippet_search(query, limit, corpus_ids)`` function.
    """

    # The inner docstring is kept in its original wording: it describes the
    # callable handed out to callers and may be read at runtime.
    async def snippet_search(
        query: str,
        limit: int = DEFAULT_RESULTS_LIMIT,
        corpus_ids: list[str | int] | None = None,
    ) -> list[dict[str, Any]]:
        """
        Search for snippets using the Semantic Scholar API.

        Args:
            query: The search query.
            limit: The maximum number of results to return.
            corpus_ids: List of Semantic Scholar corpus IDs to restrict results.

        Returns:
            A list of dictionaries containing snippet metadata.
        """
        return await _snippet_search_impl(
            query=query,
            limit=limit,
            corpus_ids=corpus_ids,
            inserted_before=inserted_before,
        )

    return snippet_search
async def _query_s2_api(
    end_pt: str = "paper/batch",
    params: dict[str, Any] | None = None,
    payload: dict[str, Any] | None = None,
    method: Literal["get", "post"] = "get",
) -> dict[str, Any]:
    """
    Send one request to the Semantic Scholar API and return the decoded JSON.

    Args:
        end_pt: Endpoint path, appended to ``S2_API_BASE_URL``.
        params: Query-string parameters.
        payload: JSON body (used for POST requests).
        method: HTTP method, ``"get"`` or ``"post"``.

    Returns:
        The JSON response from the API.

    Raises:
        ToolError: If the API answers with status 400; the response text is
            included in the message.
        HTTPStatusError: For any other status rejected by
            ``raise_for_status()``.
    """
    request = httpx.Request(
        method=method.upper(),
        url=f"{S2_API_BASE_URL}{end_pt}",
        params=params,
        json=payload,
        headers={"x-api-key": ASTA_TOOL_KEY},
    )

    # 10s->20s fixed observed test case timeout issues
    timeouts = Timeout(20.0, connect=5.0)
    async with httpx.AsyncClient(timeout=timeouts) as http_client:
        # send() reads the whole body, so the response stays usable after
        # the client is closed.
        response = await http_client.send(request)

    try:
        response.raise_for_status()
    except HTTPStatusError:
        logger.error(
            f"S2 API request to end point {end_pt} failed with status code {response.status_code}"
        )
        if response.status_code != 400:
            raise
        # A 400 is reported as a ToolError carrying the API's own message
        # instead of the raw HTTP error.
        raise ToolError(
            f"API request failed with status code {response.status_code}: {response.text}"
        )

    return response.json()
def datetime_before(date_str: str) -> datetime:
    """Return the last second before the date named by a partial date string.

    Surrounding whitespace in ``date_str`` is ignored.

    For example:
        2025 -> 2024-12-31 23:59:59
        2025-01 -> 2024-12-31 23:59:59
        2025-01-01 -> 2024-12-31 23:59:59

    Args:
        date_str: A date written as YYYY, YYYY-MM, or YYYY-MM-DD.

    Returns:
        A naive datetime one second earlier than the start of the given
        year, month, or day.

    Raises:
        ValueError: If the string's length matches none of the supported
            formats, or if it cannot be parsed with the matching format.
    """
    fmt_map = {4: "%Y", 7: "%Y-%m", 10: "%Y-%m-%d"}
    # Strip once and reuse the result: the format is picked from the stripped
    # length, so the same stripped text must be what gets parsed.
    cleaned = date_str.strip()
    fmt = fmt_map.get(len(cleaned))
    if not fmt:
        raise ValueError(
            f"date_str format must be in {list(fmt_map.values())}, got: {date_str}"
        )
    return datetime.strptime(cleaned, fmt) - timedelta(seconds=1)
def _format_publication_before(inserted_before: str) -> str:
    """Format a cutoff date as the ``publicationDateOrYear`` filter value.

    Args:
        inserted_before: Cutoff date as YYYY, YYYY-MM, or YYYY-MM-DD.

    Returns:
        A string of the form ``:YYYY-MM-DD``, where the date is the day of
        the last second before the cutoff. The leading colon leaves the
        start of the range empty.
    """
    cutoff = datetime_before(inserted_before)
    return cutoff.strftime(":%Y-%m-%d")
async def _paper_search_impl(
    kquery: str,
    limit: int,
    inserted_before: str | None = None,
) -> list[dict[str, Any]]:
    """Query the ``paper/search`` endpoint and normalise the hits.

    Args:
        kquery: The search query.
        limit: The maximum number of results requested from the API.
        inserted_before: Optional cutoff date (YYYY, YYYY-MM, or YYYY-MM-DD).

    Returns:
        The API's paper dictionaries that have a corpus ID, a title and an
        abstract, each extended in place with the keys ``corpus_id``,
        ``text``, ``section_title``, ``sentences``, ``ref_mentions``,
        ``score`` and ``stype``. Empty if the response has no ``data`` key.
    """
    query_params = {"query": kquery, "limit": limit, "fields": METADATA_FIELDS}
    if inserted_before:
        # The 'publicationDateOrYear' filter looks at when a paper was published.
        # 'inserted_before' cares about when a paper was added to the index that we're searching.
        # The /paper/search endpoint does not support filtering by the latter, so we use the former as an approximation.
        query_params["publicationDateOrYear"] = _format_publication_before(
            inserted_before
        )

    res = await _query_s2_api(
        end_pt="paper/search",
        params=query_params,
        method="get",
    )

    hits = res["data"] if "data" in res else []
    papers = []
    for paper in hits:
        # Skip records that lack any of the fields the output relies on.
        if not (
            paper.get("corpusId") and paper.get("title") and paper.get("abstract")
        ):
            continue
        paper.update(
            {
                "corpus_id": str(paper["corpusId"]),
                "text": paper["abstract"],
                "section_title": "abstract",
                "sentences": [],
                "ref_mentions": [],
                "score": 0.0,
                "stype": "public_api",
            }
        )
        papers.append(paper)
    return papers
def _snippet_to_record(fields: dict[str, Any]) -> dict[str, Any]:
    """Flatten one ``snippet/search`` result item into a flat record dict."""
    snippet, paper = fields["snippet"], fields["paper"]
    # Missing or null sub-objects are treated as empty.
    annotations = snippet.get("annotations") or {}
    offset = snippet.get("snippetOffset") or {}
    kind = snippet["snippetKind"]
    return {
        "corpus_id": paper["corpusId"],
        "title": paper["title"],
        "text": snippet["text"],
        "score": fields["score"],
        # NOTE(review): "section" is looked up on the result item, not on the
        # snippet object — confirm against the API response schema.
        "section_title": kind if kind != "body" else fields.get("section", "body"),
        "char_start_offset": offset.get("start") or 0,
        "sentence_offsets": annotations.get("sentences") or [],
        "ref_mentions": [
            rmen
            for rmen in annotations.get("refMentions") or []
            if rmen.get("matchedPaperCorpusId")
            # Compare against None so a mention starting at offset 0 is kept.
            and rmen.get("start") is not None
            and rmen.get("end")
        ],
        "stype": "vespa",
    }


async def _snippet_search_impl(
    query: str,
    limit: int,
    corpus_ids: list[str | int] | None = None,
    inserted_before: str | None = None,
) -> list[dict[str, Any]]:
    """Query the ``snippet/search`` endpoint and flatten the hits.

    Args:
        query: The search query.
        limit: The maximum number of results requested from the API.
        corpus_ids: Optional corpus IDs to restrict results to; sent as a
            comma-separated list of ``CorpusId:<id>`` in ``paperIds``.
        inserted_before: Optional cutoff date, passed through unchanged as
            the ``insertedBefore`` parameter.

    Returns:
        One dictionary per snippet with the keys ``corpus_id``, ``title``,
        ``text``, ``score``, ``section_title``, ``char_start_offset``,
        ``sentence_offsets``, ``ref_mentions`` and ``stype``. Empty if the
        response has no (or an empty) ``data`` entry.

    Raises:
        ValueError: If ``corpus_ids`` contains ``None`` or the string
            ``"None"``.
    """
    params: dict[str, Any] = {"query": query, "limit": limit}
    if corpus_ids:
        if any(cid is None or cid == "None" for cid in corpus_ids):
            raise ValueError(
                "corpus_ids contains None or 'None', which is not allowed."
            )
        params["paperIds"] = ",".join(f"CorpusId:{cid}" for cid in corpus_ids)
    if inserted_before:
        params["insertedBefore"] = inserted_before

    snippets = await _query_s2_api(
        end_pt="snippet/search",
        params=params,
        method="get",
    )
    # A response without "data" yields no results, as in the paper search,
    # instead of raising KeyError.
    res_data = snippets.get("data") or []
    return [_snippet_to_record(fields) for fields in res_data]
# Public API: only the two factory functions are exported by a star-import.
__all__ = ["make_paper_search", "make_snippet_search"]