|
| 1 | +import pytest |
| 2 | +import logging |
| 3 | + |
| 4 | +from mindsdb_sdk.utils.agents import MindsDBSQLStreamParser |
| 5 | + |
| 6 | +@pytest.fixture |
| 7 | +def parser(): |
| 8 | + return MindsDBSQLStreamParser(log_level=logging.INFO) |
| 9 | + |
| 10 | +def test_initialization(parser): |
| 11 | + assert isinstance(parser, MindsDBSQLStreamParser) |
| 12 | + assert parser.logger.level == logging.INFO |
| 13 | + |
| 14 | +def test_stream_and_parse_sql_query_with_dict(parser): |
| 15 | + mock_stream = [ |
| 16 | + {'output': 'Test output', 'type': 'text'}, |
| 17 | + {'type': 'sql', 'content': 'SELECT * FROM table'}, |
| 18 | + {'output': 'More output'} |
| 19 | + ] |
| 20 | + |
| 21 | + generator = parser.stream_and_parse_sql_query(iter(mock_stream)) |
| 22 | + results = list(generator) |
| 23 | + |
| 24 | + assert len(results) == 3 |
| 25 | + assert results[0] == {'output': 'Test output', 'sql_query': None} |
| 26 | + assert results[1] == {'output': '', 'sql_query': 'SELECT * FROM table'} |
| 27 | + assert results[2] == {'output': 'More output', 'sql_query': None} |
| 28 | + |
| 29 | +def test_stream_and_parse_sql_query_with_string(parser): |
| 30 | + mock_stream = ['String chunk 1', 'String chunk 2'] |
| 31 | + |
| 32 | + generator = parser.stream_and_parse_sql_query(iter(mock_stream)) |
| 33 | + results = list(generator) |
| 34 | + |
| 35 | + assert len(results) == 2 |
| 36 | + assert results[0] == {'output': 'String chunk 1', 'sql_query': None} |
| 37 | + assert results[1] == {'output': 'String chunk 2', 'sql_query': None} |
| 38 | + |
| 39 | + |
| 40 | +def test_process_stream(parser, caplog): |
| 41 | + mock_stream = [ |
| 42 | + {'output':'First output'}, |
| 43 | + {'type':'sql', 'content':'SELECT * FROM users'}, |
| 44 | + {'output':'Second output'} |
| 45 | + ] |
| 46 | + |
| 47 | + with caplog.at_level(logging.INFO): |
| 48 | + full_response, sql_query = parser.process_stream(iter(mock_stream)) |
| 49 | + |
| 50 | + assert full_response == 'First outputSecond output' |
| 51 | + assert sql_query == 'SELECT * FROM users' |
| 52 | + |
| 53 | + # Check for specific log messages |
| 54 | + assert 'Starting to process completion stream...' in caplog.text |
| 55 | + assert 'Output: First output' in caplog.text |
| 56 | + assert 'Extracted SQL Query: SELECT * FROM users' in caplog.text |
| 57 | + assert 'Output: Second output' in caplog.text |
| 58 | + assert f'Full Response: {full_response}' in caplog.text |
| 59 | + assert f'Final SQL Query: {sql_query}' in caplog.text |
| 60 | + |
| 61 | +def test_process_stream_no_sql(parser): |
| 62 | + mock_stream = [ |
| 63 | + {'output': 'First output'}, |
| 64 | + {'output': 'Second output'} |
| 65 | + ] |
| 66 | + |
| 67 | + full_response, sql_query = parser.process_stream(iter(mock_stream)) |
| 68 | + |
| 69 | + assert full_response == 'First outputSecond output' |
| 70 | + assert sql_query is None |
0 commit comments