1
1
# test_turntaking_atomic.py
2
2
3
3
import pytest
4
- import asyncio
5
4
from typing import List , Dict , Any , Optional , Union
6
5
import logging
7
6
from pydantic import BaseModel
8
7
from sotopia .api .websocket_utils import (
9
8
build_observation ,
10
9
get_env_agents ,
11
- WSMessageType ,
12
10
)
13
11
from sotopia .agents import LLMAgent , Agents
14
12
from sotopia .messages import Observation , AgentAction
15
- from sotopia .envs .parallel import ParallelSotopiaEnv
16
- from sotopia .envs .evaluators import (
17
- RuleBasedTerminatedEvaluator ,
18
- EpisodeLLMEvaluator ,
19
- EvaluationForTwoAgents ,
20
- SotopiaDimensions ,
13
+ from sotopia .database import (
14
+ EnvironmentProfile ,
15
+ AgentProfile ,
16
+ EvaluationDimensionBuilder ,
21
17
)
22
- from sotopia .database import EnvironmentProfile , AgentProfile , EvaluationDimensionBuilder
23
18
from sotopia .database .persistent_profile import RelationshipType
24
19
25
20
35
30
# Dummy classes to simulate database objects
36
31
# =============================================================================
37
32
33
+
38
34
# Define DummyAgentProfile as a subclass of AgentProfile so that it is compatible.
39
35
class DummyAgentProfile (AgentProfile ):
40
36
@classmethod
@@ -67,6 +63,7 @@ def construct_dummy(
67
63
}
68
64
return cls .model_validate (data ) # Using model_validate() for pydantic models
69
65
66
+
70
67
# Define DummyEnvProfile with all required fields.
71
68
class DummyEnvProfile :
72
69
def __init__ (self , pk : str ) -> None :
@@ -75,32 +72,42 @@ def __init__(self, pk: str) -> None:
75
72
# Force source to be a string (not None)
76
73
self .source : str = ""
77
74
# Use a non-None scenario.
78
- self .scenario : str = "A concrete scenario description that meets the guidelines."
75
+ self .scenario : str = (
76
+ "A concrete scenario description that meets the guidelines."
77
+ )
79
78
# For test consistency, use lowercase goals if tests expect them.
80
79
self .agent_goals : List [str ] = ["goal1" , "goal2" ]
81
- self .relationship : RelationshipType = RelationshipType .stranger
80
+ self .relationship : RelationshipType = RelationshipType .stranger
82
81
# For fields that might be str but can be None in model, ensure a string is provided.
83
82
self .age_constraint : str = ""
84
83
self .occupation_constraint : str = ""
85
84
self .agent_constraint : List [List [str ]] = []
86
85
self .tag : str = "test_tag"
87
86
87
+
88
88
def fake_agent_get(agent_id: str) -> DummyAgentProfile:
    """
    Look up one of the two canned agent profiles by id.

    Both profiles ("agent1" -> John Doe, 30 and "agent2" -> Jane Doe, 25) are
    rebuilt on every call via DummyAgentProfile.construct_dummy.
    NOTE(review): presumably patched in place of AgentProfile.get by the test
    fixtures - confirm against the monkeypatch setup.

    Raises:
        Exception: if agent_id is neither "agent1" nor "agent2".
    """
    # (pk, first name, age) for each canned profile; the last name is shared.
    specs = (("agent1", "John", 30), ("agent2", "Jane", 25))
    roster: Dict[str, DummyAgentProfile] = {
        pk: DummyAgentProfile.construct_dummy(
            pk, first_name=first, last_name="Doe", age=age
        )
        for pk, first, age in specs
    }
    if agent_id not in roster:
        raise Exception(f"AgentProfile with id {agent_id} not found")
    return roster[agent_id]
96
100
101
+
97
102
def fake_env_get(env_id: str) -> DummyEnvProfile:
    """
    Return a fresh DummyEnvProfile whose pk is the requested env_id.

    The `mp` fixture installs this function as EnvironmentProfile.get, so no
    lookup is performed: every id yields a new dummy profile.
    """
    stub_profile = DummyEnvProfile(pk=env_id)
    return stub_profile
99
104
105
+
100
106
# =============================================================================
101
107
# Fixture for monkeypatching
102
108
# =============================================================================
103
109
110
+
104
111
@pytest .fixture
105
112
def mp (monkeypatch : pytest .MonkeyPatch ) -> pytest .MonkeyPatch :
106
113
monkeypatch .setattr (EnvironmentProfile , "get" , fake_env_get )
@@ -122,22 +129,27 @@ def patched_init(
122
129
) -> None :
123
130
if agent_name is None and agent_profile is not None :
124
131
agent_name = agent_profile .pk # This sets the agent name to its pk.
125
- original_init (self , agent_name , uuid_str , agent_profile , model_name , script_like )
132
+ original_init (
133
+ self , agent_name , uuid_str , agent_profile , model_name , script_like
134
+ )
135
+
126
136
monkeypatch .setattr (LLMAgent , "__init__" , patched_init )
127
137
return monkeypatch
128
138
139
+
129
140
# =============================================================================
130
141
# Test for build_observation
131
142
# =============================================================================
132
143
144
+
133
145
def test_build_observation () -> None :
134
146
"""
135
147
Test that build_observation returns an Observation with the correct
136
148
last_turn, turn_number and available_actions.
137
149
"""
138
150
conversation_history : List [Dict [str , str ]] = [
139
151
{"role" : "client" , "content" : "Hello" },
140
- {"role" : "agent" , "content" : "Hi, how may I help you?" }
152
+ {"role" : "agent" , "content" : "Hi, how may I help you?" },
141
153
]
142
154
turn_number : int = len (conversation_history )
143
155
obs : Observation = build_observation (turn_number , conversation_history )
@@ -151,16 +163,20 @@ def test_build_observation() -> None:
151
163
]
152
164
assert obs .available_actions == expected_actions
153
165
166
+
154
167
# =============================================================================
155
168
# Test for get_env_agents
156
169
# =============================================================================
157
170
171
+
158
172
def test_get_env_agents (mp : pytest .MonkeyPatch ) -> None :
159
173
"""
160
174
Test that get_env_agents returns an environment, an agents dictionary keyed by
161
175
agent names (the full names "John Doe" and "Jane Doe" of the dummy profiles) and a dict of environment messages.
162
176
"""
163
- env , agents , env_msgs = get_env_agents ("env1" , ["agent1" , "agent2" ], ["model1" , "model2" ], "eval_model" , "dummy_list" )
177
+ env , agents , env_msgs = get_env_agents (
178
+ "env1" , ["agent1" , "agent2" ], ["model1" , "model2" ], "eval_model" , "dummy_list"
179
+ )
164
180
# NOTE: the patched LLMAgent.__init__ only substitutes agent_profile.pk when agent_name is None,
165
181
# and the assertion below expects full-name keys ("John Doe", "Jane Doe"), not the pks.
166
182
assert set (agents .keys ()) == {"Jane Doe" , "John Doe" }
@@ -169,60 +185,87 @@ def test_get_env_agents(mp: pytest.MonkeyPatch) -> None:
169
185
assert agent .goal in ["goal1" , "goal2" ]
170
186
assert isinstance (env_msgs , dict )
171
187
188
+
172
189
# =============================================================================
173
190
# Atomic test for process_turn
174
191
# =============================================================================
175
192
193
+
176
194
# Create a dummy agent by subclassing LLMAgent that returns a fixed AgentAction.
177
195
class DummyAgent(LLMAgent):
    """LLMAgent subclass whose aact() always answers with the same action."""

    async def aact(self, obs: Observation) -> AgentAction:
        """Ignore the observation and return a fixed "speak" action."""
        canned_reply = AgentAction(action_type="speak", argument="dummy response")
        return canned_reply
180
198
199
+
181
200
@pytest.fixture
def dummy_simulator(mp: pytest.MonkeyPatch) -> Any:
    """
    Build a lightweight stand-in for WebSocketSotopiaSimulator.

    The returned object holds a single DummyAgent registered as "agent1",
    an environment stub with one goal, and a conversation history that already
    contains the environment's opening message for that agent. Requesting the
    `mp` fixture ensures its monkeypatches are active while the simulator is
    built.
    """
    sole_agent = DummyAgent(
        agent_name="agent1", agent_profile=None, model_name="dummy"
    )
    agents_instance: Agents = Agents({"agent1": sole_agent})

    class DummyEnv:
        def __init__(self, goals: List[str]) -> None:
            # The profile is a throwaway class (not an instance); its class
            # attributes carry the goals and the environment pk.
            profile_attrs = {"agent_goals": goals, "pk": "env1"}
            self.agents: List[str] = list(agents_instance.keys())
            self.profile = type("DummyProfile", (), profile_attrs)

    dummy_env = DummyEnv(["goal1"])

    # Minimal observation stub: only to_natural_language() is provided.
    obs_stub_cls = type(
        "DummyObs", (), {"to_natural_language": lambda self: "initial message"}
    )
    dummy_msgs: Dict[str, Any] = {"agent1": obs_stub_cls()}

    class DummySimulator:
        def __init__(self) -> None:
            self.env: DummyEnv = dummy_env
            self.agents: Agents = agents_instance
            self.environment_messages: Dict[str, Any] = dummy_msgs
            opening_turn = {
                "role": "environment",
                "agent": "agent1",
                "content": dummy_msgs["agent1"].to_natural_language(),
            }
            self.conversation_history: List[Dict[str, str]] = [opening_turn]

        async def process_turn(
            self, client_data: Dict[str, str]
        ) -> Dict[str, Union[int, str]]:
            """
            Record the client's message, let the addressed agent act, and
            record its reply.

            Raises ValueError when client_data["agent_id"] is missing or is
            not a key of self.agents. The client message is appended to the
            history before that check, so it is recorded even on failure.
            """
            history = self.conversation_history
            client_turn = {"role": "client", "content": client_data.get("content", "")}
            history.append(client_turn)

            agent_id: Optional[str] = client_data.get("agent_id")
            if agent_id not in self.agents:
                raise ValueError(f"Agent with id {agent_id} not found")

            obs: Observation = build_observation(len(history), history)
            agent_action: AgentAction = await self.agents[agent_id].aact(obs)
            history.append({"role": "agent", "content": agent_action.argument})

            return {
                "turn": len(history),
                "agent_id": agent_id,
                "agent_response": agent_action.argument,
                "action_type": agent_action.action_type,
            }

    return DummySimulator()
225
267
268
+
226
269
@pytest .mark .asyncio
227
270
async def test_process_turn_success (dummy_simulator : Any ) -> None :
228
271
"""
@@ -237,6 +280,7 @@ async def test_process_turn_success(dummy_simulator: Any) -> None:
237
280
assert result ["agent_response" ] == "dummy response"
238
281
assert result ["action_type" ] == "speak"
239
282
283
+
240
284
@pytest .mark .asyncio
241
285
async def test_process_turn_invalid_agent (dummy_simulator : Any ) -> None :
242
286
"""
@@ -248,6 +292,7 @@ async def test_process_turn_invalid_agent(dummy_simulator: Any) -> None:
248
292
await simulator .process_turn (client_data )
249
293
assert "Agent with id nonexistent not found" in str (excinfo .value )
250
294
295
+
251
296
@pytest .mark .asyncio
252
297
async def test_multiple_turns_accumulate_history (dummy_simulator : Any ) -> None :
253
298
"""
0 commit comments