Skip to content

Commit c3d713c

Browse files
authored
Merge pull request #137 from JetBrains/kosst/prompt
Prompt changes and fix of notebook
2 parents 04e9050 + 1455e62 commit c3d713c

File tree

3 files changed

+663
-630
lines changed

3 files changed

+663
-630
lines changed

databao/executors/lighthouse/graph.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def get_query_ids_mapping(messages: list[BaseMessage]) -> dict[str, ToolMessage]
4040

4141

4242
class ExecuteSubmit:
43-
"""Simple graph with two tools: run_sql_query and submit_query_id.
43+
"""Simple graph with two tools: run_sql_query and submit_result.
4444
All context must be in the SystemMessage."""
4545

4646
MAX_TOOL_ROWS = 12
@@ -69,7 +69,7 @@ def get_result(self, state: AgentState) -> ExecutionResult:
6969
if last_ai_message is None:
7070
raise RuntimeError("No AI message found in message log")
7171
if len(last_ai_message.tool_calls) == 0:
72-
# Sometimes models don't call the submit_query_id tool, but we still want to return some dataframe.
72+
# Sometimes models don't call the submit_result tool, but we still want to return some dataframe.
7373
sql = state.get("sql", "")
7474
df = state.get("df") # Latest df result (usually from run_sql_query)
7575
visualization_prompt = state.get("visualization_prompt")
@@ -85,9 +85,9 @@ def get_result(self, state: AgentState) -> ExecutionResult:
8585
)
8686
elif len(last_ai_message.tool_calls) > 1:
8787
raise RuntimeError("Expected exactly one tool call in AI message")
88-
elif last_ai_message.tool_calls[0]["name"] != "submit_query_id":
88+
elif last_ai_message.tool_calls[0]["name"] != "submit_result":
8989
raise RuntimeError(
90-
f"Expected submit_query_id tool call in AI message, got {last_ai_message.tool_calls[0]['name']}"
90+
f"Expected submit_result tool call in AI message, got {last_ai_message.tool_calls[0]['name']}"
9191
)
9292
else:
9393
sql = state.get("sql", "")
@@ -130,7 +130,7 @@ def run_sql_query(sql: str, graph_state: Annotated[AgentState, InjectedState]) -
130130
return {"error": exception_to_string(e)}
131131

132132
@tool(parse_docstring=True)
133-
def submit_query_id(
133+
def submit_result(
134134
query_id: str,
135135
result_description: str,
136136
visualization_prompt: str,
@@ -149,7 +149,7 @@ def submit_query_id(
149149
"""
150150
return f"Query {query_id} submitted successfully. Your response is now visible to the user."
151151

152-
tools = [run_sql_query, submit_query_id]
152+
tools = [run_sql_query, submit_result]
153153
return tools
154154

155155
def compile(self, model_config: LLMConfig) -> CompiledStateGraph[Any]:
@@ -170,11 +170,11 @@ def tool_executor_node(state: AgentState) -> dict[str, Any]:
170170

171171
tool_calls = last_message.tool_calls
172172

173-
is_ready_for_user = any(tc["name"] == "submit_query_id" for tc in tool_calls)
173+
is_ready_for_user = any(tc["name"] == "submit_result" for tc in tool_calls)
174174
if is_ready_for_user:
175175
if len(tool_calls) > 1:
176176
tool_messages = [
177-
ToolMessage("submit_query_id must be the only tool call.", tool_call_id=tool_call["id"])
177+
ToolMessage("submit_result must be the only tool call.", tool_call_id=tool_call["id"])
178178
for tool_call in tool_calls
179179
]
180180
return {"messages": tool_messages, "ready_for_user": False}
@@ -244,14 +244,14 @@ def tool_executor_node(state: AgentState) -> dict[str, Any]:
244244
tool_call_id=tool_call_id,
245245
artifact=result,
246246
)
247-
elif name == "submit_query_id":
247+
elif name == "submit_result":
248248
content = str(result)
249249
query_id = tool_call["args"]["query_id"]
250250
visualization_prompt = tool_call["args"].get("visualization_prompt", "")
251251
sql = state["query_ids"][query_id].artifact["sql"]
252252
df = state["query_ids"][query_id].artifact["df"]
253253
tool_messages.append(ToolMessage(content=content, tool_call_id=tool_call_id, artifact=result))
254-
if name == "submit_query_id":
254+
if name == "submit_result":
255255
return {
256256
"messages": tool_messages,
257257
"sql": sql,
@@ -276,7 +276,7 @@ def should_continue(state: AgentState) -> Literal["tool_executor", "end"]:
276276
return "end"
277277

278278
def should_finish(state: AgentState) -> Literal["llm_node", "end"]:
279-
# Check if we just executed submit_query_id - if so, end the conversation
279+
# Check if we just executed submit_result - if so, end the conversation
280280
if state.get("ready_for_user", False):
281281
return "end"
282282
return "llm_node"

databao/executors/lighthouse/system_prompt.jinja

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
1-
You are an agent that has direct access to the database. You generate SQL requests, which are executed on a DB client with no changes.
1+
You are a "Databao" agent that has direct access to the database. You generate SQL requests, which are executed on a DB client with no changes.
2+
The user can connect several databases and DataFrames to your internal DuckDB instance. DataFrames are available as tables with the "temp.main" prefix.
23
The task is to request all necessary data and answer the user question.
4+
You can answer with
5+
- text (either a plain-text reply with no tool call, or the result_description parameter of the submit_result tool)
6+
- a table (using SQL requests and query_id parameter of submit_result tool). It will be visible as a DataFrame.
7+
- a plot (using the visualization_prompt parameter of the submit_result tool)
8+
or a combination of these.
39

410
Today's date is: {{ date }} (YYYY-MM-DD).
511

@@ -11,14 +17,13 @@ Today's date is: {{ date }} (YYYY-MM-DD).
1117
- The DB schema is provided in the 'Database schema' section. Don't waste a tool call on it.
1218
- Pay attention to SQL dialect specific commands (DuckDB is used)
1319
- Cross joins are allowed only for tables that are guaranteed small (< 5 rows), such as enums or static dictionaries.
14-
- Use 'today()' instead of 'now()' to get current date
1520
- When calculating percentages like (a - b) / a * 100, you must perform the multiplication first to prevent rounding. Use 100 * (a - b) / a.
1621
- When comparing an unfinished period like the current year to a finished one like last year, use the same date range. Never compare an unfinished period to a finished one.
17-
- Make sure the submitted query answers the user's question and it is not-empty
18-
- Result description of submitted query should contain definitions being used, important decisions and analysis of resulting data
22+
- Make sure the submitted result answers the user's question and is not empty
23+
- The result description of the submitted result should contain the definitions used, important decisions, and an analysis of the resulting data
1924
- Leave the visualization prompt empty if you don't want to visualize the result. A table with few values or with heterogeneous data doesn't need visualization
2025
- Time series require visualization
21-
- The user will see only the submitted result of submit_query_id. The user will not see intermediate results
26+
- The user will see only the submitted result - final SQL and DataFrame. The user will not see intermediate results
2227

2328

2429
# Database schema

0 commit comments

Comments
 (0)