diff --git a/src/crewai/agent.py b/src/crewai/agent.py
index 77fffc7e1b..9ed3392ff9 100644
--- a/src/crewai/agent.py
+++ b/src/crewai/agent.py
@@ -62,6 +62,7 @@ class Agent(BaseAgent):
         function_calling_llm: The language model that will handle the tool calling for this agent, it overrides the crew function_calling_llm.
         max_iter: Maximum number of iterations for an agent to execute a task.
         max_rpm: Maximum number of requests per minute for the agent execution to be respected.
+        max_tpm: Maximum number of tokens per minute for the agent execution to be respected.
         verbose: Whether the agent execution should be in verbose mode.
         allow_delegation: Whether the agent is allowed to delegate tasks to other agents.
         tools: Tools at agents disposal
@@ -401,6 +402,9 @@ def execute_task(
         if self.max_rpm and self._rpm_controller:
             self._rpm_controller.stop_rpm_counter()
 
+        if self.max_tpm and self._tpm_controller:
+            self._tpm_controller.stop_tpm_counter()
+
         # If there was any tool in self.tools_results that had result_as_answer
         # set to True, return the results of the last tool that had
         # result_as_answer set to True
@@ -512,6 +516,9 @@ def create_agent_executor(
             request_within_rpm_limit=(
                 self._rpm_controller.check_or_wait if self._rpm_controller else None
             ),
+            request_within_tpm_limit=(
+                self._tpm_controller.check_or_wait if self._tpm_controller else None
+            ),
             callbacks=[TokenCalcHandler(self._token_process)],
         )
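With these hooks in place, a token budget can be declared directly on the agent, mirroring max_rpm. A minimal sketch assuming only the fields added in this diff (the role/goal/backstory values and the budget numbers are placeholders):

    from crewai import Agent

    agent = Agent(
        role="Researcher",
        goal="Summarize recent papers",
        backstory="An analyst working within strict provider quotas.",
        max_rpm=10,       # at most 10 LLM requests per minute (existing field)
        max_tpm=50_000,   # at most ~50k tokens per minute (added by this diff)
    )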
""" @@ -83,6 +86,8 @@ class BaseAgent(ABC, BaseModel): _logger: Logger = PrivateAttr(default_factory=lambda: Logger(verbose=False)) _rpm_controller: Optional[RPMController] = PrivateAttr(default=None) _request_within_rpm_limit: Any = PrivateAttr(default=None) + _tpm_controller: Optional[TPMController] = PrivateAttr(default=None) + _request_within_tpm_limit: Any = PrivateAttr(default=None) _original_role: Optional[str] = PrivateAttr(default=None) _original_goal: Optional[str] = PrivateAttr(default=None) _original_backstory: Optional[str] = PrivateAttr(default=None) @@ -104,6 +109,10 @@ class BaseAgent(ABC, BaseModel): default=None, description="Maximum number of requests per minute for the agent execution to be respected.", ) + max_tpm: Optional[int] = Field( + default=None, + description="Maximum number of tokens per minute for the agent execution to be respected.", + ) allow_delegation: bool = Field( default=False, description="Enable agent to delegate and ask questions among each other.", @@ -213,6 +222,11 @@ def validate_and_set_attributes(self): if not self._token_process: self._token_process = TokenProcess() + if self.max_tpm and not self._tpm_controller: + self._tpm_controller = TPMController( + max_tpm=self.max_rpm, token_counter=self._token_process, logger=self._logger + ) + # Initialize security_config if not provided if self.security_config is None: self.security_config = SecurityConfig() @@ -237,6 +251,12 @@ def set_private_attrs(self): ) if not self._token_process: self._token_process = TokenProcess() + + if self.max_tpm and not self._tpm_controller: + self._tpm_controller = TPMController( + max_tpm=self.max_rpm, token_counter=self._token_process, logger=self._logger + ) + return self @property @@ -273,6 +293,8 @@ def copy(self: T) -> T: # type: ignore # Signature of "copy" incompatible with "_logger", "_rpm_controller", "_request_within_rpm_limit", + "_tpm_controller", + "_request_within_tpm_limit", "_token_process", "agent_executor", "tools", @@ -362,5 +384,14 @@ def set_rpm_controller(self, rpm_controller: RPMController) -> None: self._rpm_controller = rpm_controller self.create_agent_executor() + def set_tpm_controller(self, tpm_controller: TPMController) -> None: + """Set the tpm controller for the agent. + + Args: + tpm_controller: An instance of the TPMController class. 
+ """ + if not self._tpm_controller: + self._tpm_controller = tpm_controller + def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None): pass diff --git a/src/crewai/agents/crew_agent_executor.py b/src/crewai/agents/crew_agent_executor.py index 914f837ee1..8338550aa2 100644 --- a/src/crewai/agents/crew_agent_executor.py +++ b/src/crewai/agents/crew_agent_executor.py @@ -17,6 +17,7 @@ from crewai.utilities import I18N, Printer from crewai.utilities.agent_utils import ( enforce_rpm_limit, + enforce_tpm_limit, format_message_for_llm, get_llm_response, handle_agent_action_core, @@ -26,6 +27,8 @@ handle_unknown_error, has_reached_max_iterations, is_context_length_exceeded, + is_token_limit_exceeded, + handle_exceeded_token_limits, process_llm_response, show_agent_logs, ) @@ -56,6 +59,7 @@ def __init__( function_calling_llm: Any = None, respect_context_window: bool = False, request_within_rpm_limit: Optional[Callable[[], bool]] = None, + request_within_tpm_limit: Optional[Callable[[], bool]] = None, callbacks: List[Any] = [], ): self._i18n: I18N = I18N() @@ -78,6 +82,7 @@ def __init__( self.function_calling_llm = function_calling_llm self.respect_context_window = respect_context_window self.request_within_rpm_limit = request_within_rpm_limit + self.request_within_tpm_limit = request_within_tpm_limit self.ask_for_human_input = False self.messages: List[Dict[str, str]] = [] self.iterations = 0 @@ -152,6 +157,8 @@ def _invoke_loop(self) -> AgentFinish: enforce_rpm_limit(self.request_within_rpm_limit) + enforce_tpm_limit(self.request_within_tpm_limit) + answer = get_llm_response( llm=self.llm, messages=self.messages, @@ -203,8 +210,12 @@ def _invoke_loop(self) -> AgentFinish: ) except Exception as e: + if is_token_limit_exceeded(e): + handle_exceeded_token_limits(self.request_within_tpm_limit) + continue + if e.__class__.__module__.startswith("litellm"): - # Do not retry on litellm errors + # Do not retry on other litellm errors raise e if is_context_length_exceeded(e): handle_context_length( diff --git a/src/crewai/utilities/__init__.py b/src/crewai/utilities/__init__.py index dd6d9fa44f..8b5ca02d68 100644 --- a/src/crewai/utilities/__init__.py +++ b/src/crewai/utilities/__init__.py @@ -7,6 +7,7 @@ from .printer import Printer from .prompts import Prompts from .rpm_controller import RPMController +from .tpm_controller import TPMController from .exceptions.context_window_exceeding_exception import ( LLMContextLengthExceededException, ) @@ -21,6 +22,7 @@ "Logger", "Printer", "Prompts", + "TPMController", "RPMController", "YamlParser", "LLMContextLengthExceededException", diff --git a/src/crewai/utilities/agent_utils.py b/src/crewai/utilities/agent_utils.py index e580c26ce7..ee7f932717 100644 --- a/src/crewai/utilities/agent_utils.py +++ b/src/crewai/utilities/agent_utils.py @@ -1,7 +1,7 @@ import json import re from typing import Any, Callable, Dict, List, Optional, Sequence, Union - +from litellm.exceptions import RateLimitError from crewai.agents.parser import ( FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE, AgentAction, @@ -15,8 +15,9 @@ from crewai.tools.base_tool import BaseTool from crewai.tools.structured_tool import CrewStructuredTool from crewai.tools.tool_types import ToolResult -from crewai.utilities import I18N, Printer +from crewai.utilities import I18N, Printer, TPMController from crewai.utilities.errors import AgentRepositoryError + from crewai.utilities.exceptions.context_window_exceeding_exception import ( LLMContextLengthExceededException, ) @@ -136,6 +137,12 @@ def 
diff --git a/src/crewai/utilities/agent_utils.py b/src/crewai/utilities/agent_utils.py
index e580c26ce7..ee7f932717 100644
--- a/src/crewai/utilities/agent_utils.py
+++ b/src/crewai/utilities/agent_utils.py
@@ -1,7 +1,9 @@
 import json
 import re
 from typing import Any, Callable, Dict, List, Optional, Sequence, Union
 
+from litellm.exceptions import RateLimitError
+
 from crewai.agents.parser import (
     FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE,
     AgentAction,
@@ -136,6 +138,14 @@ def enforce_rpm_limit(
     """Enforce the requests per minute (RPM) limit if applicable."""
     if request_within_rpm_limit:
         request_within_rpm_limit()
 
+
+def enforce_tpm_limit(
+    request_within_tpm_limit: Optional[Callable[..., bool]] = None,
+) -> None:
+    """Enforce the tokens per minute (TPM) limit if applicable."""
+    if request_within_tpm_limit:
+        request_within_tpm_limit()
+
 
 def get_llm_response(
     llm: Union[LLM, BaseLLM],
@@ -322,6 +332,34 @@ def handle_context_length(
         "Context length exceeded and user opted not to summarize. Consider using smaller text or RAG tools from crewai_tools."
     )
 
+
+def is_token_limit_exceeded(exception: Exception) -> bool:
+    """Check if the exception is due to exceeding the token-per-minute limit.
+
+    Args:
+        exception: The exception to check.
+
+    Returns:
+        bool: True if the exception was raised because the provider's
+            token-per-minute rate limit was exceeded.
+    """
+    if isinstance(exception, RateLimitError):
+        return "Rate limit reached" in str(exception) or "rate_limit_exceeded" in str(exception)
+    return False
+
+
+def handle_exceeded_token_limits(
+    request_within_tpm_limit: Optional[Callable[..., bool]],
+) -> None:
+    """Handle a token-limit error by waiting for the next minute.
+
+    Args:
+        request_within_tpm_limit: The TPM controller's check_or_wait callable;
+            calling it with wait=1 sleeps until the next minute starts.
+    """
+    if request_within_tpm_limit:
+        request_within_tpm_limit(1)
+
 
 def summarize_messages(
     messages: List[Dict[str, str]],
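A quick exercise of the two helpers. Constructing litellm's RateLimitError directly is version-dependent, so this sketch only covers the negative path of is_token_limit_exceeded and uses a stand-in lambda for the controller's check_or_wait:

    from crewai.utilities.agent_utils import (
        handle_exceeded_token_limits,
        is_token_limit_exceeded,
    )

    # A generic error is never treated as a TPM error.
    assert is_token_limit_exceeded(ValueError("boom")) is False

    # The handler forwards to check_or_wait with wait=1; a None hook
    # (no TPM limit configured) is a no-op.
    waited = []
    handle_exceeded_token_limits(lambda wait=0: waited.append(wait) or True)
    handle_exceeded_token_limits(None)
    assert waited == [1]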
diff --git a/src/crewai/utilities/tpm_controller.py b/src/crewai/utilities/tpm_controller.py
new file mode 100644
index 0000000000..a53e8ff1ae
--- /dev/null
+++ b/src/crewai/utilities/tpm_controller.py
@@ -0,0 +1,87 @@
+"""Controls the token rate (tokens per minute) for agent executions."""
+
+import threading
+import time
+from typing import Optional
+
+from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
+
+from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
+from crewai.utilities.logger import Logger
+
+
+class TPMController(BaseModel):
+    """Manages tokens-per-minute (TPM) rate limiting."""
+
+    max_tpm: Optional[int] = Field(default=None)
+    logger: Logger = Field(default_factory=lambda: Logger(verbose=False))
+    token_counter: TokenProcess
+    _current_tokens: int = PrivateAttr(default=0)
+    _timer: Optional[threading.Timer] = PrivateAttr(default=None)
+    _lock: Optional[threading.Lock] = PrivateAttr(default=None)
+    _shutdown_flag: bool = PrivateAttr(default=False)
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    @model_validator(mode="after")
+    def reset_counter(self):
+        if self.max_tpm is not None:
+            if not self._shutdown_flag:
+                self._lock = threading.Lock()
+                self._reset_request_count()
+        return self
+
+    def check_or_wait(self, wait: int = 0):
+        if self.max_tpm is None:
+            return True
+
+        def _check_and_increment(wait):
+            # Under budget and not explicitly asked to wait: record usage.
+            if self._current_tokens < self.max_tpm and not wait:
+                self._current_tokens += self.token_counter.total_tokens
+                return True
+            # Budget spent (or wait requested): block until the minute resets.
+            self.logger.log(
+                "info", "Max TPM reached, waiting for next minute to start."
+            )
+            self._wait_for_next_minute()
+            return True
+
+        if self._lock:
+            with self._lock:
+                return _check_and_increment(wait)
+        return _check_and_increment(wait)
+
+    def stop_tpm_counter(self):
+        if self._timer:
+            self._timer.cancel()
+            self._timer = None
+
+    def _wait_for_next_minute(self):
+        time.sleep(60)
+        self._current_tokens = 0
+
+    def external_wait_for_next_minute(self):
+        if self._lock:
+            with self._lock:
+                self._wait_for_next_minute()
+        else:
+            self._wait_for_next_minute()
+
+    def _reset_request_count(self):
+        def _reset():
+            self._current_tokens = 0
+            if not self._shutdown_flag:
+                self._timer = threading.Timer(60.0, self._reset_request_count)
+                self._timer.start()
+
+        if self._lock:
+            with self._lock:
+                _reset()
+        else:
+            _reset()
+
+    def __del__(self):
+        if self._timer:
+            self._shutdown_flag = True
+            self._timer.cancel()
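Standalone, the controller can be exercised without an agent. A minimal sketch assuming TokenProcess's sum_completion_tokens helper (the budget numbers are arbitrary); note that check_or_wait reads the counter's running total_tokens on each call, and that constructing the controller with max_tpm set starts the 60-second reset timer immediately:

    from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
    from crewai.utilities.tpm_controller import TPMController

    counter = TokenProcess()
    controller = TPMController(max_tpm=10_000, token_counter=counter)

    counter.sum_completion_tokens(1_200)  # record usage, as TokenCalcHandler does
    controller.check_or_wait()            # under budget: tallies tokens, returns True

    controller.stop_tpm_counter()         # cancel the background reset timer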