Skip to content

Commit 1b0e333

Browse files
authored
fix(max-ai): handle SQL generation retry exhaustion gracefully (#64818)
Co-authored-by: posthog[bot] <206114724+posthog[bot]@users.noreply.github.com>
1 parent 8a757c2 commit 1b0e333

5 files changed

Lines changed: 96 additions & 7 deletions

File tree

ee/hogai/chat_agent/query_executor/nodes.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ async def _extract_artifact(self, state: AssistantState) -> ArtifactMessage | No
5959
if isinstance(last_message, FailureMessage):
6060
return None # Exit early - something failed earlier
6161

62+
if isinstance(last_message, AssistantToolCallMessage):
63+
return None # Exit early - a generator already produced a terminal tool response (e.g. graceful failure)
64+
6265
if not isinstance(last_message, ArtifactRefMessage):
6366
raise ValueError(f"Expected an ArtifactRefMessage, found {type(last_message)}")
6467

ee/hogai/chat_agent/sql/mixins.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -170,10 +170,16 @@ def _validate_hogql_query_sync(self, query: str) -> AssistantHogQLQuery:
170170
err_msg = str(err)
171171
# Both the antlr-based cpp parser and the hand-rolled rust-py parser produce
172172
# terse low-level error wording on syntax failures ("no viable alternative…",
173-
# "trailing tokens after expression…", "unexpected token in expression…").
174-
# Replace any of them with a single human/LLM-friendly message.
173+
# "trailing tokens after expression…", "unexpected token in expression…",
174+
# "mismatched input … expecting …"). Replace any of them with a single
175+
# human/LLM-friendly message.
175176
if err_msg.startswith(
176-
("no viable alternative", "trailing tokens after expression", "unexpected token in expression")
177+
(
178+
"no viable alternative",
179+
"trailing tokens after expression",
180+
"unexpected token in expression",
181+
"mismatched input",
182+
)
177183
):
178184
err_msg = "HogQL parsing error: this query isn't valid HogQL."
179185
raise PydanticOutputParserException(llm_output=cleaned_query, validation_message=err_msg)

ee/hogai/chat_agent/sql/nodes.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
1+
from uuid import uuid4
2+
13
from langchain_core.runnables import RunnableConfig
24

3-
from posthog.schema import DataVisualizationNode
5+
from posthog.schema import AssistantToolCallMessage, DataVisualizationNode, FailureMessage
46

57
from posthog.hogql.context import HogQLContext
68

79
from ee.hogai.utils.types import AssistantState, PartialAssistantState
810

9-
from ..schema_generator.nodes import SchemaGeneratorNode, SchemaGeneratorToolsNode
11+
from ..schema_generator.nodes import SchemaGenerationException, SchemaGeneratorNode, SchemaGeneratorToolsNode
1012
from .mixins import HogQLGeneratorMixin, SQLSchemaGeneratorOutput
13+
from .prompts import SQL_GENERATION_FAILURE_MESSAGE
1114
from .toolkit import SQL_SCHEMA
1215

1316

@@ -20,7 +23,35 @@ class SQLGeneratorNode(HogQLGeneratorMixin, SchemaGeneratorNode[DataVisualizatio
2023

2124
async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState:
2225
prompt = await self._construct_system_prompt()
23-
return await super()._run_with_prompt(state, prompt, config=config)
26+
try:
27+
return await super()._run_with_prompt(state, prompt, config=config)
28+
except SchemaGenerationException as e:
29+
# The LLM exhausted its retries on invalid HogQL. Surface this as a graceful tool
30+
# response so the calling agent can recover, instead of letting it bubble up to the
31+
# runner's generic handler and be captured as an unhandled application error.
32+
return self._handle_generation_failure(state, e)
33+
34+
def _handle_generation_failure(
35+
self, state: AssistantState, error: SchemaGenerationException
36+
) -> PartialAssistantState:
37+
tool_call_id = state.root_tool_call_id
38+
content = SQL_GENERATION_FAILURE_MESSAGE.format(error_message=error.validation_message)
39+
# Respond to the calling agent when there's a tool call to answer; otherwise emit a
40+
# FailureMessage so the run still terminates cleanly via the query executor.
41+
message = (
42+
AssistantToolCallMessage(content=content, id=str(uuid4()), tool_call_id=tool_call_id)
43+
if tool_call_id
44+
else FailureMessage(content=content, id=str(uuid4()))
45+
)
46+
return PartialAssistantState(
47+
messages=[message],
48+
intermediate_steps=None,
49+
plan=None,
50+
rag_context=None,
51+
root_tool_call_id=None,
52+
root_tool_insight_plan=None,
53+
root_tool_insight_type=None,
54+
)
2455

2556

2657
class SQLGeneratorToolsNode(SchemaGeneratorToolsNode):

ee/hogai/chat_agent/sql/prompts.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
SQL_GENERATION_FAILURE_MESSAGE = (
2+
"I wasn't able to generate a valid SQL query for this request after several attempts. Error: {error_message}"
3+
)
4+
15
HOGQL_GENERATOR_SYSTEM_PROMPT = """
26
You are an expert in writing HogQL. HogQL is PostHog's variant of SQL that supports most of ClickHouse SQL. We're going to use terms "HogQL" and "SQL" interchangeably.
37
You write HogQL based on a prompt. You don't help with other knowledge. You are provided with the current HogQL query that the user is editing. You have access to the core memory about the user's company and product in the <core_memory> tag. Use this memory in your responses.

ee/hogai/chat_agent/sql/test/test_nodes.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,23 @@
22
from unittest.mock import patch
33

44
from langchain_core.runnables import RunnableConfig, RunnableLambda
5+
from parameterized import parameterized
56

6-
from posthog.schema import ArtifactContentType, ArtifactSource, HumanMessage
7+
from posthog.schema import ArtifactContentType, ArtifactSource, AssistantToolCallMessage, FailureMessage, HumanMessage
78

89
from products.posthog_ai.backend.models.assistant import Conversation
910

11+
from ee.hogai.chat_agent.schema_generator.nodes import SchemaGenerationException
1012
from ee.hogai.chat_agent.sql.nodes import SQLGeneratorNode
1113
from ee.hogai.utils.types import AssistantState
1214
from ee.hogai.utils.types.base import ArtifactRefMessage
1315

1416

1517
class TestSQLGeneratorNode(NonAtomicBaseTest):
1618
maxDiff = None
19+
# NonAtomicBaseTest truncates all tables (RESTART IDENTITY) after each test, so class-level
20+
# test data created once in setUpClass is gone by the second test. Recreate it per test.
21+
CLASS_DATA_LEVEL_SETUP = False
1722

1823
def setUp(self):
1924
super().setUp()
@@ -51,3 +56,43 @@ async def test_node_runs(self):
5156
self.assertIsNone(new_state.intermediate_steps)
5257
self.assertIsNone(new_state.plan)
5358
self.assertIsNone(new_state.rag_context)
59+
60+
@parameterized.expand(
61+
[
62+
("with_tool_call", "tool_123", AssistantToolCallMessage),
63+
("without_tool_call", None, FailureMessage),
64+
]
65+
)
66+
async def test_node_handles_retry_exhaustion_gracefully(self, _name, root_tool_call_id, expected_message_type):
67+
node = SQLGeneratorNode(self.team, self.user)
68+
config = RunnableConfig(configurable={"thread_id": str(self.conversation.id)})
69+
70+
async def _raise(*args, **kwargs):
71+
raise SchemaGenerationException(
72+
"WITH date_end AS toDate(now()) SELECT 1",
73+
"HogQL parsing error: this query isn't valid HogQL.",
74+
)
75+
76+
with patch("ee.hogai.chat_agent.schema_generator.nodes.SchemaGeneratorNode._run_with_prompt", new=_raise):
77+
new_state = await node(
78+
AssistantState(
79+
messages=[HumanMessage(content="Text")],
80+
plan="Plan",
81+
root_tool_call_id=root_tool_call_id,
82+
root_tool_insight_plan="question",
83+
),
84+
config,
85+
)
86+
87+
assert new_state is not None
88+
self.assertEqual(len(new_state.messages), 1)
89+
msg = new_state.messages[0]
90+
self.assertIsInstance(msg, expected_message_type)
91+
assert isinstance(msg, AssistantToolCallMessage | FailureMessage)
92+
assert msg.content is not None
93+
self.assertIn("valid SQL query", msg.content)
94+
if isinstance(msg, AssistantToolCallMessage):
95+
self.assertEqual(msg.tool_call_id, root_tool_call_id)
96+
# Node ends gracefully and clears the tool call so the run terminates
97+
self.assertIsNone(new_state.root_tool_call_id)
98+
self.assertIsNone(new_state.intermediate_steps)

0 commit comments

Comments
 (0)