Skip to content

working till just before agenerate #305

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14,740 changes: 14,740 additions & 0 deletions docs/package-lock.json

Large diffs are not rendered by default.

81 changes: 51 additions & 30 deletions examples/experimental/sotopia_original_replica/llm_agent_sotopia.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,11 @@
pass

# Configure logging
FORMAT = "%(asctime)s - %(levelname)s - %(name)s - %(message)s"
logging.basicConfig(
level=logging.WARNING,
format=FORMAT,
datefmt="[%X]",
handlers=[RichHandler()],
)
log = logging.getLogger("sotopia.llm_agent")
log.setLevel(logging.INFO)
# Prevent propagation to root logger
log.propagate = False
log.addHandler(RichHandler(rich_tracebacks=True, show_time=True))


@NodeFactory.register("llm_agent")
Expand All @@ -36,6 +34,7 @@ def __init__(
output_channel: str,
node_name: str,
model_name: str,
agents_str: str,
goal: str,
agent_name: str = "",
background: dict[str, Any] | None = None,
Expand All @@ -49,6 +48,7 @@ def __init__(
node_name,
)
self.output_channel = output_channel
self.agents_str: str = agents_str
self.count_ticks: int = 0
self.message_history: list[Observation] = []
self.goal: str = goal
Expand Down Expand Up @@ -83,7 +83,9 @@ async def aact(self, obs: Observation) -> AgentAction:
agent_name=self.name,
output_channel=self.output_channel,
action_type="none",
argument="",
argument=json.dumps(
{"pk": self.agent_profile_pk, "model_name": self.model_name}
),
)
args = json.loads(obs.last_turn)
self.set_profile(args["use_pk_value"])
Expand All @@ -104,7 +106,9 @@ async def aact(self, obs: Observation) -> AgentAction:
agent_name=self.name,
output_channel=self.output_channel,
action_type="none",
argument="",
argument=json.dumps(
{"pk": self.agent_profile_pk, "model_name": self.model_name}
),
)
elif len(obs.available_actions) == 1 and "leave" in obs.available_actions:
self.shutdown_event.set()
Expand All @@ -116,29 +120,46 @@ async def aact(self, obs: Observation) -> AgentAction:
)
else:
history = self._format_message_history(self.message_history)
action: str = await agenerate(
model_name=self.model_name,
template="Imagine that you are a friend of the other persons. Here is the "
"conversation between you and them.\n"
"You are {agent_name} in the conversation.\n"
"{message_history}\n"
"and you plan to {goal}.\n"
"You can choose to interrupt the other person "
"by saying something or not to interrupt by outputting notiong. What would you say? "
"Please only output a sentence or not outputting anything."
"{format_instructions}",
input_values={
"message_history": history,
"goal": self.goal,
"agent_name": self.name,
},
temperature=0.7,
output_parser=StrOutputParser(),
)

try:
to, response = await agenerate(
model_name=self.model_name,
template="Imagine that you are a friend of the other persons. Here is the "
"conversation between you and them.\n"
"List of people involved: {all_agents}"
"You are {agent_name} in the conversation.\n"
"{message_history}\n"
"and you plan to {goal}.\n"
"You can choose to interrupt the other person "
"by saying something or not interrupt by outputting nothing. What would you say? \n"
"If you choose to say something, before your output mention the person you want to talk to "
"or 'All' if you want to say something to everyone\n"
"For example, 'All: I am tired' if you want to address everyone or 'John Doe: I am tired' "
"if you want to address John Doe. If you choose to address a person, ensure they are in the list of "
"people involved\n"
"Please only output a sentence or do not output anything."
"{format_instructions}",
input_values={
"message_history": history,
"goal": self.goal,
"agent_name": self.name,
"all_agents": self.agents_str,
},
temperature=0.7,
output_parser=StrOutputParser(),
)
except Exception as e:
log.error(f"Error generating response: {e}")
response = ""
if not response:
return AgentAction(
agent_name=self.name,
output_channel=self.output_channel,
action_type="none",
argument=json.dumps({"action": "", "to": "all"}),
)
return AgentAction(
agent_name=self.name,
output_channel=self.output_channel,
action_type="speak",
argument=action,
argument=json.dumps({"action": response, "to": to}),
)
7 changes: 5 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,11 @@ dependencies = [
"together>=0.2.4,<1.5.0",
"pydantic>=2.5.0,<3.0.0",
"hiredis>=3.0.0",
"litellm>=1.65.0",
"aact"
"litellm>=0.1.1",
"aact",
"modal>=0.68.18",
"streamlit>=1.41.1",
"fastapi[standard]>=0.115.6",
]

[build-system]
Expand Down
3 changes: 2 additions & 1 deletion sotopia/api/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ modal deploy scripts/modal/modal_api_server.py

To run the FastAPI server, you can use the following commands:
```bash
uv run fastapi run sotopia/api/fastapi_server.py --port 8080
uv run rq worker
uv run fastapi run sotopia/api/fastapi_server.py --workers 4 --port 8080
```

Here is also an example of using the FastAPI server:
Expand Down
176 changes: 141 additions & 35 deletions sotopia/api/fastapi_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@
EpisodeLLMEvaluator,
EvaluationForTwoAgents,
)
from sotopia.server import arun_one_episode
from sotopia.agents import LLMAgent, Agents
from sotopia.agents import Agents, LLMAgent
from fastapi import (
FastAPI,
WebSocket,
Expand All @@ -53,8 +52,10 @@
import logging
from fastapi.responses import Response

logger = logging.getLogger(__name__)

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# app = FastAPI()

Expand Down Expand Up @@ -186,37 +187,107 @@ async def handle_client_message(
) -> bool:
try:
msg_type = message.get("type")
data = message.get("data", {})
if msg_type == WSMessageType.FINISH_SIM.value:
return True
# TODO handle other message types
if msg_type == WSMessageType.CLIENT_MSG.value:
# Check if this is a message with content
if "content" in data:
message_content = data.get("content", "")
receiver = data.get("to", "")
if not message_content:
await self.send_error(
websocket,
ErrorType.INVALID_MESSAGE,
"Message must include content",
)
return False

# Forward to simulator
await simulator.send_message(
{
"content": message_content,
"sender": "redis_agent",
"receiver": receiver,
}
)
return False
except Exception as e:
msg = f"Error handling client message: {e}"
logger.error(msg)
await self.send_error(websocket, ErrorType.INVALID_MESSAGE, msg)
error_msg = f"Error handling client message: {e}"
logger.error(error_msg)
await self.send_error(websocket, ErrorType.INVALID_MESSAGE, error_msg)
return False

async def run_simulation(
self, websocket: WebSocket, simulator: WebSocketSotopiaSimulator
) -> None:
"""Run the simulation and process client messages"""
try:
async for message in simulator.arun():
await self.send_message(websocket, WSMessageType.SERVER_MSG, message)
# Start the simulation tasks
sim_task = asyncio.create_task(
self._process_simulation(websocket, simulator)
)
client_task = asyncio.create_task(
self._process_client_messages(websocket, simulator)
)

# Wait for either task to complete
done, pending = await asyncio.wait(
[sim_task, client_task], return_when=asyncio.FIRST_COMPLETED
)

# Cancel the remaining task
for task in pending:
task.cancel()
try:
data = await asyncio.wait_for(websocket.receive_json(), timeout=0.1)
if await self.handle_client_message(websocket, simulator, data):
break
except asyncio.TimeoutError:
continue
await task
except asyncio.CancelledError:
pass

except Exception as e:
msg = f"Error running simulation: {e}"
logger.error(msg)
await self.send_error(websocket, ErrorType.SIMULATION_ISSUE, msg)
finally:
# Always send END_SIM message
await self.send_message(websocket, WSMessageType.END_SIM, {})

async def _process_simulation(
self,
websocket: WebSocket,
simulator: WebSocketSotopiaSimulator,
) -> None:
"""
Process the simulation by running either the simulator epilogs or redis epilogs

Args:
websocket: The WebSocket connection
simulator: The simulation manager
"""
try:
# Use the simulator's built-in arun generator to get messages
async for message in simulator.arun():
await self.send_message(websocket, WSMessageType.SERVER_MSG, message)
except Exception as e:
logger.error(f"Error in simulation processor: {e}")
raise

async def _process_client_messages(
self, websocket: WebSocket, simulator: WebSocketSotopiaSimulator
) -> None:
"""Process messages from the client"""
try:
while True:
data = await websocket.receive_json()
should_end = await self.handle_client_message(
websocket, simulator, data
)
if should_end:
break
except Exception as e:
logger.error(f"Error processing client messages: {e}")
raise

@staticmethod
async def send_message(
websocket: WebSocket, msg_type: WSMessageType, data: dict[str, Any]
Expand Down Expand Up @@ -294,6 +365,7 @@ async def nonstreaming_simulation(
),
}
)
from sotopia.server import arun_one_episode

await arun_one_episode(
env=env,
Expand Down Expand Up @@ -713,29 +785,63 @@ async def websocket_endpoint(websocket: WebSocket, token: str) -> None:
while True:
start_msg = await websocket.receive_json()
if start_msg.get("type") != WSMessageType.START_SIM.value:
await manager.send_error(
websocket,
ErrorType.INVALID_MESSAGE,
"First message must be of type START_SIM",
)
continue
sim_data = start_msg.get("data", {})
env_id = sim_data.get("env_id", "")
agent_ids = sim_data.get("agent_ids", [])

async with manager.state.start_simulation(token):
simulator = await manager.create_simulator(
env_id=start_msg["data"]["env_id"],
agent_ids=start_msg["data"]["agent_ids"],
agent_models=start_msg["data"].get(
"agent_models", ["gpt-4o-mini", "gpt-4o-mini"]
),
env_profile_dict=start_msg["data"].get(
"env_profile_dict", {}
),
agent_profile_dicts=start_msg["data"].get(
"agent_profile_dicts", []
),
evaluator_model=start_msg["data"].get(
"evaluator_model", "gpt-4o"
),
evaluation_dimension_list_name=start_msg["data"].get(
"evaluation_dimension_list_name", "sotopia"
),
max_turns=start_msg["data"].get("max_turns", 20),
)
await manager.run_simulation(websocket, simulator)
try:
# Create the simulator
simulator = await manager.create_simulator(
env_id=env_id,
agent_ids=agent_ids,
agent_models=sim_data.get(
"agent_models", ["gpt-4o-mini"] * len(agent_ids)
),
evaluator_model=sim_data.get(
"evaluator_model", "gpt-4o"
),
evaluation_dimension_list_name=sim_data.get(
"evaluation_dimension_list_name", "sotopia"
),
env_profile_dict=sim_data.get("env_profile_dict", {}),
agent_profile_dicts=sim_data.get(
"agent_profile_dicts", []
),
max_turns=sim_data.get("max_turns", 20),
)
await simulator.connect_to_redis()

# Initial message to client
await manager.send_message(
websocket,
WSMessageType.SERVER_MSG,
{
"status": "simulation_started",
"env_id": simulator.env_id,
"agent_ids": simulator.agent_ids,
"connection_id": simulator.connection_id,
},
)
logger.info(
"WebSocket start sim message confirmation sent."
)
# Run the simulation
await manager.run_simulation(websocket, simulator)

except Exception as e:
logger.error(f"Error creating or running simulator: {e}")
await manager.send_error(
websocket,
ErrorType.SIMULATION_ISSUE,
f"Error in simulation: {str(e)}",
)

except WebSocketDisconnect:
logger.info(f"Client disconnected: {token}")
Expand Down
Loading
Loading