Skip to content

Commit 431c3d5

Browse files
authored
chore(SemanticAgent): add samples in schema and support back-tick json load (#1241)
* fix(SemanticAgent): join data to be fixed * fix(semantic_agent): json load to also look for json in backtick
1 parent e10510a commit 431c3d5

5 files changed

Lines changed: 113 additions & 80 deletions

File tree

pandasai/ee/agents/semantic_agent/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from pandasai.ee.agents.semantic_agent.prompts.generate_df_schema import (
1616
GenerateDFSchemaPrompt,
1717
)
18+
from pandasai.ee.helpers.json_helper import extract_json_from_json_str
1819
from pandasai.exceptions import InvalidConfigError, InvalidSchemaJson, InvalidTrainJson
1920
from pandasai.helpers.cache import Cache
2021
from pandasai.helpers.memory import Memory
@@ -186,7 +187,7 @@ def _create_schema(self):
186187
"""
187188
)
188189
self._schema = result.replace("# SAMPLE SCHEMA", "")
189-
schema_data = json.loads(result.replace("# SAMPLE SCHEMA", ""))
190+
schema_data = extract_json_from_json_str(result.replace("# SAMPLE SCHEMA", ""))
190191
if isinstance(schema_data, dict):
191192
schema_data = [schema_data]
192193

pandasai/ee/agents/semantic_agent/pipeline/llm_call.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
import json
21
from typing import Any
32

3+
from pandasai.ee.helpers.json_helper import extract_json_from_json_str
44
from pandasai.helpers.logger import Logger
55
from pandasai.pipelines.base_logic_unit import BaseLogicUnit
66
from pandasai.pipelines.logic_unit_output import LogicUnitOutput
@@ -42,7 +42,7 @@ def execute(self, input: Any, **kwargs) -> Any:
4242
)
4343
try:
4444
# Validate is valid Json
45-
response_json = json.loads(response)
45+
response_json = extract_json_from_json_str(response)
4646

4747
pipeline_context.add("llm_call", response)
4848

pandasai/ee/agents/semantic_agent/prompts/generate_df_schema.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from jinja2 import Environment, FileSystemLoader
55

6+
from pandasai.ee.helpers.json_helper import extract_json_from_json_str
67
from pandasai.prompts.base import BasePrompt
78

89

@@ -30,7 +31,9 @@ def __init__(self, **kwargs):
3031

3132
def validate(self, output: str) -> bool:
3233
try:
33-
json_data = json.loads(output.replace("# SAMPLE SCHEMA", ""))
34+
json_data = extract_json_from_json_str(
35+
output.replace("# SAMPLE SCHEMA", "")
36+
)
3437
context = self.props["context"]
3538
if isinstance(json_data, dict):
3639
json_data = [json_data]

pandasai/ee/agents/semantic_agent/prompts/templates/generate_df_schema.tmpl

Lines changed: 91 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,132 +1,147 @@
11
# SAMPLE SCHEMA
22
[
33
{
4-
"name":"Contracts",
5-
"table":"contracts",
6-
"measures":[
4+
"name": "Contracts",
5+
"table": "contracts",
6+
"measures": [
77
{
8-
"name":"contract_count",
9-
"type":"count",
10-
"sql":"store_id"
8+
"name": "contract_count",
9+
"type": "count",
10+
"sql": "store_id"
1111
},
1212
{
13-
"name":"contract_duration",
14-
"type":"number",
15-
"sql":"${contract_end_date} - ${contract_start_date}"
13+
"name": "contract_duration",
14+
"type": "number",
15+
"sql": "${contract_end_date} - ${contract_start_date}"
1616
},
1717
{
18-
"name":"contract_avg_duration",
19-
"type":"avg",
20-
"sql":"${contract_duration}"
18+
"name": "contract_avg_duration",
19+
"type": "avg",
20+
"sql": "${contract_duration}"
2121
}
2222
],
23-
"dimensions":[
23+
"dimensions": [
2424
{
25-
"name":"contract_code",
26-
"type":"string",
27-
"sql":"contract_code"
25+
"name": "contract_code",
26+
"type": "string",
27+
"sql": "contract_code",
28+
"samples": ["C12345", "C67890"]
2829
},
2930
{
30-
"name":"store_id",
31-
"type":"string",
32-
"sql":"store_id"
31+
"name": "store_id",
32+
"type": "string",
33+
"sql": "store_id",
34+
"samples": ["S12345", "S67890"]
3335
},
3436
{
35-
"name":"tenant_code",
36-
"type":"string",
37-
"sql":"tenant_code"
37+
"name": "tenant_code",
38+
"type": "string",
39+
"sql": "tenant_code",
40+
"samples": ["T12345", "T67890"]
3841
},
3942
{
40-
"name":"tenant_name",
41-
"type":"string",
42-
"sql":"tenant_name"
43+
"name": "tenant_name",
44+
"type": "string",
45+
"sql": "tenant_name",
46+
"samples": ["Tenant A", "Tenant B"]
4347
},
4448
{
45-
"name":"store_brand",
46-
"type":"string",
47-
"sql":"store_brand"
49+
"name": "store_brand",
50+
"type": "string",
51+
"sql": "store_brand",
52+
"samples": ["Brand X", "Brand Y"]
4853
},
4954
{
50-
"name":"branch_segment_1",
51-
"type":"string",
52-
"sql":"branch_segment_1"
55+
"name": "branch_segment_1",
56+
"type": "string",
57+
"sql": "branch_segment_1",
58+
"samples": ["Segment 1", "Segment 2"]
5359
},
5460
{
55-
"name":"branch_segment_2",
56-
"type":"string",
57-
"sql":"branch_segment_2"
61+
"name": "branch_segment_2",
62+
"type": "string",
63+
"sql": "branch_segment_2",
64+
"samples": ["Segment A", "Segment B"]
5865
},
5966
{
60-
"name":"contract_start_date",
61-
"type":"date",
62-
"sql":"contract_start_date"
67+
"name": "contract_start_date",
68+
"type": "date",
69+
"sql": "contract_start_date",
70+
"samples": ["2023-01-01", "2023-02-01"]
6371
},
6472
{
65-
"name":"contract_end_date",
66-
"type":"date",
67-
"sql":"contract_end_date"
73+
"name": "contract_end_date",
74+
"type": "date",
75+
"sql": "contract_end_date",
76+
"samples": ["2024-01-01", "2024-02-01"]
6877
}
6978
],
70-
"joins":[
79+
"joins": [
7180
{
72-
"name":"corrispettivi",
73-
"join_type":"left",
74-
"sql":"${Contracts.contract_code} = ${Fees.contract_id}"
81+
"name": "Fee",
82+
"join_type": "left",
83+
"sql": "${Contracts.contract_code} = ${Fees.contract_id}"
7584
}
7685
]
7786
},
7887
{
79-
"name":"Fees",
80-
"table":"fees",
81-
"measures":[
88+
"name": "Fees",
89+
"table": "fees",
90+
"measures": [
8291
{
83-
"name":"total_taxable",
84-
"type":"sum",
85-
"sql":"imponibile_tot"
92+
"name": "total_taxable",
93+
"type": "sum",
94+
"sql": "imponibile_tot"
8695
},
8796
{
88-
"name":"total_revenue",
89-
"type":"sum",
90-
"sql":"totale_tot"
97+
"name": "total_revenue",
98+
"type": "sum",
99+
"sql": "totale_tot"
91100
}
92101
],
93-
"dimensions":[
102+
"dimensions": [
94103
{
95-
"name":"contract_id",
96-
"type":"string",
97-
"sql":"contract_id"
104+
"name": "contract_id",
105+
"type": "string",
106+
"sql": "contract_id",
107+
"samples": ["C12345", "C67890"]
98108
},
99109
{
100-
"name":"code",
101-
"type":"string",
102-
"sql":"code"
110+
"name": "code",
111+
"type": "string",
112+
"sql": "code",
113+
"samples": ["F12345", "F67890"]
103114
},
104115
{
105-
"name":"station",
106-
"type":"string",
107-
"sql":"station"
116+
"name": "station",
117+
"type": "string",
118+
"sql": "station",
119+
"samples": ["Station X", "Station Y"]
108120
},
109121
{
110-
"name":"tenant_id",
111-
"type":"string",
112-
"sql":"tenant_id"
122+
"name": "tenant_id",
123+
"type": "string",
124+
"sql": "tenant_id",
125+
"samples": ["T12345", "T67890"]
113126
},
114127
{
115-
"name":"day",
116-
"type":"date",
117-
"sql":"day"
128+
"name": "day",
129+
"type": "date",
130+
"sql": "day",
131+
"samples": ["2023-01-01", "2023-02-01"]
118132
},
119133
{
120-
"name":"store_id",
121-
"type":"string",
122-
"sql":"store_id"
134+
"name": "store_id",
135+
"type": "string",
136+
"sql": "store_id",
137+
"samples": ["S12345", "S67890"]
123138
}
124139
],
125-
"joins":[
140+
"joins": [
126141
{
127-
"name":"contracts",
128-
"join_type":"right",
129-
"sql":"${Fees.contract_id} = ${Fees.contract_code}"
142+
"name": "Contracts",
143+
"join_type": "right",
144+
"sql": "${Fees.contract_id} = ${Contracts.contract_code}"
130145
}
131146
]
132147
}

pandasai/ee/helpers/json_helper.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import json
2+
3+
4+
def extract_json_from_json_str(json_str):
5+
start_index = json_str.find("```json")
6+
7+
end_index = json_str.find("```", start_index)
8+
9+
if start_index == -1:
10+
return json.loads(json_str)
11+
12+
json_data = json_str[(start_index + len("```json")) : end_index].strip()
13+
14+
return json.loads(json_data)

0 commit comments

Comments
 (0)