-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathstate.py
More file actions
81 lines (61 loc) · 2.91 KB
/
state.py
File metadata and controls
81 lines (61 loc) · 2.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import copy
from logging import getLogger
from typing import Any, Callable
from inspect_ai.agent import as_solver, bridge
from inspect_ai.solver import Generate, Solver, TaskState, solver
from inspect_ai.tool import Tool, ToolDef
logger = getLogger(__name__)
@solver
def merge_tools_with_state(
    tools: list[Tool],
    prefer_given_tools: bool = False,
    select_fn: Callable[[ToolDef], bool] | None = None,
) -> Solver:
    """Build a solver that merges `tools` into `state.tools`.

    When a given tool and a state tool share a name, only one of them is
    kept. By default the tool already in the task state wins (it may come
    with e.g. a date restriction); pass `prefer_given_tools=True` to let the
    given tool win instead.

    `select_fn`, if supplied, is applied to every surviving tool (whether it
    came from the given tools or from the state). A tool is kept when it
    returns `True` and dropped when it returns `False`. With no `select_fn`,
    every tool is kept.
    """
    keep = select_fn if select_fn else (lambda _td: True)

    async def solve(state: TaskState, generate: Generate) -> TaskState:
        from_state = [ToolDef(t) for t in state.tools]
        from_args = [ToolDef(t) for t in tools]

        # The "winners" are kept in full; "losers" only contribute tools
        # whose names the winners do not already use.
        if prefer_given_tools:
            winners, losers = from_args, from_state
        else:
            winners, losers = from_state, from_args

        taken_names = {td.name for td in winners}
        merged = winners + [td for td in losers if td.name not in taken_names]

        selected = []
        for td in merged:
            if keep(td):
                selected.append(td.as_tool())
        state.tools = selected
        return state

    return solve
@solver
def full_state_bridge(func: Solver) -> Solver:
    """Run a full-`TaskState` solver under InspectAI's `bridge()`.

    This behaves like `bridge()` (so the wrapped code gets the openai patch),
    except that the wrapped callable is a regular `Solver` receiving the
    complete `TaskState`, rather than a function that only receives
    `{"input": ...}`.
    """

    async def solve(state: TaskState, generate: Generate) -> TaskState:
        # One-slot holder for the state produced by `func`; it starts out as
        # the incoming state and is overwritten from inside the bridged call.
        captured: list[TaskState] = [state]

        async def run_wrapped(sample: dict[str, Any]) -> dict[str, Any]:
            # The bridge's `sample` dict is ignored: `func` is invoked with
            # the enclosing `state` and `generate` instead.
            produced = await func(state, generate)
            # Keep a deep copy in case `bridge` touches the state object
            # after this function returns.
            captured[0] = copy.deepcopy(produced)
            # The returned value is a placeholder; nothing reads it.
            return {"output": ""}

        bridged = as_solver(bridge(run_wrapped))
        # The bridged solver's own result is discarded on purpose; the state
        # we want was stored in `captured` by `run_wrapped`.
        await bridged(state, generate)
        return captured[0]

    return solve