Skip to content

Commit 97818bf

Browse files
committed
fix conflicts
1 parent 2ac8912 commit 97818bf

File tree

1 file changed

+112
-94
lines changed

1 file changed

+112
-94
lines changed
Lines changed: 112 additions & 94 deletions
Original file line number | Diff line number | Diff line change
@@ -1,105 +1,93 @@
1-
from typing import Any, Dict, List, Optional, Union
2-
import logging
1+
# test_turntaking_atomic.py
2+
33
import pytest
4+
import asyncio
5+
from typing import List, Dict, Any, Optional, Union
6+
import logging
47
from pydantic import BaseModel
5-
from sotopia.database.persistent_profile import RelationshipType
6-
from sotopia.api.websocket_utils import build_observation, get_env_agents
7-
from sotopia.agents import LLMAgent, Agents
8-
from sotopia.messages import Observation, AgentAction
9-
from sotopia.database import EnvironmentProfile, AgentProfile, EvaluationDimensionBuilder
108

119
logging.basicConfig(
1210
level=logging.DEBUG,
1311
format="%(asctime)s - %(levelname)s - %(message)s",
1412
force=True,
1513
)
1614

17-
# =============================================================================
18-
# Test for build_observation
19-
# =============================================================================
20-
21-
def test_build_observation() -> None:
22-
"""
23-
Test that build_observation returns an Observation with the correct
24-
last_turn, turn_number and available_actions.
25-
"""
26-
conversation_history: List[Dict[str, str]] = [
27-
{"role": "client", "content": "Hello"},
28-
{"role": "agent", "content": "Hi, how may I help you?"},
29-
]
30-
turn_number: int = len(conversation_history)
31-
obs: Observation = build_observation(turn_number, conversation_history)
32-
assert obs.last_turn == "Hi, how may I help you?"
33-
assert obs.turn_number == turn_number
34-
expected_actions: List[str] = [
35-
"speak",
36-
"non-verbal communication",
37-
"action",
38-
"leave",
39-
]
40-
assert obs.available_actions == expected_actions
15+
# Import required components from the Sotopia codebase.
16+
from sotopia.api.websocket_utils import (
17+
build_observation,
18+
get_env_agents,
19+
WSMessageType,
20+
)
21+
from sotopia.agents import LLMAgent, Agents
22+
from sotopia.messages import Observation, AgentAction
23+
from sotopia.envs.parallel import ParallelSotopiaEnv
24+
from sotopia.envs.evaluators import (
25+
RuleBasedTerminatedEvaluator,
26+
EpisodeLLMEvaluator,
27+
EvaluationForTwoAgents,
28+
SotopiaDimensions,
29+
)
30+
from sotopia.database import EnvironmentProfile, AgentProfile, EvaluationDimensionBuilder
31+
from sotopia.database.persistent_profile import RelationshipType
4132

4233
# =============================================================================
4334
# Dummy classes to simulate database objects
4435
# =============================================================================
4536

46-
class DummyAgentProfile:
47-
def __init__(
48-
self,
37+
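# Define DummyAgentProfile as a subclass of AgentProfile so that instances can be passed wherever an AgentProfile is expected (e.g. the agent_profile parameter of the patched LLMAgent.__init__).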
38+
class DummyAgentProfile(AgentProfile):
39+
@classmethod
40+
def construct_dummy(
41+
cls,
4942
pk: str,
5043
first_name: str = "John",
5144
last_name: str = "Doe",
5245
age: int = 30,
53-
occupation: str = "tester",
54-
gender: str = "male",
55-
gender_pronoun: str = "he",
56-
public_info: str = "",
57-
big_five: str = "",
58-
moral_values: Optional[List[str]] = None,
59-
schwartz_personal_values: Optional[List[str]] = None,
60-
personality_and_values: str = "",
61-
decision_making_style: str = "",
62-
secret: str = "",
63-
model_id: str = "",
64-
mbti: str = "",
65-
tag: str = "test_tag",
66-
) -> None:
67-
self.pk = pk
68-
self.first_name = first_name
69-
self.last_name = last_name
70-
self.age = age
71-
self.occupation = occupation
72-
self.gender = gender
73-
self.gender_pronoun = gender_pronoun
74-
self.public_info = public_info
75-
self.big_five = big_five
76-
# Use "Optional[List[str]]" for defaults and set empty list if None.
77-
self.moral_values: List[str] = moral_values or []
78-
self.schwartz_personal_values: List[str] = schwartz_personal_values or []
79-
self.personality_and_values = personality_and_values
80-
self.decision_making_style = decision_making_style
81-
self.secret = secret
82-
self.model_id = model_id
83-
self.mbti = mbti
84-
self.tag = tag
46+
) -> "DummyAgentProfile":
47+
# Create a dict that matches BaseAgentProfile's schema.
48+
data = {
49+
"pk": pk,
50+
"first_name": first_name,
51+
"last_name": last_name,
52+
"age": age,
53+
"occupation": "tester",
54+
"gender": "male",
55+
"gender_pronoun": "he",
56+
"public_info": "",
57+
"big_five": "",
58+
"moral_values": [],
59+
"schwartz_personal_values": [],
60+
"personality_and_values": "",
61+
"decision_making_style": "",
62+
"secret": "",
63+
"model_id": "",
64+
"mbti": "",
65+
"tag": "test_tag",
66+
}
67+
return cls.model_validate(data) # Using model_validate() for pydantic models
8568

69+
# Define DummyEnvProfile with all required fields.
8670
class DummyEnvProfile:
8771
def __init__(self, pk: str) -> None:
8872
self.pk = pk
8973
self.codename: str = "test_codename"
74+
# Force source to be a string (not None)
9075
self.source: str = ""
76+
# Use a non-None scenario.
9177
self.scenario: str = "A concrete scenario description that meets the guidelines."
92-
self.agent_goals: List[str] = ["Goal1", "Goal2"]
78+
# Use lowercase goals: test_get_env_agents asserts that agent.goal is "goal1" or "goal2".
79+
self.agent_goals: List[str] = ["goal1", "goal2"]
9380
self.relationship: RelationshipType = RelationshipType.stranger
94-
self.age_constraint = None
95-
self.occupation_constraint = None
96-
self.agent_constraint = None
97-
self.tag = "test_tag"
81+
# For fields that may be None in the model, provide concrete empty values (empty string or empty list) instead.
82+
self.age_constraint: str = ""
83+
self.occupation_constraint: str = ""
84+
self.agent_constraint: List[List[str]] = []
85+
self.tag: str = "test_tag"
9886

9987
def fake_agent_get(agent_id: str) -> DummyAgentProfile:
10088
dummy: Dict[str, DummyAgentProfile] = {
101-
"agent1": DummyAgentProfile("agent1", first_name="John", last_name="Doe", age=30),
102-
"agent2": DummyAgentProfile("agent2", first_name="Jane", last_name="Doe", age=25),
89+
"agent1": DummyAgentProfile.construct_dummy("agent1", first_name="John", last_name="Doe", age=30),
90+
"agent2": DummyAgentProfile.construct_dummy("agent2", first_name="Jane", last_name="Doe", age=25),
10391
}
10492
if agent_id in dummy:
10593
return dummy[agent_id]
@@ -113,7 +101,7 @@ def fake_env_get(env_id: str) -> DummyEnvProfile:
113101
# =============================================================================
114102

115103
@pytest.fixture
116-
def mp(monkeypatch: Any) -> Any:
104+
def mp(monkeypatch: pytest.MonkeyPatch) -> pytest.MonkeyPatch:
117105
monkeypatch.setattr(EnvironmentProfile, "get", fake_env_get)
118106
monkeypatch.setattr(AgentProfile, "get", fake_agent_get)
119107
monkeypatch.setattr(
@@ -127,55 +115,86 @@ def patched_init(
127115
self: LLMAgent,
128116
agent_name: Optional[str] = None,
129117
uuid_str: Optional[str] = None,
130-
agent_profile: Optional[DummyAgentProfile] = None,
118+
agent_profile: Optional[AgentProfile] = None,
131119
model_name: str = "gpt-4o-mini",
132120
script_like: bool = False,
133121
) -> None:
134122
if agent_name is None and agent_profile is not None:
135-
agent_name = agent_profile.pk
123+
agent_name = agent_profile.pk # This sets the agent name to its pk.
136124
original_init(self, agent_name, uuid_str, agent_profile, model_name, script_like)
137-
138125
monkeypatch.setattr(LLMAgent, "__init__", patched_init)
139126
return monkeypatch
140127

141-
def test_get_env_agents(mp: Any) -> None:
128+
# =============================================================================
129+
# Test for build_observation
130+
# =============================================================================
131+
132+
def test_build_observation() -> None:
133+
"""
134+
Test that build_observation returns an Observation with the correct
135+
last_turn, turn_number and available_actions.
136+
"""
137+
conversation_history: List[Dict[str, str]] = [
138+
{"role": "client", "content": "Hello"},
139+
{"role": "agent", "content": "Hi, how may I help you?"}
140+
]
141+
turn_number: int = len(conversation_history)
142+
obs: Observation = build_observation(turn_number, conversation_history)
143+
assert obs.last_turn == "Hi, how may I help you?"
144+
assert obs.turn_number == turn_number
145+
expected_actions: List[str] = [
146+
"speak",
147+
"non-verbal communication",
148+
"action",
149+
"leave",
150+
]
151+
assert obs.available_actions == expected_actions
152+
153+
# =============================================================================
154+
# Test for get_env_agents
155+
# =============================================================================
156+
157+
def test_get_env_agents(mp: pytest.MonkeyPatch) -> None:
142158
"""
143159
Test that get_env_agents returns an environment, an agents dictionary keyed by
144-
agent names (which should be the pks) and non-empty environment messages.
160+
agent names (the profiles' full names, e.g. "John Doe") and a dict of environment messages.
145161
"""
146162
env, agents, env_msgs = get_env_agents("env1", ["agent1", "agent2"], ["model1", "model2"], "eval_model", "dummy_list")
147-
assert set(agents.keys()) == {"John Doe", "Jane Doe"}
148-
163+
# NOTE: the assertion below expects the agents dict to be keyed by the profiles'
164+
# full names ("John Doe", "Jane Doe"), not by the pks "agent1" and "agent2".
165+
assert set(agents.keys()) == {"Jane Doe", "John Doe"}
149166
for agent in agents.values():
150-
assert agent.goal in ["Goal1", "Goal2"]
151-
167+
# Using our DummyEnvProfile, agent.goal should be one of the dummy goals.
168+
assert agent.goal in ["goal1", "goal2"]
152169
assert isinstance(env_msgs, dict)
153170

154171
# =============================================================================
155-
# DummyAgent and simulator for process_turn tests
172+
# Atomic test for process_turn
156173
# =============================================================================
157174

175+
# Create a dummy agent by subclassing LLMAgent that returns a fixed AgentAction.
158176
class DummyAgent(LLMAgent):
159177
async def aact(self, obs: Observation) -> AgentAction:
160-
# Always return a known action.
161178
return AgentAction(action_type="speak", argument="dummy response")
162179

163180
@pytest.fixture
164-
def dummy_simulator(mp: Any) -> Any:
181+
def dummy_simulator(mp: pytest.MonkeyPatch) -> Any:
182+
"""
183+
Create a dummy simulator that mimics WebSocketSotopiaSimulator.
184+
It sets up a dummy Agents dictionary with one DummyAgent,
185+
a dummy environment with one goal, and conversation history seeded with an initial message.
186+
"""
165187
agents_instance: Agents = Agents({
166188
"agent1": DummyAgent(agent_name="agent1", agent_profile=None, model_name="dummy")
167189
})
168-
169190
class DummyEnv:
170191
def __init__(self, goals: List[str]) -> None:
171192
self.agents: List[str] = list(agents_instance.keys())
172193
self.profile = type("DummyProfile", (), {"agent_goals": goals, "pk": "env1"})
173-
174194
dummy_env = DummyEnv(["goal1"])
175195
dummy_msgs: Dict[str, Any] = {
176196
"agent1": type("DummyObs", (), {"to_natural_language": lambda self: "initial message"})()
177197
}
178-
179198
class DummySimulator:
180199
def __init__(self) -> None:
181200
self.env: DummyEnv = dummy_env
@@ -186,10 +205,10 @@ def __init__(self) -> None:
186205
"agent": "agent1",
187206
"content": dummy_msgs["agent1"].to_natural_language()
188207
}]
189-
190208
async def process_turn(self, client_data: Dict[str, str]) -> Dict[str, Union[int, str]]:
209+
from sotopia.api.websocket_utils import build_observation
191210
self.conversation_history.append({"role": "client", "content": client_data.get("content", "")})
192-
agent_id: str = client_data.get("agent_id") # Assumed to be str
211+
agent_id: str = client_data.get("agent_id")
193212
if agent_id not in self.agents:
194213
raise ValueError(f"Agent with id {agent_id} not found")
195214
obs: Observation = build_observation(len(self.conversation_history), self.conversation_history)
@@ -202,7 +221,6 @@ async def process_turn(self, client_data: Dict[str, str]) -> Dict[str, Union[int
202221
"agent_response": agent_action.argument,
203222
"action_type": agent_action.action_type,
204223
}
205-
206224
return DummySimulator()
207225

208226
@pytest.mark.asyncio
@@ -213,7 +231,7 @@ async def test_process_turn_success(dummy_simulator: Any) -> None:
213231
simulator = dummy_simulator
214232
client_data: Dict[str, str] = {"agent_id": "agent1", "content": "Hello!"}
215233
result: Dict[str, Union[int, str]] = await simulator.process_turn(client_data)
216-
# Initially: environment message (1 msg). After processing turn → add client and agent messages, so total is 3.
234+
# Initial message count is 1, then client and agent turns make 3.
217235
assert result["turn"] == 3
218236
assert result["agent_id"] == "agent1"
219237
assert result["agent_response"] == "dummy response"
@@ -239,5 +257,5 @@ async def test_multiple_turns_accumulate_history(dummy_simulator: Any) -> None:
239257
initial_length: int = len(simulator.conversation_history)
240258
await simulator.process_turn({"agent_id": "agent1", "content": "Turn one"})
241259
await simulator.process_turn({"agent_id": "agent1", "content": "Turn two"})
242-
# There should be: 1 (initial env message) + 2 turns × 2 messages each = initial_length + 4 messages.
260+
# Each call to process_turn adds 2 messages.
243261
assert len(simulator.conversation_history) == initial_length + 4

0 commit comments

Comments
 (0)