Skip to content

Implement defined queries #1144

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: gated
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lumen/ai/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def _add_step(self, title: str = "", **kwargs):
return nullcontext(self._null_step) if self.interface is None else self.interface.add_step(title=title, **kwargs)

async def _gather_prompt_context(self, prompt_name: str, messages: list[Message], **context):
context["actor_name"] = self.name
context["memory"] = self._memory
context["actor_name"] = self.name
return context
Expand Down
11 changes: 11 additions & 0 deletions lumen/ai/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,7 @@ class SQLAgent(LumenBaseAgent):
"main": {
"response_model": Sql,
"template": PROMPTS_DIR / "SQLAgent" / "main.jinja2",
"tools": []
},
"find_tables": {
"response_model": make_find_tables_model,
Expand All @@ -521,6 +522,16 @@ class SQLAgent(LumenBaseAgent):

requires = param.List(default=["source", "table_sql_metaset"], readonly=True)

# def __init__(self, **params):
# super().__init__(**params)
# # Dynamically add QueryLookup to tools if defined_queries is populated
# if "defined_queries" in self._memory and self._memory["defined_queries"]:
# from .tools import QueryLookup
# if QueryLookup not in self.prompts['main'].get('tools', []):
# if 'tools' not in self.prompts['main']:
# self.prompts['main']['tools'] = []
# self.prompts['main']['tools'].append(QueryLookup)

_extensions = ('codeeditor', 'tabulator',)

_output_type = SQLOutput
Expand Down
220 changes: 218 additions & 2 deletions lumen/ai/coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from panel import bind
from panel.chat import ChatFeed, ChatInterface, ChatStep
from panel.layout import (
Card, Column, FlexBox, Tabs,
Card, Column, Divider, FlexBox, Tabs,
)
from panel.pane import HTML
from panel.viewable import Viewable, Viewer
Expand All @@ -29,7 +29,7 @@
from .logs import ChatLogs
from .models import YesNo, make_agent_model, make_plan_models
from .tools import (
IterativeTableLookup, TableLookup, Tool, ToolUser,
IterativeTableLookup, QueryLookup, TableLookup, Tool, ToolUser,
)
from .utils import (
fuse_messages, log_debug, mutate_user_message, retry_llm_output,
Expand Down Expand Up @@ -114,6 +114,11 @@ class Coordinator(Viewer, ToolUser):
demo_inputs = param.List(default=DEMO_MESSAGES, doc="""
List of instructions to demo the Coordinator.""")

defined_queries = param.Dict(default={}, doc="""
Dictionary of predefined queries that can be run individually.
The keys are the names of the queries and the values are
dictionaries containing 'sql' and 'description' keys.""")

history = param.Integer(default=3, doc="""
Number of previous user-assistant interactions to include in the chat history.""")

Expand Down Expand Up @@ -232,6 +237,10 @@ def on_rerun(instance, _):
}
self._add_suggestions_to_footer(self.suggestions)

# Add defined query buttons if any exist
if self.defined_queries:
self._add_defined_query_buttons()

if "source" in self._memory and "sources" not in self._memory:
self._memory["sources"] = [self._memory["source"]]
elif "source" not in self._memory and self._memory.get("sources"):
Expand Down Expand Up @@ -325,6 +334,75 @@ async def run_demo(event):
self.interface.param.watch(hide_suggestions, "objects")
return message

def _add_defined_query_buttons(self):
"""Add buttons for defined queries to the interface."""
async def run_defined_query(event):
button = event.obj
query_name = button.name.replace("Query ", "")
with button.param.update(loading=True), self.interface.active_widget.param.update(loading=True):
if query_name in self.defined_queries:
await self._fast_track_sql_visualization(query_name, f"Run query: {query_name}")

async def run_all_queries(event):
button = event.obj
with button.param.update(loading=True), self.interface.active_widget.param.update(loading=True):
if event.new > 1: # prevent double clicks
return
for query_name in self.defined_queries.keys():
while self.interface.disabled:
await asyncio.sleep(1.25)
await self._fast_track_sql_visualization(query_name, f"Run query: {query_name}")
await asyncio.sleep(2)

query_buttons = FlexBox(
*[
Button(
name=f"Query {query_name}",
button_style="outline",
button_type="success",
on_click=run_defined_query,
margin=5,
disabled=self.interface.param.loading
)
for query_name in self.defined_queries
],
margin=(5, 5),
)

# Add the Run All Queries button
run_all_button = Button(
name="Run all queries",
button_type="primary",
on_click=run_all_queries,
margin=5,
disabled=self.interface.param.loading
)

# Disable button temporarily after click to prevent double-clicks
disable_js = "cb_obj.origin.disabled = true; setTimeout(() => cb_obj.origin.disabled = false, 3000)"
for b in query_buttons:
b.js_on_click(code=disable_js)
run_all_button.js_on_click(code=disable_js)

# Add the Run All button to the flexbox
query_buttons.append(run_all_button)

# Add a title above the buttons
container = Column(Divider(), query_buttons)

# Add the buttons to the existing message
message = self.interface.objects[-1]
if not hasattr(message, 'footer_objects') or not message.footer_objects:
message.footer_objects = [container]
# If footer already exists, append to it if it's a Column
elif isinstance(message.footer_objects[0], Column):
message.footer_objects[0].append(container)
else:
# Otherwise create a new Column with existing footer and our container
existing_footer = message.footer_objects[0]
message.footer_objects = [Column(existing_footer, container)]
return message

async def _add_analysis_suggestions(self):
pipeline = self._memory["pipeline"]
current_analysis = self._memory.get("analysis")
Expand All @@ -343,8 +421,135 @@ async def _add_analysis_suggestions(self):

async def _chat_invoke(self, contents: list | str, user: str, instance: ChatInterface):
log_debug(f"New Message: \033[91m{contents!r}\033[0m", show_sep="above")

# Fast track for "Run query: ..." messages
if isinstance(contents, str) and contents.startswith("Run query:"):
query_name = contents.replace("Run query:", "").strip()
if query_name in self.defined_queries:
log_debug(f"Fast-tracking query: {query_name}")
await self._fast_track_sql_visualization(query_name, contents)
return

await self.respond(contents)

async def _fast_track_sql_visualization(self, query_name: str, original_message: str):
"""Fast-track a predefined query directly to SQLAgent then VegaLiteAgent."""
# Get the query data
query_data = self.defined_queries[query_name]
sql = query_data.get('sql', '')
description = query_data.get('description', query_name)
table_name = None

# Extract possible table name from SQL
# TODO: use sqlglot
sql_lower = sql.lower()
for pattern in [r'from\s+([\w._]+)', r'join\s+([\w._]+)']:
matches = re.findall(pattern, sql_lower)
if matches:
table_name = matches[0]
break

# Construct a simple message for the agents
messages = [{"role": "user", "content": f"Execute this SQL query and visualize the results: {sql}"}]

# Create our execution graph manually
agents = {agent.name[:-5]: agent for agent in self.agents}
tools = {tool.name[:-5]: tool for tool in self._tools.get("main", [])}
execution_graph = [
ExecutionNode(
actor=QueryLookup(),
instruction=sql,
title="Finding closest defined query",
render_output=False
)
]

# Check if we need to set up SQL dependencies first
lookup_tool = None
for tool_name in ["IterativeTableLookup", "TableLookup"]:
if tool_name in tools:
lookup_tool = tools[tool_name]
break

# If we don't have table_sql_metaset, add a lookup tool to the execution graph
if lookup_tool and "table_sql_metaset" not in self._memory:
# Create message with table hint if we have one
lookup_message = messages[0]["content"]
if table_name:
lookup_message += f" This query is specifically about the {table_name} table."

# Add lookup tool node
instruction = f"Identify relevant tables for query: {query_name}"
if table_name:
instruction += f" (specifically looking for table: {table_name})"

execution_graph.append(
ExecutionNode(
actor=lookup_tool,
provides=["tables_vector_metaset", "table_sql_metaset"],
instruction=instruction,
title=f"Identifying tables for query: {query_name}",
render_output=False
)
)

# Add SQLAgent node
if 'SQLAgent' in agents:
sql_instruction = f"Execute this predefined query '{query_name}': {sql}"
if table_name:
sql_instruction += f" This query specifically targets the {table_name} table."

execution_graph.append(
ExecutionNode(
actor=agents['SQLAgent'],
provides=["table", "sql", "pipeline", "data"],
instruction=sql_instruction,
title=f"Executing query: {query_name}",
render_output=True
)
)

# Add VegaLiteAgent
if 'VegaLiteAgent' in agents:
viz_instruction = f"Create a visualization for the results of query '{query_name}': {description}"
execution_graph.append(
ExecutionNode(
actor=agents['VegaLiteAgent'],
provides=["view"],
instruction=viz_instruction,
title=f"Visualizing query results: {query_name}",
render_output=True
)
)

if not execution_graph:
self.interface.stream("Could not find required agents for fast-track execution.", user='Lumen')
return

# Execute the graph
self._memory["agent_tool_contexts"] = {}
with self.interface.param.update(loading=True):
# Send the original message to the chat
self.interface.send(original_message, user="User", respond=False)

messages = fuse_messages(
self.interface.serialize(custom_serializer=self._serialize, limit=10),
max_user_messages=self.history
)

for i, node in enumerate(execution_graph):
succeeded = await self._execute_graph_node(node, messages, execution_graph=execution_graph)
if not succeeded:
break
if "pipeline" in self._memory:
await self._add_analysis_suggestions()
log_debug("\033[92mCompleted: Fast-track SQL Visualization\033[0m", show_sep="below")

for message_obj in self.interface.objects[::-1]:
if isinstance(message_obj.object, Card):
message_obj.object.collapsed = True
break

@retry_llm_output()
async def _fill_model(self, messages, system, agent_model, errors=None):
if errors:
Expand Down Expand Up @@ -682,10 +887,21 @@ class Planner(Coordinator):
"template": PROMPTS_DIR / "Planner" / "main.jinja2",
"response_model": make_plan_models,
"tools": [TableLookup, IterativeTableLookup]
# QueryLookup will be added if defined_queries is populated
},
}
)

def __init__(self, **params):
super().__init__(**params)
# Dynamically add QueryLookup to tools if defined_queries is populated
tools = self.prompts['main'].get("tools", [])
if self.defined_queries and 'QueryLookup' not in [t.__name__ for t in tools]:
self._memory["defined_queries"] = self.defined_queries
tools.append(QueryLookup)
if "tools" not in self.prompts["main"]:
self.prompts["main"]["tools"] = tools

async def _make_plan(
self,
messages: list[Message],
Expand Down
Loading
Loading