|
1 | 1 | import os
|
2 | 2 | from typing import Optional
|
3 |
| -from unittest.mock import MagicMock, Mock, patch |
| 3 | +from unittest.mock import MagicMock, Mock, mock_open, patch |
4 | 4 |
|
5 | 5 | import pandas as pd
|
6 | 6 | import pytest
|
7 | 7 |
|
| 8 | +from pandasai import DatasetLoader, VirtualDataFrame |
8 | 9 | from pandasai.agent.base import Agent
|
9 | 10 | from pandasai.config import Config, ConfigManager
|
| 11 | +from pandasai.data_loader.semantic_layer_schema import SemanticLayerSchema |
10 | 12 | from pandasai.dataframe.base import DataFrame
|
11 | 13 | from pandasai.exceptions import CodeExecutionError
|
12 | 14 | from pandasai.llm.fake import FakeLLM
|
|
15 | 17 | class TestAgent:
|
16 | 18 | "Unit tests for Agent class"
|
17 | 19 |
|
| 20 | + @pytest.fixture |
| 21 | + def mysql_schema(self): |
| 22 | + raw_schema = { |
| 23 | + "name": "countries", |
| 24 | + "source": { |
| 25 | + "type": "mysql", |
| 26 | + "connection": { |
| 27 | + "host": "localhost", |
| 28 | + "port": 3306, |
| 29 | + "database": "test_db", |
| 30 | + "user": "test_user", |
| 31 | + "password": "test_password", |
| 32 | + }, |
| 33 | + "table": "countries", |
| 34 | + }, |
| 35 | + } |
| 36 | + return SemanticLayerSchema(**raw_schema) |
| 37 | + |
18 | 38 | @pytest.fixture
|
19 | 39 | def sample_df(self) -> DataFrame:
|
20 | 40 | return DataFrame(
|
@@ -429,3 +449,52 @@ def test_train_method_with_code_but_no_queries(self, agent):
|
429 | 449 | codes = ["code1", "code2"]
|
430 | 450 | with pytest.raises(ValueError):
|
431 | 451 | agent.train(codes)
|
| 452 | + |
| 453 | + def test_execute_local_sql_query_success(self, agent): |
| 454 | + query = "SELECT count(*) as total from countries;" |
| 455 | + expected_result = pd.DataFrame({"total": [4]}) |
| 456 | + result = agent._execute_local_sql_query(query) |
| 457 | + pd.testing.assert_frame_equal(result, expected_result) |
| 458 | + |
| 459 | + def test_execute_local_sql_query_failure(self, agent): |
| 460 | + with pytest.raises(RuntimeError, match="SQL execution failed"): |
| 461 | + agent._execute_local_sql_query("wrong query;") |
| 462 | + |
| 463 | + def test_execute_sql_query_success_local(self, agent): |
| 464 | + query = "SELECT count(*) as total from countries;" |
| 465 | + expected_result = pd.DataFrame({"total": [4]}) |
| 466 | + result = agent._execute_sql_query(query) |
| 467 | + pd.testing.assert_frame_equal(result, expected_result) |
| 468 | + |
| 469 | + @patch("os.path.exists", return_value=True) |
| 470 | + def test_execute_sql_query_success_virtual_dataframe( |
| 471 | + self, mock_exists, agent, mysql_schema, sample_df |
| 472 | + ): |
| 473 | + query = "SELECT count(*) as total from countries;" |
| 474 | + loader = DatasetLoader() |
| 475 | + expected_result = pd.DataFrame({"total": [4]}) |
| 476 | + |
| 477 | + with patch( |
| 478 | + "builtins.open", mock_open(read_data=str(mysql_schema.to_yaml())) |
| 479 | + ), patch( |
| 480 | + "pandasai.data_loader.loader.DatasetLoader.execute_query" |
| 481 | + ) as mock_query: |
| 482 | + # Set up the mock for both the sample data and the query result |
| 483 | + mock_query.side_effect = [sample_df, expected_result] |
| 484 | + |
| 485 | + virtual_dataframe = loader.load("test/users") |
| 486 | + agent._state.dfs = [virtual_dataframe] |
| 487 | + |
| 488 | + pd.testing.assert_frame_equal(virtual_dataframe.head(), sample_df) |
| 489 | + result = agent._execute_sql_query(query) |
| 490 | + pd.testing.assert_frame_equal(result, expected_result) |
| 491 | + |
| 492 | + # Verify execute_query was called appropriately |
| 493 | + assert mock_query.call_count == 2 # Once for head(), once for the SQL query |
| 494 | + |
| 495 | + def test_execute_sql_query_error_no_dataframe(self, agent): |
| 496 | + query = "SELECT count(*) as total from countries;" |
| 497 | + agent._state.dfs = None |
| 498 | + |
| 499 | + with pytest.raises(ValueError, match="No DataFrames available"): |
| 500 | + agent._execute_sql_query(query) |
0 commit comments