diff --git a/src/praisonai-agents/praisonaiagents/agent/router_agent.py b/src/praisonai-agents/praisonaiagents/agent/router_agent.py
index 4237842f3..bbcdd9764 100644
--- a/src/praisonai-agents/praisonaiagents/agent/router_agent.py
+++ b/src/praisonai-agents/praisonaiagents/agent/router_agent.py
@@ -10,7 +10,8 @@
 from typing import Dict, List, Optional, Any, Union
 from .agent import Agent
 from ..llm.model_router import ModelRouter
-from ..llm import LLM
+from ..llm import LLM, TokenUsage
+from ..trace.protocol import get_default_emitter
 
 logger = logging.getLogger(__name__)
 
@@ -213,8 +214,8 @@ def _execute_with_model(
         full_prompt = f"{context}\n\n{prompt}"
 
         try:
-            # Execute with the selected model
-            response = llm_instance.get_response(
+            # Execute with the selected model, requesting token usage tracking
+            result = llm_instance.get_response(
                 prompt=full_prompt,
                 system_prompt=self._build_system_prompt(),
                 tools=tools,
@@ -225,16 +226,45 @@
                 agent_role=self.role,
                 agent_tools=[t.__name__ if hasattr(t, '__name__') else str(t) for t in (tools or [])],
                 execute_tool_fn=self.execute_tool if tools else None,
+                return_token_usage=True,  # Request token usage information
                 **kwargs
             )
 
+            # Extract response and token usage
+            if isinstance(result, tuple):
+                response, token_usage = result
+            else:
+                # Fallback for backward compatibility
+                response = result
+                token_usage = TokenUsage()
+
             # Update usage statistics
             self.model_usage_stats[model_name]['calls'] += 1
+            self.model_usage_stats[model_name]['tokens'] += token_usage.total_tokens
+
+            # Calculate and store cost estimate
+            model_info = self.model_router.get_model_info(model_name)
+            if model_info and token_usage.total_tokens > 0:
+                cost = self.model_router.estimate_cost(model_name, token_usage.total_tokens)
+                self.model_usage_stats[model_name]['cost'] += cost
 
-            # TODO: Implement token tracking when LLM.get_response() is updated to return token usage
-            # The LLM response currently returns only text, but litellm provides usage info in:
-            # response.get("usage") with prompt_tokens, completion_tokens, and total_tokens
-            # This would require modifying the LLM class to return both text and metadata
+            # Emit token usage via trace system for observability
+            try:
+                trace_emitter = get_default_emitter()
+                trace_emitter.output(
+                    content=f"RouterAgent routing decision completed",
+                    agent_name=self.name,
+                    metadata={
+                        'selected_model': model_name,
+                        'routing_strategy': self.routing_strategy,
+                        'token_usage': token_usage.to_dict(),
+                        'estimated_cost': self.model_usage_stats[model_name]['cost'],
+                        'total_calls': self.model_usage_stats[model_name]['calls'],
+                    }
+                )
+            except Exception as trace_error:
+                # Don't fail the request if tracing fails
+                logger.debug(f"Failed to emit trace event: {trace_error}")
 
             return response
 
diff --git a/src/praisonai-agents/praisonaiagents/llm/__init__.py b/src/praisonai-agents/praisonaiagents/llm/__init__.py
index 803e469c1..708ed36ed 100644
--- a/src/praisonai-agents/praisonaiagents/llm/__init__.py
+++ b/src/praisonai-agents/praisonaiagents/llm/__init__.py
@@ -94,6 +94,10 @@ def __getattr__(name):
         from .rate_limiter import RateLimiter
         _lazy_cache[name] = RateLimiter
         return RateLimiter
+    elif name == "TokenUsage":
+        from .llm import TokenUsage
+        _lazy_cache[name] = TokenUsage
+        return TokenUsage
     raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
 
@@ -117,5 +121,6 @@
     "ModelProfile",
     "TaskComplexity",
     "create_routing_agent",
-    "RateLimiter"
+    "RateLimiter",
+    "TokenUsage"
 ]
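The `__init__.py` hunk above extends the package's existing PEP 562 lazy-import convention: a module-level `__getattr__` resolves `TokenUsage` on first access, so importing the package does not eagerly load the heavy `llm` module. A minimal sketch of the caller's view under that assumption (the token counts below are made-up illustration values, not output from the library):

```python
# Triggers the package-level __getattr__("TokenUsage") hook exactly once;
# the resolved class is then cached in _lazy_cache for later lookups.
from praisonaiagents.llm import TokenUsage

usage = TokenUsage(prompt_tokens=120, completion_tokens=48, total_tokens=168)
print(usage.to_dict()["total_tokens"])  # -> 168
```
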
diff --git a/src/praisonai-agents/praisonaiagents/llm/llm.py b/src/praisonai-agents/praisonaiagents/llm/llm.py
index 3890d428c..e816f5775 100644
--- a/src/praisonai-agents/praisonaiagents/llm/llm.py
+++ b/src/praisonai-agents/praisonaiagents/llm/llm.py
@@ -4,6 +4,7 @@
 import re
 import inspect
 import asyncio
+from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Union, Literal, Callable, TYPE_CHECKING
 
 if TYPE_CHECKING:
@@ -90,6 +91,36 @@ def _is_context_limit_error(self, error_message: str) -> bool:
         ]
         return any(phrase in error_message.lower() for phrase in context_limit_phrases)
 
+
+@dataclass
+class TokenUsage:
+    """
+    Token usage information from LLM response.
+
+    This class provides structured access to token consumption data
+    returned by language models, enabling cost tracking and observability.
+    """
+    prompt_tokens: int = 0
+    completion_tokens: int = 0
+    total_tokens: int = 0
+    cached_tokens: int = 0
+    reasoning_tokens: int = 0
+    audio_input_tokens: int = 0
+    audio_output_tokens: int = 0
+
+    def to_dict(self) -> Dict[str, int]:
+        """Convert to dictionary format."""
+        return {
+            'prompt_tokens': self.prompt_tokens,
+            'completion_tokens': self.completion_tokens,
+            'total_tokens': self.total_tokens,
+            'cached_tokens': self.cached_tokens,
+            'reasoning_tokens': self.reasoning_tokens,
+            'audio_input_tokens': self.audio_input_tokens,
+            'audio_output_tokens': self.audio_output_tokens,
+        }
+
+
 class LLM:
     """
     Easy to use wrapper for language models. Supports multiple providers like OpenAI,
@@ -1566,10 +1597,24 @@ def get_response(
         stream: bool = True,
         stream_callback: Optional[Callable] = None,
         emit_events: bool = False,
+        return_token_usage: bool = False,
         **kwargs
-    ) -> str:
+    ) -> Union[str, tuple[str, TokenUsage]]:
         """Enhanced get_response with all OpenAI-like features"""
         logging.debug(f"Getting response from {self.model}")
+
+        # Variable to store final response for token usage extraction
+        _final_llm_response = None
+
+        # Helper closure to return appropriate format based on return_token_usage
+        def _prepare_return_value(text: str) -> Union[str, tuple]:
+            if not return_token_usage:
+                return text
+            token_usage = self._extract_token_usage(_final_llm_response) if _final_llm_response else None
+            if token_usage is None:
+                token_usage = TokenUsage()
+            return text, token_usage
+
         # Log all self values when in debug mode
         self._log_llm_config(
             'LLM instance',
@@ -1864,6 +1909,7 @@
                             reasoning_content = resp["choices"][0]["message"].get("provider_specific_fields", {}).get("reasoning_content")
                             response_text = resp["choices"][0]["message"]["content"]
                             final_response = resp
+                            _final_llm_response = resp  # Store for token usage extraction
 
                             # Emit StreamEvent for reasoning content if callback provided
                             if _emit and reasoning_content:
@@ -2094,6 +2140,7 @@
                                 **kwargs
                             )
                         )
+                        _final_llm_response = final_response  # Store for token usage extraction
                         # Handle None content from Gemini
                         response_content = final_response["choices"][0]["message"].get("content")
                         response_text = response_content if response_content is not None else ""
@@ -2285,6 +2332,7 @@
                                 **kwargs
                             )
                         )
+                        _final_llm_response = final_response  # Store for token usage extraction
                         # Handle None content from Gemini
                         response_content = final_response["choices"][0]["message"].get("content")
                         response_text = response_content if response_content is not None else ""
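The remaining hunks of this file (below) rewrite every `return` site in `get_response` to funnel through the `_prepare_return_value` closure, so the optional `(text, TokenUsage)` tuple is produced in exactly one place and plain-string callers stay unaffected. A condensed, self-contained sketch of that pattern, with simplified names (`get_response_sketch` is hypothetical; the real method has many more return sites and a richer `TokenUsage`):

```python
from dataclasses import dataclass
from typing import Optional, Tuple, Union

@dataclass
class TokenUsage:
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0

def get_response_sketch(prompt: str, return_token_usage: bool = False) -> Union[str, Tuple[str, TokenUsage]]:
    # Captures the raw provider response; the closure reads it at call time.
    _final_llm_response: Optional[dict] = None

    def _prepare_return_value(text: str) -> Union[str, Tuple[str, TokenUsage]]:
        # Callers that did not opt in get the plain string they always got.
        if not return_token_usage:
            return text
        usage = (_final_llm_response or {}).get("usage") or {}
        return text, TokenUsage(
            prompt_tokens=usage.get("prompt_tokens", 0),
            completion_tokens=usage.get("completion_tokens", 0),
            total_tokens=usage.get("total_tokens", 0),
        )

    # The real method calls the model here; a canned response stands in.
    _final_llm_response = {"usage": {"prompt_tokens": 12, "completion_tokens": 5, "total_tokens": 17}}
    return _prepare_return_value(f"echo: {prompt}")

# Opt-in callers unpack the tuple; everyone else is unaffected.
text, usage = get_response_sketch("hi", return_token_usage=True)
print(text, usage.total_tokens)  # -> echo: hi 17
```

Routing every exit through one closure is what keeps the diff mechanical: each hunk below is a one-line substitution rather than duplicated extraction logic.
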
@@ -2698,7 +2746,7 @@
                                 task_id=task_id
                             )
                             callback_executed = True
-                        return final_response_text
+                        return _prepare_return_value(final_response_text)
 
                     # No tool calls were made in this iteration, return the response
                     generation_time_val = time.time() - start_time
@@ -2787,7 +2835,7 @@
                                 task_id=task_id
                             )
                             callback_executed = True
-                        return response_text
+                        return _prepare_return_value(response_text)
 
                     if not self_reflect:
                         if verbose and not interaction_displayed:
@@ -2816,8 +2864,8 @@
                     # Return reasoning content if reasoning_steps is True
                     if reasoning_steps and stored_reasoning_content:
-                        return stored_reasoning_content
-                    return response_text
+                        return _prepare_return_value(stored_reasoning_content)
+                    return _prepare_return_value(response_text)
 
                 # Handle self-reflection loop
                 while reflection_count < max_reflect:
@@ -2999,7 +3047,7 @@
                                 agent_name=agent_name, agent_role=agent_role, agent_tools=agent_tools, task_name=task_name, task_description=task_description, task_id=task_id)
                             interaction_displayed = True
-                            return response_text
+                            return _prepare_return_value(response_text)
                     continue
             except Exception as e:
                 _get_display_functions()['display_error'](f"Error in LLM response: {str(e)}")
@@ -3010,12 +3058,12 @@
                 _get_display_functions()['display_interaction'](prompt, response_text, markdown=markdown, generation_time=time.time() - start_time, console=self.console)
                 interaction_displayed = True
-            return response_text
+            return _prepare_return_value(response_text)
 
         except Exception as error:
             _get_display_functions()['display_error'](f"Error in get_response: {str(error)}")
             raise
-
+
         # Log completion time if in debug mode
         if logging.getLogger().getEffectiveLevel() == logging.DEBUG:
             total_time = time.time() - start_time
@@ -4192,6 +4240,49 @@ def _track_token_usage(self, response: Dict[str, Any], model: str) -> Optional[T
             logging.warning(f"Failed to track token usage: {e}")
         return None
 
+    def _extract_token_usage(self, response: Union[Dict[str, Any], Any]) -> Optional[TokenUsage]:
+        """Extract token usage from LiteLLM response for public API."""
+        try:
+            usage = None
+
+            # Handle both dict and ModelResponse object formats
+            if isinstance(response, dict):
+                usage = response.get("usage", {})
+            else:
+                # ModelResponse object
+                usage = getattr(response, 'usage', None)
+
+            if not usage:
+                return None
+
+            # Extract token counts with support for both dict and object access
+            if isinstance(usage, dict):
+                return TokenUsage(
+                    prompt_tokens=usage.get("prompt_tokens", 0),
+                    completion_tokens=usage.get("completion_tokens", 0),
+                    total_tokens=usage.get("total_tokens", 0),
+                    cached_tokens=usage.get("cached_tokens", 0),
+                    reasoning_tokens=usage.get("reasoning_tokens", 0),
+                    audio_input_tokens=usage.get("audio_input_tokens", 0),
+                    audio_output_tokens=usage.get("audio_output_tokens", 0),
+                )
+            else:
+                # Object-style access
+                return TokenUsage(
+                    prompt_tokens=getattr(usage, 'prompt_tokens', 0) or 0,
+                    completion_tokens=getattr(usage, 'completion_tokens', 0) or 0,
+                    total_tokens=getattr(usage, 'total_tokens', 0) or 0,
+                    cached_tokens=getattr(usage, 'cached_tokens', 0) or 0,
+                    reasoning_tokens=getattr(usage, 'reasoning_tokens', 0) or 0,
+                    audio_input_tokens=getattr(usage, 'audio_input_tokens', 0) or 0,
+                    audio_output_tokens=getattr(usage, 'audio_output_tokens', 0) or 0,
+                )
+
+        except Exception as e:
+            if self.verbose:
+                logging.warning(f"Failed to extract token usage: {e}")
+            return None
+
     def set_current_agent(self, agent_name: Optional[str]):
         """Set the current agent name for token tracking."""
         self.current_agent_name = agent_name
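Taken together, the diff wires opt-in token accounting end to end: `LLM.get_response(return_token_usage=True)` captures the raw LiteLLM response, `_extract_token_usage` normalizes it into a `TokenUsage`, and `RouterAgent` accumulates per-model call, token, and cost statistics. A hypothetical usage sketch under the assumption that the names introduced above land as shown; the `LLM` constructor arguments are illustrative guesses, not a verified public API:

```python
from praisonaiagents.llm import LLM, TokenUsage

# Assumed constructor signature for illustration only.
llm = LLM(model="gpt-4o-mini")

# Existing callers are untouched: a bare call still returns a plain string.
text = llm.get_response(prompt="Summarize the release notes.")

# Opting in returns a (text, TokenUsage) tuple instead.
text, usage = llm.get_response(
    prompt="Summarize the release notes.",
    return_token_usage=True,
)
assert isinstance(usage, TokenUsage)
print(usage.total_tokens, usage.to_dict())
```
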