Skip to content

Commit 9741640

Browse files
authored
fix(SemanticAgent): join data to be fixed (#1239)
1 parent e321aa3 commit 9741640

3 files changed

Lines changed: 82 additions & 2 deletions

File tree

pandasai/ee/helpers/query_builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ def _build_from_clause(self, main_table_entry):
296296

297297
def _build_joins_clause(self, main_table_entry, referenced_tables):
298298
sql = ""
299-
main_table = main_table_entry["table"]
299+
main_table = main_table_entry["name"]
300300

301301
for table_name in referenced_tables:
302302
if table_name != main_table:

tests/unit_tests/ee/helpers/schema.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,46 @@
4343
],
4444
}
4545
]
46+
47+
48+
MULTI_JOIN_SCHEMA = [
49+
{
50+
"name": "Sales",
51+
"table": "sales",
52+
"measures": [
53+
{"name": "total_revenue", "type": "sum", "sql": "revenue"},
54+
{"name": "total_sales", "type": "count", "sql": "id"},
55+
],
56+
"dimensions": [
57+
{"name": "product", "type": "string", "sql": "product"},
58+
{"name": "region", "type": "string", "sql": "region"},
59+
{"name": "sales_date", "type": "date", "sql": "sales_date"},
60+
{"name": "id", "type": "string", "sql": "id"},
61+
],
62+
"joins": [
63+
{
64+
"name": "Engagement",
65+
"join_type": "left",
66+
"sql": "${Sales.id} = ${Engagement.id}",
67+
}
68+
],
69+
},
70+
{
71+
"name": "Engagement",
72+
"table": "engagement",
73+
"measures": [{"name": "total_duration", "type": "sum", "sql": "duration"}],
74+
"dimensions": [
75+
{"name": "id", "type": "string", "sql": "id"},
76+
{"name": "user_id", "type": "string", "sql": "user_id"},
77+
{"name": "activity_type", "type": "string", "sql": "activity_type"},
78+
{"name": "engagement_date", "type": "date", "sql": "engagement_date"},
79+
],
80+
"joins": [
81+
{
82+
"name": "Sales",
83+
"join_type": "right",
84+
"sql": "${Engagement.id} = ${Sales.id}",
85+
}
86+
],
87+
},
88+
]

tests/unit_tests/ee/helpers/test_query_builder.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import unittest
22

33
from pandasai.ee.helpers.query_builder import QueryBuilder
4-
from tests.unit_tests.ee.helpers.schema import VIZ_QUERY_SCHEMA
4+
from tests.unit_tests.ee.helpers.schema import MULTI_JOIN_SCHEMA, VIZ_QUERY_SCHEMA
55

66

77
class TestQueryBuilder(unittest.TestCase):
@@ -191,3 +191,40 @@ def test_sql_with_filters_with_set_filter(self):
191191
"SELECT SUM(`orders`.`freight`) AS total_freight, `orders`.`ship_country` AS ship_country FROM `orders` GROUP BY ship_country HAVING SUM(`orders`.`freight`) IS NOT NULL ORDER BY total_freight asc",
192192
"SELECT `orders`.`ship_country` AS ship_country, SUM(`orders`.`freight`) AS total_freight FROM `orders` GROUP BY ship_country HAVING SUM(`orders`.`freight`) IS NOT NULL ORDER BY total_freight asc",
193193
]
194+
195+
def test_sql_with_filters_with_join(self):
196+
query_builder = QueryBuilder(MULTI_JOIN_SCHEMA)
197+
198+
json_str = {
199+
"type": "bar",
200+
"dimensions": ["Engagement.activity_type"],
201+
"measures": ["Sales.total_revenue"],
202+
"timeDimensions": [],
203+
"options": {
204+
"xLabel": "Activity Type",
205+
"yLabel": "Total Revenue",
206+
"title": "Total Revenue Generated from Users who Logged in Before Purchase",
207+
"legend": {"display": True, "position": "top"},
208+
},
209+
"joins": [
210+
{
211+
"name": "Engagement",
212+
"join_type": "right",
213+
"sql": "${Sales.id} = ${Engagement.id}",
214+
}
215+
],
216+
"filters": [
217+
{
218+
"member": "Engagement.engagement_date",
219+
"operator": "beforeDate",
220+
"values": ["${Sales.sales_date}"],
221+
}
222+
],
223+
"order": [{"id": "Sales.total_revenue", "direction": "asc"}],
224+
}
225+
sql_query = query_builder.generate_sql(json_str)
226+
227+
assert (
228+
sql_query
229+
== "SELECT `engagement`.`activity_type` AS activity_type, SUM(`sales`.`revenue`) AS total_revenue FROM `sales` RIGHT JOIN `engagement` ON `engagement`.`id` = `sales`.`id` WHERE `engagement`.`engagement_date` < '${Sales.sales_date}' GROUP BY activity_type ORDER BY total_revenue asc"
230+
)

0 commit comments

Comments
 (0)