Add TPM Controller to Respect LLM API Token Rate Limits #2841

Open · wants to merge 2 commits into base: main
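For context, a minimal sketch of how the new option is meant to be used; the agent fields and the 10,000 limit are illustrative, not taken from this diff:

```python
from crewai import Agent

# Illustrative values; only `max_tpm` is the new option from this PR.
researcher = Agent(
    role="Researcher",
    goal="Summarize sources without tripping the provider's token quota",
    backstory="Runs against an API tier capped at 10,000 tokens per minute.",
    max_tpm=10_000,  # pause before the next LLM call once the per-minute budget is spent
)
```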
7 changes: 7 additions & 0 deletions src/crewai/agent.py
@@ -62,6 +62,7 @@ class Agent(BaseAgent):
function_calling_llm: The language model that will handle the tool calling for this agent; it overrides the crew function_calling_llm.
max_iter: Maximum number of iterations for an agent to execute a task.
max_rpm: Maximum number of requests per minute for the agent execution to be respected.
max_tpm: Maximum number of tokens per minute for the agent execution to be respected.
verbose: Whether the agent execution should be in verbose mode.
allow_delegation: Whether the agent is allowed to delegate tasks to other agents.
tools: Tools at the agent's disposal
@@ -401,6 +402,9 @@ def execute_task(
if self.max_rpm and self._rpm_controller:
self._rpm_controller.stop_rpm_counter()

if self.max_tpm and self._tpm_controller:
self._tpm_controller.stop_tpm_counter()

# If there was any tool in self.tools_results that had result_as_answer
# set to True, return the results of the last tool that had
# result_as_answer set to True
@@ -512,6 +516,9 @@ def create_agent_executor(
request_within_rpm_limit=(
self._rpm_controller.check_or_wait if self._rpm_controller else None
),
request_within_tpm_limit=(
self._tpm_controller.check_or_wait if self._tpm_controller else None
),
callbacks=[TokenCalcHandler(self._token_process)],
)

33 changes: 32 additions & 1 deletion src/crewai/agents/agent_builder/base_agent.py
@@ -23,7 +23,7 @@
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.security.security_config import SecurityConfig
from crewai.tools.base_tool import BaseTool, Tool
from crewai.utilities import I18N, Logger, RPMController
from crewai.utilities import I18N, Logger, RPMController, TPMController
from crewai.utilities.config import process_config
from crewai.utilities.converter import Converter
from crewai.utilities.string_utils import interpolate_only
@@ -43,6 +43,7 @@ class BaseAgent(ABC, BaseModel):
config (Optional[Dict[str, Any]]): Configuration for the agent.
verbose (bool): Verbose mode for the Agent Execution.
max_rpm (Optional[int]): Maximum number of requests per minute for the agent execution.
max_tpm (Optional[int]): Maximum number of tokens to be used per minute for the agent execution.
allow_delegation (bool): Allow delegation of tasks to agents.
tools (Optional[List[Any]]): Tools at the agent's disposal.
max_iter (int): Maximum iterations for an agent to execute a task.
@@ -75,6 +76,8 @@ class BaseAgent(ABC, BaseModel):
Create a copy of the agent.
set_rpm_controller(rpm_controller: RPMController) -> None:
Set the rpm controller for the agent.
set_tpm_controller(tpm_controller: TPMController) -> None:
Set the tpm controller for the agent.
set_private_attrs() -> "BaseAgent":
Set private attributes.
"""
@@ -83,6 +86,8 @@
_logger: Logger = PrivateAttr(default_factory=lambda: Logger(verbose=False))
_rpm_controller: Optional[RPMController] = PrivateAttr(default=None)
_request_within_rpm_limit: Any = PrivateAttr(default=None)
_tpm_controller: Optional[TPMController] = PrivateAttr(default=None)
_request_within_tpm_limit: Any = PrivateAttr(default=None)
_original_role: Optional[str] = PrivateAttr(default=None)
_original_goal: Optional[str] = PrivateAttr(default=None)
_original_backstory: Optional[str] = PrivateAttr(default=None)
@@ -104,6 +109,10 @@ class BaseAgent(ABC, BaseModel):
default=None,
description="Maximum number of requests per minute for the agent execution to be respected.",
)
max_tpm: Optional[int] = Field(
default=None,
description="Maximum number of tokens per minute for the agent execution to be respected.",
)
allow_delegation: bool = Field(
default=False,
description="Enable agent to delegate and ask questions among each other.",
@@ -213,6 +222,11 @@ def validate_and_set_attributes(self):
if not self._token_process:
self._token_process = TokenProcess()

if self.max_tpm and not self._tpm_controller:
self._tpm_controller = TPMController(
max_tpm=self.max_tpm, token_counter=self._token_process, logger=self._logger
)

# Initialize security_config if not provided
if self.security_config is None:
self.security_config = SecurityConfig()
@@ -237,6 +251,12 @@ def set_private_attrs(self):
)
if not self._token_process:
self._token_process = TokenProcess()

if self.max_tpm and not self._tpm_controller:
self._tpm_controller = TPMController(
max_tpm=self.max_tpm, token_counter=self._token_process, logger=self._logger
)

return self

@property
@@ -273,6 +293,8 @@ def copy(self: T) -> T: # type: ignore # Signature of "copy" incompatible with
"_logger",
"_rpm_controller",
"_request_within_rpm_limit",
"_tpm_controller",
"_request_within_tpm_limit",
"_token_process",
"agent_executor",
"tools",
@@ -362,5 +384,14 @@ def set_rpm_controller(self, rpm_controller: RPMController) -> None:
self._rpm_controller = rpm_controller
self.create_agent_executor()

def set_tpm_controller(self, tpm_controller: TPMController) -> None:
"""Set the tpm controller for the agent.

Args:
tpm_controller: An instance of the TPMController class.
"""
if not self._tpm_controller:
self._tpm_controller = tpm_controller
self.create_agent_executor()

def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None):
pass
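If crew-level sharing follows the existing RPMController pattern (an assumption; the crew side is not part of this diff), a single shared controller could be pushed to every agent like so:

```python
# Hypothetical crew-level wiring, mirroring how a shared RPMController is
# distributed today; `crew` and the 20,000 limit are illustrative.
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
from crewai.utilities import Logger, TPMController

shared_tpm = TPMController(
    max_tpm=20_000,
    token_counter=TokenProcess(),
    logger=Logger(verbose=True),
)
for agent in crew.agents:
    agent.set_tpm_controller(shared_tpm)
```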
13 changes: 12 additions & 1 deletion src/crewai/agents/crew_agent_executor.py
@@ -17,6 +17,7 @@
from crewai.utilities import I18N, Printer
from crewai.utilities.agent_utils import (
enforce_rpm_limit,
enforce_tpm_limit,
format_message_for_llm,
get_llm_response,
handle_agent_action_core,
@@ -26,6 +27,8 @@
handle_unknown_error,
has_reached_max_iterations,
is_context_length_exceeded,
is_token_limit_exceeded,
handle_exceeded_token_limits,
process_llm_response,
show_agent_logs,
)
@@ -56,6 +59,7 @@ def __init__(
function_calling_llm: Any = None,
respect_context_window: bool = False,
request_within_rpm_limit: Optional[Callable[[], bool]] = None,
request_within_tpm_limit: Optional[Callable[..., bool]] = None,
callbacks: List[Any] = [],
):
self._i18n: I18N = I18N()
@@ -78,6 +82,7 @@
self.function_calling_llm = function_calling_llm
self.respect_context_window = respect_context_window
self.request_within_rpm_limit = request_within_rpm_limit
self.request_within_tpm_limit = request_within_tpm_limit
self.ask_for_human_input = False
self.messages: List[Dict[str, str]] = []
self.iterations = 0
@@ -152,6 +157,8 @@ def _invoke_loop(self) -> AgentFinish:

enforce_rpm_limit(self.request_within_rpm_limit)

enforce_tpm_limit(self.request_within_tpm_limit)

answer = get_llm_response(
llm=self.llm,
messages=self.messages,
@@ -203,8 +210,12 @@ def _invoke_loop(self) -> AgentFinish:
)

except Exception as e:
if is_token_limit_exceeded(e):
handle_exceeded_token_limits(self.request_within_tpm_limit)
continue

if e.__class__.__module__.startswith("litellm"):
# Do not retry on other litellm errors
raise e
if is_context_length_exceeded(e):
handle_context_length(
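To make the control flow above concrete, here is a self-contained sketch of the retry-on-rate-limit pattern, with stand-ins for the LLM call and the TPM limiter (none of these names are from the diff):

```python
# Stand-alone illustration of the retry loop added to _invoke_loop above.
attempts = {"n": 0}

def flaky_llm_call() -> str:
    # Fails once with a rate-limit-style error, then succeeds.
    attempts["n"] += 1
    if attempts["n"] == 1:
        raise RuntimeError("Rate limit reached")
    return "final answer"

def wait_out_minute(wait: int = 0) -> bool:
    # Stand-in for TPMController.check_or_wait; a real controller would sleep here.
    return True

while True:
    try:
        answer = flaky_llm_call()
        break
    except RuntimeError as e:
        if "Rate limit reached" in str(e):
            wait_out_minute(1)  # wait out the window, then retry the same request
            continue
        raise

assert answer == "final answer"
```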
2 changes: 2 additions & 0 deletions src/crewai/utilities/__init__.py
@@ -7,6 +7,7 @@
from .printer import Printer
from .prompts import Prompts
from .rpm_controller import RPMController
from .tpm_controller import TPMController
from .exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededException,
)
@@ -21,6 +22,7 @@
"Logger",
"Printer",
"Prompts",
"TPMController",
"RPMController",
"YamlParser",
"LLMContextLengthExceededException",
36 changes: 34 additions & 2 deletions src/crewai/utilities/agent_utils.py
@@ -1,7 +1,7 @@
import json
import re
from typing import Any, Callable, Dict, List, Optional, Sequence, Union

from litellm.exceptions import RateLimitError
from crewai.agents.parser import (
FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE,
AgentAction,
Expand All @@ -15,8 +15,9 @@
from crewai.tools.base_tool import BaseTool
from crewai.tools.structured_tool import CrewStructuredTool
from crewai.tools.tool_types import ToolResult
from crewai.utilities import I18N, Printer
from crewai.utilities.errors import AgentRepositoryError

from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededException,
)
@@ -136,6 +137,12 @@ def enforce_rpm_limit(
if request_within_rpm_limit:
request_within_rpm_limit()

def enforce_tpm_limit(
request_within_tpm_limit: Optional[Callable[[], bool]] = None,
) -> None:
"""Enforce the tokens used per minute (TPM) limit if applicable."""
if request_within_tpm_limit:
request_within_tpm_limit()

def get_llm_response(
llm: Union[LLM, BaseLLM],
@@ -322,6 +329,31 @@ def handle_context_length(
"Context length exceeded and user opted not to summarize. Consider using smaller text or RAG tools from crewai_tools."
)


def is_token_limit_exceeded(exception: Exception) -> bool:
"""Check if the exception is due to exceeding token limit per minute.

Args:
exception: The exception to check

Returns:
bool: True if the exception is due to exceeding the token limit per minute.
"""
if isinstance(exception, RateLimitError):
return "Rate limit reached" in str(exception) or "rate_limit_exceeded" in str(exception)
return False


def handle_exceeded_token_limits(
request_within_tpm_limit: Optional[Callable[[int], bool]] = None,
) -> None:
"""Handle a token rate limit error by waiting for the TPM window to reset.

Args:
request_within_tpm_limit: The TPM controller's check_or_wait callable.
"""
if request_within_tpm_limit:
# A non-zero wait value forces the controller down its wait path.
request_within_tpm_limit(1)


def summarize_messages(
messages: List[Dict[str, str]],
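A quick check of the two new helpers, using a stub in place of TPMController.check_or_wait (the stub and values are illustrative):

```python
from crewai.utilities.agent_utils import (
    handle_exceeded_token_limits,
    is_token_limit_exceeded,
)

forced_waits = []

def fake_check_or_wait(wait: int = 0) -> bool:
    # Stand-in for TPMController.check_or_wait that records forced waits.
    forced_waits.append(wait)
    return True

handle_exceeded_token_limits(fake_check_or_wait)
assert forced_waits == [1]  # the helper forces the controller's wait path

# Non-rate-limit exceptions are not treated as TPM errors.
assert not is_token_limit_exceeded(ValueError("unrelated error"))
```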
93 changes: 93 additions & 0 deletions src/crewai/utilities/tpm_controller.py
@@ -0,0 +1,93 @@
"""Controls the token usage rate (tokens per minute) for agent LLM calls."""

import threading
import time
from typing import Optional

from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator

from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
from crewai.utilities.logger import Logger


class TPMController(BaseModel):
"""Manages Tokens per minute limiting."""

max_tpm: Optional[int] = Field(default=None)
logger: Logger = Field(default_factory=lambda: Logger(verbose=False))
token_counter: TokenProcess
_current_tokens: int = PrivateAttr(default=0)
_timer: Optional[threading.Timer] = PrivateAttr(default=None)
_lock: Optional[threading.Lock] = PrivateAttr(default=None)
_shutdown_flag: bool = PrivateAttr(default=False)

model_config = ConfigDict(arbitrary_types_allowed=True)

@model_validator(mode="after")
def reset_counter(self):
if self.max_tpm is not None:
if not self._shutdown_flag:
self._lock = threading.Lock()
self._reset_request_count()
return self

def check_or_wait(self, wait: int = 0):
"""Record token usage, sleeping out the minute once the budget is spent."""
if self.max_tpm is None:
return True

def _check_and_increment(wait: int) -> bool:
if self.max_tpm is not None and self._current_tokens < self.max_tpm and not wait:
# Charge the tokens counted so far against this minute's budget.
self._current_tokens += self.token_counter.total_tokens
return True
if self.max_tpm is not None:
self.logger.log(
"info", "Max TPM reached, waiting for next minute to start."
)
self._wait_for_next_minute()
return True

if self._lock:
with self._lock:
return _check_and_increment(wait)
return _check_and_increment(wait)

def stop_tpm_counter(self):
if self._timer:
self._timer.cancel()
self._timer = None

def _wait_for_next_minute(self):
time.sleep(60)
self._current_tokens = 0

def external_wait_for_next_minute(self):
# Perform the wait under the lock when one exists so concurrent
# callers cannot reset the counter mid-wait.
if self._lock:
with self._lock:
self._wait_for_next_minute()
else:
self._wait_for_next_minute()

def _reset_request_count(self):
def _reset():
self._current_tokens = 0
if not self._shutdown_flag:
self._timer = threading.Timer(60.0, self._reset_request_count)
self._timer.start()

if self._lock:
with self._lock:
_reset()
else:
_reset()

def __del__(self):
# Stop the background reset timer so the controller shuts down cleanly.
if self._timer:
self._shutdown_flag = True
self._timer.cancel()
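A minimal standalone exercise of the controller as defined above, assuming TokenProcess default-constructs with zeroed counters:

```python
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
from crewai.utilities.tpm_controller import TPMController

controller = TPMController(max_tpm=5_000, token_counter=TokenProcess())

assert controller.check_or_wait()   # under budget: charges tokens, returns True
# controller.check_or_wait(wait=1) would block until the minute window resets.
controller.stop_tpm_counter()       # cancel the background reset timer when done
```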