Skip to content

Commit f594212

Browse files
authored
Merge pull request #1233 from Sinaptik-AI/semantic_schema_timestamp
fix: error prompt generation and extra validation on schema generation
2 parents 2bbbb92 + 4f1e527 commit f594212

4 files changed

Lines changed: 12 additions & 8 deletions

File tree

pandasai/ee/agents/semantic_agent/pipeline/code_generator.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,14 +79,15 @@ def execute(self, input_data: Any, **kwargs) -> Any:
7979

8080
traceback_errors = traceback.format_exc()
8181

82-
input_data = self.on_failure(input, traceback_errors)
82+
input_data = self.on_failure(input_data, traceback_errors)
8383

8484
retry_count += 1
8585

8686
def _get_type(self, input: dict) -> bool:
8787
return (
8888
"plot"
89-
if input["type"] in ["bar", "line", "histogram", "pie", "scatter"]
89+
if input["type"]
90+
in ["bar", "line", "histogram", "pie", "scatter", "boxplot"]
9091
else input["type"]
9192
)
9293

@@ -99,7 +100,7 @@ def _generate_code(self, type, query):
99100
"""
100101
elif type == "dataframe":
101102
return """
102-
result = {{"type": "dataframe","value": data}}
103+
result = {"type": "dataframe","value": data}
103104
"""
104105
else:
105106
code = self.generate_matplotlib_code(query)
@@ -119,8 +120,8 @@ def _generate_code_for_number(self, query: dict) -> str:
119120

120121
def generate_matplotlib_code(self, query: dict) -> str:
121122
chart_type = query["type"]
122-
x_label = query["options"].get("xLabel", None)
123-
y_label = query["options"].get("yLabel", None)
123+
x_label = query.get("options", {}).get("xLabel", None)
124+
y_label = query.get("options", {}).get("yLabel", None)
124125
title = query["options"].get("title", None)
125126
legend_display = {"display": True}
126127
legend_position = "best"

pandasai/ee/agents/semantic_agent/pipeline/error_correction_pipeline/error_correction_pipeline.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def __init__(
5858
on_code_generation=on_code_generation,
5959
on_prompt_generation=on_prompt_generation,
6060
)
61+
self.query_exec_tracker = query_exec_tracker
6162

6263
self._context = context
6364
self._logger = logger

pandasai/ee/agents/semantic_agent/prompts/generate_df_schema.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,16 @@ def __init__(self, **kwargs):
3131
def validate(self, output: str) -> bool:
3232
try:
3333
json_data = json.loads(output.replace("# SAMPLE SCHEMA", ""))
34+
context = self.props["context"]
3435
if isinstance(json_data, dict):
3536
json_data = [json_data]
3637
if isinstance(json_data, list):
3738
for record in json_data:
3839
if not all(key in record for key in ("name", "table")):
3940
return False
40-
return True
41+
42+
return len(context.dfs) == len(json_data)
43+
4144
except json.JSONDecodeError:
4245
pass
4346
return False

tests/unit_tests/ee/semantic_agent/test__semantic_code_generator.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -307,11 +307,10 @@ def test_generate_matplolib_boxplot_chart_code(
307307

308308
logic_unit = code_gen.execute(json_str, context=context, logger=logger)
309309
assert isinstance(logic_unit, LogicUnitOutput)
310-
print(logic_unit.output)
311310
assert (
312311
logic_unit.output
313312
== """
314-
313+
import matplotlib.pyplot as plt
315314
import pandas as pd
316315
317316
sql_query="SELECT `orders`.`ship_country` AS ship_country, SUM(`orders`.`freight`) AS total_freight FROM `orders` GROUP BY ship_country"

0 commit comments

Comments (0)