Skip to content

Commit c0933a2

Browse files
authored
fix(View): fixing aliases in view (#1614)
* fix(views): transformation using raw sql * fix(Views): multiple joins * fix(View): correct error message for missing dependencies in view * fix(View): correct error message for missing dependencies in view * fix(View): redundant code
1 parent a1b36f4 commit c0933a2

3 files changed

Lines changed: 307 additions & 186 deletions

File tree

pandasai/data_loader/view_loader.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,16 @@ def _get_dependencies_datasets(self) -> set[str]:
4848
} or {self.schema.columns[0].name.split(".")[0]}
4949

5050
def _get_dependencies_schemas(self) -> dict[str, DatasetLoader]:
51-
dependency_dict = {
52-
dep: DatasetLoader.create_loader_from_path(f"{self.org_name}/{dep}")
53-
for dep in self.dependencies_datasets
54-
}
51+
dependency_dict = {}
52+
for dep in self.dependencies_datasets:
53+
try:
54+
dependency_dict[dep] = DatasetLoader.create_loader_from_path(
55+
f"{self.org_name}/{dep}"
56+
)
57+
except FileNotFoundError:
58+
raise FileNotFoundError(
59+
f"View failed to load. Missing required dataset: '{dep}'. Try pulling the dataset to resolve the issue."
60+
)
5561

5662
loaders = list(dependency_dict.values())
5763

pandasai/query_builders/view_query_builder.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,16 @@ def _get_group_by_columns(self) -> list[str]:
4040
group_by_cols.append(self.normalize_view_column_alias(col))
4141
return group_by_cols
4242

43+
def _get_aliases(self) -> list[str]:
44+
return [
45+
col.alias or self.normalize_view_column_alias(col.name)
46+
for col in self.schema.columns
47+
]
48+
4349
def _get_columns(self) -> list[str]:
4450
columns = []
45-
for col in self.schema.columns:
51+
aliases = self._get_aliases()
52+
for i, col in enumerate(self.schema.columns):
4653
if col.expression:
4754
# Pre-process the expression to handle hyphens between letters
4855
expr = re.sub(r"([a-zA-Z])-([a-zA-Z])", r"\1_\2", col.expression)
@@ -51,9 +58,7 @@ def _get_columns(self) -> list[str]:
5158
else:
5259
column_expr = self.normalize_view_column_alias(col.name)
5360

54-
alias = (
55-
col.alias if col.alias else self.normalize_view_column_alias(col.name)
56-
)
61+
alias = aliases[i]
5762
column_expr = f"{column_expr} AS {alias}"
5863

5964
columns.append(column_expr)
@@ -62,32 +67,17 @@ def _get_columns(self) -> list[str]:
6267

6368
def build_query(self) -> str:
6469
"""Build the SQL query with proper group by column aliasing."""
65-
query = select(*self._get_columns()).from_(self._get_table_expression())
66-
67-
if self.schema.group_by:
68-
query = query.group_by(
69-
*[normalize_identifiers(col) for col in self._get_group_by_columns()]
70-
)
71-
70+
query = select(*self._get_aliases()).from_(self._get_table_expression())
7271
if self.schema.order_by:
7372
query = query.order_by(*self.schema.order_by)
74-
7573
if self.schema.limit:
7674
query = query.limit(self.schema.limit)
77-
7875
return query.sql(pretty=True)
7976

8077
def get_head_query(self, n=5):
8178
"""Get the head query with proper group by column aliasing."""
82-
query = select(*self._get_columns()).from_(self._get_table_expression())
83-
84-
if self.schema.group_by:
85-
query = query.group_by(
86-
*[normalize_identifiers(col) for col in self._get_group_by_columns()]
87-
)
88-
79+
query = select(*self._get_aliases()).from_(self._get_table_expression())
8980
query = query.limit(n)
90-
9181
return query.sql(pretty=True)
9282

9383
def _get_sub_query_from_loader(self, loader: DatasetLoader) -> Subquery:
@@ -132,4 +122,14 @@ def _get_table_expression(self) -> str:
132122
append=True,
133123
)
134124
alias = normalize_identifiers(self.schema.name).sql()
135-
return exp.Subquery(this=query, alias=alias).sql(pretty=True)
125+
126+
subquery = exp.Subquery(this=query).sql(pretty=True)
127+
128+
final_query = select(*self._get_columns()).from_(subquery)
129+
130+
if self.schema.group_by:
131+
final_query = final_query.group_by(
132+
*[normalize_identifiers(col) for col in self._get_group_by_columns()]
133+
)
134+
135+
return exp.Subquery(this=final_query, alias=alias).sql(pretty=True)

0 commit comments

Comments
 (0)