Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
File renamed without changes.
2 changes: 1 addition & 1 deletion massgen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@
from .message_templates import MessageTemplates, get_templates
from .orchestrator import Orchestrator, create_orchestrator

__version__ = "0.1.71"
__version__ = "0.1.72"
__author__ = "MassGen Contributors"


Expand Down
19 changes: 17 additions & 2 deletions massgen/agent_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,19 +674,32 @@ def create_claude_config(
return cls(backend_params=backend_params)

@classmethod
def create_grok_config(
    cls,
    model: str = "grok-2-1212",
    enable_web_search: bool = False,
    enable_x_search: bool = False,
    enable_code_execution: bool = False,
    **kwargs,
) -> "AgentConfig":
    """Create xAI Grok configuration.

    Args:
        model: Grok model name
        enable_web_search: Enable xAI web search
        enable_x_search: Enable xAI X search
        enable_code_execution: Enable xAI code execution
        **kwargs: Additional backend parameters

    Returns:
        AgentConfig whose backend_params carry the model name, any extra
        kwargs, and a True flag for each tool that was enabled.
    """
    backend_params = {"model": model, **kwargs}

    # Add tool enablement to backend_params.
    # Only enabled tools are recorded; a disabled tool leaves no key at all,
    # so downstream param handlers can use simple truthy .get() checks.
    if enable_web_search:
        backend_params["enable_web_search"] = True
    if enable_x_search:
        backend_params["enable_x_search"] = True
    if enable_code_execution:
        backend_params["enable_code_execution"] = True

    return cls(backend_params=backend_params)

Expand Down Expand Up @@ -945,6 +958,8 @@ def for_computational_task(cls, model: str = "gpt-4o", backend: str = "openai")
"""
if backend == "openai":
return cls.create_openai_config(model, enable_code_interpreter=True)
elif backend == "grok":
return cls.create_grok_config(model, enable_code_execution=True)
elif backend == "claude":
return cls.create_claude_config(model, enable_code_execution=True)
elif backend == "gemini":
Expand Down
18 changes: 15 additions & 3 deletions massgen/api_params_handler/_response_api_params_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ def get_excluded_params(self) -> set[str]:
return self.get_base_excluded_params().union(
{
"enable_web_search",
"enable_x_search",
"enable_code_execution",
"enable_code_interpreter",
"allowed_tools",
"exclude_tools",
Expand All @@ -38,11 +40,15 @@ def get_excluded_params(self) -> set[str]:
def get_provider_tools(self, all_params: dict[str, Any]) -> list[dict[str, Any]]:
"""Get provider tools for Response API format."""
provider_tools = []
provider_name = self.backend.get_provider_name()

if all_params.get("enable_web_search", False):
provider_tools.append({"type": "web_search"})

if all_params.get("enable_code_interpreter", False):
if provider_name == "Grok" and all_params.get("enable_x_search", False):
provider_tools.append({"type": "x_search"})

if all_params.get("enable_code_interpreter", False) or all_params.get("enable_code_execution", False):
provider_tools.append(
{
"type": "code_interpreter",
Expand Down Expand Up @@ -116,12 +122,18 @@ async def build_api_params(
logger.debug(f"Using previous_response_id for reasoning continuity: {previous_response_id}")

# Handle parallel_tool_calls with built-in tools constraint
builtin_flags = ("enable_web_search", "enable_code_interpreter", "_has_file_search_files")
builtin_flags = (
"enable_web_search",
"enable_x_search",
"enable_code_execution",
"enable_code_interpreter",
"_has_file_search_files",
)
if any(all_params.get(f, False) for f in builtin_flags):
# Built-in tools present - MUST disable parallel calling
if all_params.get("parallel_tool_calls") is True:
logger.warning(
"parallel_tool_calls=true is not supported with built-in tools " "(web_search, code_interpreter, file_search). " "Setting parallel_tool_calls=false.",
"parallel_tool_calls=true is not supported with built-in tools " "(web_search, x_search, code_interpreter, file_search). " "Setting parallel_tool_calls=false.",
)
api_params["parallel_tool_calls"] = False
elif "parallel_tool_calls" in all_params:
Expand Down
23 changes: 21 additions & 2 deletions massgen/backend/capabilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class Capability(Enum):
"""Enumeration of all possible backend capabilities."""

WEB_SEARCH = "web_search"
X_SEARCH = "x_search"
CODE_EXECUTION = "code_execution"
BASH = "bash"
MULTIMODAL = "multimodal" # Legacy - being phased out
Expand Down Expand Up @@ -386,10 +387,12 @@ class BackendCapabilities:
provider_name="Grok",
supported_capabilities={
"web_search",
"x_search",
"code_execution",
"mcp",
"image_understanding",
},
builtin_tools=["web_search"],
builtin_tools=["web_search", "x_search", "code_execution"],
filesystem_support="mcp",
models=[
"grok-4.20-0309-reasoning",
Expand All @@ -403,7 +406,7 @@ class BackendCapabilities:
],
default_model="grok-4.20-0309-reasoning",
env_var="XAI_API_KEY",
notes="Web search includes real-time data access. Image understanding capabilities.",
notes=("Uses xAI's Responses API tooling surface. " "Supports web_search, x_search, and code execution. " "Legacy Chat Completions search_parameters are not supported."),
model_release_dates={
"grok-4.20-0309-reasoning": "2026-03",
"grok-4-1-fast-reasoning": "2025-11",
Expand All @@ -414,6 +417,7 @@ class BackendCapabilities:
"grok-3": "2025-02",
"grok-3-mini": "2025-05",
},
base_url="https://api.x.ai/v1",
),
"azure_openai": BackendCapabilities(
backend_type="azure_openai",
Expand Down Expand Up @@ -854,6 +858,14 @@ def validate_backend_config(backend_type: str, config: dict) -> list[str]:
if "web_search" not in caps.supported_capabilities:
errors.append(f"{backend_type} does not support web_search")

if config.get("enable_x_search"):
if backend_type != "grok":
errors.append(
f"enable_x_search is only supported by Grok backend, not {backend_type}",
)
elif "x_search" not in caps.supported_capabilities:
errors.append(f"{backend_type} does not support x_search")

if "enable_code_execution" in config and config["enable_code_execution"]:
if "code_execution" not in caps.supported_capabilities:
errors.append(f"{backend_type} does not support code_execution")
Expand All @@ -880,6 +892,13 @@ def validate_backend_config(backend_type: str, config: dict) -> list[str]:
if "mcp" not in caps.supported_capabilities:
errors.append(f"{backend_type} does not support MCP")

if backend_type == "grok":
extra_body = config.get("extra_body")
if isinstance(extra_body, dict) and "search_parameters" in extra_body:
errors.append(
"Grok no longer supports extra_body.search_parameters. " "Use enable_web_search and/or enable_x_search instead.",
)

# Check for deprecated system prompt parameters (standardized across all backends)
if "append_system_prompt" in config:
errors.append(
Expand Down
42 changes: 40 additions & 2 deletions massgen/backend/chat_completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@
CustomToolChunk,
ToolExecutionConfig,
)
from .llm_circuit_breaker import (
CircuitBreakerOpenError,
LLMCircuitBreaker,
LLMCircuitBreakerConfig,
)


class ChatCompletionsBackend(StreamingBufferMixin, CustomToolAndMCPBackend):
Expand All @@ -58,6 +63,8 @@ class ChatCompletionsBackend(StreamingBufferMixin, CustomToolAndMCPBackend):
"""

def __init__(self, api_key: str | None = None, **kwargs):
# Extract circuit breaker config before passing to super
cb_config = self._build_circuit_breaker_config(kwargs)
super().__init__(api_key, **kwargs)
# Backend name is already set in MCPBackend, but we may need to override it
self.backend_name = self.get_provider_name()
Expand All @@ -72,6 +79,27 @@ def __init__(self, api_key: str | None = None, **kwargs):
self._stream_usage_received: bool = True # True = no pending estimation needed
# Track reasoning state for streaming (needed for reasoning_done transition)
self._reasoning_active: bool = False
self.circuit_breaker = LLMCircuitBreaker(
config=cb_config,
backend_name=self.get_provider_name(),
)

@staticmethod
def _build_circuit_breaker_config(
    kwargs: dict[str, Any],
) -> LLMCircuitBreakerConfig:
    """Pop ``llm_circuit_breaker_*`` entries from *kwargs* and build a config.

    Mutates *kwargs* in place: every matching key is removed so it is not
    forwarded to the underlying client constructor. The stripped suffix of
    each key becomes the corresponding config parameter name.
    """
    prefix = "llm_circuit_breaker_"
    # Snapshot matching keys first so we never mutate while iterating.
    matching = [k for k in kwargs if k.startswith(prefix)]
    cb_kwargs: dict[str, Any] = {key[len(prefix):]: kwargs.pop(key) for key in matching}
    return LLMCircuitBreakerConfig(**cb_kwargs)

def finalize_token_tracking(self) -> None:
"""Finalize token tracking by estimating tokens for any interrupted streams.
Expand Down Expand Up @@ -276,9 +304,19 @@ async def _stream_with_custom_and_mcp_tools(
model=model,
operation="stream",
) as llm_span:
# Start streaming - wrap in try/except for context length errors
# Start streaming - wrap with circuit breaker + context length handling
try:
stream = await client.chat.completions.create(**api_params)

async def _make_api_call():
return await client.chat.completions.create(**api_params)

stream = await self.circuit_breaker.call_with_retry(
_make_api_call,
agent_id=agent_id,
)
except CircuitBreakerOpenError:
self.end_api_call_timing(success=False, error="circuit_breaker_open")
raise
except Exception as e:
if is_context_length_error(e) and not _compression_retry:
# Context length exceeded on initial request - compress and retry
Expand Down
52 changes: 51 additions & 1 deletion massgen/backend/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@
PostEvaluationResponse,
VoteOnlyCoordinationResponse,
)
from .llm_circuit_breaker import (
CircuitBreakerOpenError,
LLMCircuitBreaker,
LLMCircuitBreakerConfig,
)
from .rate_limiter import GlobalRateLimiter


Expand Down Expand Up @@ -247,6 +252,9 @@ def __init__(self, api_key: str | None = None, **kwargs):
# Store Gemini-specific API key before calling parent init
gemini_api_key = api_key or os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")

# Extract circuit breaker config before other kwargs processing
cb_config = self._build_circuit_breaker_config(kwargs)

# Extract and remove enable_rate_limit and backoff config
enable_rate_limit = kwargs.pop("enable_rate_limit", False)
model_name = kwargs.get("model", "")
Expand Down Expand Up @@ -293,6 +301,12 @@ def __init__(self, api_key: str | None = None, **kwargs):
self.backoff_retry_count = 0
self.backoff_total_delay = 0.0

# LLM circuit breaker (opt-in, default disabled)
self.circuit_breaker = LLMCircuitBreaker(
config=cb_config,
backend_name="gemini",
)

# Initialize multi-dimensional rate limiter for Gemini API
# Supports RPM (Requests Per Minute), TPM (Tokens Per Minute), RPD (Requests Per Day)
# Configuration loaded from massgen/config/rate_limits.yaml
Expand Down Expand Up @@ -335,6 +349,23 @@ def __init__(self, api_key: str | None = None, **kwargs):
self.rate_limiter = None
logger.info(f"[Gemini] Rate limiting disabled for '{model_name}'")

@staticmethod
def _build_circuit_breaker_config(
    kwargs: dict[str, Any],
) -> LLMCircuitBreakerConfig:
    """Extract circuit breaker settings from kwargs and build config.

    Keys shaped like ``llm_circuit_breaker_<param>`` are consumed
    (removed from *kwargs*) and passed as ``<param>`` to the config
    constructor, so they never reach the Gemini client setup.
    """
    prefix = "llm_circuit_breaker_"
    cb_kwargs: dict[str, Any] = {}
    # Collect the matching keys up front, then pop each one out of kwargs.
    for key in [k for k in kwargs if k.startswith(prefix)]:
        cb_kwargs[key.removeprefix(prefix)] = kwargs.pop(key)
    return LLMCircuitBreakerConfig(**cb_kwargs)

def _normalize_and_resolve_tool_name(self, tool_name: str) -> str:
"""Normalize Gemini tool names and resolve MCP aliases.

Expand Down Expand Up @@ -777,6 +808,11 @@ async def stream_with_tools(self, messages: list[dict[str, Any]], tools: list[di
last_response_with_candidates = None

cfg = self.backoff_config

# Circuit breaker gate
if self.circuit_breaker.should_block():
raise CircuitBreakerOpenError("Circuit breaker is open for gemini")

first_token_recorded = False
for stream_attempt in range(1, cfg.max_attempts + 1):
try:
Expand Down Expand Up @@ -863,6 +899,7 @@ async def stream_with_tools(self, messages: list[dict[str, Any]], tools: list[di

# End API call timing on successful completion
self.end_api_call_timing(success=True)
self.circuit_breaker.record_success()
break

except Exception as stream_exc:
Expand All @@ -873,6 +910,10 @@ async def stream_with_tools(self, messages: list[dict[str, Any]], tools: list[di

if not is_retryable or stream_attempt >= cfg.max_attempts:
if is_retryable:
self.circuit_breaker.record_failure(
error_type=f"exhausted_{status_code or 'unknown'}",
error_message=f"Max retries exhausted: {error_msg[:200]}",
)
yield StreamChunk(
type="error",
error=f"⚠️ Rate limit exceeded after {cfg.max_attempts} retries. Please try again later.",
Expand Down Expand Up @@ -1443,6 +1484,10 @@ def tool_config_for_call(call: dict[str, Any]) -> ToolExecutionConfig:
cont_first_token_recorded = False

# Retry for continuation with backoff
# Circuit breaker gate
if self.circuit_breaker.should_block():
raise CircuitBreakerOpenError("Circuit breaker is open for gemini")

for cont_attempt in range(1, cfg.max_attempts + 1):
try:
# Start API call timing for continuation
Expand Down Expand Up @@ -1519,16 +1564,21 @@ def tool_config_for_call(call: dict[str, Any]) -> ToolExecutionConfig:

# End API call timing on successful completion
self.end_api_call_timing(success=True)
self.circuit_breaker.record_success()
break

except Exception as cont_exc:
# End API call timing with failure
self.end_api_call_timing(success=False, error=str(cont_exc))
is_retryable, status_code, _ = _is_retryable_gemini_error(cont_exc, cfg.retry_statuses)
is_retryable, status_code, error_msg = _is_retryable_gemini_error(cont_exc, cfg.retry_statuses)

if not is_retryable or cont_attempt >= cfg.max_attempts:
# Yield user-friendly error before raising
if is_retryable:
self.circuit_breaker.record_failure(
error_type=f"exhausted_{status_code or 'unknown'}",
error_message=f"Max retries exhausted: {error_msg[:200]}",
)
yield StreamChunk(
type="error",
error=f"⚠️ Rate limit exceeded after {cfg.max_attempts} retries. Please try again later.",
Expand Down
Loading
Loading