Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions src/rosa/rosa.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, AsyncIterable, Dict, Literal, Optional, Union

from rich.console import Console
from langchain.agents import AgentExecutor, create_tool_calling_agent
from langchain.prompts import MessagesPlaceholder
from langchain_community.callbacks import get_openai_callback
Expand Down Expand Up @@ -216,7 +217,7 @@ async def astream(self, query: str) -> AsyncIterable[Dict[str, Any]]:
# Extract the content from the event and yield it
content = event["data"]["chunk"].content
if content:
final_output += f" {content}"
final_output += content
yield {"type": "token", "content": content}

# Handle tool start events
Expand Down Expand Up @@ -328,9 +329,9 @@ def _print_usage(self, cb):
"""Print the token usage if show_token_usage is enabled."""
if cb is None or not self.__show_token_usage:
return
print(f"[bold]Prompt Tokens:[/bold] {cb.prompt_tokens}")
print(f"[bold]Completion Tokens:[/bold] {cb.completion_tokens}")
print(f"[bold]Total Cost (USD):[/bold] ${cb.total_cost}")
Console().print(f"[bold]Prompt Tokens:[/bold] {cb.prompt_tokens}")
Console().print(f"[bold]Completion Tokens:[/bold] {cb.completion_tokens}")
Console().print(f"[bold]Total Cost (USD):[/bold] ${cb.total_cost}")

def _record_chat_history(self, query: str, response: str):
"""Record the chat history if accumulation is enabled."""
Expand Down