-
Notifications
You must be signed in to change notification settings - Fork 18
Expand file tree
/
Copy pathasta_tools.py
More file actions
797 lines (653 loc) · 28.3 KB
/
asta_tools.py
File metadata and controls
797 lines (653 loc) · 28.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
import asyncio
import concurrent.futures
import functools
import inspect
import itertools
import json
import logging
import os
from contextlib import asynccontextmanager
from datetime import datetime
from typing import Any, Awaitable, Callable, Collection, Literal
import anyio
import httpx
from inspect_ai.tool import ContentText, Tool, ToolDef, ToolError, ToolSource
from mcp.client.streamable_http import streamablehttp_client
try:
    # inspect_ai<=0.3.114: private MCPServerImpl class, driven below with our
    # own stream-client factory (see create_server_streamable_http).
    from inspect_ai.tool._mcp._mcp import MCPServerImpl
except ImportError:  # pragma: no cover - exercised in scorer env on newer Inspect
    # None is a sentinel checked at call time to select the newer API.
    MCPServerImpl = None
try:
    # inspect_ai>=0.3.137: public mcp_server_http helper.
    from inspect_ai.tool import mcp_server_http
except ImportError:  # pragma: no cover - exercised in root env on older Inspect
    # None is a sentinel checked at call time.
    mcp_server_http = None
from .search import (
DEFAULT_RESULTS_LIMIT,
METADATA_FIELDS,
_format_publication_before,
datetime_before,
)
# Module-level logger used by the helpers below for retry/backoff warnings
# and tool-description parsing warnings.
logger = logging.getLogger(__name__)
@asynccontextmanager
async def fixed_streamablehttp_client(
    url, headers=None, timeout=5, sse_read_timeout=300
):
    """Adapt ``streamablehttp_client`` to yield only ``(read, write)``.

    ``streamablehttp_client`` yields a 3-tuple ``(read, write, session_id)``;
    the legacy ``MCPServerImpl`` path needs just the two streams, so the
    session id is discarded. All arguments are forwarded positionally.
    """
    async with streamablehttp_client(
        url, headers, timeout, sse_read_timeout
    ) as streams:
        read_stream, write_stream, _ = streams
        yield read_stream, write_stream
def create_server_streamable_http(
    url: str,
    headers: dict[str, Any] | None = None,
    timeout: float = 5,
    sse_read_timeout: float = 60 * 5,
) -> ToolSource:
    """Create an MCP ``ToolSource`` for a streamable-HTTP server.

    Whichever Inspect API is importable is used. The legacy
    ``MCPServerImpl`` is preferred when present; otherwise the public
    ``mcp_server_http`` helper is used.

    Args:
        url: Server URL; also used as the tool source name.
        headers: Optional HTTP headers (e.g. the API key) for each request.
        timeout: Timeout forwarded to the MCP stream client.
        sse_read_timeout: SSE read timeout forwarded to the MCP stream client.

    Raises:
        RuntimeError: if neither Inspect API could be imported.
    """
    if MCPServerImpl is None and mcp_server_http is None:
        raise RuntimeError("No compatible inspect_ai MCP HTTP tool source API found.")

    if MCPServerImpl is None:
        # Modern Inspect unpacks 2- and 3-tuple stream clients itself
        # (`read, write, *_`), so the official helper is used directly.
        return mcp_server_http(
            name=url,
            url=url,
            headers=headers,
            timeout=timeout,
            sse_read_timeout=sse_read_timeout,
        )

    def _open_streams():
        # Older Inspect expects a factory whose client yields a 2-tuple.
        return fixed_streamablehttp_client(url, headers, timeout, sse_read_timeout)

    return MCPServerImpl(_open_streams, name=url, events=True)
def make_asta_toolsource(api_key: str | None = None) -> ToolSource:
    """Build the MCP tool source for the Asta tools endpoint.

    Args:
        api_key: Asta API key. When ``None``, the ``ASTA_TOOL_KEY``
            environment variable is used instead.

    Raises:
        ValueError: if the resolved key is missing or empty.
    """
    resolved_key = os.getenv("ASTA_TOOL_KEY") if api_key is None else api_key
    if not resolved_key:
        raise ValueError("api_key not given and ASTA_TOOL_KEY is not set")
    auth_headers = {"x-api-key": resolved_key}
    return create_server_streamable_http(
        "https://asta-tools.allen.ai/mcp/v1",
        headers=auth_headers,
    )
def _filter_one_paper(
paper: dict[str, Any],
prune_after_date: datetime,
remove_fields: Collection[str] | None = None,
) -> dict[str, Any] | None:
"""Filter a single paper based on the date (remove papers strictly newer
than prune_after_date), and remove specified fields if requested. Returns
the adjusted paper object if it meets the date criterion, or None if it
should be pruned."""
# publicationDate is str but may be null, year is integer and is seemingly never null
if paper.get("publicationDate"):
pub_date = datetime.strptime(paper["publicationDate"], "%Y-%m-%d")
elif paper.get("year"):
# When only year is available, use the last day of that year
# This ensures papers from year Y are included when filtering for dates in year Y+1
pub_date = datetime(paper["year"], 12, 31)
else:
return None
if pub_date > prune_after_date:
return None
# Copy since we're modifying
paper = paper.copy()
if remove_fields:
for field in remove_fields:
paper.pop(field, None)
return paper
def filter_papers(
    papers: list[dict[str, Any]],
    insertion_date: str,
    remove_fields: Collection[str] | None = None,
) -> list[dict[str, Any]]:
    """Keep only papers not later than the cutoff derived from ``insertion_date``.

    The cutoff is ``datetime_before(insertion_date)``. The check uses each
    paper's *publication* date as a proxy for when it entered the corpus.
    That can be slightly inaccurate, but is treated as a reasonable
    approximation.

    Args:
        papers: Paper records; none of them are mutated.
        insertion_date: Cutoff date string.
        remove_fields: Keys to strip from every surviving paper.

    Returns:
        A new list of pruned paper copies, in input order.
    """
    cutoff = datetime_before(insertion_date)
    candidates = (_filter_one_paper(paper, cutoff, remove_fields) for paper in papers)
    return [paper for paper in candidates if paper is not None]
def _filter_response_subfield(
    raw_resp: list[ContentText],
    insertion_date: str,
    added_date_fields: Collection[str],
    subfield_name: str,
) -> None:
    """Date-filter a nested paper list inside every item of a tool response.

    Each ``ContentText`` is parsed as JSON. If the payload is a dict with a
    ``"data"`` key, that value is used. The payload must then be a list of
    paper dicts or a single paper dict. Each paper's ``subfield_name`` list
    is replaced by the output of ``filter_papers``, and the item's text is
    re-serialized in place (``indent=2``).

    The format is assumed strictly on purpose: the API is expected to return
    a ToolError for anything that would not yield valid papers, so unexpected
    shapes surface as exceptions instead of being skipped.

    Raises:
        json.JSONDecodeError: if an item's text is not valid JSON.
        KeyError: if a paper lacks ``subfield_name``.
        TypeError: if the payload is neither a list nor a dict.
    """

    def _prune(paper):
        paper[subfield_name] = filter_papers(
            paper[subfield_name],
            insertion_date,
            remove_fields=added_date_fields,
        )

    for content in raw_resp:
        decoded = json.loads(content.text)
        # Some APIs wrap everything in a "data" envelope; only its contents
        # are papers, but the whole envelope is written back.
        # TODO: consider unpacking the envelope, as the old search tools do.
        if isinstance(decoded, dict) and "data" in decoded:
            payload = decoded["data"]
        else:
            payload = decoded

        if isinstance(payload, list):
            for paper in payload:
                _prune(paper)
        elif isinstance(payload, dict):
            _prune(payload)
        else:
            raise TypeError(
                f"Response does not appear to be papers, got {type(payload)}"
            )
        content.text = json.dumps(decoded, indent=2)
def _error_if_recent_papers_toplevel(
    raw_resp: list[ContentText],
    insertion_date: str,
    added_date_fields: Collection[str],
) -> None:
    """Reject responses whose top-level papers are newer than the cutoff.

    This is meant for get_paper and get_paper_batch, where every
    ``ContentText`` holds exactly one paper object as JSON. Each paper is
    checked with ``_filter_one_paper`` against
    ``datetime_before(insertion_date)``. Its text is rewritten in place
    without ``added_date_fields``; a rejected paper is written as ``null``.

    A single ToolError is raised rather than returning a shortened list. For
    get_paper_batch, a shortened list would not return one paper per
    requested id, which could confuse the model; the error lets it retry
    without the offending ids.

    Args:
        raw_resp: ContentText items, each holding a single paper as JSON.
        insertion_date: The cutoff date string.
        added_date_fields: Injected date fields to remove from each paper.

    Raises:
        TypeError: if an item does not decode to a JSON object.
        ToolError: if any paper is newer than the cutoff or has no date.
    """
    cutoff = datetime_before(insertion_date)
    rejected_ids = []
    for content in raw_resp:
        paper = json.loads(content.text)
        if not isinstance(paper, dict):
            raise TypeError(
                f"Expected paper dict in ContentText, got {type(paper)}. "
                f"This function only works with get_paper and get_paper_batch."
            )
        kept = _filter_one_paper(paper, cutoff, remove_fields=added_date_fields)
        if kept is None:
            rejected_ids.append(paper.get("paperId", "unknown"))
        content.text = json.dumps(kept, indent=2)

    if not rejected_ids:
        return
    if len(rejected_ids) == 1:
        raise ToolError(
            f"Paper {rejected_ids[0]} is newer than the date cutoff of "
            f"{insertion_date} and is not allowed to be requested"
        )
    raise ToolError(
        f"Papers {', '.join(rejected_ids)} are newer than the date cutoff of "
        f"{insertion_date} and are not allowed to be requested"
    )
def _adjust_fields_arg_for_subfield_filter(
    original_fields: str,
    subfield_name: str,
) -> tuple[str, set[str]]:
    """Ensure the date fields needed for filtering appear in a ``fields`` arg.

    ``<prefix>year`` and ``<prefix>publicationDate`` are appended, in that
    order, when not already requested. ``<prefix>`` is ``"<subfield_name>."``,
    or empty when ``subfield_name`` is ``""`` (the top-level paper).

    Requested names are compared after stripping whitespace, matching how
    ``_wrap_subfield_date_filter`` parses the same argument. Without this,
    ``"title, year"`` would re-inject ``year``, report it as added, and the
    date filters would then strip the user-requested field from the response.

    Args:
        original_fields: Comma-separated field list as passed to the tool.
        subfield_name: Subfield to target, or ``""`` for the top level.

    Returns:
        ``(new_fields, added)``. ``new_fields`` is the possibly extended
        field string; it is unchanged when nothing was added. ``added`` holds
        the unprefixed names that were injected, so callers can remove them
        from the response later.
    """
    prefix = f"{subfield_name}." if subfield_name != "" else ""
    requested = {field.strip() for field in original_fields.split(",")}
    new_fields = original_fields
    added_date_fields: set[str] = set()
    for date_field in ("year", "publicationDate"):
        qualified = f"{prefix}{date_field}"
        if qualified in requested:
            continue
        # Avoid producing a leading comma when the field list was empty.
        new_fields = f"{new_fields},{qualified}" if new_fields.strip() else qualified
        added_date_fields.add(date_field)
    return new_fields, added_date_fields
def _wrap_subfield_date_filter(
    tool: ToolDef,
    insertion_date: str,
    subfield_name: Literal["citations", "papers"],
    fields_arg_name: str = "fields",
) -> None:
    """Install a date filter on a nested paper list in ``tool``'s responses.

    The filter applies when a call's ``fields_arg_name`` requests
    ``subfield_name``, either bare or as ``<subfield_name>.<field>``. It then:
      1. injects the date fields needed for filtering;
      2. calls the original tool;
      3. prunes papers newer than the cutoff from that subfield and removes
         the injected fields again.
    Calls that do not request the subfield pass through untouched. Injecting
    ``<subfield>.publicationDate`` for them would add the subfield to the
    response. ``tool.tool`` is replaced in place.

    Args:
        tool: The ToolDef to modify.
        insertion_date: The date cutoff (papers after this will be pruned).
        subfield_name: The paper-list subfield to filter. Only "citations"
            and "papers" are tested, and no other subfield holds paper data.
        fields_arg_name: Name of the tool argument listing the fields to
            return. Defaults to "fields".

    The installed wrapper raises KeyError when ``fields_arg_name`` is absent
    from a call. Otherwise a server-side default containing the subfield
    would go unfiltered. In this codebase the fields default is always
    overridden, and the server-side defaults never contain a
    paper-containing subfield.
    """
    inner = tool.tool

    def _wants_subfield(fields_value: str) -> bool:
        # Comma-separated field names, possibly with stray whitespace.
        for field in (part.strip() for part in fields_value.split(",")):
            if field == subfield_name or field.startswith(f"{subfield_name}."):
                return True
        return False

    @functools.wraps(inner)
    async def _req_wrapper(**kwargs):
        if fields_arg_name not in kwargs:
            raise KeyError(
                f"Tools with post-hoc subfield filtering should always set a default '{fields_arg_name}' value. But it's missing for '{tool.name}' which only provided {kwargs.keys()}"
            )
        if not _wants_subfield(kwargs[fields_arg_name]):
            return await inner(**kwargs)

        widened_fields, injected = _adjust_fields_arg_for_subfield_filter(
            kwargs[fields_arg_name], subfield_name
        )
        kwargs[fields_arg_name] = widened_fields
        response: list[ContentText] = await inner(**kwargs)
        _filter_response_subfield(
            response,
            insertion_date,
            added_date_fields=injected,
            subfield_name=subfield_name,
        )
        return response

    tool.tool = _req_wrapper
def _wrap_toplevel_date_filter(
    tool: ToolDef,
    insertion_date: str,
    fields_arg_name: str = "fields",
) -> None:
    """Make ``tool`` reject top-level papers published after the cutoff.

    This is intended for get_paper and get_paper_batch. The installed
    wrapper:
      1. adds ``year`` and ``publicationDate`` to the requested fields when
         missing;
      2. calls the original tool;
      3. hands the response to ``_error_if_recent_papers_toplevel``.
    That helper strips the injected fields and raises a ToolError for any
    paper newer than the cutoff. ``tool.tool`` is replaced in place.

    Args:
        tool: The ToolDef to modify.
        insertion_date: The date cutoff (papers after this cause an error).
        fields_arg_name: Name of the tool argument listing the fields to
            return. Defaults to "fields".

    The wrapper raises KeyError if a call lacks ``fields_arg_name``; the
    override wrapper is expected to have supplied it.
    """
    inner = tool.tool

    @functools.wraps(inner)
    async def _req_wrapper(**kwargs):
        if fields_arg_name not in kwargs:
            raise KeyError(
                f"Tools with post-hoc toplevel filtering should always set a default '{fields_arg_name}' value. But it's missing for '{tool.name}' which only provided {kwargs.keys()}"
            )
        widened_fields, injected = _adjust_fields_arg_for_subfield_filter(
            kwargs[fields_arg_name], ""
        )
        kwargs[fields_arg_name] = widened_fields
        response: list[ContentText] = await inner(**kwargs)
        _error_if_recent_papers_toplevel(
            response,
            insertion_date,
            added_date_fields=injected,
        )
        return response

    tool.tool = _req_wrapper
def _remove_arg_description_block(
    arg_to_remove: str, description: str, arg_list: Collection[str]
) -> str:
    """Drop ``arg_to_remove``'s entry from a docstring-style "Args:" section.

    This is used when an argument has been overridden, so the LLM never sees
    its original description. The entry starts at the first line after the
    "Args:" header whose stripped text begins with ``"<arg_to_remove>:"``. It
    runs up to, but not including, whichever comes first:
      - the next non-blank line indented less than the entry;
      - the next line starting with ``"<arg>:"`` for an ``arg`` in
        ``arg_list``;
      - the end of the text.
    Blank lines before that terminating line are removed too.

    The description is returned unchanged when:
      - there is no "Args:" header, or nothing follows it (logs a warning);
      - the argument is not found;
      - the argument's line is indented less than the header (logs a
        warning).
    """

    def indent_of(text: str) -> int:
        return len(text) - len(text.lstrip())

    lines = description.split("\n")
    header_idx = next(
        (i for i, text in enumerate(lines) if text.lower().lstrip().startswith("args:")),
        None,
    )
    if header_idx is None or header_idx + 1 >= len(lines):
        logger.warning(f"Could not find args block in tool description: {description}")
        return description
    header_indent = indent_of(lines[header_idx])

    target_prefix = arg_to_remove + ":"
    start = next(
        (
            i
            for i in range(header_idx + 1, len(lines))
            if lines[i].lstrip().startswith(target_prefix)
        ),
        None,
    )
    if start is None:
        return description

    target_indent = indent_of(lines[start])
    if target_indent < header_indent:
        logger.warning(
            f"Argument '{arg_to_remove}' in tool description is not indented correctly. Not removing."
        )
        return description

    arg_prefixes = tuple(arg + ":" for arg in arg_list)
    stop = len(lines)
    for i in range(start + 1, len(lines)):
        text = lines[i]
        if not text.strip():
            continue
        if indent_of(text) < target_indent or text.lstrip().startswith(arg_prefixes):
            stop = i
            break
    return "\n".join(lines[:start] + lines[stop:])
def _unravel_exception_group(error: Exception) -> list[Any]:
"""Bubbles up all nested ExceptionGroups into a single list of exceptions.
If a non-ExceptionGroup exception is given, returns a list containing just
that exception."""
if isinstance(error, BaseExceptionGroup):
# Technically `error.exceptions` can be Any so the iterator is a little clunky
return list(
itertools.chain.from_iterable(
(_unravel_exception_group(e) if isinstance(e, Exception) else [e])
for e in error.exceptions
)
)
elif isinstance(error, Exception):
return [error]
else:
raise TypeError(f"Expected Exception or BaseExceptionGroup, got {type(error)}")
def _is_retryable_error(error: Exception) -> bool:
"""Check if an error is retryable (429, 529, or 504 HTTP status codes).
Handles both direct httpx.HTTPStatusError and errors within BaseExceptionGroup,
following the pattern from replay_solver.py.
Args:
error: The exception to check
Returns:
True if the error has a retryable HTTP status code, False otherwise
"""
retryable_codes = {429, 529, 504}
# In practice there is always just one exception, but it might be wrapped
# deeply in a nested group
for base_error in _unravel_exception_group(error):
if isinstance(base_error, httpx.HTTPStatusError) and hasattr(
base_error, "response"
):
return base_error.response.status_code in retryable_codes
elif isinstance(base_error, anyio.BrokenResourceError):
# 429s sometimes manifest as BrokenResourceError due to the
# internal MCP client async task (post_writer) dying from it; it
# doesn't happen consistently for some reason, though.
# Unfortunately this means a BrokenResourceError has a chance to be
# caused by a code that we normally wouldn't want to retry, so we
# have to make a blind assumption here.
return True
return False
def make_override_wrapper(
    td: ToolDef,
    arg_defaults: dict[str, Any] | None = None,
    arg_overrides: dict[str, Any] | None = None,
):
    """Return a wrapper around ``td.tool`` applying defaults and hidden overrides.

    ``arg_defaults`` are written into the parameter schema as defaults and
    removed from ``required``; callers may still pass their own value.

    ``arg_overrides`` are forced into every call. They are removed from the
    schema, the "Args:" section of ``td.description`` and the wrapper's
    signature, so the model can neither see nor change them.

    ``td.parameters`` and ``td.description`` are mutated in place. The
    caller is responsible for installing the returned wrapper as
    ``td.tool``.
    """
    defaults = arg_defaults or {}
    forced = arg_overrides or {}

    # Defaults: advertise in the schema and make optional.
    for name, value in defaults.items():
        td.parameters.properties[name].default = value
        if name in td.parameters.required:
            td.parameters.required.remove(name)

    # Overrides: hide the argument from description, schema and required list.
    for name in forced:
        td.description = _remove_arg_description_block(
            name,
            td.description,
            td.parameters.properties.keys(),
        )
        del td.parameters.properties[name]
        if name in td.parameters.required:
            td.parameters.required.remove(name)

    # Re-validate the edited schema to catch anything broken above.
    td.parameters.model_validate(td.parameters)

    inner = td.tool
    unset = inspect.Parameter.empty

    @functools.wraps(inner)
    async def wrapper(*args, **kwargs):
        for name, value in defaults.items():
            if kwargs.get(name, unset) is unset:
                kwargs[name] = value
        # Forced values always win over anything the caller passed.
        kwargs.update(forced)
        return await inner(*args, **kwargs)

    # Expose a signature without the forced args (and with the new defaults)
    # so the LLM-facing view matches the edited schema.
    base_signature = inspect.signature(inner)
    visible_params = []
    for param in base_signature.parameters.values():
        if param.name in forced:
            continue
        if param.name in defaults:
            param = param.replace(default=defaults[param.name])
        visible_params.append(param)
    wrapper.__signature__ = base_signature.replace(  # type: ignore[attr-defined]
        parameters=visible_params
    )
    return wrapper
def make_retry_wrapper(
    td: ToolDef,
    max_retries: int = 10,
    backoff_multiplier: float = 2.0,
    base_delay: float = 1.0,
    max_delay: float = 60.0,
) -> Callable[..., Awaitable[Any]]:
    """Wrap a tool to retry retryable errors (429, 529, 504) with exponential backoff.

    Retryability is decided by ``_is_retryable_error``. Before retrying
    after failed attempt ``k`` (0-based), the wrapper sleeps
    ``min(max_delay, base_delay * backoff_multiplier**k)`` seconds.

    The returned wrapper:
      - re-raises ToolErrors (even when nested in exception groups) without
        retrying or console logging;
      - logs and re-raises non-retryable errors;
      - raises a ToolError once ``max_retries`` retries are exhausted.

    Args:
        td: The ToolDef whose ``tool`` is wrapped; ``td`` itself is not
            modified.
        max_retries: Maximum number of retries after the first attempt
            (default: 10). Must be >= 0.
        backoff_multiplier: Multiplier for exponential backoff (default: 2.0).
        base_delay: Base delay in seconds (default: 1.0).
        max_delay: Maximum delay in seconds (default: 60.0).

    Returns:
        An async callable wrapping ``td.tool`` (via ``functools.wraps``).

    Raises:
        ValueError: if ``max_retries`` is negative. Otherwise the retry loop
            would never run, and the wrapper would return ``None`` without
            calling the tool.
    """
    if max_retries < 0:
        raise ValueError(f"max_retries must be >= 0, got {max_retries}")

    origtool = td.tool

    @functools.wraps(origtool)
    async def wrapper(*args, **kwargs):
        for attempt in range(max_retries + 1):  # +1 for initial attempt
            try:
                return await origtool(*args, **kwargs)
            except Exception as e:
                # Flatten once; the same leaf list drives every check below.
                leaf_errors = _unravel_exception_group(e)
                if any(isinstance(err, ToolError) for err in leaf_errors):
                    # ToolErrors are not retryable and usually mean the LLM
                    # made a bad request. They are shown to the LLM and
                    # appear in Inspect logs, so don't spam the console too.
                    raise
                err_shortstring = "+".join(str(x) for x in leaf_errors)
                if not _is_retryable_error(e):
                    logger.exception(
                        f"Failed with non-retryable error {err_shortstring}"
                    )
                    raise
                if attempt >= max_retries:
                    msg = f"Tool {td.name} failed after {max_retries} retries with retryable error: {err_shortstring}"
                    logger.warning(msg)
                    raise ToolError(msg) from e
                delay = min(max_delay, base_delay * (backoff_multiplier**attempt))
                logger.warning(
                    f"Tool {td.name} attempt {attempt + 1} failed with retryable error: {err_shortstring}. "
                    f"Retrying in {delay:.1f} seconds..."
                )
                await asyncio.sleep(delay)
        # Unreachable: the last iteration always returns or raises.
        raise AssertionError("retry loop exited without returning or raising")

    return wrapper
def _wrap_title_match_not_found_as_empty(tool: ToolDef) -> None:
    """Turn "title match not found" failures into an empty result.

    ``tool.tool`` is replaced in place. The wrapped call's error is
    flattened (it may be nested in exception groups). If any leaf is a
    ToolError whose message contains "title match not found"
    (case-insensitive), a single ContentText holding ``{"data": []}`` is
    returned instead. Any other error is re-raised unchanged.
    """
    inner = tool.tool

    def _is_title_miss(exc: Exception) -> bool:
        return any(
            isinstance(err, ToolError)
            and "title match not found" in err.message.lower()
            for err in _unravel_exception_group(exc)
        )

    @functools.wraps(inner)
    async def _req_wrapper(*args, **kwargs):
        try:
            return await inner(*args, **kwargs)
        except Exception as exc:
            if not _is_title_miss(exc):
                raise
            return [ContentText(type="text", text=json.dumps({"data": []}))]

    tool.tool = _req_wrapper
async def async_make_asta_mcp_tools(
    api_key: str | None = None, insertion_date: str | None = None
) -> list[Tool]:
    """Connect to Asta MCP and return its tools, adapted to the Asta ecosystem.

    Each server tool is wrapped as follows, from innermost to outermost:
      1. A retry wrapper (``make_retry_wrapper``) for retryable HTTP errors.
      2. When ``insertion_date`` is given, date filters:
         - a top-level check for get_paper/get_paper_batch;
         - a nested-list filter on ``papers`` for search_authors_by_name;
         - a nested-list filter on ``citations`` for tools that take a
           paper-fields argument.
      3. An override wrapper. It defaults the field list(s) and ``limit``,
         so models get informative fields without listing them every time.
         When ``insertion_date`` is given, it also forces ``inserted_before``
         or ``publication_date_range``.
      4. For search_paper_by_title, a wrapper mapping "title match not
         found" to an empty result.

    Args:
        api_key: Asta API key; ``None`` falls back to ``ASTA_TOOL_KEY``.
        insertion_date: Optional date cutoff for the filtering above.

    Returns:
        One ``Tool`` per server tool not listed in ``exclude_tools``.

    Raises:
        ValueError: if the server exposes a tool outside ``include_tools``,
            or lacks one of them.
    """
    PAPER_FIELDS = METADATA_FIELDS
    AUTHOR_FIELDS = "authorId,name,paperCount"
    # Asta also has separate paper_search and snippet_search tools. To hide
    # the overlapping MCP tools, use:
    # exclude_tools = {"snippet_search", "search_papers_by_relevance"}
    # Currently nothing is excluded.
    exclude_tools: set[str] = set()
    # Remaining tools should be included; we define this list just to help
    # confirm that there are no surprises with unknown tools
    include_tools = {
        "snippet_search",
        "search_papers_by_relevance",
        "get_paper",
        "get_paper_batch",
        "get_citations",
        "search_authors_by_name",
        "get_author_papers",
        "search_paper_by_title",
    }
    asta_tools = await make_asta_toolsource(api_key).tools()
    found_tools: set[str] = set()
    new_tools: list[Tool] = []
    for t in asta_tools:
        td = ToolDef(t)
        if td.name in exclude_tools:
            continue
        if td.name not in include_tools:
            raise ValueError(f"Unexpected tool {td.name} found in Asta MCP tools")
        found_tools.add(td.name)
        # Retry if we get rate-limited by Asta MCP. This is the innermost
        # wrapper, so ToolErrors raised by the date filters below are never
        # retried.
        td.tool = make_retry_wrapper(td)
        combined_arg_defaults: dict[str, Any] = {}
        combined_arg_overrides: dict[str, Any] = {}
        # Tools name their paper field list either `fields` or `paper_fields`.
        paper_fields_arg_name: str | None = None
        if "fields" in td.parameters.properties:
            paper_fields_arg_name = "fields"
        elif "paper_fields" in td.parameters.properties:
            paper_fields_arg_name = "paper_fields"
        # Ordering invariant: the combined override wrapper must remain
        # outermost relative to the date/subfield filters, since those filters
        # inspect the effective kwargs (for example `fields`) and assume any
        # benchmark-specified defaults/overrides have already been applied.
        # So, the date/subfield filters are installed first and remain
        # innermost at runtime.
        if td.name == "search_authors_by_name":
            if insertion_date is not None:
                # Filter papers subfield; note: there is no `papers.citations` so
                # we don't need to worry about nested filtering
                _wrap_subfield_date_filter(
                    td,
                    insertion_date=insertion_date,
                    subfield_name="papers",
                    fields_arg_name="fields",
                )
            # Here `fields` lists *author* fields, so the paper-field default
            # and citations filter below must not apply.
            combined_arg_defaults["fields"] = AUTHOR_FIELDS
            paper_fields_arg_name = None
        elif td.name in ["get_paper", "get_paper_batch"]:
            # Top-level check in case the LLM requests a paper newer than the
            # cutoff; the citations filter below is applied to these too.
            if insertion_date is not None:
                _wrap_toplevel_date_filter(
                    td,
                    insertion_date=insertion_date,
                    fields_arg_name="fields",
                )
        if paper_fields_arg_name is not None:
            if insertion_date is not None:
                # Everything with a paper-field specifier also allows
                # a citations subfield that we need to filter
                _wrap_subfield_date_filter(
                    td,
                    insertion_date=insertion_date,
                    subfield_name="citations",
                    fields_arg_name=paper_fields_arg_name,
                )
            combined_arg_defaults[paper_fields_arg_name] = PAPER_FIELDS
        if "limit" in td.parameters.properties:
            combined_arg_defaults["limit"] = DEFAULT_RESULTS_LIMIT
        # Add publication date filtering for tools that support it; these are
        # forced overrides, hidden from the model.
        if insertion_date is not None:
            if "inserted_before" in td.parameters.properties:
                combined_arg_overrides["inserted_before"] = insertion_date
            elif "publication_date_range" in td.parameters.properties:
                combined_arg_overrides["publication_date_range"] = (
                    _format_publication_before(insertion_date)
                )
        if combined_arg_defaults or combined_arg_overrides:
            td.tool = make_override_wrapper(
                td,
                arg_defaults=combined_arg_defaults,
                arg_overrides=combined_arg_overrides,
            )
        # Outermost wrapper: turns title-search misses into an empty result.
        if td.name == "search_paper_by_title":
            _wrap_title_match_not_found_as_empty(td)
        new_tools.append(td.as_tool())
    if len(include_tools - found_tools) > 0:
        raise ValueError(
            f"Missing expected tools in Asta MCP: {include_tools - found_tools}"
        )
    return new_tools
def make_asta_mcp_tools(
    insertion_date: str | None = None, api_key: str | None = None
) -> list[Tool]:
    """Synchronous wrapper around ``async_make_asta_mcp_tools``.

    ``asyncio.run`` cannot be called from a thread whose event loop is
    already running. The coroutine is therefore run on a dedicated worker
    thread with its own loop. The executor is used as a context manager so
    the worker thread is shut down instead of being leaked.

    Args:
        insertion_date: Optional date cutoff forwarded to the async builder.
        api_key: Optional Asta API key. When ``None``, the builder falls back
            to the ``ASTA_TOOL_KEY`` environment variable.

    Returns:
        The list of wrapped Asta MCP tools.
    """

    def _build():
        # Create the coroutine inside the worker so it can never be left
        # un-awaited if submission fails.
        return asyncio.run(
            async_make_asta_mcp_tools(api_key=api_key, insertion_date=insertion_date)
        )

    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        return executor.submit(_build).result()
async def main():
    """Smoke test: connect to the Asta MCP server and report its tool count."""
    source = make_asta_toolsource()
    available_tools = await source.tools()
    print(
        f"Successfully connected to Semantic Scholar API and found {len(available_tools)} tools."
    )


if __name__ == "__main__":
    asyncio.run(main())