fix: implement token tracking in RouterAgent for cost observability #1190
File: src/praisonai-agents/praisonaiagents/agent/router_agent.py
@@ -10,7 +10,8 @@
 from typing import Dict, List, Optional, Any, Union
 from .agent import Agent
 from ..llm.model_router import ModelRouter
-from ..llm import LLM
+from ..llm import LLM, TokenUsage
+from ..trace.protocol import get_default_emitter

 logger = logging.getLogger(__name__)
@@ -213,8 +214,8 @@ def _execute_with_model(
         full_prompt = f"{context}\n\n{prompt}"

         try:
-            # Execute with the selected model
-            response = llm_instance.get_response(
+            # Execute with the selected model, requesting token usage tracking
+            result = llm_instance.get_response(
                 prompt=full_prompt,
                 system_prompt=self._build_system_prompt(),
                 tools=tools,
@@ -225,16 +226,45 @@ def _execute_with_model(
                 agent_role=self.role,
                 agent_tools=[t.__name__ if hasattr(t, '__name__') else str(t) for t in (tools or [])],
                 execute_tool_fn=self.execute_tool if tools else None,
+                return_token_usage=True,  # Request token usage information
                 **kwargs
             )

+            # Extract response and token usage
+            if isinstance(result, tuple):
+                response, token_usage = result
+            else:
+                # Fallback for backward compatibility
+                response = result
+                token_usage = TokenUsage()
+
+            # Update usage statistics
+            self.model_usage_stats[model_name]['calls'] += 1
+            self.model_usage_stats[model_name]['tokens'] += token_usage.total_tokens
+
+            # Calculate and store cost estimate
+            model_info = self.model_router.get_model_info(model_name)
+            if model_info and token_usage.total_tokens > 0:
+                cost = self.model_router.estimate_cost(model_name, token_usage.total_tokens)
+                self.model_usage_stats[model_name]['cost'] += cost
Comment on lines +245 to +249

Contributor

Emit per-decision cost in the trace event. The trace metadata currently reports the cumulative cost under 'estimated_cost'; report the cost of this call instead, and expose the running total under a separate key.

Suggested patch:

-            model_info = self.model_router.get_model_info(model_name)
-            if model_info and token_usage.total_tokens > 0:
-                cost = self.model_router.estimate_cost(model_name, token_usage.total_tokens)
+            cost = 0.0
+            model_info = self.model_router.get_model_info(model_name)
+            if model_info and token_usage.total_tokens > 0:
+                cost = self.model_router.estimate_cost(model_name, token_usage.total_tokens)
                 self.model_usage_stats[model_name]['cost'] += cost

-                        'estimated_cost': self.model_usage_stats[model_name]['cost'],
+                        'estimated_cost': cost,
+                        'cumulative_estimated_cost': self.model_usage_stats[model_name]['cost'],

Also applies to: lines 251-263
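With this patch, a trace consumer can attribute cost to each individual routing decision via 'estimated_cost' while still reading the running total from the new 'cumulative_estimated_cost' key.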
-            # TODO: Implement token tracking when LLM.get_response() is updated to return token usage
-            # The LLM response currently returns only text, but litellm provides usage info in:
-            # response.get("usage") with prompt_tokens, completion_tokens, and total_tokens
-            # This would require modifying the LLM class to return both text and metadata
+            # Emit token usage via trace system for observability
+            try:
+                trace_emitter = get_default_emitter()
+                trace_emitter.output(
+                    content=f"RouterAgent routing decision completed",
+                    agent_name=self.name,
+                    metadata={
+                        'selected_model': model_name,
+                        'routing_strategy': self.routing_strategy,
+                        'token_usage': token_usage.to_dict(),
+                        'estimated_cost': self.model_usage_stats[model_name]['cost'],
+                        'total_calls': self.model_usage_stats[model_name]['calls'],
+                    }
+                )
+            except Exception as trace_error:
+                # Don't fail the request if tracing fails
+                logger.debug(f"Failed to emit trace event: {trace_error}")

             return response
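After this change, per-model accounting accumulates on the agent as calls are routed. A minimal sketch of reading it back, assuming a RouterAgent instance (here `router_agent`) constructed elsewhere; the stat keys follow the diff above:

    # Inspect accumulated per-model stats; 'calls', 'tokens' and 'cost' are the
    # keys updated in _execute_with_model above.
    for model_name, stats in router_agent.model_usage_stats.items():
        print(f"{model_name}: {stats['calls']} calls, "
              f"{stats['tokens']} tokens, ~${stats['cost']:.4f} estimated")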
File: src/praisonai-agents/praisonaiagents/llm/llm.py
@@ -4,6 +4,7 @@
 import re
 import inspect
 import asyncio
+from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Union, Literal, Callable, TYPE_CHECKING

 if TYPE_CHECKING:

@@ -90,6 +91,36 @@ def _is_context_limit_error(self, error_message: str) -> bool:
         ]
         return any(phrase in error_message.lower() for phrase in context_limit_phrases)
+@dataclass
+class TokenUsage:
+    """
+    Token usage information from LLM response.
+
+    This class provides structured access to token consumption data
+    returned by language models, enabling cost tracking and observability.
+    """
+    prompt_tokens: int = 0
+    completion_tokens: int = 0
+    total_tokens: int = 0
+    cached_tokens: int = 0
+    reasoning_tokens: int = 0
+    audio_input_tokens: int = 0
+    audio_output_tokens: int = 0
+
+    def to_dict(self) -> Dict[str, int]:
+        """Convert to dictionary format."""
+        return {
+            'prompt_tokens': self.prompt_tokens,
+            'completion_tokens': self.completion_tokens,
+            'total_tokens': self.total_tokens,
+            'cached_tokens': self.cached_tokens,
+            'reasoning_tokens': self.reasoning_tokens,
+            'audio_input_tokens': self.audio_input_tokens,
+            'audio_output_tokens': self.audio_output_tokens,
+        }
+
+
 class LLM:
     """
     Easy to use wrapper for language models. Supports multiple providers like OpenAI,
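For reference, a minimal sketch of the dataclass in use (values are illustrative; the import path is assumed from the `from ..llm import LLM, TokenUsage` line in the RouterAgent diff):

    from praisonaiagents.llm import TokenUsage

    usage = TokenUsage(prompt_tokens=120, completion_tokens=45, total_tokens=165)
    print(usage.to_dict())
    # {'prompt_tokens': 120, 'completion_tokens': 45, 'total_tokens': 165,
    #  'cached_tokens': 0, 'reasoning_tokens': 0, 'audio_input_tokens': 0,
    #  'audio_output_tokens': 0}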
@@ -1566,10 +1597,24 @@ def get_response(
         stream: bool = True,
         stream_callback: Optional[Callable] = None,
         emit_events: bool = False,
+        return_token_usage: bool = False,
         **kwargs
-    ) -> str:
+    ) -> Union[str, tuple[str, TokenUsage]]:
Comment on lines +1600 to +1602

Contributor

Mirror this in the async counterpart. The sync API now exposes token usage, but the async counterpart still advertises a plain str return type.
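A sketch of the kind of change the reviewer is asking for; it is hypothetical, since the async method's name and exact signature are not shown in this diff:

    # Hypothetical mirror of the sync change on the async path.
    async def get_response_async(
        self,
        prompt: str,
        return_token_usage: bool = False,  # mirror the new sync flag
        **kwargs
    ) -> Union[str, tuple[str, TokenUsage]]:
        ...  # same (text, TokenUsage) contract as the sync path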
         """Enhanced get_response with all OpenAI-like features"""
         logging.debug(f"Getting response from {self.model}")

+        # Variable to store final response for token usage extraction
+        _final_llm_response = None
Comment on lines +1606 to +1607

Contributor

Most execution paths still drop the raw usage payload. Please capture _final_llm_response at every point where a final response is produced, not only the sites updated in this diff.

Also applies to: lines 1903, 2134, 2326
+        # Helper closure to return appropriate format based on return_token_usage
+        def _prepare_return_value(text: str) -> Union[str, tuple]:
+            if not return_token_usage:
+                return text
+            token_usage = self._extract_token_usage(_final_llm_response) if _final_llm_response else None
+            if token_usage is None:
+                token_usage = TokenUsage()
+            return text, token_usage
+
         # Log all self values when in debug mode
         self._log_llm_config(
             'LLM instance',
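To make the new contract concrete, a minimal usage sketch; the LLM constructor arguments are assumed for illustration, but the flag and return shape follow the signature above:

    llm = LLM(model="gpt-4o-mini")  # constructor args assumed

    # Opting in changes the return type to a (text, TokenUsage) tuple.
    text, usage = llm.get_response(
        prompt="Summarize the changelog.",
        return_token_usage=True,
    )
    print(usage.total_tokens)

    # Default behavior is unchanged: a plain string comes back.
    text = llm.get_response(prompt="Summarize the changelog.")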
@@ -1864,6 +1909,7 @@ def get_response(
                     reasoning_content = resp["choices"][0]["message"].get("provider_specific_fields", {}).get("reasoning_content")
                     response_text = resp["choices"][0]["message"]["content"]
                     final_response = resp
+                    _final_llm_response = resp  # Store for token usage extraction

                     # Emit StreamEvent for reasoning content if callback provided
                     if _emit and reasoning_content:
@@ -2094,6 +2140,7 @@ def get_response(
                             **kwargs
                         )
                    )
+                    _final_llm_response = final_response  # Store for token usage extraction
                    # Handle None content from Gemini
                    response_content = final_response["choices"][0]["message"].get("content")
                    response_text = response_content if response_content is not None else ""
@@ -2285,6 +2332,7 @@ def get_response(
                             **kwargs
                         )
                    )
+                    _final_llm_response = final_response  # Store for token usage extraction
                    # Handle None content from Gemini
                    response_content = final_response["choices"][0]["message"].get("content")
                    response_text = response_content if response_content is not None else ""
@@ -2698,7 +2746,7 @@ def get_response(
                                 task_id=task_id
                             )
                             callback_executed = True
-                        return final_response_text
+                        return _prepare_return_value(final_response_text)

                     # No tool calls were made in this iteration, return the response
                     generation_time_val = time.time() - start_time
@@ -2787,7 +2835,7 @@ def get_response(
                                 task_id=task_id
                             )
                             callback_executed = True
-                        return response_text
+                        return _prepare_return_value(response_text)

                     if not self_reflect:
                         if verbose and not interaction_displayed:
@@ -2816,8 +2864,8 @@ def get_response(

                         # Return reasoning content if reasoning_steps is True
                         if reasoning_steps and stored_reasoning_content:
-                            return stored_reasoning_content
-                        return response_text
+                            return _prepare_return_value(stored_reasoning_content)
+                        return _prepare_return_value(response_text)

                     # Handle self-reflection loop
                     while reflection_count < max_reflect:
@@ -2999,7 +3047,7 @@ def get_response(
                                 agent_name=agent_name, agent_role=agent_role, agent_tools=agent_tools,
                                 task_name=task_name, task_description=task_description, task_id=task_id)
                             interaction_displayed = True
-                        return response_text
+                        return _prepare_return_value(response_text)
                     continue
                 except Exception as e:
                     _get_display_functions()['display_error'](f"Error in LLM response: {str(e)}")
@@ -3010,12 +3058,12 @@ def get_response(
                        _get_display_functions()['display_interaction'](prompt, response_text, markdown=markdown,
                                         generation_time=time.time() - start_time, console=self.console)
                        interaction_displayed = True
-                    return response_text
+                    return _prepare_return_value(response_text)

            except Exception as error:
                _get_display_functions()['display_error'](f"Error in get_response: {str(error)}")
                raise

        # Log completion time if in debug mode
        if logging.getLogger().getEffectiveLevel() == logging.DEBUG:
            total_time = time.time() - start_time
@@ -4192,6 +4240,49 @@ def _track_token_usage(self, response: Dict[str, Any], model: str) -> Optional[T
                 logging.warning(f"Failed to track token usage: {e}")
             return None

+    def _extract_token_usage(self, response: Union[Dict[str, Any], Any]) -> Optional[TokenUsage]:
+        """Extract token usage from LiteLLM response for public API."""
+        try:
+            usage = None
+
+            # Handle both dict and ModelResponse object formats
+            if isinstance(response, dict):
+                usage = response.get("usage", {})
+            else:
+                # ModelResponse object
+                usage = getattr(response, 'usage', None)
+
+            if not usage:
+                return None
+
+            # Extract token counts with support for both dict and object access
+            if isinstance(usage, dict):
+                return TokenUsage(
+                    prompt_tokens=usage.get("prompt_tokens", 0),
+                    completion_tokens=usage.get("completion_tokens", 0),
+                    total_tokens=usage.get("total_tokens", 0),
+                    cached_tokens=usage.get("cached_tokens", 0),
+                    reasoning_tokens=usage.get("reasoning_tokens", 0),
+                    audio_input_tokens=usage.get("audio_input_tokens", 0),
+                    audio_output_tokens=usage.get("audio_output_tokens", 0),
+                )
+            else:
+                # Object-style access
+                return TokenUsage(
+                    prompt_tokens=getattr(usage, 'prompt_tokens', 0) or 0,
+                    completion_tokens=getattr(usage, 'completion_tokens', 0) or 0,
+                    total_tokens=getattr(usage, 'total_tokens', 0) or 0,
+                    cached_tokens=getattr(usage, 'cached_tokens', 0) or 0,
+                    reasoning_tokens=getattr(usage, 'reasoning_tokens', 0) or 0,
+                    audio_input_tokens=getattr(usage, 'audio_input_tokens', 0) or 0,
+                    audio_output_tokens=getattr(usage, 'audio_output_tokens', 0) or 0,
+                )
+
+        except Exception as e:
+            if self.verbose:
+                logging.warning(f"Failed to extract token usage: {e}")
+            return None
Comment on lines +4243 to +4284

Contributor

Add fallback to Responses API token field names in _extract_token_usage.

The method only looks for prompt_tokens and completion_tokens, but the OpenAI Responses API reports usage as input_tokens and output_tokens, so usage payloads in that shape would be counted as zero. Add fallback logic to check the Responses API field names and derive total_tokens when it is absent.

Suggested patch:
             if isinstance(usage, dict):
+                prompt_tokens = usage.get("prompt_tokens", usage.get("input_tokens", 0))
+                completion_tokens = usage.get("completion_tokens", usage.get("output_tokens", 0))
                 return TokenUsage(
-                    prompt_tokens=usage.get("prompt_tokens", 0),
-                    completion_tokens=usage.get("completion_tokens", 0),
-                    total_tokens=usage.get("total_tokens", 0),
+                    prompt_tokens=prompt_tokens,
+                    completion_tokens=completion_tokens,
+                    total_tokens=usage.get("total_tokens", prompt_tokens + completion_tokens),
                     cached_tokens=usage.get("cached_tokens", 0),
                     reasoning_tokens=usage.get("reasoning_tokens", 0),
                     audio_input_tokens=usage.get("audio_input_tokens", 0),
                     audio_output_tokens=usage.get("audio_output_tokens", 0),
                 )
             else:
+                prompt_tokens = getattr(usage, "prompt_tokens", None)
+                if prompt_tokens is None:
+                    prompt_tokens = getattr(usage, "input_tokens", 0) or 0
+                completion_tokens = getattr(usage, "completion_tokens", None)
+                if completion_tokens is None:
+                    completion_tokens = getattr(usage, "output_tokens", 0) or 0
                 return TokenUsage(
-                    prompt_tokens=getattr(usage, 'prompt_tokens', 0) or 0,
-                    completion_tokens=getattr(usage, 'completion_tokens', 0) or 0,
-                    total_tokens=getattr(usage, 'total_tokens', 0) or 0,
+                    prompt_tokens=prompt_tokens,
+                    completion_tokens=completion_tokens,
+                    total_tokens=getattr(usage, 'total_tokens', prompt_tokens + completion_tokens) or (prompt_tokens + completion_tokens),
                     cached_tokens=getattr(usage, 'cached_tokens', 0) or 0,
                     reasoning_tokens=getattr(usage, 'reasoning_tokens', 0) or 0,
                     audio_input_tokens=getattr(usage, 'audio_input_tokens', 0) or 0,
                     audio_output_tokens=getattr(usage, 'audio_output_tokens', 0) or 0,
                 )

Ruff (0.15.7) additionally flags line 4283 with BLE001: do not catch blind exception.
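A quick illustration of what the fallback buys, assuming the suggested patch is applied (the payload shape and numbers are made up for the example):

    # A Responses-API-shaped usage payload with no prompt_tokens/completion_tokens keys.
    payload = {"usage": {"input_tokens": 200, "output_tokens": 80}}
    usage = llm._extract_token_usage(payload)
    # Before the patch: prompt_tokens=0, completion_tokens=0, total_tokens=0
    # After the patch:  prompt_tokens=200, completion_tokens=80, total_tokens=280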
     def set_current_agent(self, agent_name: Optional[str]):
         """Set the current agent name for token tracking."""
         self.current_agent_name = agent_name
2. RouterAgent token_usage not persisted (requirement gap)