Skip to content

Commit e6bfae8

Browse files
committed
fixing the submit answer tool and the search tool schema
1 parent 5c481b0 commit e6bfae8

3 files changed

Lines changed: 42 additions & 33 deletions

File tree

mdcrow/ldp_env/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
11
from .state import MDCrowState
2+
from .environment import MDCrowEnv,MySimpleAgent
3+
from .utils import PathRegistry
4+
__all__ = ["MDCrowState",
5+
"MDCrowEnv",
6+
"PathRegistry",
7+
"MySimpleAgent"]
28

3-
__all__ = ["MDCrowState"]

mdcrow/ldp_env/environment.py

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,24 @@
66
import os
77
from typing import Any, Literal
88

9-
from analysis_tools import (
9+
from aviary.core import (
10+
Environment,
11+
Message,
12+
Messages,
13+
Tool,
14+
ToolRequestMessage,
15+
ToolResponseMessage,
16+
)
17+
18+
# get secrets from environment variables
19+
from dotenv import load_dotenv
20+
from pydantic import BaseModel, ConfigDict, Field
21+
22+
from ldp.agent import Agent
23+
from ldp.graph import LLMCallOp, OpResult, compute_graph
24+
from mdcrow.ldp_env.state import MDCrowState
25+
26+
from .analysis_tools import (
1027
compute_bond_angles,
1128
compute_contacts,
1229
compute_distance,
@@ -24,18 +41,7 @@
2441
perform_pca_analysis,
2542
summarize_protein_structure,
2643
)
27-
from aviary.core import (
28-
Environment,
29-
Message,
30-
Messages,
31-
Tool,
32-
ToolRequestMessage,
33-
ToolResponseMessage,
34-
)
35-
36-
# get secrets from environment variables
37-
from dotenv import load_dotenv
38-
from preprocess_tools import (
44+
from .preprocess_tools import (
3945
GetActiveSites,
4046
GetAllKnownSites,
4147
GetAllSequences,
@@ -57,13 +63,8 @@
5763
get_small_molecule_PDB,
5864
pack_molecules,
5965
)
60-
from pydantic import BaseModel, ConfigDict, Field
61-
from simulation_tools import modify_simulation_script, setup_and_run_simulation
62-
from util_tools import ListRegistryPaths, MapPath2Name, scholar2result_llm
63-
64-
from ldp.agent import Agent
65-
from ldp.graph import LLMCallOp, OpResult, compute_graph
66-
from mdcrow.ldp_env.state import MDCrowState
66+
from .simulation_tools import modify_simulation_script, setup_and_run_simulation
67+
from .util_tools import ListRegistryPaths, MapPath2Name, scholar2result_llm
6768

6869
load_dotenv()
6970

@@ -312,30 +313,29 @@ async def step(
312313
False,
313314
)
314315

315-
def submit_answer(self, answer: str) -> tuple[bool, float, Literal[True]]:
316+
def submit_answer(
317+
self, answer: str, finished: bool
318+
) -> tuple[str, float, Literal[True]]:
316319
"""Submit the proposed answer and check if it is correct. This action is terminal.
317320
318321
Args:
319322
answer: Proposed answer.
323+
finished: Whether the task is finished.
320324
321325
Returns:
322326
Three-tuple of if correct, associated reward (correct_reward if correct,
323327
tool_failure_reward if tool failure, otherwise incorrect_reward), and
324328
True indicating done.
325329
"""
326330
try:
327-
correct: bool = (
328-
abs(float(answer) - self.answer)
329-
/ (abs(self.answer) + self.config.rel_tol)
330-
< self.config.rel_tol
331-
)
331+
answer = answer.strip()
332332
reward = (
333-
self.config.correct_reward if correct else self.config.incorrect_reward
333+
self.config.correct_reward if finished else self.config.incorrect_reward
334334
)
335335
except ValueError:
336-
return False, self.config.tool_failure_reward, True
336+
return answer, self.config.tool_failure_reward, True
337337
else:
338-
return correct, reward, True
338+
return answer, reward, True
339339

340340
def calculator(self, expr: str) -> tuple[float | str, float, bool]:
341341
"""Calculate a mathematical expression.

mdcrow/ldp_env/util_tools/search_tools.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
11
import paperqa
22

3+
from mdcrow.ldp_env.state import MDCrowState
34
llm_model_args = {
45
"name": "gpt-4o-2024-08-06",
56
"temperature": 0.5,
67
}
78

89

9-
def scholar2result_llm(state, query):
10+
def scholar2result_llm(
11+
state:MDCrowState,
12+
query:str
13+
):
1014
"""
11-
Useful to answer questions that may be found in literature.
12-
Ask a specific question as the input.
15+
Useful to answer questions that may be found in literature. Ask a
16+
specific question as the input.
1317
1418
Args:
1519
state (MDCrowState): The state of the MDCrow environment.

0 commit comments

Comments
 (0)