|
6 | 6 | import os |
7 | 7 | from typing import Any, Literal |
8 | 8 |
|
9 | | -from analysis_tools import ( |
| 9 | +from aviary.core import ( |
| 10 | + Environment, |
| 11 | + Message, |
| 12 | + Messages, |
| 13 | + Tool, |
| 14 | + ToolRequestMessage, |
| 15 | + ToolResponseMessage, |
| 16 | +) |
| 17 | + |
| 18 | +# get secrets from environment variables |
| 19 | +from dotenv import load_dotenv |
| 20 | +from pydantic import BaseModel, ConfigDict, Field |
| 21 | + |
| 22 | +from ldp.agent import Agent |
| 23 | +from ldp.graph import LLMCallOp, OpResult, compute_graph |
| 24 | +from mdcrow.ldp_env.state import MDCrowState |
| 25 | + |
| 26 | +from .analysis_tools import ( |
10 | 27 | compute_bond_angles, |
11 | 28 | compute_contacts, |
12 | 29 | compute_distance, |
|
24 | 41 | perform_pca_analysis, |
25 | 42 | summarize_protein_structure, |
26 | 43 | ) |
27 | | -from aviary.core import ( |
28 | | - Environment, |
29 | | - Message, |
30 | | - Messages, |
31 | | - Tool, |
32 | | - ToolRequestMessage, |
33 | | - ToolResponseMessage, |
34 | | -) |
35 | | - |
36 | | -# get secrets from environment variables |
37 | | -from dotenv import load_dotenv |
38 | | -from preprocess_tools import ( |
| 44 | +from .preprocess_tools import ( |
39 | 45 | GetActiveSites, |
40 | 46 | GetAllKnownSites, |
41 | 47 | GetAllSequences, |
|
57 | 63 | get_small_molecule_PDB, |
58 | 64 | pack_molecules, |
59 | 65 | ) |
60 | | -from pydantic import BaseModel, ConfigDict, Field |
61 | | -from simulation_tools import modify_simulation_script, setup_and_run_simulation |
62 | | -from util_tools import ListRegistryPaths, MapPath2Name, scholar2result_llm |
63 | | - |
64 | | -from ldp.agent import Agent |
65 | | -from ldp.graph import LLMCallOp, OpResult, compute_graph |
66 | | -from mdcrow.ldp_env.state import MDCrowState |
| 66 | +from .simulation_tools import modify_simulation_script, setup_and_run_simulation |
| 67 | +from .util_tools import ListRegistryPaths, MapPath2Name, scholar2result_llm |
67 | 68 |
|
68 | 69 | load_dotenv() |
69 | 70 |
|
@@ -312,30 +313,29 @@ async def step( |
312 | 313 | False, |
313 | 314 | ) |
314 | 315 |
|
def submit_answer(
    self, answer: str, finished: bool
) -> tuple[str, float, Literal[True]]:
    """Submit the final answer and end the episode. This action is terminal.

    No correctness check is performed here: the reward is chosen solely from
    the caller-supplied ``finished`` flag.

    Args:
        answer: Proposed answer. Leading and trailing whitespace is removed
            before it is returned.
        finished: Whether the task is finished.

    Returns:
        Three-tuple of the stripped answer, the associated reward
        (correct_reward if finished, otherwise incorrect_reward), and
        True indicating done.
    """
    # str.strip() cannot raise ValueError, so the try/except that guarded the
    # former float(answer) parse is not needed around this call.
    answer = answer.strip()
    reward = (
        self.config.correct_reward if finished else self.config.incorrect_reward
    )
    return answer, reward, True
339 | 339 |
|
340 | 340 | def calculator(self, expr: str) -> tuple[float | str, float, bool]: |
341 | 341 | """Calculate a mathematical expression. |
|
0 commit comments