Add TPM Controller to Respect LLM API Token Rate Limits #2841
base: main
File: crewai/agents/agent_builder/base_agent.py
@@ -23,7 +23,7 @@
 from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
 from crewai.security.security_config import SecurityConfig
 from crewai.tools.base_tool import BaseTool, Tool
-from crewai.utilities import I18N, Logger, RPMController
+from crewai.utilities import I18N, Logger, RPMController, TPMController
 from crewai.utilities.config import process_config
 from crewai.utilities.converter import Converter
 from crewai.utilities.string_utils import interpolate_only
@@ -43,6 +43,7 @@ class BaseAgent(ABC, BaseModel):
     config (Optional[Dict[str, Any]]): Configuration for the agent.
     verbose (bool): Verbose mode for the Agent Execution.
     max_rpm (Optional[int]): Maximum number of requests per minute for the agent execution.
+    max_tpm (Optional[int]): Maximum number of tokens to be used per minute for the agent execution.
     allow_delegation (bool): Allow delegation of tasks to agents.
     tools (Optional[List[Any]]): Tools at the agent's disposal.
     max_iter (int): Maximum iterations for an agent to execute a task.
@@ -75,6 +76,8 @@ class BaseAgent(ABC, BaseModel):
     Create a copy of the agent.
     set_rpm_controller(rpm_controller: RPMController) -> None:
         Set the rpm controller for the agent.
+    set_tpm_controller(tpm_controller: TPMController) -> None:
+        Set the tpm controller for the agent.
     set_private_attrs() -> "BaseAgent":
         Set private attributes.
     """
@@ -83,6 +86,8 @@ class BaseAgent(ABC, BaseModel):
     _logger: Logger = PrivateAttr(default_factory=lambda: Logger(verbose=False))
     _rpm_controller: Optional[RPMController] = PrivateAttr(default=None)
     _request_within_rpm_limit: Any = PrivateAttr(default=None)
+    _tpm_controller: Optional[TPMController] = PrivateAttr(default=None)
+    _request_within_tpm_limit: Any = PrivateAttr(default=None)
     _original_role: Optional[str] = PrivateAttr(default=None)
     _original_goal: Optional[str] = PrivateAttr(default=None)
     _original_backstory: Optional[str] = PrivateAttr(default=None)
@@ -104,6 +109,10 @@ class BaseAgent(ABC, BaseModel):
         default=None,
         description="Maximum number of requests per minute for the agent execution to be respected.",
     )
+    max_tpm: Optional[int] = Field(
+        default=None,
+        description="Maximum number of tokens per minute for the agent execution to be respected.",
+    )
     allow_delegation: bool = Field(
         default=False,
         description="Enable agent to delegate and ask questions among each other.",
@@ -213,6 +222,11 @@ def validate_and_set_attributes(self):
         if not self._token_process:
             self._token_process = TokenProcess()

+        if self.max_tpm and not self._tpm_controller:
+            self._tpm_controller = TPMController(
+                max_tpm=self.max_tpm, token_counter=self._token_process, logger=self._logger
+            )
+
         # Initialize security_config if not provided
         if self.security_config is None:
             self.security_config = SecurityConfig()
@@ -237,6 +251,12 @@ def set_private_attrs(self):
         )
         if not self._token_process:
             self._token_process = TokenProcess()

+        if self.max_tpm and not self._tpm_controller:
+            self._tpm_controller = TPMController(
+                max_tpm=self.max_tpm, token_counter=self._token_process, logger=self._logger
+            )
+
         return self

     @property

Review comment: Any reason to try to re-initialize `_tpm_controller` here?
@@ -273,6 +293,8 @@ def copy(self: T) -> T:  # type: ignore # Signature of "copy" incompatible with
             "_logger",
             "_rpm_controller",
             "_request_within_rpm_limit",
+            "_tpm_controller",
+            "_request_within_tpm_limit",
             "_token_process",
             "agent_executor",
             "tools",
@@ -362,5 +384,14 @@ def set_rpm_controller(self, rpm_controller: RPMController) -> None:
         self._rpm_controller = rpm_controller
         self.create_agent_executor()

+    def set_tpm_controller(self, tpm_controller: TPMController) -> None:
+        """Set the tpm controller for the agent.
+
+        Args:
+            tpm_controller: An instance of the TPMController class.
+        """
+        if not self._tpm_controller:
+            self._tpm_controller = tpm_controller
+
     def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None):
         pass
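With this change, `max_tpm` is configured on an agent the same way as `max_rpm`. A minimal sketch of the intended usage, assuming the rest of the Agent API is unchanged (the role/goal/backstory and limit values are placeholders):

    from crewai import Agent

    researcher = Agent(
        role="Researcher",
        goal="Summarize recent papers",
        backstory="Illustrative agent definition.",
        max_rpm=10,     # existing requests-per-minute cap
        max_tpm=30000,  # new tokens-per-minute cap introduced by this PR
    )

Per `validate_and_set_attributes` above, setting `max_tpm` makes the agent construct its own `TPMController` around its `TokenProcess`, so no further wiring is needed at the call site.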
File: crewai/agents/crew_agent_executor.py
@@ -17,6 +17,7 @@
 from crewai.utilities import I18N, Printer
 from crewai.utilities.agent_utils import (
     enforce_rpm_limit,
+    enforce_tpm_limit,
     format_message_for_llm,
     get_llm_response,
     handle_agent_action_core,
@@ -26,6 +27,8 @@
     handle_unknown_error,
     has_reached_max_iterations,
     is_context_length_exceeded,
+    is_token_limit_exceeded,
+    handle_exceeded_token_limits,
     process_llm_response,
     show_agent_logs,
 )
@@ -56,6 +59,7 @@ def __init__(
         function_calling_llm: Any = None,
         respect_context_window: bool = False,
         request_within_rpm_limit: Optional[Callable[[], bool]] = None,
+        request_within_tpm_limit: Optional[Callable[[], bool]] = None,
         callbacks: List[Any] = [],
     ):
         self._i18n: I18N = I18N()
@@ -78,6 +82,7 @@ def __init__(
         self.function_calling_llm = function_calling_llm
         self.respect_context_window = respect_context_window
         self.request_within_rpm_limit = request_within_rpm_limit
+        self.request_within_tpm_limit = request_within_tpm_limit
         self.ask_for_human_input = False
         self.messages: List[Dict[str, str]] = []
         self.iterations = 0
@@ -152,6 +157,8 @@ def _invoke_loop(self) -> AgentFinish:

             enforce_rpm_limit(self.request_within_rpm_limit)

+            enforce_tpm_limit(self.request_within_tpm_limit)
+
             answer = get_llm_response(
                 llm=self.llm,
                 messages=self.messages,
@@ -203,8 +210,12 @@ def _invoke_loop(self) -> AgentFinish:
             )

         except Exception as e:
+            if is_token_limit_exceeded(e):
+                handle_exceeded_token_limits(self.request_within_tpm_limit)
+                continue
+
             if e.__class__.__module__.startswith("litellm"):
-                # Do not retry on litellm errors
+                # Do not retry on other litellm errors
                 raise e
             if is_context_length_exceeded(e):
                 handle_context_length(

Review comment (on lines +213 to +215): Should we handle this error after …
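The diff does not show how `request_within_tpm_limit` is populated when the executor is built. Presumably `Agent.create_agent_executor` passes the controller's bound `check_or_wait`, mirroring how the RPM limit is wired today; a sketch under that assumption:

    # Inside Agent.create_agent_executor (sketch, not part of this diff).
    # Existing arguments (llm, task, agent, prompt, tools, ...) are elided.
    self.agent_executor = CrewAgentExecutor(
        request_within_rpm_limit=(
            self._rpm_controller.check_or_wait if self._rpm_controller else None
        ),
        # Hand the executor the TPM controller's gate the same way.
        request_within_tpm_limit=(
            self._tpm_controller.check_or_wait if self._tpm_controller else None
        ),
    )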
File: crewai/utilities/agent_utils.py
@@ -1,7 +1,7 @@
 import json
 import re
 from typing import Any, Callable, Dict, List, Optional, Sequence, Union
-
+from litellm.exceptions import RateLimitError
 from crewai.agents.parser import (
     FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE,
     AgentAction,
@@ -15,8 +15,9 @@
 from crewai.tools.base_tool import BaseTool
 from crewai.tools.structured_tool import CrewStructuredTool
 from crewai.tools.tool_types import ToolResult
-from crewai.utilities import I18N, Printer
+from crewai.utilities import I18N, Printer, TPMController
 from crewai.utilities.errors import AgentRepositoryError
+
 from crewai.utilities.exceptions.context_window_exceeding_exception import (
     LLMContextLengthExceededException,
 )
@@ -136,6 +137,12 @@ def enforce_rpm_limit(
     if request_within_rpm_limit:
         request_within_rpm_limit()

+def enforce_tpm_limit(
+    request_within_tpm_limit: Optional[Callable[[], bool]] = None,
+) -> None:
+    """Enforce the tokens used per minute (TPM) limit if applicable."""
+    if request_within_tpm_limit:
+        request_within_tpm_limit()
+
 def get_llm_response(
     llm: Union[LLM, BaseLLM],
@@ -322,6 +329,31 @@ def handle_context_length(
         "Context length exceeded and user opted not to summarize. Consider using smaller text or RAG tools from crewai_tools."
     )

+
+def is_token_limit_exceeded(exception: Exception) -> bool:
+    """Check if the exception is due to exceeding the token limit per minute.
+
+    Args:
+        exception: The exception to check.
+
+    Returns:
+        bool: True if the exception is due to exceeding the token limit per minute.
+    """
+    if isinstance(exception, RateLimitError):
+        return "Rate limit reached" in str(exception) or "rate_limit_exceeded" in str(exception)
+    return False
+
+
+def handle_exceeded_token_limits(
+    tpm_controller: TPMController,
+) -> None:
+    """Handle a token limit error by waiting.
+
+    Args:
+        tpm_controller: The TPM controller callable used to wait for the next minute.
+    """
+    tpm_controller(1)
+
+
 def summarize_messages(
     messages: List[Dict[str, str]],

Review comment (on the handle_exceeded_token_limits docstring): The args is …

Review comment (on tpm_controller(1)): I'm assuming 1 is meant to represent …
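Taken together, the new helpers are meant to bracket an LLM call: gate before the call, then wait and retry when the provider still rejects it. A minimal sketch, where `call_with_tpm_guard` and `call_llm` are hypothetical names and `tpm_controller` is assumed to be a configured `TPMController`:

    def call_with_tpm_guard(call_llm, tpm_controller):
        # Block until the current minute's token budget allows another call.
        enforce_tpm_limit(tpm_controller.check_or_wait)
        try:
            return call_llm()  # placeholder for the real completion call
        except Exception as e:
            # Provider-side TPM errors surface as litellm RateLimitError;
            # check_or_wait(1) forces a wait for the next window.
            if is_token_limit_exceeded(e):
                handle_exceeded_token_limits(tpm_controller.check_or_wait)
                return call_llm()
            raise

This matches the executor above, which passes `self.request_within_tpm_limit` (a callable) to `handle_exceeded_token_limits` rather than the controller instance itself.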
New file: TPMController (exported from crewai.utilities)

@@ -0,0 +1,93 @@
+"""Controls token rate."""
+
+import threading
+import time
+from typing import Optional
+
+from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
+
+from crewai.utilities.logger import Logger
+from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
+
+
+class TPMController(BaseModel):
+    """Manages tokens-per-minute (TPM) limiting."""
+
+    max_tpm: Optional[int] = Field(default=None)
+    logger: Logger = Field(default_factory=lambda: Logger(verbose=False))
+    token_counter: TokenProcess
+    _current_tokens: int = PrivateAttr(default=0)
+    _timer: Optional[threading.Timer] = PrivateAttr(default=None)
+    _lock: Optional[threading.Lock] = PrivateAttr(default=None)
+    _shutdown_flag: bool = PrivateAttr(default=False)
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    @model_validator(mode="after")
+    def reset_counter(self):
+        if self.max_tpm is not None:
+            if not self._shutdown_flag:
+                self._lock = threading.Lock()
+                self._reset_request_count()
+        return self
+
+    def check_or_wait(self, wait: int = 0):
+        if self.max_tpm is None:
+            return True
+
+        def _check_and_increment(wait):
+            if self.max_tpm is not None and self._current_tokens < self.max_tpm and not wait:
+                # Accumulate the shared token counter's running total into
+                # this minute's window.
+                self._current_tokens += self.token_counter.total_tokens
+                return True
+            elif self.max_tpm is not None:
+                self.logger.log(
+                    "info", "Max TPM reached, waiting for next minute to start."
+                )
+                self._wait_for_next_minute()
+                return True
+            return True
+
+        if self._lock:
+            with self._lock:
+                return _check_and_increment(wait)
+        else:
+            return _check_and_increment(wait)
+
+    def stop_tpm_counter(self):
+        if self._timer:
+            self._timer.cancel()
+            self._timer = None
+
+    def _wait_for_next_minute(self):
+        time.sleep(60)
+        self._current_tokens = 0
+
+    def external_wait_for_next_minute(self):
+        if self._lock:
+            with self._lock:
+                self._wait_for_next_minute()
+        else:
+            self._wait_for_next_minute()
+
+    def _reset_request_count(self):
+        def _reset():
+            self._current_tokens = 0
+            if not self._shutdown_flag:
+                self._timer = threading.Timer(60.0, self._reset_request_count)
+                self._timer.start()
+
+        if self._lock:
+            with self._lock:
+                _reset()
+        else:
+            _reset()
+
+    def __del__(self):
+        if self._timer:
+            self._shutdown_flag = True
+            self._timer.cancel()
Review comment: typo here
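For reference, a standalone sketch of the controller's lifecycle, assuming it is built around the agent's existing `TokenProcess` exactly as `BaseAgent.validate_and_set_attributes` does above (the limit value is a placeholder):

    from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess

    token_process = TokenProcess()
    controller = TPMController(max_tpm=30000, token_counter=token_process)

    # Before each LLM call: returns immediately while the minute's token
    # budget holds, otherwise sleeps until the window resets.
    controller.check_or_wait()

    # After a provider-side rate-limit error: force a wait for the next window.
    controller.check_or_wait(wait=1)

    # Cancel the background reset timer when the agent is done.
    controller.stop_tpm_counter()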