Skip to content

Commit c159671

Browse files
authored
fix(Agent): adding test to improve coverage (#1528)
* fix(test): fixing tests for when PANDABI_API_KEY is present in .env * fix(Agent): fixing chatting with multiple local dataframes * fix(Agent): adding test to improve coverage
1 parent 2d54457 commit c159671

File tree

3 files changed

+72
-7
lines changed

3 files changed

+72
-7
lines changed

pandasai/agent/base.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,7 @@ def execute_code(self, code: str) -> dict:
106106
"""Execute the generated code."""
107107
self._state.logger.log(f"Executing code: {code}")
108108
code_executor = CodeExecutor(self._state.config)
109-
code_executor.add_to_env("execute_sql_query", self.execute_sql_query)
110-
109+
code_executor.add_to_env("execute_sql_query", self._execute_sql_query)
111110
return code_executor.execute_and_return_result(code)
112111

113112
def _execute_local_sql_query(self, query: str) -> pd.DataFrame:
@@ -125,7 +124,7 @@ def _execute_local_sql_query(self, query: str) -> pd.DataFrame:
125124
except duckdb.Error as e:
126125
raise RuntimeError(f"SQL execution failed: {e}") from e
127126

128-
def execute_sql_query(self, query: str) -> pd.DataFrame:
127+
def _execute_sql_query(self, query: str) -> pd.DataFrame:
129128
"""
130129
Executes an SQL query on registered DataFrames.
131130

tests/unit_tests/agent/test_agent.py

+70-1
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import os
22
from typing import Optional
3-
from unittest.mock import MagicMock, Mock, patch
3+
from unittest.mock import MagicMock, Mock, mock_open, patch
44

55
import pandas as pd
66
import pytest
77

8+
from pandasai import DatasetLoader, VirtualDataFrame
89
from pandasai.agent.base import Agent
910
from pandasai.config import Config, ConfigManager
11+
from pandasai.data_loader.semantic_layer_schema import SemanticLayerSchema
1012
from pandasai.dataframe.base import DataFrame
1113
from pandasai.exceptions import CodeExecutionError
1214
from pandasai.llm.fake import FakeLLM
@@ -15,6 +17,24 @@
1517
class TestAgent:
1618
"Unit tests for Agent class"
1719

20+
@pytest.fixture
21+
def mysql_schema(self):
22+
raw_schema = {
23+
"name": "countries",
24+
"source": {
25+
"type": "mysql",
26+
"connection": {
27+
"host": "localhost",
28+
"port": 3306,
29+
"database": "test_db",
30+
"user": "test_user",
31+
"password": "test_password",
32+
},
33+
"table": "countries",
34+
},
35+
}
36+
return SemanticLayerSchema(**raw_schema)
37+
1838
@pytest.fixture
1939
def sample_df(self) -> DataFrame:
2040
return DataFrame(
@@ -429,3 +449,52 @@ def test_train_method_with_code_but_no_queries(self, agent):
429449
codes = ["code1", "code2"]
430450
with pytest.raises(ValueError):
431451
agent.train(codes)
452+
453+
def test_execute_local_sql_query_success(self, agent):
454+
query = "SELECT count(*) as total from countries;"
455+
expected_result = pd.DataFrame({"total": [4]})
456+
result = agent._execute_local_sql_query(query)
457+
pd.testing.assert_frame_equal(result, expected_result)
458+
459+
def test_execute_local_sql_query_failure(self, agent):
460+
with pytest.raises(RuntimeError, match="SQL execution failed"):
461+
agent._execute_local_sql_query("wrong query;")
462+
463+
def test_execute_sql_query_success_local(self, agent):
464+
query = "SELECT count(*) as total from countries;"
465+
expected_result = pd.DataFrame({"total": [4]})
466+
result = agent._execute_sql_query(query)
467+
pd.testing.assert_frame_equal(result, expected_result)
468+
469+
@patch("os.path.exists", return_value=True)
470+
def test_execute_sql_query_success_virtual_dataframe(
471+
self, mock_exists, agent, mysql_schema, sample_df
472+
):
473+
query = "SELECT count(*) as total from countries;"
474+
loader = DatasetLoader()
475+
expected_result = pd.DataFrame({"total": [4]})
476+
477+
with patch(
478+
"builtins.open", mock_open(read_data=str(mysql_schema.to_yaml()))
479+
), patch(
480+
"pandasai.data_loader.loader.DatasetLoader.execute_query"
481+
) as mock_query:
482+
# Set up the mock for both the sample data and the query result
483+
mock_query.side_effect = [sample_df, expected_result]
484+
485+
virtual_dataframe = loader.load("test/users")
486+
agent._state.dfs = [virtual_dataframe]
487+
488+
pd.testing.assert_frame_equal(virtual_dataframe.head(), sample_df)
489+
result = agent._execute_sql_query(query)
490+
pd.testing.assert_frame_equal(result, expected_result)
491+
492+
# Verify execute_query was called appropriately
493+
assert mock_query.call_count == 2 # Once for head(), once for the SQL query
494+
495+
def test_execute_sql_query_error_no_dataframe(self, agent):
496+
query = "SELECT count(*) as total from countries;"
497+
agent._state.dfs = None
498+
499+
with pytest.raises(ValueError, match="No DataFrames available"):
500+
agent._execute_sql_query(query)

tests/unit_tests/dataframe/test_loader.py

-3
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
11
import logging
2-
import sys
3-
from datetime import datetime, timedelta
42
from unittest.mock import mock_open, patch
53

64
import pandas as pd
75
import pytest
8-
import yaml
96

107
from pandasai.data_loader.loader import DatasetLoader
118
from pandasai.data_loader.semantic_layer_schema import SemanticLayerSchema

0 commit comments

Comments
 (0)