Skip to content

Commit 3943012

Browse files
[autofix.ci] apply automated fixes
1 parent 3643e68 commit 3943012

File tree

1 file changed

+75
-30
lines changed

1 file changed

+75
-30
lines changed

tests/api/test_turntaking_atomic.py

Lines changed: 75 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,20 @@
11
# test_turntaking_atomic.py
22

33
import pytest
4-
import asyncio
54
from typing import List, Dict, Any, Optional, Union
65
import logging
76
from pydantic import BaseModel
87
from sotopia.api.websocket_utils import (
98
build_observation,
109
get_env_agents,
11-
WSMessageType,
1210
)
1311
from sotopia.agents import LLMAgent, Agents
1412
from sotopia.messages import Observation, AgentAction
15-
from sotopia.envs.parallel import ParallelSotopiaEnv
16-
from sotopia.envs.evaluators import (
17-
RuleBasedTerminatedEvaluator,
18-
EpisodeLLMEvaluator,
19-
EvaluationForTwoAgents,
20-
SotopiaDimensions,
13+
from sotopia.database import (
14+
EnvironmentProfile,
15+
AgentProfile,
16+
EvaluationDimensionBuilder,
2117
)
22-
from sotopia.database import EnvironmentProfile, AgentProfile, EvaluationDimensionBuilder
2318
from sotopia.database.persistent_profile import RelationshipType
2419

2520

@@ -35,6 +30,7 @@
3530
# Dummy classes to simulate database objects
3631
# =============================================================================
3732

33+
3834
# Define DummyAgentProfile as a subclass of AgentProfile so that it is compatible.
3935
class DummyAgentProfile(AgentProfile):
4036
@classmethod
@@ -67,6 +63,7 @@ def construct_dummy(
6763
}
6864
return cls.model_validate(data) # Using model_validate() for pydantic models
6965

66+
7067
# Define DummyEnvProfile with all required fields.
7168
class DummyEnvProfile:
7269
def __init__(self, pk: str) -> None:
@@ -75,32 +72,42 @@ def __init__(self, pk: str) -> None:
7572
# Force source to be a string (not None)
7673
self.source: str = ""
7774
# Use a non-None scenario.
78-
self.scenario: str = "A concrete scenario description that meets the guidelines."
75+
self.scenario: str = (
76+
"A concrete scenario description that meets the guidelines."
77+
)
7978
# For test consistency, use lowercase goals if tests expect them.
8079
self.agent_goals: List[str] = ["goal1", "goal2"]
81-
self.relationship: RelationshipType = RelationshipType.stranger
80+
self.relationship: RelationshipType = RelationshipType.stranger
8281
# For fields that might be str but can be None in model, ensure a string is provided.
8382
self.age_constraint: str = ""
8483
self.occupation_constraint: str = ""
8584
self.agent_constraint: List[List[str]] = []
8685
self.tag: str = "test_tag"
8786

87+
8888
def fake_agent_get(agent_id: str) -> DummyAgentProfile:
8989
dummy: Dict[str, DummyAgentProfile] = {
90-
"agent1": DummyAgentProfile.construct_dummy("agent1", first_name="John", last_name="Doe", age=30),
91-
"agent2": DummyAgentProfile.construct_dummy("agent2", first_name="Jane", last_name="Doe", age=25),
90+
"agent1": DummyAgentProfile.construct_dummy(
91+
"agent1", first_name="John", last_name="Doe", age=30
92+
),
93+
"agent2": DummyAgentProfile.construct_dummy(
94+
"agent2", first_name="Jane", last_name="Doe", age=25
95+
),
9296
}
9397
if agent_id in dummy:
9498
return dummy[agent_id]
9599
raise Exception(f"AgentProfile with id {agent_id} not found")
96100

101+
97102
def fake_env_get(env_id: str) -> DummyEnvProfile:
98103
return DummyEnvProfile(env_id)
99104

105+
100106
# =============================================================================
101107
# Fixture for monkeypatching
102108
# =============================================================================
103109

110+
104111
@pytest.fixture
105112
def mp(monkeypatch: pytest.MonkeyPatch) -> pytest.MonkeyPatch:
106113
monkeypatch.setattr(EnvironmentProfile, "get", fake_env_get)
@@ -122,22 +129,27 @@ def patched_init(
122129
) -> None:
123130
if agent_name is None and agent_profile is not None:
124131
agent_name = agent_profile.pk # This sets the agent name to its pk.
125-
original_init(self, agent_name, uuid_str, agent_profile, model_name, script_like)
132+
original_init(
133+
self, agent_name, uuid_str, agent_profile, model_name, script_like
134+
)
135+
126136
monkeypatch.setattr(LLMAgent, "__init__", patched_init)
127137
return monkeypatch
128138

139+
129140
# =============================================================================
130141
# Test for build_observation
131142
# =============================================================================
132143

144+
133145
def test_build_observation() -> None:
134146
"""
135147
Test that build_observation returns an Observation with the correct
136148
last_turn, turn_number and available_actions.
137149
"""
138150
conversation_history: List[Dict[str, str]] = [
139151
{"role": "client", "content": "Hello"},
140-
{"role": "agent", "content": "Hi, how may I help you?"}
152+
{"role": "agent", "content": "Hi, how may I help you?"},
141153
]
142154
turn_number: int = len(conversation_history)
143155
obs: Observation = build_observation(turn_number, conversation_history)
@@ -151,16 +163,20 @@ def test_build_observation() -> None:
151163
]
152164
assert obs.available_actions == expected_actions
153165

166+
154167
# =============================================================================
155168
# Test for get_env_agents
156169
# =============================================================================
157170

171+
158172
def test_get_env_agents(mp: pytest.MonkeyPatch) -> None:
159173
"""
160174
Test that get_env_agents returns an environment, an agents dictionary keyed by
161175
agent names (which should be the dummy profile pks) and non-empty environment messages.
162176
"""
163-
env, agents, env_msgs = get_env_agents("env1", ["agent1", "agent2"], ["model1", "model2"], "eval_model", "dummy_list")
177+
env, agents, env_msgs = get_env_agents(
178+
"env1", ["agent1", "agent2"], ["model1", "model2"], "eval_model", "dummy_list"
179+
)
164180
# Because our patched LLMAgent.__init__ uses agent_profile.pk as agent_name,
165181
# the keys should be "agent1" and "agent2".
166182
assert set(agents.keys()) == {"Jane Doe", "John Doe"}
@@ -169,60 +185,87 @@ def test_get_env_agents(mp: pytest.MonkeyPatch) -> None:
169185
assert agent.goal in ["goal1", "goal2"]
170186
assert isinstance(env_msgs, dict)
171187

188+
172189
# =============================================================================
173190
# Atomic test for process_turn
174191
# =============================================================================
175192

193+
176194
# Create a dummy agent by subclassing LLMAgent that returns a fixed AgentAction.
177195
class DummyAgent(LLMAgent):
178196
async def aact(self, obs: Observation) -> AgentAction:
179197
return AgentAction(action_type="speak", argument="dummy response")
180198

199+
181200
@pytest.fixture
182201
def dummy_simulator(mp: pytest.MonkeyPatch) -> Any:
183202
"""
184203
Create a dummy simulator that mimics WebSocketSotopiaSimulator.
185204
It sets up a dummy Agents dictionary with one DummyAgent,
186205
a dummy environment with one goal, and conversation history seeded with an initial message.
187206
"""
188-
agents_instance: Agents = Agents({
189-
"agent1": DummyAgent(agent_name="agent1", agent_profile=None, model_name="dummy")
190-
})
207+
agents_instance: Agents = Agents(
208+
{
209+
"agent1": DummyAgent(
210+
agent_name="agent1", agent_profile=None, model_name="dummy"
211+
)
212+
}
213+
)
214+
191215
class DummyEnv:
192216
def __init__(self, goals: List[str]) -> None:
193217
self.agents: List[str] = list(agents_instance.keys())
194-
self.profile = type("DummyProfile", (), {"agent_goals": goals, "pk": "env1"})
218+
self.profile = type(
219+
"DummyProfile", (), {"agent_goals": goals, "pk": "env1"}
220+
)
221+
195222
dummy_env = DummyEnv(["goal1"])
196223
dummy_msgs: Dict[str, Any] = {
197-
"agent1": type("DummyObs", (), {"to_natural_language": lambda self: "initial message"})()
224+
"agent1": type(
225+
"DummyObs", (), {"to_natural_language": lambda self: "initial message"}
226+
)()
198227
}
228+
199229
class DummySimulator:
200230
def __init__(self) -> None:
201231
self.env: DummyEnv = dummy_env
202232
self.agents: Agents = agents_instance
203233
self.environment_messages: Dict[str, Any] = dummy_msgs
204-
self.conversation_history: List[Dict[str, str]] = [{
205-
"role": "environment",
206-
"agent": "agent1",
207-
"content": dummy_msgs["agent1"].to_natural_language()
208-
}]
209-
async def process_turn(self, client_data: Dict[str, str]) -> Dict[str, Union[int, str]]:
210-
self.conversation_history.append({"role": "client", "content": client_data.get("content", "")})
234+
self.conversation_history: List[Dict[str, str]] = [
235+
{
236+
"role": "environment",
237+
"agent": "agent1",
238+
"content": dummy_msgs["agent1"].to_natural_language(),
239+
}
240+
]
241+
242+
async def process_turn(
243+
self, client_data: Dict[str, str]
244+
) -> Dict[str, Union[int, str]]:
245+
self.conversation_history.append(
246+
{"role": "client", "content": client_data.get("content", "")}
247+
)
211248
agent_id: str | None = client_data.get("agent_id")
212249
if agent_id not in self.agents:
213250
raise ValueError(f"Agent with id {agent_id} not found")
214-
obs: Observation = build_observation(len(self.conversation_history), self.conversation_history)
251+
obs: Observation = build_observation(
252+
len(self.conversation_history), self.conversation_history
253+
)
215254
agent = self.agents[agent_id]
216255
agent_action: AgentAction = await agent.aact(obs)
217-
self.conversation_history.append({"role": "agent", "content": agent_action.argument})
256+
self.conversation_history.append(
257+
{"role": "agent", "content": agent_action.argument}
258+
)
218259
return {
219260
"turn": len(self.conversation_history),
220261
"agent_id": agent_id,
221262
"agent_response": agent_action.argument,
222263
"action_type": agent_action.action_type,
223264
}
265+
224266
return DummySimulator()
225267

268+
226269
@pytest.mark.asyncio
227270
async def test_process_turn_success(dummy_simulator: Any) -> None:
228271
"""
@@ -237,6 +280,7 @@ async def test_process_turn_success(dummy_simulator: Any) -> None:
237280
assert result["agent_response"] == "dummy response"
238281
assert result["action_type"] == "speak"
239282

283+
240284
@pytest.mark.asyncio
241285
async def test_process_turn_invalid_agent(dummy_simulator: Any) -> None:
242286
"""
@@ -248,6 +292,7 @@ async def test_process_turn_invalid_agent(dummy_simulator: Any) -> None:
248292
await simulator.process_turn(client_data)
249293
assert "Agent with id nonexistent not found" in str(excinfo.value)
250294

295+
251296
@pytest.mark.asyncio
252297
async def test_multiple_turns_accumulate_history(dummy_simulator: Any) -> None:
253298
"""

0 commit comments

Comments
 (0)