Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions examples/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import portus

logging.basicConfig(level=logging.INFO)
logging.basicConfig(level=logging.WARNING)

engine = create_engine(
"postgresql://readonly_role:>sU9y95R([email protected]/netflix?options=endpoint%3Dep-young-breeze-a5cq8xns&sslmode=require"
Expand Down Expand Up @@ -34,8 +34,9 @@
thread = session.thread()
thread.ask("count cancelled shows by directors")
print(thread.text())
print(thread.code)
print(thread.df())
print(f"\n```\n{thread.code}\n```\n")
df = thread.df()
print(f"\n{df.to_markdown() if df is not None else df}\n")

plot = thread.plot()
print(plot.text)
67 changes: 66 additions & 1 deletion portus/agents/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
from typing import Any

from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.graph.state import CompiledStateGraph

from portus.agents.frontend.text_frontend import TextStreamFrontend
from portus.configs.llm import LLMConfig
from portus.core import Executor, Opa, Session

Expand All @@ -16,7 +19,7 @@

class AgentExecutor(Executor):
"""
Base class for agents that execute with a DuckDB connection and LLM configuration.
Base class for LangGraph agents that execute with a DuckDB connection and LLM configuration.
Provides common functionality for graph caching, message handling, and OPA processing.
"""

Expand All @@ -25,6 +28,7 @@ def __init__(self) -> None:
self._cached_compiled_graph: Any | None = None
self._cached_connection_id: int | None = None
self._cached_llm_config_id: int | None = None
self._graph_recursion_limit = 50

def _get_data_connection(self, session: Session) -> Any:
"""Get DuckDB connection from session."""
Expand Down Expand Up @@ -132,3 +136,64 @@ def _update_message_history(self, session: Session, cache_scope: str, final_mess
"""Update message history in cache with final messages from graph execution."""
if final_messages:
self._set_messages(session, cache_scope, final_messages)

def _invoke_graph(
    self,
    compiled_graph: CompiledStateGraph[Any],
    start_state: dict[str, Any],
    *,
    config: RunnableConfig | None = None,
    stream: bool = True,
    **kwargs: Any,
) -> Any:
    """Invoke the graph with the given start state and return the output state.

    Args:
        compiled_graph: Compiled LangGraph graph to run.
        start_state: Initial graph state (expected to contain a "messages" key).
        config: Optional runnable config (e.g. recursion limit).
        stream: When True, stream tokens/states to a TextStreamFrontend while
            executing; when False, run the graph in a single blocking call.
        **kwargs: Extra arguments forwarded to the underlying stream/invoke call.

    Returns:
        The final graph state produced by the run.
    """
    if stream:
        return self._execute_stream_sync(compiled_graph, start_state, config=config, **kwargs)
    # Fix: previously **kwargs were silently dropped on the non-streaming path,
    # so the two branches honored different call options. Forward them here too.
    return compiled_graph.invoke(start_state, config=config, **kwargs)

@staticmethod
async def _execute_stream(
    compiled_graph: CompiledStateGraph[Any],
    start_state: dict[str, Any],
    *,
    config: RunnableConfig | None = None,
    **kwargs: Any,
) -> Any:
    """Asynchronously run the graph, streaming output, and return the final state.

    Streams both "values" (full state snapshots) and "messages" (token chunks)
    through a TextStreamFrontend, remembering the last state snapshot seen.
    """
    frontend = TextStreamFrontend(start_state)
    final_state = None
    stream_iter = compiled_graph.astream(
        start_state,
        stream_mode=["values", "messages"],
        config=config,
        **kwargs,
    )
    async for stream_mode, payload in stream_iter:
        frontend.write_stream_chunk(stream_mode, payload)
        if stream_mode == "values":
            final_state = payload
    frontend.end()
    # The graph always emits at least one "values" snapshot (the input state).
    assert final_state is not None
    return final_state

@staticmethod
def _execute_stream_sync(
    compiled_graph: CompiledStateGraph[Any],
    start_state: dict[str, Any],
    *,
    config: RunnableConfig | None = None,
    **kwargs: Any,
) -> Any:
    """Synchronously run the graph, streaming output, and return the final state.

    Synchronous twin of ``_execute_stream``: streams "values" and "messages"
    chunks through a TextStreamFrontend and keeps the last state snapshot.
    """
    frontend = TextStreamFrontend(start_state)
    final_state = None
    stream_iter = compiled_graph.stream(
        start_state,
        stream_mode=["values", "messages"],
        config=config,
        **kwargs,
    )
    for stream_mode, payload in stream_iter:
        frontend.write_stream_chunk(stream_mode, payload)
        if stream_mode == "values":
            final_state = payload
    frontend.end()
    # The graph always emits at least one "values" snapshot (the input state).
    assert final_state is not None
    return final_state
Empty file.
35 changes: 35 additions & 0 deletions portus/agents/frontend/messages.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from langchain_core.messages import AIMessage, AIMessageChunk, ToolCall


def get_tool_call_sql(tool_call: ToolCall) -> str | None:
    """Return the SQL string from a tool call's arguments, or None if absent."""
    args = tool_call["args"]
    # Currently, there is only run_sql_query with an sql param
    if "sql" not in args:
        return None
    sql = args["sql"]
    assert isinstance(sql, str), f"Expected SQL to be a string, got {type(sql)}"
    return sql


def get_reasoning_content(message: AIMessage | AIMessageChunk) -> str:
    """Extract reasoning/"thinking" text from an AI message across provider formats.

    Checks, in order:
      1. OpenAI ``output_version: v0`` — ``additional_kwargs["reasoning"]["summary"]``,
      2. Qwen-style — ``additional_kwargs["reasoning_content"]``,
      3. OpenAI ``output_version: responses/v1`` — ``"reasoning"`` content blocks.

    Returns the concatenated reasoning text ("" when none is present).
    """
    # Assume only one of the reasoning parts is present, so there will be no duplication.
    reasoning_text = ""

    # OpenAI output_version: v0
    reasoning_chunk = message.additional_kwargs.get("reasoning", {})
    reasoning_summary_chunks = reasoning_chunk.get("summary", [])
    for reasoning_summary_chunk in reasoning_summary_chunks:
        reasoning_text += reasoning_summary_chunk.get("text", "")

    # "Qwen" style reasoning. `or ""` guards against an explicit None value,
    # which would otherwise raise on string concatenation.
    reasoning_text += message.additional_kwargs.get("reasoning_content", "") or ""

    # OpenAI output_version: responses/v1
    blocks = message.content if isinstance(message.content, list) else [message.content]
    for block in blocks:
        if isinstance(block, dict) and block.get("type", "text") == "reasoning":
            # Fix: a reasoning block may omit "summary" (or a summary entry may
            # omit "text"); use .get() defaults like the v0 path above instead
            # of raising KeyError.
            for summary in block.get("summary", []):
                reasoning_text += summary.get("text", "")

    assert isinstance(reasoning_text, str), f"Expected a string, got {type(reasoning_text)}"
    return reasoning_text
124 changes: 124 additions & 0 deletions portus/agents/frontend/text_frontend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import re
from typing import Any, TextIO

import pandas as pd
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage, BaseMessageChunk, ToolMessage

from portus.agents.frontend.messages import get_reasoning_content, get_tool_call_sql


class TextStreamFrontend:
    """Helper for streaming LangGraph LLM outputs to a text stream (stdout, stderr, a file, etc.).

    Consumes the two LangGraph stream modes:

    - ``"messages"``: incremental AIMessageChunk tokens (text, reasoning, tool-call args),
      written as they arrive via :meth:`write_message_chunk`;
    - ``"values"``: full state snapshots, from which only newly appended messages
      (tool results, completed tool calls) are rendered via :meth:`write_state_chunk`.

    Tool-call arguments are wrapped in a markdown code fence that is opened on the
    first tool-call chunk and closed when normal text or a state snapshot arrives.
    """

    def __init__(
        self,
        start_state: dict[str, Any],
        *,
        writer: TextIO | None = None,
        escape_markdown: bool = False,
        show_headers: bool = True,
        pretty_sql: bool = False,
    ) -> None:
        # None means "default print target" (stdout).
        self._writer = writer  # Use io.Writer type in Python 3.14
        # Escape $amounts and ~strikethrough in streamed text for markdown renderers.
        self._escape_markdown = escape_markdown
        self._show_headers = show_headers
        # Messages already present before streaming started; state snapshots repeat
        # the whole history, so this index lets us render only the new tail.
        self._message_count = len(start_state.get("messages", []))
        # Set once the opening header has been emitted (lazily, on first write).
        self._started = False
        # True while a ``` code fence for tool-call arguments is open.
        self._is_tool_calling = False
        self._pretty_sql = pretty_sql

    def write(self, text: str) -> None:
        """Write raw text, lazily emitting the opening header on first use."""
        if not self._started:
            self.start()
        print(text, end="", flush=True, file=self._writer)

    def write_dataframe(self, df: pd.DataFrame) -> None:
        """Render a DataFrame as a markdown table."""
        self.write(df.to_markdown())

    def write_message_chunk(self, chunk: BaseMessageChunk) -> None:
        """Write one incremental token chunk (reasoning, text, or tool-call args)."""
        if not isinstance(chunk, AIMessageChunk):
            return  # Handle ToolMessage results in add_state_chunk

        # Reasoning is prepended so "thinking" text appears before the answer text.
        reasoning_text = get_reasoning_content(chunk)
        text = reasoning_text + chunk.text()
        if self._escape_markdown:
            text = escape_markdown_text(text)
        self.write(text)

        if len(chunk.tool_call_chunks) > 0:
            # N.B. LangChain sometimes waits for the whole string to complete before yielding chunks
            # That's why long "sql" tool calls take some time to show up and then the whole sql is shown in a batch
            if not self._is_tool_calling:
                self.write("\n\n```\n")  # Open code block
                self._is_tool_calling = True
            for tool_call_chunk in chunk.tool_call_chunks:
                if tool_call_chunk["args"] is not None:
                    self.write(tool_call_chunk["args"])
        elif self._is_tool_calling:
            self.write("\n```\n\n")  # Close code block
            self._is_tool_calling = False

    def write_state_chunk(self, state_chunk: dict[str, Any]) -> None:
        """The state chunk is assumed to contain a "messages" key."""
        # A state snapshot means token streaming for the current step is done,
        # so any open tool-call fence must be closed first.
        if self._is_tool_calling:
            self.write("\n```\n\n")  # Close code block
            self._is_tool_calling = False

        # Loop through new messages only.
        # We could either force the caller of the frontend to provide new messages only,
        # but for ease of use we assume the state contains a list of messages and do it here.
        messages: list[BaseMessage] = state_chunk.get("messages", [])
        new_messages = messages[self._message_count :]
        self._message_count += len(new_messages)

        for message in new_messages:
            if isinstance(message, ToolMessage):
                # Prefer the structured artifact (DataFrame) over the raw text content.
                if message.artifact is not None:
                    if "df" in message.artifact and message.artifact["df"] is not None:
                        self.write_dataframe(message.artifact["df"])
                    else:
                        self.write(f"\n```\n{message.content}\n```\n\n")
                else:
                    self.write(f"\n```\n{message.content}\n```\n\n")
            elif self._pretty_sql and isinstance(message, AIMessage):
                # During tool calling we show raw JSON chunks, but for SQL we also want pretty formatting.
                for tool_call in message.tool_calls:
                    sql = get_tool_call_sql(tool_call)
                    if sql is not None:
                        self.write(f"\n```sql\n{sql}\n```\n\n")

    def write_stream_chunk(self, mode: str, chunk: Any) -> None:
        """Dispatch one (mode, chunk) pair from a multi-mode LangGraph stream."""
        if mode == "messages":
            # "messages" chunks are (token, metadata) tuples; metadata is unused here.
            token_chunk, _token_metadata = chunk
            self.write_message_chunk(token_chunk)
        elif mode == "values":
            if isinstance(chunk, dict):
                self.write_state_chunk(chunk)
            else:
                raise ValueError(f"Unexpected chunk type: {type(chunk)}")

    def start(self) -> None:
        """Emit the opening header; called lazily from write()."""
        self._started = True
        if self._show_headers:
            self.write("=" * 8 + " <THINKING> " + "=" * 8 + "\n\n")

    def end(self) -> None:
        """Emit the closing header and reset so the frontend can be reused."""
        if self._show_headers:
            self.write("\n" + "=" * 8 + " </THINKING> " + "=" * 8 + "\n\n")
        self._started = False


def escape_currency_dollar_signs(text: str) -> str:
    """Backslash-escape dollar amounts (e.g. "$5" -> "\\$5") so markdown/MathJax
    renderers do not treat them as math delimiters."""
    dollar_amount = re.compile(r"\$(\d+)")
    return dollar_amount.sub(r"\$\1", text)


def escape_strikethrough(text: str) -> str:
    """Prevents aggressive markdown strikethrough formatting by escaping a tilde
    that precedes a (possibly prefixed) number, e.g. "~10" -> "\\~10"."""
    tilde_number = re.compile(r"~(.?\d+)")
    return tilde_number.sub(r"\~\1", text)


def escape_markdown_text(text: str) -> str:
    """Apply all markdown-safety escapes (strikethrough first, then currency)."""
    return escape_currency_dollar_signs(escape_strikethrough(text))
25 changes: 11 additions & 14 deletions portus/agents/lighthouse/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,13 @@ def _get_graph_and_compiled(self, session: Session) -> tuple[Any, ExecuteSubmit,
return data_connection, self._cached_graph, compiled_graph

def execute(
self, session: Session, opa: Opa, *, rows_limit: int = 100, cache_scope: str = "common_cache"
self,
session: Session,
opa: Opa,
*,
rows_limit: int = 100,
cache_scope: str = "common_cache",
stream: bool = True,
) -> ExecutionResult:
# TODO rows_limit is ignored

Expand All @@ -69,23 +75,14 @@ def execute(
]

init_state = graph.init_state(messages_with_system)
last_state: dict[str, Any] | None = None
try:
for chunk in compiled_graph.stream(
init_state,
stream_mode="values",
config=RunnableConfig(recursion_limit=50),
):
assert isinstance(chunk, dict)
last_state = chunk
except Exception as e:
return ExecutionResult(text=str(e), meta={"messages": messages_with_system})
assert last_state is not None
invoke_config = RunnableConfig(recursion_limit=self._graph_recursion_limit)
last_state = self._invoke_graph(compiled_graph, init_state, config=invoke_config, stream=stream)
execution_result = graph.get_result(last_state)

# Update message history (excluding system message which we add dynamically)
final_messages = last_state.get("messages", [])
if final_messages:
messages_without_system = [msg for msg in final_messages if msg.type != "system"]
self._update_message_history(session, cache_scope, messages_without_system)

return graph.get_result(last_state)
return execution_result
18 changes: 14 additions & 4 deletions portus/agents/react_duckdb/agent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import logging
from typing import Any

from langchain_core.runnables import RunnableConfig

from portus.agents.base import AgentExecutor
from portus.configs.llm import LLMConfig
from portus.core import ExecutionResult, Opa, Session
Expand All @@ -15,7 +17,13 @@ def _create_graph(self, data_connection: Any, llm_config: LLMConfig) -> Any:
return make_react_duckdb_agent(data_connection, llm_config.chat_model)

def execute(
self, session: Session, opa: Opa, *, rows_limit: int = 100, cache_scope: str = "common_cache"
self,
session: Session,
opa: Opa,
*,
rows_limit: int = 100,
cache_scope: str = "common_cache",
stream: bool = True,
) -> ExecutionResult:
# Get or create graph (cached after first use)
data_connection, compiled_graph = self._get_or_create_cached_graph(session)
Expand All @@ -24,13 +32,15 @@ def execute(
messages = self._process_opa(session, opa, cache_scope)

# Execute the graph
state = compiled_graph.invoke({"messages": messages})
answer: AgentResponse = state["structured_response"]
init_state = {"messages": messages}
invoke_config = RunnableConfig(recursion_limit=self._graph_recursion_limit)
last_state = self._invoke_graph(compiled_graph, init_state, config=invoke_config, stream=stream)
answer: AgentResponse = last_state["structured_response"]
logger.info("Generated query: %s", answer.sql)
df = data_connection.execute(f"SELECT * FROM ({sql_strip(answer.sql)}) t LIMIT {rows_limit}").df()

# Update message history
final_messages = state.get("messages", [])
final_messages = last_state.get("messages", [])
self._update_message_history(session, cache_scope, final_messages)

return ExecutionResult(text=answer.explanation, code=answer.sql, df=df, meta={})
1 change: 1 addition & 0 deletions portus/core/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,6 @@ def execute(
*,
rows_limit: int = 100,
cache_scope: str = "common_cache",
stream: bool = True,
) -> ExecutionResult:
pass
Loading