1- from typing import Any , Dict , List , Optional , Union
2- import logging
1+ # test_turntaking_atomic.py
2+
33import pytest
4+ import asyncio
5+ from typing import List , Dict , Any , Optional , Union
6+ import logging
47from pydantic import BaseModel
5- from sotopia .database .persistent_profile import RelationshipType
6- from sotopia .api .websocket_utils import build_observation , get_env_agents
7- from sotopia .agents import LLMAgent , Agents
8- from sotopia .messages import Observation , AgentAction
9- from sotopia .database import EnvironmentProfile , AgentProfile , EvaluationDimensionBuilder
108
119logging .basicConfig (
1210 level = logging .DEBUG ,
1311 format = "%(asctime)s - %(levelname)s - %(message)s" ,
1412 force = True ,
1513)
1614
17- # =============================================================================
18- # Test for build_observation
19- # =============================================================================
20-
21- def test_build_observation () -> None :
22- """
23- Test that build_observation returns an Observation with the correct
24- last_turn, turn_number and available_actions.
25- """
26- conversation_history : List [Dict [str , str ]] = [
27- {"role" : "client" , "content" : "Hello" },
28- {"role" : "agent" , "content" : "Hi, how may I help you?" },
29- ]
30- turn_number : int = len (conversation_history )
31- obs : Observation = build_observation (turn_number , conversation_history )
32- assert obs .last_turn == "Hi, how may I help you?"
33- assert obs .turn_number == turn_number
34- expected_actions : List [str ] = [
35- "speak" ,
36- "non-verbal communication" ,
37- "action" ,
38- "leave" ,
39- ]
40- assert obs .available_actions == expected_actions
15+ # Import required components from the Sotopia codebase.
16+ from sotopia .api .websocket_utils import (
17+ build_observation ,
18+ get_env_agents ,
19+ WSMessageType ,
20+ )
21+ from sotopia .agents import LLMAgent , Agents
22+ from sotopia .messages import Observation , AgentAction
23+ from sotopia .envs .parallel import ParallelSotopiaEnv
24+ from sotopia .envs .evaluators import (
25+ RuleBasedTerminatedEvaluator ,
26+ EpisodeLLMEvaluator ,
27+ EvaluationForTwoAgents ,
28+ SotopiaDimensions ,
29+ )
30+ from sotopia .database import EnvironmentProfile , AgentProfile , EvaluationDimensionBuilder
31+ from sotopia .database .persistent_profile import RelationshipType
4132
4233# =============================================================================
4334# Dummy classes to simulate database objects
4435# =============================================================================
4536
class DummyAgentProfile(AgentProfile):
    """Test double for an agent record.

    It subclasses ``AgentProfile`` so that instances remain type-compatible
    with code that expects the real profile class. Instances are created
    through :meth:`construct_dummy` rather than the regular constructor.
    """

    @classmethod
    def construct_dummy(
        cls,
        pk: str,
        first_name: str = "John",
        last_name: str = "Doe",
        age: int = 30,
    ) -> "DummyAgentProfile":
        """Build a validated dummy profile.

        Args:
            pk: Primary key stored on the profile.
            first_name: First name of the dummy agent.
            last_name: Last name of the dummy agent.
            age: Age of the dummy agent.

        Returns:
            The instance produced by ``cls.model_validate`` from the assembled
            field dictionary.
        """
        # Caller-controlled identity fields.
        identity: Dict[str, Any] = {
            "pk": pk,
            "first_name": first_name,
            "last_name": last_name,
            "age": age,
        }
        # Remaining fields are fixed placeholders; they are intended to fill
        # out the profile schema so validation succeeds.
        placeholders: Dict[str, Any] = {
            "occupation": "tester",
            "gender": "male",
            "gender_pronoun": "he",
            "public_info": "",
            "big_five": "",
            "moral_values": [],
            "schwartz_personal_values": [],
            "personality_and_values": "",
            "decision_making_style": "",
            "secret": "",
            "model_id": "",
            "mbti": "",
            "tag": "test_tag",
        }
        # Merge in the same key order as a single literal would have.
        return cls.model_validate({**identity, **placeholders})
8568
class DummyEnvProfile:
    """Plain-object test double for an environment profile record.

    Every attribute is assigned a concrete, non-``None`` value (empty string
    or empty list where there is nothing meaningful to store).
    """

    def __init__(self, pk: str) -> None:
        """Populate the dummy profile; only ``pk`` varies between instances."""
        scenario_text: str = "A concrete scenario description that meets the guidelines."
        # Lowercase goal strings; the tests in this module compare against
        # these exact values.
        goals: List[str] = ["goal1", "goal2"]

        self.pk = pk
        self.codename: str = "test_codename"
        # Empty string rather than None.
        self.source: str = ""
        self.scenario: str = scenario_text
        self.agent_goals: List[str] = goals
        self.relationship: RelationshipType = RelationshipType.stranger
        # Constraints are left empty instead of None.
        self.age_constraint: str = ""
        self.occupation_constraint: str = ""
        self.agent_constraint: List[List[str]] = []
        self.tag: str = "test_tag"
9886
9987def fake_agent_get (agent_id : str ) -> DummyAgentProfile :
10088 dummy : Dict [str , DummyAgentProfile ] = {
101- "agent1" : DummyAgentProfile ("agent1" , first_name = "John" , last_name = "Doe" , age = 30 ),
102- "agent2" : DummyAgentProfile ("agent2" , first_name = "Jane" , last_name = "Doe" , age = 25 ),
89+ "agent1" : DummyAgentProfile . construct_dummy ("agent1" , first_name = "John" , last_name = "Doe" , age = 30 ),
90+ "agent2" : DummyAgentProfile . construct_dummy ("agent2" , first_name = "Jane" , last_name = "Doe" , age = 25 ),
10391 }
10492 if agent_id in dummy :
10593 return dummy [agent_id ]
@@ -113,7 +101,7 @@ def fake_env_get(env_id: str) -> DummyEnvProfile:
113101# =============================================================================
114102
115103@pytest .fixture
116- def mp (monkeypatch : Any ) -> Any :
104+ def mp (monkeypatch : pytest . MonkeyPatch ) -> pytest . MonkeyPatch :
117105 monkeypatch .setattr (EnvironmentProfile , "get" , fake_env_get )
118106 monkeypatch .setattr (AgentProfile , "get" , fake_agent_get )
119107 monkeypatch .setattr (
@@ -127,55 +115,86 @@ def patched_init(
127115 self : LLMAgent ,
128116 agent_name : Optional [str ] = None ,
129117 uuid_str : Optional [str ] = None ,
130- agent_profile : Optional [DummyAgentProfile ] = None ,
118+ agent_profile : Optional [AgentProfile ] = None ,
131119 model_name : str = "gpt-4o-mini" ,
132120 script_like : bool = False ,
133121 ) -> None :
134122 if agent_name is None and agent_profile is not None :
135- agent_name = agent_profile .pk
123+ agent_name = agent_profile .pk # This sets the agent name to its pk.
136124 original_init (self , agent_name , uuid_str , agent_profile , model_name , script_like )
137-
138125 monkeypatch .setattr (LLMAgent , "__init__" , patched_init )
139126 return monkeypatch
140127
141- def test_get_env_agents (mp : Any ) -> None :
128+ # =============================================================================
129+ # Test for build_observation
130+ # =============================================================================
131+
def test_build_observation() -> None:
    """
    Check that build_observation produces an Observation whose last_turn is the
    content of the final history entry, whose turn_number equals the value passed
    in, and whose available_actions is the expected four-action list.
    """
    history: List[Dict[str, str]] = [
        {"role": "client", "content": "Hello"},
        {"role": "agent", "content": "Hi, how may I help you?"},
    ]
    turn_count: int = len(history)

    observation: Observation = build_observation(turn_count, history)

    # The last history entry's content is what the observation should surface.
    assert observation.last_turn == "Hi, how may I help you?"
    assert observation.turn_number == turn_count
    assert observation.available_actions == [
        "speak",
        "non-verbal communication",
        "action",
        "leave",
    ]
152+
153+ # =============================================================================
154+ # Test for get_env_agents
155+ # =============================================================================
156+
def test_get_env_agents(mp: pytest.MonkeyPatch) -> None:
    """
    Test that get_env_agents returns an agents mapping keyed by the dummy
    profiles' full names ("John Doe", "Jane Doe"), that every agent's goal is one
    of the dummy environment goals, and that the environment messages are a dict.

    The ``mp`` fixture must be active so that profile lookups resolve to the
    dummy profiles instead of the database.
    """
    # The returned environment is not inspected by this test.
    _env, agents, env_msgs = get_env_agents(
        "env1",
        ["agent1", "agent2"],
        ["model1", "model2"],
        "eval_model",
        "dummy_list",
    )
    # The expected keys match "<first_name> <last_name>" of the dummy profiles,
    # not their pks: the patched LLMAgent.__init__ only falls back to the pk when
    # no agent_name is supplied.
    # NOTE(review): presumably get_env_agents passes/uses the full name; confirm.
    assert set(agents.keys()) == {"Jane Doe", "John Doe"}
    for agent in agents.values():
        # Goals come from DummyEnvProfile.agent_goals.
        assert agent.goal in ["goal1", "goal2"]
    # Only the container type is checked; emptiness is not asserted here.
    assert isinstance(env_msgs, dict)
153170
154171# =============================================================================
155- # DummyAgent and simulator for process_turn tests
172+ # Atomic test for process_turn
156173# =============================================================================
157174
class DummyAgent(LLMAgent):
    """LLMAgent subclass whose ``aact`` is overridden to return a constant action."""

    async def aact(self, obs: Observation) -> AgentAction:
        """Ignore ``obs`` and return a fixed "speak" action with a known argument."""
        fixed_action = AgentAction(action_type="speak", argument="dummy response")
        return fixed_action
162179
163180@pytest .fixture
164- def dummy_simulator (mp : Any ) -> Any :
181+ def dummy_simulator (mp : pytest .MonkeyPatch ) -> Any :
182+ """
183+ Create a dummy simulator that mimics WebSocketSotopiaSimulator.
184+ It sets up a dummy Agents dictionary with one DummyAgent,
185+ a dummy environment with one goal, and conversation history seeded with an initial message.
186+ """
165187 agents_instance : Agents = Agents ({
166188 "agent1" : DummyAgent (agent_name = "agent1" , agent_profile = None , model_name = "dummy" )
167189 })
168-
169190 class DummyEnv :
170191 def __init__ (self , goals : List [str ]) -> None :
171192 self .agents : List [str ] = list (agents_instance .keys ())
172193 self .profile = type ("DummyProfile" , (), {"agent_goals" : goals , "pk" : "env1" })
173-
174194 dummy_env = DummyEnv (["goal1" ])
175195 dummy_msgs : Dict [str , Any ] = {
176196 "agent1" : type ("DummyObs" , (), {"to_natural_language" : lambda self : "initial message" })()
177197 }
178-
179198 class DummySimulator :
180199 def __init__ (self ) -> None :
181200 self .env : DummyEnv = dummy_env
@@ -186,10 +205,10 @@ def __init__(self) -> None:
186205 "agent" : "agent1" ,
187206 "content" : dummy_msgs ["agent1" ].to_natural_language ()
188207 }]
189-
190208 async def process_turn (self , client_data : Dict [str , str ]) -> Dict [str , Union [int , str ]]:
209+ from sotopia .api .websocket_utils import build_observation
191210 self .conversation_history .append ({"role" : "client" , "content" : client_data .get ("content" , "" )})
192- agent_id : str = client_data .get ("agent_id" ) # Assumed to be str
211+ agent_id : str = client_data .get ("agent_id" )
193212 if agent_id not in self .agents :
194213 raise ValueError (f"Agent with id { agent_id } not found" )
195214 obs : Observation = build_observation (len (self .conversation_history ), self .conversation_history )
@@ -202,7 +221,6 @@ async def process_turn(self, client_data: Dict[str, str]) -> Dict[str, Union[int
202221 "agent_response" : agent_action .argument ,
203222 "action_type" : agent_action .action_type ,
204223 }
205-
206224 return DummySimulator ()
207225
208226@pytest .mark .asyncio
@@ -213,7 +231,7 @@ async def test_process_turn_success(dummy_simulator: Any) -> None:
213231 simulator = dummy_simulator
214232 client_data : Dict [str , str ] = {"agent_id" : "agent1" , "content" : "Hello!" }
215233 result : Dict [str , Union [int , str ]] = await simulator .process_turn (client_data )
216- # Initially: environment message (1 msg). After processing turn → add client and agent messages, so total is 3.
234+ # Initial message count is 1, then client and agent turns make 3.
217235 assert result ["turn" ] == 3
218236 assert result ["agent_id" ] == "agent1"
219237 assert result ["agent_response" ] == "dummy response"
@@ -239,5 +257,5 @@ async def test_multiple_turns_accumulate_history(dummy_simulator: Any) -> None:
239257 initial_length : int = len (simulator .conversation_history )
240258 await simulator .process_turn ({"agent_id" : "agent1" , "content" : "Turn one" })
241259 await simulator .process_turn ({"agent_id" : "agent1" , "content" : "Turn two" })
242- # There should be: 1 (initial env message) + 2 turns × 2 messages each = initial_length + 4 messages.
260+ # Each call to process_turn adds 2 messages.
243261 assert len (simulator .conversation_history ) == initial_length + 4
0 commit comments