Skip to content

Commit ad3fde5

Browse files
lc5211The tunix Authors
authored andcommitted
[Tunix] Update ModelAgent to accept observations with a "prompts" key rather than "question".
PiperOrigin-RevId: 875785435
1 parent 791d90c commit ad3fde5

File tree

1 file changed

+26
-3
lines changed

1 file changed

+26
-3
lines changed

tunix/rl/agentic/agents/model_agent.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Agent implementation for single-turn interactions."""
1616

1717
import copy
18+
from typing import Any, Dict
1819

1920
from tunix.rl.agentic.agents import agent_types
2021
from tunix.rl.agentic.agents import base_agent
@@ -26,9 +27,31 @@ class ModelAgent(base_agent.ConversationAgentBase):
2627
def __init__(self, system_prompt: str):
2728
super().__init__(system_prompt=system_prompt)
2829

29-
# If you want to handle observations in a special way, you can override
30-
# _observation_to_messages. Here, we stick to the default behavior of
31-
# ConversationAgentBase.
30+
def _observation_to_messages(
31+
self, observation: Any, reward: float, done: bool, info: Dict[str, Any]
32+
) -> None:
33+
"""Convert environment observation into chat messages.
34+
35+
Default behavior:
36+
* If observation is a dict containing "question", use it as user content.
37+
* If observation is a string, append as a user message.
38+
* Otherwise, do nothing.
39+
40+
Subclasses can override this to handle richer observation formats.
41+
42+
Args:
43+
observation: The observation from the environment.
44+
reward: The reward from the environment.
45+
done: Whether the episode is done.
46+
info: Additional information from the environment.
47+
"""
48+
del reward, done, info # Unused in default implementation.
49+
if isinstance(observation, dict) and "prompts" in observation:
50+
self._messages.append(
51+
{"role": "user", "content": observation["prompts"]}
52+
)
53+
elif isinstance(observation, str):
54+
self._messages.append({"role": "user", "content": observation})
3255

3356
def update_from_model(self, response: str, **kwargs) -> agent_types.Action:
3457
"""Receive model response and return it as the final action."""

0 commit comments

Comments
 (0)