Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
File renamed without changes.
2 changes: 1 addition & 1 deletion massgen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@
from .message_templates import MessageTemplates, get_templates
from .orchestrator import Orchestrator, create_orchestrator

__version__ = "0.1.71"
__version__ = "0.1.72"
__author__ = "MassGen Contributors"


Expand Down
42 changes: 40 additions & 2 deletions massgen/backend/chat_completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@
CustomToolChunk,
ToolExecutionConfig,
)
from .llm_circuit_breaker import (
CircuitBreakerOpenError,
LLMCircuitBreaker,
LLMCircuitBreakerConfig,
)


class ChatCompletionsBackend(StreamingBufferMixin, CustomToolAndMCPBackend):
Expand All @@ -58,6 +63,8 @@ class ChatCompletionsBackend(StreamingBufferMixin, CustomToolAndMCPBackend):
"""

def __init__(self, api_key: str | None = None, **kwargs):
# Extract circuit breaker config before passing to super
cb_config = self._build_circuit_breaker_config(kwargs)
super().__init__(api_key, **kwargs)
# Backend name is already set in MCPBackend, but we may need to override it
self.backend_name = self.get_provider_name()
Expand All @@ -72,6 +79,27 @@ def __init__(self, api_key: str | None = None, **kwargs):
self._stream_usage_received: bool = True # True = no pending estimation needed
# Track reasoning state for streaming (needed for reasoning_done transition)
self._reasoning_active: bool = False
self.circuit_breaker = LLMCircuitBreaker(
config=cb_config,
backend_name=self.get_provider_name(),
)

@staticmethod
def _build_circuit_breaker_config(
    kwargs: dict[str, Any],
) -> LLMCircuitBreakerConfig:
    """Build a circuit-breaker config from ``llm_circuit_breaker_*`` kwargs.

    Every key of the form ``llm_circuit_breaker_<param>`` is removed from
    *kwargs* (mutated in place, so the remaining kwargs can be forwarded to
    the parent constructor untouched) and passed as ``<param>`` to
    LLMCircuitBreakerConfig.
    """
    prefix = "llm_circuit_breaker_"
    matched = [key for key in kwargs if key.startswith(prefix)]
    cb_kwargs: dict[str, Any] = {key[len(prefix) :]: kwargs.pop(key) for key in matched}
    return LLMCircuitBreakerConfig(**cb_kwargs)

def finalize_token_tracking(self) -> None:
"""Finalize token tracking by estimating tokens for any interrupted streams.
Expand Down Expand Up @@ -276,9 +304,19 @@ async def _stream_with_custom_and_mcp_tools(
model=model,
operation="stream",
) as llm_span:
# Start streaming - wrap in try/except for context length errors
# Start streaming - wrap with circuit breaker + context length handling
try:
stream = await client.chat.completions.create(**api_params)

async def _make_api_call():
return await client.chat.completions.create(**api_params)

stream = await self.circuit_breaker.call_with_retry(
_make_api_call,
agent_id=agent_id,
)
except CircuitBreakerOpenError:
self.end_api_call_timing(success=False, error="circuit_breaker_open")
raise
except Exception as e:
if is_context_length_error(e) and not _compression_retry:
# Context length exceeded on initial request - compress and retry
Expand Down
52 changes: 51 additions & 1 deletion massgen/backend/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@
PostEvaluationResponse,
VoteOnlyCoordinationResponse,
)
from .llm_circuit_breaker import (
CircuitBreakerOpenError,
LLMCircuitBreaker,
LLMCircuitBreakerConfig,
)
from .rate_limiter import GlobalRateLimiter


Expand Down Expand Up @@ -247,6 +252,9 @@ def __init__(self, api_key: str | None = None, **kwargs):
# Store Gemini-specific API key before calling parent init
gemini_api_key = api_key or os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")

# Extract circuit breaker config before other kwargs processing
cb_config = self._build_circuit_breaker_config(kwargs)

# Extract and remove enable_rate_limit and backoff config
enable_rate_limit = kwargs.pop("enable_rate_limit", False)
model_name = kwargs.get("model", "")
Expand Down Expand Up @@ -293,6 +301,12 @@ def __init__(self, api_key: str | None = None, **kwargs):
self.backoff_retry_count = 0
self.backoff_total_delay = 0.0

# LLM circuit breaker (opt-in, default disabled)
self.circuit_breaker = LLMCircuitBreaker(
config=cb_config,
backend_name="gemini",
)

# Initialize multi-dimensional rate limiter for Gemini API
# Supports RPM (Requests Per Minute), TPM (Tokens Per Minute), RPD (Requests Per Day)
# Configuration loaded from massgen/config/rate_limits.yaml
Expand Down Expand Up @@ -335,6 +349,23 @@ def __init__(self, api_key: str | None = None, **kwargs):
self.rate_limiter = None
logger.info(f"[Gemini] Rate limiting disabled for '{model_name}'")

@staticmethod
def _build_circuit_breaker_config(
    kwargs: dict[str, Any],
) -> LLMCircuitBreakerConfig:
    """Pop ``llm_circuit_breaker_*`` entries out of *kwargs* into a config.

    *kwargs* is mutated in place: each matching key is removed so it is not
    forwarded to the base backend, and its suffix becomes a constructor
    argument of LLMCircuitBreakerConfig.
    """
    prefix = "llm_circuit_breaker_"
    cb_kwargs: dict[str, Any] = {}
    # Snapshot the keys first so we can pop from kwargs while iterating.
    for key in list(kwargs):
        if key.startswith(prefix):
            cb_kwargs[key[len(prefix) :]] = kwargs.pop(key)
    return LLMCircuitBreakerConfig(**cb_kwargs)

def _normalize_and_resolve_tool_name(self, tool_name: str) -> str:
"""Normalize Gemini tool names and resolve MCP aliases.

Expand Down Expand Up @@ -777,6 +808,11 @@ async def stream_with_tools(self, messages: list[dict[str, Any]], tools: list[di
last_response_with_candidates = None

cfg = self.backoff_config

# Circuit breaker gate
if self.circuit_breaker.should_block():
raise CircuitBreakerOpenError("Circuit breaker is open for gemini")

first_token_recorded = False
for stream_attempt in range(1, cfg.max_attempts + 1):
try:
Expand Down Expand Up @@ -863,6 +899,7 @@ async def stream_with_tools(self, messages: list[dict[str, Any]], tools: list[di

# End API call timing on successful completion
self.end_api_call_timing(success=True)
self.circuit_breaker.record_success()
break

except Exception as stream_exc:
Expand All @@ -873,6 +910,10 @@ async def stream_with_tools(self, messages: list[dict[str, Any]], tools: list[di

if not is_retryable or stream_attempt >= cfg.max_attempts:
if is_retryable:
self.circuit_breaker.record_failure(
error_type=f"exhausted_{status_code or 'unknown'}",
error_message=f"Max retries exhausted: {error_msg[:200]}",
)
yield StreamChunk(
type="error",
error=f"⚠️ Rate limit exceeded after {cfg.max_attempts} retries. Please try again later.",
Expand Down Expand Up @@ -1443,6 +1484,10 @@ def tool_config_for_call(call: dict[str, Any]) -> ToolExecutionConfig:
cont_first_token_recorded = False

# Retry for continuation with backoff
# Circuit breaker gate
if self.circuit_breaker.should_block():
raise CircuitBreakerOpenError("Circuit breaker is open for gemini")

for cont_attempt in range(1, cfg.max_attempts + 1):
try:
# Start API call timing for continuation
Expand Down Expand Up @@ -1519,16 +1564,21 @@ def tool_config_for_call(call: dict[str, Any]) -> ToolExecutionConfig:

# End API call timing on successful completion
self.end_api_call_timing(success=True)
self.circuit_breaker.record_success()
break

except Exception as cont_exc:
# End API call timing with failure
self.end_api_call_timing(success=False, error=str(cont_exc))
is_retryable, status_code, _ = _is_retryable_gemini_error(cont_exc, cfg.retry_statuses)
is_retryable, status_code, error_msg = _is_retryable_gemini_error(cont_exc, cfg.retry_statuses)

if not is_retryable or cont_attempt >= cfg.max_attempts:
# Yield user-friendly error before raising
if is_retryable:
self.circuit_breaker.record_failure(
error_type=f"exhausted_{status_code or 'unknown'}",
error_message=f"Max retries exhausted: {error_msg[:200]}",
)
yield StreamChunk(
type="error",
error=f"⚠️ Rate limit exceeded after {cfg.max_attempts} retries. Please try again later.",
Expand Down
55 changes: 53 additions & 2 deletions massgen/backend/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@
ToolExecutionConfig,
UploadFileError,
)
from .llm_circuit_breaker import (
CircuitBreakerOpenError,
LLMCircuitBreaker,
LLMCircuitBreakerConfig,
)


class _WSEvent:
Expand Down Expand Up @@ -85,6 +90,8 @@ class ResponseBackend(StreamingBufferMixin, CustomToolAndMCPBackend):
"""Backend using the standard Response API format with multimodal support."""

def __init__(self, api_key: str | None = None, **kwargs):
# Extract circuit breaker config before passing to super
cb_config = self._build_circuit_breaker_config(kwargs)
super().__init__(api_key, **kwargs)
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
self.formatter = ResponseFormatter()
Expand All @@ -107,6 +114,27 @@ def __init__(self, api_key: str | None = None, **kwargs):
self._uploaded_file_ids: list[str] = []

# Note: _streaming_buffer is provided by StreamingBufferMixin
self.circuit_breaker = LLMCircuitBreaker(
config=cb_config,
backend_name="response_api",
)

@staticmethod
def _build_circuit_breaker_config(
    kwargs: dict[str, Any],
) -> LLMCircuitBreakerConfig:
    """Extract ``llm_circuit_breaker_``-prefixed settings into a config object.

    Matching keys are deleted from *kwargs* in place (so super().__init__
    never sees them); their un-prefixed names are handed to
    LLMCircuitBreakerConfig as keyword arguments.
    """
    prefix = "llm_circuit_breaker_"
    strip = len(prefix)
    cb_kwargs: dict[str, Any] = {}
    for key in tuple(kwargs.keys()):
        if not key.startswith(prefix):
            continue
        cb_kwargs[key[strip:]] = kwargs.pop(key)
    return LLMCircuitBreakerConfig(**cb_kwargs)

def supports_upload_files(self) -> bool:
return True
Expand Down Expand Up @@ -244,12 +272,20 @@ async def _stream_without_custom_and_mcp_tools(
_compression_retry = kwargs.get("_compression_retry", False)
ws_transport = kwargs.get("_ws_transport")

# Start API call timing for non-MCP path
model = api_params.get("model", "unknown")
self.start_api_call_timing(model)

try:
stream = await self._create_response_stream(
api_params,
client,
ws_transport,
agent_id=agent_id,
)
except CircuitBreakerOpenError:
self.end_api_call_timing(success=False, error="circuit_breaker_open")
raise
except Exception as e:
# Debug: Catch input[N].content format errors and print the problematic message
error_str = str(e)
Expand All @@ -271,6 +307,7 @@ async def _stream_without_custom_and_mcp_tools(
from ._context_errors import is_context_length_error

if is_context_length_error(e) and not _compression_retry:
self.end_api_call_timing(success=False, error=str(e))
logger.warning(
f"[{self.get_provider_name()}] Context length exceeded, " f"attempting compression recovery...",
)
Expand Down Expand Up @@ -307,6 +344,7 @@ async def _stream_without_custom_and_mcp_tools(
api_params,
client,
ws_transport,
agent_id=agent_id,
)

# Notify user that compression succeeded
Expand All @@ -323,6 +361,7 @@ async def _stream_without_custom_and_mcp_tools(
f"[{self.get_provider_name()}] Compression recovery successful via summarization " f"({input_count} items)",
)
else:
self.end_api_call_timing(success=False, error=str(e))
raise

async for chunk in self._process_stream(stream, all_params, agent_id):
Expand Down Expand Up @@ -471,7 +510,11 @@ async def _stream_with_custom_and_mcp_tools(
api_params,
client,
ws_transport,
agent_id=agent_id,
)
except CircuitBreakerOpenError:
self.end_api_call_timing(success=False, error="circuit_breaker_open")
raise
except Exception as e:
# Debug: Catch input[N].content format errors and print the problematic message
error_str = str(e)
Expand Down Expand Up @@ -533,6 +576,7 @@ async def _stream_with_custom_and_mcp_tools(
api_params,
client,
ws_transport,
agent_id=agent_id,
)

# Notify user that compression succeeded
Expand Down Expand Up @@ -1758,12 +1802,19 @@ def extract_tool_result_content(self, tool_result_message: dict[str, Any]) -> st
"""Extract content from OpenAI Responses API tool result message."""
return tool_result_message.get("output", "")

async def _create_response_stream(self, api_params, client, ws_transport=None):
async def _create_response_stream(self, api_params, client, ws_transport=None, agent_id=None):
"""Create a response stream via HTTP or websocket transport."""
if ws_transport is not None and ws_transport.is_connected:
logger.debug("[WebSocket] Sending response.create via WebSocket")
return self._ws_event_stream(ws_transport, api_params)
return await client.responses.create(**api_params)

async def _make_api_call():
return await client.responses.create(**api_params)

return await self.circuit_breaker.call_with_retry(
_make_api_call,
agent_id=agent_id,
)

async def _ws_event_stream(self, ws_transport, api_params):
"""Wrap websocket JSON events as objects matching SDK stream chunks."""
Expand Down
Loading
Loading