Skip to content

Commit 4a6f725

Browse files
authored
Merge pull request #158 from mindsdb/extract-query-from-text2sql-agent
Extract query from text2sql agent
2 parents d587a0f + 341bc38 commit 4a6f725

File tree

4 files changed

+208
-18
lines changed

4 files changed

+208
-18
lines changed

examples/using_agents_with_text2sql.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
con = mindsdb_sdk.connect()
66

77
open_ai_key = os.getenv('OPENAI_API_KEY')
8-
model_name = 'gpt-4'
8+
model_name = 'gpt-4o'
99

1010
# Now create an agent that will use the model we just created.
1111
agent = con.agents.create(name=f'mindsdb_sql_agent_{model_name}_{uuid4().hex}',
12-
model='gpt-4')
12+
model=model_name)
1313

1414

1515
# Set up a Postgres data source with our new agent.
Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
1+
import logging
2+
13
import mindsdb_sdk
24
from uuid import uuid4
35
import os
46

7+
from mindsdb_sdk.utils.agents import MindsDBSQLStreamParser
8+
59
con = mindsdb_sdk.connect()
610

711
open_ai_key = os.getenv('OPENAI_API_KEY')
8-
model_name = 'gpt-4'
12+
model_name = 'gpt-4o'
913

1014
# Now create an agent that will use the model we just created.
1115
agent = con.agents.create(name=f'mindsdb_sql_agent_{model_name}_{uuid4().hex}',
12-
model='gpt-4')
13-
16+
model=model_name)
1417

1518
# Set up a Postgres data source with our new agent.
1619
data_source = 'postgres'
@@ -32,23 +35,13 @@
3235
# Actually connect the agent to the datasource.
3336
agent.add_database(database.name, [], description)
3437

35-
3638
question = 'How many three-bedroom houses were sold in 2008?'
3739

3840
completion_stream = agent.completion_stream([{'question': question, 'answer': None}])
3941

40-
# Process the streaming response
41-
full_response = ""
42-
for chunk in completion_stream:
43-
print(chunk) # Print the entire chunk for debugging
44-
if isinstance(chunk, dict):
45-
if 'output' in chunk:
46-
full_response += chunk['output']
47-
elif isinstance(chunk, str):
48-
full_response += chunk
49-
50-
print("\n\nFull response:")
51-
print(full_response)
42+
# Default logging level is INFO; change it to DEBUG for more detailed logs and the full agent steps
43+
mdb_parser = MindsDBSQLStreamParser()
44+
full_response, sql_query = mdb_parser.process_stream(completion_stream)
5245

5346
con.databases.drop(database.name)
5447
con.agents.drop(agent.name)

mindsdb_sdk/utils/agents.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
import re
2+
import json
3+
import logging
4+
from typing import Dict, Any, Generator, Optional, Tuple
5+
6+
7+
class MindsDBSQLStreamParser:
    """
    A utility class for parsing SQL queries from MindsDB completion streams.

    This class provides methods to process completion streams, extract SQL queries,
    and accumulate full responses.

    Attributes:
        logger (logging.Logger): The logger instance for this class.
    """

    def __init__(self, log_level: int = logging.INFO):
        """
        Initialize the MindsDBSQLStreamParser.

        Args:
            log_level (int, optional): The logging level to use. Defaults to logging.INFO.
        """
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(log_level)

        # Attach a console handler only if the logger has none yet: the module
        # logger is shared, so unconditionally adding a StreamHandler on every
        # instantiation caused each record to be emitted once per parser created.
        if not self.logger.handlers:
            ch = logging.StreamHandler()
            ch.setLevel(log_level)

            formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
            ch.setFormatter(formatter)

            self.logger.addHandler(ch)

    def stream_and_parse_sql_query(self, completion_stream: Generator[Dict[str, Any], None, None]) -> Generator[
        Dict[str, Optional[str]], None, None]:
        """
        Stream and parse the completion stream, yielding output and SQL queries.

        This generator function processes each chunk of the completion stream,
        extracts any output and SQL queries, and yields the results.

        Args:
            completion_stream (Generator[Dict[str, Any], None, None]): The input completion stream.
                Chunks may be dicts (with optional 'output', 'type'/'content',
                'messages' and 'quick_response' keys) or plain strings.

        Yields:
            Dict[str, Optional[str]]: A dictionary containing 'output' and 'sql_query' keys.
                - 'output': The extracted output string from the chunk, if any.
                - 'sql_query': The extracted SQL query string, if found in the chunk.

        Note:
            This function will only yield the first SQL query it finds in the stream;
            subsequent 'sql' chunks yield 'sql_query': None.
        """
        sql_query_found = False

        for chunk in completion_stream:
            output = ""
            sql_query = None

            # Log full chunk at DEBUG level; default=str keeps this from
            # raising TypeError on chunks with non-JSON-serializable values.
            self.logger.debug(f"Processing chunk: {json.dumps(chunk, indent=2, default=str)}")

            # Log important info at INFO level
            if isinstance(chunk, dict):
                if 'quick_response' in chunk:
                    self.logger.info(f"Quick response received: {json.dumps(chunk, default=str)}")

                output = chunk.get('output', '')
                if output:
                    self.logger.info(f"Chunk output: {output}")

                if 'messages' in chunk:
                    for message in chunk['messages']:
                        if message.get('role') == 'assistant':
                            self.logger.info(f"Assistant message: {message.get('content', '')}")

                # Honor the documented contract: only the first SQL chunk is
                # surfaced (the flag was previously set but never consulted).
                # .get('content') avoids a KeyError on a malformed 'sql' chunk.
                if not sql_query_found and chunk.get('type') == 'sql':
                    sql_query = chunk.get('content')
                    sql_query_found = sql_query is not None
                    self.logger.info(f"Generated SQL: {sql_query}")

            elif isinstance(chunk, str):
                output = chunk
                self.logger.info(f"String chunk received: {chunk}")

            yield {
                'output': output,
                'sql_query': sql_query
            }

    def process_stream(self, completion_stream: Generator[Dict[str, Any], None, None]) -> Tuple[str, Optional[str]]:
        """
        Process the completion stream and extract the SQL query.

        This method iterates through the stream, accumulates the full response,
        logs outputs, and extracts the SQL query when found.

        Args:
            completion_stream (Generator[Dict[str, Any], None, None]): The input completion stream.

        Returns:
            Tuple[str, Optional[str]]: A tuple containing:
                - The full accumulated response as a string.
                - The extracted SQL query as a string, or None if no query was found.
        """
        full_response = ""
        sql_query = None

        self.logger.info("Starting to process completion stream...")

        for result in self.stream_and_parse_sql_query(completion_stream):
            if result['output']:
                self.logger.info(f"Output: {result['output']}")
                full_response += result['output']

            # Keep only the first query seen, mirroring stream_and_parse_sql_query.
            if result['sql_query'] and sql_query is None:
                sql_query = result['sql_query']
                self.logger.info(f"Extracted SQL Query: {sql_query}")

        self.logger.info(f"Full Response: {full_response}")
        self.logger.info(f"Final SQL Query: {sql_query}")

        return full_response, sql_query

tests/test_agent_stream_process.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import pytest
2+
import logging
3+
4+
from mindsdb_sdk.utils.agents import MindsDBSQLStreamParser
5+
6+
@pytest.fixture
def parser():
    """Provide a MindsDBSQLStreamParser configured at INFO level."""
    instance = MindsDBSQLStreamParser(log_level=logging.INFO)
    return instance
9+
10+
def test_initialization(parser):
    """Constructing the parser yields the right type with a logger at the requested level."""
    expected_level = logging.INFO
    assert isinstance(parser, MindsDBSQLStreamParser)
    assert parser.logger.level == expected_level
13+
14+
def test_stream_and_parse_sql_query_with_dict(parser):
    """Dict chunks yield their output text, and an 'sql' chunk surfaces its query."""
    chunks = [
        {'output': 'Test output', 'type': 'text'},
        {'type': 'sql', 'content': 'SELECT * FROM table'},
        {'output': 'More output'},
    ]

    parsed = [item for item in parser.stream_and_parse_sql_query(iter(chunks))]

    expected = [
        {'output': 'Test output', 'sql_query': None},
        {'output': '', 'sql_query': 'SELECT * FROM table'},
        {'output': 'More output', 'sql_query': None},
    ]
    assert parsed == expected
28+
29+
def test_stream_and_parse_sql_query_with_string(parser):
    """Plain string chunks pass through as output with no SQL query attached."""
    chunks = ['String chunk 1', 'String chunk 2']

    parsed = list(parser.stream_and_parse_sql_query(iter(chunks)))

    assert parsed == [
        {'output': 'String chunk 1', 'sql_query': None},
        {'output': 'String chunk 2', 'sql_query': None},
    ]
38+
39+
40+
def test_process_stream(parser, caplog):
    """process_stream accumulates outputs, extracts the SQL query, and logs its progress."""
    chunks = [
        {'output': 'First output'},
        {'type': 'sql', 'content': 'SELECT * FROM users'},
        {'output': 'Second output'},
    ]

    with caplog.at_level(logging.INFO):
        response, query = parser.process_stream(iter(chunks))

    assert response == 'First outputSecond output'
    assert query == 'SELECT * FROM users'

    # Each milestone of the stream processing should appear in the captured log.
    expected_messages = [
        'Starting to process completion stream...',
        'Output: First output',
        'Extracted SQL Query: SELECT * FROM users',
        'Output: Second output',
        f'Full Response: {response}',
        f'Final SQL Query: {query}',
    ]
    for message in expected_messages:
        assert message in caplog.text
60+
61+
def test_process_stream_no_sql(parser):
    """A stream without any 'sql' chunk returns the joined output and a None query."""
    chunks = [{'output': 'First output'}, {'output': 'Second output'}]

    response, query = parser.process_stream(iter(chunks))

    assert response == 'First outputSecond output'
    assert query is None

0 commit comments

Comments
 (0)