|
2 | 2 | from unittest.mock import patch |
3 | 3 |
|
4 | 4 | from langchain_core.runnables import RunnableConfig, RunnableLambda |
| 5 | +from parameterized import parameterized |
5 | 6 |
|
6 | | -from posthog.schema import ArtifactContentType, ArtifactSource, HumanMessage |
| 7 | +from posthog.schema import ArtifactContentType, ArtifactSource, AssistantToolCallMessage, FailureMessage, HumanMessage |
7 | 8 |
|
8 | 9 | from products.posthog_ai.backend.models.assistant import Conversation |
9 | 10 |
|
| 11 | +from ee.hogai.chat_agent.schema_generator.nodes import SchemaGenerationException |
10 | 12 | from ee.hogai.chat_agent.sql.nodes import SQLGeneratorNode |
11 | 13 | from ee.hogai.utils.types import AssistantState |
12 | 14 | from ee.hogai.utils.types.base import ArtifactRefMessage |
13 | 15 |
|
14 | 16 |
|
15 | 17 | class TestSQLGeneratorNode(NonAtomicBaseTest): |
16 | 18 | maxDiff = None |
| 19 | + # NonAtomicBaseTest truncates all tables (RESTART IDENTITY) after each test, so class-level |
| 20 | + # test data created once in setUpClass is gone by the second test. Recreate it per test. |
| 21 | + CLASS_DATA_LEVEL_SETUP = False |
17 | 22 |
|
18 | 23 | def setUp(self): |
19 | 24 | super().setUp() |
@@ -51,3 +56,43 @@ async def test_node_runs(self): |
51 | 56 | self.assertIsNone(new_state.intermediate_steps) |
52 | 57 | self.assertIsNone(new_state.plan) |
53 | 58 | self.assertIsNone(new_state.rag_context) |
| 59 | + |
| 60 | + @parameterized.expand( |
| 61 | + [ |
| 62 | + ("with_tool_call", "tool_123", AssistantToolCallMessage), |
| 63 | + ("without_tool_call", None, FailureMessage), |
| 64 | + ] |
| 65 | + ) |
| 66 | + async def test_node_handles_retry_exhaustion_gracefully(self, _name, root_tool_call_id, expected_message_type): |
| 67 | + node = SQLGeneratorNode(self.team, self.user) |
| 68 | + config = RunnableConfig(configurable={"thread_id": str(self.conversation.id)}) |
| 69 | + |
| 70 | + async def _raise(*args, **kwargs): |
| 71 | + raise SchemaGenerationException( |
| 72 | + "WITH date_end AS toDate(now()) SELECT 1", |
| 73 | + "HogQL parsing error: this query isn't valid HogQL.", |
| 74 | + ) |
| 75 | + |
| 76 | + with patch("ee.hogai.chat_agent.schema_generator.nodes.SchemaGeneratorNode._run_with_prompt", new=_raise): |
| 77 | + new_state = await node( |
| 78 | + AssistantState( |
| 79 | + messages=[HumanMessage(content="Text")], |
| 80 | + plan="Plan", |
| 81 | + root_tool_call_id=root_tool_call_id, |
| 82 | + root_tool_insight_plan="question", |
| 83 | + ), |
| 84 | + config, |
| 85 | + ) |
| 86 | + |
| 87 | + assert new_state is not None |
| 88 | + self.assertEqual(len(new_state.messages), 1) |
| 89 | + msg = new_state.messages[0] |
| 90 | + self.assertIsInstance(msg, expected_message_type) |
| 91 | + assert isinstance(msg, AssistantToolCallMessage | FailureMessage) |
| 92 | + assert msg.content is not None |
| 93 | + self.assertIn("valid SQL query", msg.content) |
| 94 | + if isinstance(msg, AssistantToolCallMessage): |
| 95 | + self.assertEqual(msg.tool_call_id, root_tool_call_id) |
| 96 | + # Node ends gracefully and clears the tool call so the run terminates |
| 97 | + self.assertIsNone(new_state.root_tool_call_id) |
| 98 | + self.assertIsNone(new_state.intermediate_steps) |
0 commit comments