Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions oxygent/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,12 @@ class Config:
"mcp_is_keep_alive": True,
"is_concurrent_init": True,
},
"rate_limiter": {
"enabled": False,
"default_rate": 1.0,
"default_capacity": 10,
"per_oxy_limits": {},
},
}

@classmethod
Expand Down Expand Up @@ -623,3 +629,45 @@ def set_tool_is_concurrent_init(cls, is_concurrent_init):
@classmethod
def get_tool_is_concurrent_init(cls):
return cls.get_module_config("tool", "is_concurrent_init")

""" rate_limiter """

@classmethod
def set_rate_limiter_config(cls, rate_limiter_config):
return cls.set_module_config("rate_limiter", rate_limiter_config)

@classmethod
def get_rate_limiter_config(cls):
return cls.get_module_config("rate_limiter")

@classmethod
def set_rate_limiter_enabled(cls, enabled: bool):
cls.set_module_config("rate_limiter", "enabled", enabled)

@classmethod
def get_rate_limiter_enabled(cls) -> bool:
return cls.get_module_config("rate_limiter", "enabled", False)

@classmethod
def set_rate_limiter_default_rate(cls, rate: float):
cls.set_module_config("rate_limiter", "default_rate", rate)

@classmethod
def get_rate_limiter_default_rate(cls) -> float:
return cls.get_module_config("rate_limiter", "default_rate", 1.0)

@classmethod
def set_rate_limiter_default_capacity(cls, capacity: int):
cls.set_module_config("rate_limiter", "default_capacity", capacity)

@classmethod
def get_rate_limiter_default_capacity(cls) -> int:
return cls.get_module_config("rate_limiter", "default_capacity", 10)

@classmethod
def set_rate_limiter_per_oxy_limits(cls, per_oxy_limits: dict):
cls.set_module_config("rate_limiter", "per_oxy_limits", per_oxy_limits)

@classmethod
def get_rate_limiter_per_oxy_limits(cls) -> dict:
return cls.get_module_config("rate_limiter", "per_oxy_limits", {})
83 changes: 83 additions & 0 deletions oxygent/mas.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from .db_factory import DBFactory
from .log_setup import setup_logging
from .oxy import Oxy
from .rate_limiter import get_rate_limit_manager, RateLimitManager
from .oxy.agents.base_agent import BaseAgent
from .oxy.agents.remote_agent import RemoteAgent
from .oxy.base_flow import BaseFlow
Expand Down Expand Up @@ -97,6 +98,10 @@ class MAS(BaseModel):
func_interceptor: Optional[Callable] = Field(
lambda x: None, exclude=True, description="interceptor function"
)

rate_limiter: Optional[RateLimitManager] = Field(
None, exclude=True, description="Rate limiter manager"
)

func_process_message: Optional[Callable] = Field(
lambda x, oxy_request: x, exclude=True, description="process message function"
Expand Down Expand Up @@ -127,6 +132,10 @@ def __init__(self, **kwargs):
Config.set_app_name(self.name)
else:
self.name = Config.get_app_name()

# Initialize rate limiter if enabled
if Config.get_rate_limiter_enabled():
self._init_rate_limiter()

async def __aenter__(self):
await self.init()
Expand Down Expand Up @@ -181,6 +190,14 @@ def add_oxy(self, oxy: Oxy):
if oxy.name in self.oxy_name_to_oxy:
raise Exception(f"oxy [{oxy.name}] already exists.")
self.oxy_name_to_oxy[oxy.name] = oxy

# Create rate limiter for the new oxy if rate limiting is enabled
if self.rate_limiter and Config.get_rate_limiter_enabled():
per_oxy_limits = Config.get_rate_limiter_per_oxy_limits()
oxy_limits = per_oxy_limits.get(oxy.name, {})
rate = oxy_limits.get("rate", Config.get_rate_limiter_default_rate())
capacity = oxy_limits.get("capacity", Config.get_rate_limiter_default_capacity())
self._create_oxy_limiter(oxy.name, rate, capacity)

def add_oxy_list(self, oxy_list: list[Oxy]):
"""Register a list of Oxy objects.
Expand All @@ -190,6 +207,72 @@ def add_oxy_list(self, oxy_list: list[Oxy]):
"""
for oxy in oxy_list:
self.add_oxy(oxy)

def _init_rate_limiter(self):
"""Initialize the rate limiter manager and create limiters for oxy instances."""
logger.info("Initializing rate limiter...")
self.rate_limiter = get_rate_limit_manager()
self.rate_limiter.enable()

# Create default limiter with configuration
default_rate = Config.get_rate_limiter_default_rate()
default_capacity = Config.get_rate_limiter_default_capacity()

# Create limiters for existing oxy instances
for oxy_name, oxy in self.oxy_name_to_oxy.items():
self._create_oxy_limiter(oxy_name, default_rate, default_capacity)

# Apply per-oxy limits from configuration
per_oxy_limits = Config.get_rate_limiter_per_oxy_limits()
for oxy_name, limits in per_oxy_limits.items():
if oxy_name in self.oxy_name_to_oxy:
rate = limits.get("rate", default_rate)
capacity = limits.get("capacity", default_capacity)
self._create_oxy_limiter(oxy_name, rate, capacity)

logger.info(f"Rate limiter initialized with {len(self.rate_limiter._limiters)} limiters")

def _create_oxy_limiter(self, oxy_name: str, rate: float, capacity: int):
"""Create a rate limiter for a specific oxy instance."""
if self.rate_limiter:
self.rate_limiter.create_limiter(oxy_name, rate, capacity)
logger.debug(f"Created rate limiter for oxy '{oxy_name}': rate={rate}, capacity={capacity}")

def check_rate_limit(self, oxy_name: str, tokens: int = 1) -> bool:
"""Check if rate limit allows the operation for an oxy instance.

Args:
oxy_name: Name of the oxy instance
tokens: Number of tokens to acquire

Returns:
True if operation is allowed, False otherwise
"""
if not self.rate_limiter:
return True
return self.rate_limiter.check_rate_limit(oxy_name, tokens)

async def check_rate_limit_async(self, oxy_name: str, tokens: int = 1) -> bool:
"""Async check if rate limit allows the operation for an oxy instance.

Args:
oxy_name: Name of the oxy instance
tokens: Number of tokens to acquire

Returns:
True if operation is allowed, False otherwise
"""
if not self.rate_limiter:
return True
return await self.rate_limiter.check_rate_limit_async(oxy_name, tokens)

def get_rate_limiter_manager(self) -> Optional[RateLimitManager]:
"""Get the rate limiter manager.

Returns:
Rate limiter manager if initialized, None otherwise
"""
return self.rate_limiter

async def init(self):
"""Initialize the MAS. This coroutine performs all necessary setup steps to
Expand Down
27 changes: 27 additions & 0 deletions oxygent/oxy/base_oxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,11 @@ class Oxy(BaseModel, ABC):
timeout: float = Field(3600, description="Timeout in seconds.")
retries: int = Field(2)
delay: float = Field(1.0)

rate_limiter_enabled: bool = Field(
default_factory=Config.get_rate_limiter_enabled,
description="Enable rate limiting for this oxy"
)

def __init__(self, **kwargs):
super().__init__(**kwargs)
Expand Down Expand Up @@ -409,6 +414,28 @@ async def _pre_send_message(self, oxy_request: OxyRequest):
)

async def _before_execute(self, oxy_request: OxyRequest) -> OxyRequest:
"""Check rate limit before execution."""
if (self.mas and
self.rate_limiter_enabled and
Config.get_rate_limiter_enabled()):
# Check rate limit for this oxy instance
allowed = await self.mas.check_rate_limit_async(self.name)
if not allowed:
logger.warning(
f"Rate limit exceeded for oxy {self.name}",
extra={
"trace_id": oxy_request.current_trace_id,
"node_id": oxy_request.node_id,
},
)
# Create a rate limited response
from ..schemas import OxyResponse, OxyState
rate_limited_response = OxyResponse(
state=OxyState.FAILED,
output=f"Rate limit exceeded for {self.name}. Please try again later.",
)
rate_limited_response.oxy_request = oxy_request
raise Exception(f"Rate limit exceeded for {self.name}")
return oxy_request

@abstractmethod
Expand Down
Loading