Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions massgen/api_params_handler/_api_params_handler_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ def get_base_excluded_params(self) -> set[str]:
"fairness_enabled",
"fairness_lead_cap_answers",
"max_midstream_injections_per_round",
# WebSocket mode (transport control, not an API parameter)
"websocket_mode",
"defer_peer_updates_until_restart",
"allow_midstream_peer_updates_before_checklist_submit",
"max_checklist_calls_per_round",
Expand Down
13 changes: 9 additions & 4 deletions massgen/api_params_handler/_response_api_params_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ def get_excluded_params(self) -> set[str]:
"enable_audio_generation", # Internal flag for audio generation (used in system messages only)
"enable_video_generation", # Internal flag for video generation (used in system messages only)
"previous_response_id", # Handled explicitly above for reasoning continuity
"websocket_mode", # Transport control, not an API parameter
"base_url", # Client-constructor param, not a request-body param
"organization", # Client-constructor param, not a request-body param
},
)

Expand Down Expand Up @@ -71,10 +74,12 @@ async def build_api_params(
converted_messages = self.formatter.format_messages(messages)

# Response API uses 'input' instead of 'messages'
api_params = {
"input": converted_messages,
"stream": True,
}
websocket_mode = all_params.get("websocket_mode", False) # In WebSocket mode, stream/background are not used (transport handles streaming)
api_params = {"input": converted_messages}
if not websocket_mode:
api_params["stream"] = True
else:
all_params.pop("background", None)
Comment on lines +77 to +82
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Mutating all_params may cause unintended side effects.

The all_params.pop("background", None) call modifies the input dictionary in place. If the caller reuses all_params after this call, the background key will be unexpectedly missing. Consider operating on a copy or documenting this mutation.

🛡️ Suggested fix to avoid mutation
         websocket_mode = all_params.get("websocket_mode", False)  # In WebSocket mode, stream/background are not used (transport handles streaming)
         api_params = {"input": converted_messages}
         if not websocket_mode:
             api_params["stream"] = True
         else:
-            all_params.pop("background", None)
+            # Don't include 'background' in api_params for WebSocket mode
+            # (handled below via excluded params or explicit skip)

Alternatively, if background must be removed to prevent it from being added later in the loop, ensure this mutation is documented in the docstring.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@massgen/api_params_handler/_response_api_params_handler.py` around lines 77 -
82, The code mutates the input dict by calling all_params.pop("background",
None), which can cause unexpected side effects; change the logic to operate on a
shallow copy of all_params (e.g., work with a new dict variable before
modifying) or explicitly build a new params dict instead of mutating all_params
so websocket_mode, api_params, and the rest of the flow remain the same; update
the code paths that reference websocket_mode, api_params, and any later use of
all_params to use the new copy or constructed dict, or document the mutation in
the function docstring if mutation is intentional.


# Set default reasoning configuration for reasoning models (GPT-5, o-series)
# Per OpenAI docs, GPT-5.1 and GPT-5.2 default to reasoning=none, but GPT-5 defaults to medium
Expand Down
152 changes: 152 additions & 0 deletions massgen/backend/_websocket_transport.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
"""
WebSocket transport for the OpenAI Responses API.
Persistent connection for response.create events.
See https://developers.openai.com/api/docs/guides/websocket-mode/
"""

from __future__ import annotations

import asyncio
import json
from typing import Any

import websockets
from websockets.protocol import State as WSState

from ..logger_config import logger

# Default WebSocket endpoint for the OpenAI Responses API.
DEFAULT_WS_URL = "wss://api.openai.com/v1/responses"
# Retry policy for connect(): up to MAX_RECONNECT_ATTEMPTS tries with
# exponential backoff starting at RECONNECT_BASE_DELAY seconds.
MAX_RECONNECT_ATTEMPTS = 3
RECONNECT_BASE_DELAY = 1.0


class WebSocketConnectionError(Exception):
    """Signals that the WebSocket connection could not be established."""


def _extract_error_details(event: dict[str, Any]) -> tuple[str, str]:
"""Extract message and code from websocket error events."""
nested_error = event.get("error")
if isinstance(nested_error, dict):
return (
nested_error.get("message", "Unknown error"),
nested_error.get("code", ""),
)
return (event.get("message", "Unknown error"), event.get("code", ""))


class WebSocketResponseTransport:
    """Persistent WebSocket transport for the OpenAI Responses API.

    A single long-lived connection carries ``response.create`` events; the
    server streams back events whose ``type`` values mirror the HTTP SSE
    event types, so downstream parsing can be shared between transports.
    """

    def __init__(
        self,
        api_key: str,
        url: str = DEFAULT_WS_URL,
        organization: str | None = None,
    ):
        """Initialize the transport. No connection is opened until connect().

        Args:
            api_key: OpenAI API key, sent as a Bearer token on handshake.
            url: WebSocket endpoint to connect to.
            organization: Optional OpenAI organization id, sent as the
                ``OpenAI-Organization`` header when provided.
        """
        self.api_key = api_key
        self.url = url
        self.organization = organization
        self._ws = None  # Live websocket connection, or None when closed.

    def _build_headers(self) -> dict[str, str]:
        """Build the HTTP headers used during the WebSocket handshake."""
        headers = {
            "Authorization": f"Bearer {self.api_key}",
        }
        if self.organization:
            headers["OpenAI-Organization"] = self.organization
        return headers

    def _build_response_create_event(self, payload: dict[str, Any]) -> str:
        """Wrap an API params dict as a serialized response.create event."""
        event = {"type": "response.create", **payload}
        return json.dumps(event)

    async def connect(self) -> None:
        """Establish the WebSocket connection with retry logic.

        Retries up to MAX_RECONNECT_ATTEMPTS times with exponential backoff
        (RECONNECT_BASE_DELAY * 2**attempt seconds between attempts). If a
        connection is already open, it is closed first so repeated calls do
        not leak sockets.

        Raises:
            WebSocketConnectionError: If every attempt fails; the last
                underlying error is included in the message.
        """
        if self._ws is not None:
            # Reconnect path: release the existing socket before opening a
            # new one, otherwise the previous connection would leak.
            await self.close()

        headers = self._build_headers()
        last_error: Exception | None = None

        for attempt in range(MAX_RECONNECT_ATTEMPTS):
            try:
                self._ws = await websockets.connect(
                    self.url,
                    additional_headers=headers,
                    max_size=None,  # Responses can be large; disable frame size cap.
                    ping_interval=30,
                    ping_timeout=10,
                )
                logger.info(
                    f"[WebSocket] Connected to {self.url} (attempt {attempt + 1})",
                )
                return
            except Exception as e:
                last_error = e
                if attempt < MAX_RECONNECT_ATTEMPTS - 1:
                    delay = RECONNECT_BASE_DELAY * (2**attempt)
                    logger.warning(
                        f"[WebSocket] Connection attempt {attempt + 1} failed: {e}. Retrying in {delay}s...",
                    )
                    await asyncio.sleep(delay)

        raise WebSocketConnectionError(
            f"Failed to connect to {self.url} after {MAX_RECONNECT_ATTEMPTS} attempts: {last_error}",
        )

    async def send_and_receive(
        self,
        api_params: dict[str, Any],
    ):
        """Send a response.create event and yield parsed response events.

        Args:
            api_params: The API params dict (same as HTTP body, minus stream/background).

        Yields:
            Parsed event dicts with a "type" field matching the HTTP SSE event types
            (e.g. "response.output_text.delta", "response.completed").

        Raises:
            WebSocketConnectionError: If not connected, or if the server
                reports an ``error`` event for this request.
        """
        if self._ws is None:
            raise WebSocketConnectionError("Not connected. Call connect() first.")

        message = self._build_response_create_event(api_params)
        await self._ws.send(message)
        logger.debug("[WebSocket] Sent response.create event")

        async for raw_message in self._ws:
            event = json.loads(raw_message)
            event_type = event.get("type", "")
            logger.debug(f"[WebSocket] Received event: {event_type}")

            if event_type == "error":
                error_msg, error_code = _extract_error_details(event)
                logger.error(
                    f"[WebSocket] Server error: {error_msg} (code={error_code})",
                )
                raise WebSocketConnectionError(
                    f"WebSocket response.create failed: {error_msg}" + (f" (code={error_code})" if error_code else ""),
                )

            yield event

            # Terminal events end this response's stream; stop iterating so
            # the connection stays open for the next request.
            if event_type in (
                "response.completed",
                "response.incomplete",
                "response.failed",
            ):
                break

    async def close(self) -> None:
        """Close the WebSocket connection. Safe to call when not connected."""
        if self._ws is not None:
            try:
                await self._ws.close()
                logger.info("[WebSocket] Connection closed")
            except Exception as e:
                logger.warning(f"[WebSocket] Error closing connection: {e}")
            finally:
                # Always drop the reference so is_connected reports False.
                self._ws = None

    @property
    def is_connected(self) -> bool:
        # True only when a socket exists and its protocol state is OPEN.
        return self._ws is not None and self._ws.state == WSState.OPEN
2 changes: 2 additions & 0 deletions massgen/backend/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,8 @@ def get_base_excluded_config_params(cls) -> set:
"fairness_enabled",
"fairness_lead_cap_answers",
"max_midstream_injections_per_round",
# WebSocket mode (transport control, not an API parameter)
"websocket_mode",
"defer_peer_updates_until_restart",
"allow_midstream_peer_updates_before_checklist_submit",
"max_checklist_calls_per_round",
Expand Down
1 change: 1 addition & 0 deletions massgen/config_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,7 @@ def _validate_backend(self, backend_config: dict[str, Any], location: str, resul
"enable_programmatic_flow",
"enable_tool_search",
"enable_strict_tool_use",
"websocket_mode",
]
for field_name in boolean_fields:
if field_name in backend_config:
Expand Down
13 changes: 13 additions & 0 deletions massgen/configs/providers/openai/gpt5_2_websocket.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# GPT-5.2 with WebSocket Mode
# Single agent using the latest GPT-5.2 model over persistent WebSocket.
agents:
  - id: "gpt-5-2-ws"
    backend:
      type: "openai"
      model: "gpt-5.2"
      # Transport control flag: stream over a persistent WebSocket
      # connection instead of HTTP SSE. Not forwarded as an API parameter.
      websocket_mode: true
      enable_web_search: true
      enable_code_interpreter: true
ui:
  display_type: "textual_terminal"
  logging_enabled: true
12 changes: 12 additions & 0 deletions massgen/configs/providers/openai/gpt5_nano_websocket.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# GPT-5-nano with WebSocket Mode
# Single agent with persistent WebSocket connection.
agents:
  - id: "gpt-5-nano-ws"
    backend:
      type: "openai"
      model: "gpt-5-nano"
      # Transport control flag: stream over a persistent WebSocket
      # connection instead of HTTP SSE. Not forwarded as an API parameter.
      websocket_mode: true
      enable_code_interpreter: true
ui:
  display_type: "textual_terminal"
  logging_enabled: true
18 changes: 18 additions & 0 deletions massgen/configs/providers/openai/multi_model_websocket.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Multi-Model WebSocket Mode
# Two agents with different models, both using WebSocket transport.
agents:
  - id: "gpt-5-2"
    backend:
      type: "openai"
      model: "gpt-5.2"
      # Each agent gets its own persistent WebSocket connection.
      websocket_mode: true
      enable_code_interpreter: true
  - id: "gpt-5-nano"
    backend:
      type: "openai"
      model: "gpt-5-nano"
      websocket_mode: true
      enable_code_interpreter: true
ui:
  display_type: "textual_terminal"
  logging_enabled: true
6 changes: 4 additions & 2 deletions massgen/formatter/_response_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,10 @@ def format_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]
continue

if message.get("type") == "message":
# Response API output items have type="message" with content array
# containing output_text items. Convert to simple assistant message format.
# Messages with 'id' must keep identity for reasoning→message pairing on replay
if "id" in message:
converted_messages.append(message)
continue
role = message.get("role", "assistant")
content_items = message.get("content", [])
# Extract text from output_text items
Expand Down
Loading
Loading