-
Notifications
You must be signed in to change notification settings - Fork 2
Add stdout token streaming #56
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
9897cf5
Add basic text streaming
mare5x cc46ba9
Add tests for escape_markdown_text
mare5x bd936b7
Prettier demo.py output
mare5x 3c6cf88
More structured printing
mare5x ac0906d
Stream new messages only
mare5x 7fa3a06
More TextStreamFrontend formatting
mare5x c40d851
Remove ToolCallSQL object
mare5x File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,7 +5,7 @@ | |
|
|
||
| import portus | ||
|
|
||
| logging.basicConfig(level=logging.INFO) | ||
| logging.basicConfig(level=logging.WARNING) | ||
|
|
||
| engine = create_engine( | ||
| "postgresql://readonly_role:>sU9y95R([email protected]/netflix?options=endpoint%3Dep-young-breeze-a5cq8xns&sslmode=require" | ||
|
|
@@ -34,8 +34,9 @@ | |
| thread = session.thread() | ||
| thread.ask("count cancelled shows by directors") | ||
| print(thread.text()) | ||
| print(thread.code) | ||
| print(thread.df()) | ||
| print(f"\n```\n{thread.code}\n```\n") | ||
| df = thread.df() | ||
| print(f"\n{df.to_markdown() if df is not None else df}\n") | ||
|
|
||
| plot = thread.plot() | ||
| print(plot.text) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,35 @@ | ||
| from langchain_core.messages import AIMessage, AIMessageChunk, ToolCall | ||
|
|
||
|
|
||
| def get_tool_call_sql(tool_call: ToolCall) -> str | None: | ||
| args = tool_call["args"] | ||
| # Currently, there is only run_sql_query with an sql param | ||
| if "sql" in args: | ||
| sql = args["sql"] | ||
| assert isinstance(sql, str), f"Expected SQL to be a string, got {type(sql)}" | ||
| return sql | ||
| return None | ||
|
|
||
|
|
||
| def get_reasoning_content(message: AIMessage | AIMessageChunk) -> str: | ||
| # Assume only one of the reasoning parts is present, so there will be no duplication. | ||
| reasoning_text = "" | ||
|
|
||
| # OpenAI output_version: v0 | ||
| reasoning_chunk = message.additional_kwargs.get("reasoning", {}) | ||
| reasoning_summary_chunks = reasoning_chunk.get("summary", []) | ||
| for reasoning_summary_chunk in reasoning_summary_chunks: | ||
| reasoning_text += reasoning_summary_chunk.get("text", "") | ||
|
|
||
| # "Qwen" style reasoning: | ||
| reasoning_text += message.additional_kwargs.get("reasoning_content", "") | ||
|
|
||
| # OpenAI output_version: responses/v1 | ||
| blocks = message.content if isinstance(message.content, list) else [message.content] | ||
| for block in blocks: | ||
| if isinstance(block, dict) and block.get("type", "text") == "reasoning": | ||
| for summary in block["summary"]: | ||
| reasoning_text += summary["text"] | ||
|
|
||
| assert isinstance(reasoning_text, str), f"Expected a string, got {type(reasoning_text)}" | ||
| return reasoning_text |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,124 @@ | ||
| import re | ||
| from typing import Any, TextIO | ||
|
|
||
| import pandas as pd | ||
| from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage, BaseMessageChunk, ToolMessage | ||
|
|
||
| from portus.agents.frontend.messages import get_reasoning_content, get_tool_call_sql | ||
|
|
||
|
|
||
| class TextStreamFrontend: | ||
| """Helper for streaming LangGraph LLM outputs to a text stream (stdout, stderr, a file, etc.).""" | ||
|
|
||
| def __init__( | ||
| self, | ||
| start_state: dict[str, Any], | ||
| *, | ||
| writer: TextIO | None = None, | ||
| escape_markdown: bool = False, | ||
| show_headers: bool = True, | ||
| pretty_sql: bool = False, | ||
| ): | ||
| self._writer = writer # Use io.Writer type in Python 3.14 | ||
| self._escape_markdown = escape_markdown | ||
| self._show_headers = show_headers | ||
| self._message_count = len(start_state.get("messages", [])) | ||
| self._started = False | ||
| self._is_tool_calling = False | ||
| self._pretty_sql = pretty_sql | ||
|
|
||
| def write(self, text: str) -> None: | ||
| if not self._started: | ||
| self.start() | ||
| print(text, end="", flush=True, file=self._writer) | ||
|
|
||
| def write_dataframe(self, df: pd.DataFrame) -> None: | ||
| self.write(df.to_markdown()) | ||
|
|
||
| def write_message_chunk(self, chunk: BaseMessageChunk) -> None: | ||
| if not isinstance(chunk, AIMessageChunk): | ||
| return # Handle ToolMessage results in add_state_chunk | ||
|
|
||
| reasoning_text = get_reasoning_content(chunk) | ||
| text = reasoning_text + chunk.text() | ||
| if self._escape_markdown: | ||
| text = escape_markdown_text(text) | ||
| self.write(text) | ||
|
|
||
| if len(chunk.tool_call_chunks) > 0: | ||
| # N.B. LangChain sometimes waits for the whole string to complete before yielding chunks | ||
| # That's why long "sql" tool calls take some time to show up and then the whole sql is shown in a batch | ||
| if not self._is_tool_calling: | ||
| self.write("\n\n```\n") # Open code block | ||
| self._is_tool_calling = True | ||
| for tool_call_chunk in chunk.tool_call_chunks: | ||
| if tool_call_chunk["args"] is not None: | ||
| self.write(tool_call_chunk["args"]) | ||
| elif self._is_tool_calling: | ||
| self.write("\n```\n\n") # Close code block | ||
| self._is_tool_calling = False | ||
|
|
||
| def write_state_chunk(self, state_chunk: dict[str, Any]) -> None: | ||
| """The state chunk is assumed to contain a "messages" key.""" | ||
| if self._is_tool_calling: | ||
| self.write("\n```\n\n") # Close code block | ||
| self._is_tool_calling = False | ||
|
|
||
| # Loop through new messages only. | ||
| # We could either force the caller of the frontend to provide new messages only, | ||
| # but for ease of use we assume the state contains a list of messages and do it here. | ||
| messages: list[BaseMessage] = state_chunk.get("messages", []) | ||
| new_messages = messages[self._message_count :] | ||
| self._message_count += len(new_messages) | ||
|
|
||
| for message in new_messages: | ||
| if isinstance(message, ToolMessage): | ||
| if message.artifact is not None: | ||
| if "df" in message.artifact and message.artifact["df"] is not None: | ||
| self.write_dataframe(message.artifact["df"]) | ||
| else: | ||
| self.write(f"\n```\n{message.content}\n```\n\n") | ||
| else: | ||
| self.write(f"\n```\n{message.content}\n```\n\n") | ||
| elif self._pretty_sql and isinstance(message, AIMessage): | ||
| # During tool calling we show raw JSON chunks, but for SQL we also want pretty formatting. | ||
| for tool_call in message.tool_calls: | ||
| sql = get_tool_call_sql(tool_call) | ||
| if sql is not None: | ||
| self.write(f"\n```sql\n{sql}\n```\n\n") | ||
|
|
||
| def write_stream_chunk(self, mode: str, chunk: Any) -> None: | ||
| if mode == "messages": | ||
| token_chunk, _token_metadata = chunk | ||
| self.write_message_chunk(token_chunk) | ||
| elif mode == "values": | ||
| if isinstance(chunk, dict): | ||
| self.write_state_chunk(chunk) | ||
| else: | ||
| raise ValueError(f"Unexpected chunk type: {type(chunk)}") | ||
|
|
||
| def start(self) -> None: | ||
| self._started = True | ||
| if self._show_headers: | ||
| self.write("=" * 8 + " <THINKING> " + "=" * 8 + "\n\n") | ||
|
|
||
| def end(self) -> None: | ||
| if self._show_headers: | ||
| self.write("\n" + "=" * 8 + " </THINKING> " + "=" * 8 + "\n\n") | ||
| self._started = False | ||
|
|
||
|
|
||
| def escape_currency_dollar_signs(text: str) -> str: | ||
| """Escapes dollar signs in a string to prevent MathJax interpretation in markdown environments.""" | ||
| return re.sub(r"\$(\d+)", r"\$\1", text) | ||
|
|
||
|
|
||
| def escape_strikethrough(text: str) -> str: | ||
| """Prevents aggressive markdown strikethrough formatting.""" | ||
| return re.sub(r"~(.?\d+)", r"\~\1", text) | ||
|
|
||
|
|
||
| def escape_markdown_text(text: str) -> str: | ||
| text = escape_strikethrough(text) | ||
| text = escape_currency_dollar_signs(text) | ||
| return text |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.