Skip to content

Commit 357c97c

Browse files
feature(paginator): adding paginator functionality (#1644)
* feature(paginator): adding paginator functionality * Update pandasai/helpers/sql_sanitizer.py Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> * Update pandasai/query_builders/sql_parser.py Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> * feature(paginator): docstring * feature(paginator): format issues --------- Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
1 parent fc1fefc commit 357c97c

6 files changed

Lines changed: 518 additions & 3 deletions

File tree

pandasai/helpers/sql_sanitizer.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,3 +97,26 @@ def is_sql_query_safe(query: str, dialect: str = "postgres") -> bool:
9797

9898
except sqlglot.errors.ParseError:
9999
return False
100+
101+
102+
def is_sql_query(query: str) -> bool:
103+
# Define SQL patterns with context to avoid standalone keyword matches
104+
sql_patterns = [
105+
r"\bSELECT\b.*\bFROM\b",
106+
r"\bINSERT\b.*\bINTO\b",
107+
r"\bUPDATE\b.*\bSET\b",
108+
r"\bDELETE\b.*\bFROM\b",
109+
r"\bDROP\b.*\b(TABLE|DATABASE)\b",
110+
r"\bCREATE\b.*\b(DATABASE|TABLE)\b",
111+
r"\bALTER\b.*\bTABLE\b",
112+
r"\bJOIN\b.*\bON\b",
113+
r"\bWHERE\b",
114+
]
115+
116+
# Combine all patterns into a single regex
117+
sql_regex = re.compile("|".join(sql_patterns), re.IGNORECASE)
118+
119+
# If the query matches any SQL pattern, it's considered a SQL query
120+
if sql_regex.search(query):
121+
return True
122+
return False

pandasai/query_builders/base_query_builder.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import sqlglot
44
from sqlglot import select
5-
from sqlglot.expressions import false
65
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
76

87
from pandasai.data_loader.semantic_layer_schema import SemanticLayerSchema, Source
Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
import datetime
2+
import json
3+
import uuid
4+
from typing import List, Optional, Tuple
5+
6+
import sqlglot
7+
from pydantic import BaseModel, Field, field_validator
8+
9+
from pandasai.helpers.sql_sanitizer import is_sql_query
10+
11+
12+
class PaginationParams(BaseModel):
13+
"""Parameters for pagination requests"""
14+
15+
page: int = Field(ge=1, description="Page number, starting from 1")
16+
page_size: int = Field(
17+
ge=1, le=100, description="Number of items per page, maximum 100"
18+
)
19+
search: Optional[str] = Field(
20+
None, description="Search term to filter across all fields"
21+
)
22+
sort_by: Optional[str] = Field(None, description="Column to sort by")
23+
sort_order: Optional[str] = Field(
24+
None, pattern="^(asc|desc)$", description="Sort order (asc or desc)"
25+
)
26+
filters: Optional[str] = Field(None, description="Filters to apply to the data")
27+
28+
@field_validator("search", "filters", "sort_by", "sort_order")
29+
@classmethod
30+
def not_sql(cls, field):
31+
if is_sql_query(str(field)):
32+
raise ValueError(
33+
f"SQL queries are not allowed in pagination parameters: {field}"
34+
)
35+
return field
36+
37+
38+
class DatasetPaginator:
39+
@staticmethod
40+
def is_float(value: str) -> bool:
41+
try:
42+
# Try to cast the value to a number
43+
float(value)
44+
return True
45+
except (ValueError, TypeError):
46+
# If it fails, it's not a number
47+
return False
48+
49+
@staticmethod
50+
def is_valid_boolean(value):
51+
"""Check if the value is a valid boolean."""
52+
return (
53+
value.lower() in ["true", "false"]
54+
if isinstance(value, str)
55+
else isinstance(value, bool)
56+
)
57+
58+
@staticmethod
59+
def is_valid_uuid(value):
60+
try:
61+
uuid.UUID(value)
62+
return True
63+
except ValueError:
64+
return False
65+
66+
@staticmethod
67+
def is_valid_datetime(value: str) -> bool:
68+
try:
69+
datetime.datetime.strptime(value, "%Y-%m-%d %H:%M:%S")
70+
return True
71+
except ValueError:
72+
return False
73+
74+
@staticmethod
75+
def apply_pagination(
76+
query: str,
77+
columns: List[dict],
78+
pagination: Optional[PaginationParams],
79+
target_dialect: str = "postgres",
80+
) -> Tuple[str, List]:
81+
"""
82+
Apply pagination to a SQL query.
83+
84+
Args:
85+
query (str): The SQL query to apply pagination to
86+
columns (List[dict]): A list of dictionaries containing
87+
information about the columns in the result set. Each
88+
dictionary should have the following structure:
89+
{
90+
"name": str,
91+
"type": str
92+
}
93+
The type should be one of: "string", "number", "integer", "float",
94+
"boolean", "datetime"
95+
pagination (Optional[PaginationParams]): The pagination parameters
96+
to apply to the query. If None, the query is returned unchanged
97+
target_dialect (str): The SQL dialect to generate the query for.
98+
Defaults to "postgres".
99+
100+
Returns:
101+
Tuple[str, List]: A tuple containing the modified SQL query and a
102+
list of parameters to pass to the query.
103+
"""
104+
105+
params = []
106+
107+
if not pagination:
108+
return query, params
109+
110+
filtering_query = f"SELECT * FROM ({query}) AS filtered_data"
111+
conditions = []
112+
113+
# Handle search functionality
114+
if pagination.search:
115+
search_conditions = []
116+
for column in columns:
117+
column_name = column["name"]
118+
column_type = column["type"]
119+
120+
if column_type == "string":
121+
search_conditions.append(f"{column_name} ILIKE %s")
122+
params.append(f"%{pagination.search}%")
123+
124+
elif column_type == "float" and DatasetPaginator.is_float(
125+
pagination.search
126+
):
127+
search_conditions.append(f"{column_name} = %s")
128+
params.append(pagination.search)
129+
130+
elif (
131+
column_type in ["number", "integer"]
132+
and pagination.search.isnumeric()
133+
):
134+
search_conditions.append(f"{column_name} = %s")
135+
params.append(pagination.search)
136+
137+
elif column_type == "datetime" and DatasetPaginator.is_valid_datetime(
138+
pagination.search
139+
):
140+
search_conditions.append(f"{column_name} = %s")
141+
params.append(
142+
datetime.datetime.strptime(
143+
pagination.search, "%Y-%m-%d %H:%M:%S"
144+
)
145+
)
146+
147+
elif column_type == "boolean" and DatasetPaginator.is_valid_boolean(
148+
pagination.search
149+
):
150+
search_conditions.append(f"{column_name} = %s")
151+
params.append(pagination.search)
152+
153+
elif column_type == "uuid" and DatasetPaginator.is_valid_uuid(
154+
pagination.search
155+
):
156+
search_conditions.append(f"{column_name}::TEXT = %s")
157+
params.append(pagination.search)
158+
159+
if search_conditions:
160+
conditions.append(" OR ".join(search_conditions))
161+
162+
# Handle filters
163+
if pagination.filters:
164+
try:
165+
filters = (
166+
json.loads(pagination.filters)
167+
if isinstance(pagination.filters, str)
168+
else pagination.filters
169+
)
170+
for column, values in filters.items():
171+
if not isinstance(values, list):
172+
values = [values]
173+
placeholders = ", ".join(["%s"] * len(values))
174+
conditions.append(f"{column} IN ({placeholders})")
175+
params.extend(values)
176+
except json.JSONDecodeError as e:
177+
raise ValueError(f"Invalid filters format: {e}")
178+
179+
# Add WHERE clause if conditions exist
180+
if conditions:
181+
filtering_query += " WHERE " + " AND ".join(conditions)
182+
183+
# Handle sorting
184+
if pagination.sort_by and pagination.sort_order:
185+
if not any(pagination.sort_by == column["name"] for column in columns):
186+
raise ValueError(
187+
f"Sort column '{pagination.sort_by}' not found in available columns"
188+
)
189+
190+
filtering_query += (
191+
f" ORDER BY {pagination.sort_by} {pagination.sort_order.upper()}"
192+
)
193+
194+
# Handle page and page_size
195+
if pagination.page and pagination.page_size:
196+
filtering_query += " LIMIT %s OFFSET %s"
197+
params.extend(
198+
[pagination.page_size, (pagination.page - 1) * pagination.page_size]
199+
)
200+
201+
# Replace placeholders for target dialect
202+
placeholder = "___PLACEHOLDER___"
203+
temp_query = filtering_query.replace("%s", placeholder)
204+
transpiled_query = sqlglot.transpile(
205+
temp_query, read="postgres", write=target_dialect
206+
)[0]
207+
final_query = transpiled_query.replace(placeholder, "%s")
208+
209+
return final_query, params

pandasai/query_builders/sql_parser.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,16 @@ def transform_node(node):
5757
return transformed.sql(pretty=True)
5858

5959
@staticmethod
60-
def transpile_sql_dialect(query, to_dialect, from_dialect=None):
60+
def transpile_sql_dialect(
61+
query: str, to_dialect: str, from_dialect: Optional[str] = None
62+
):
63+
placeholder = "___PLACEHOLDER___"
64+
query = query.replace("%s", placeholder)
6165
query = (
6266
parse_one(query, read=from_dialect) if from_dialect else parse_one(query)
6367
)
64-
return query.sql(dialect=to_dialect, pretty=True)
68+
result = query.sql(dialect=to_dialect, pretty=True)
69+
return result.replace(placeholder, "%s")
6570

6671
@staticmethod
6772
def extract_table_names(sql_query: str, dialect: str = "postgres") -> List[str]:

tests/unit_tests/helpers/test_sql_sanitizer.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from pandasai.helpers.sql_sanitizer import (
2+
is_sql_query,
23
is_sql_query_safe,
34
sanitize_file_name,
45
sanitize_view_column_name,
@@ -94,3 +95,37 @@ def test_safe_query_with_subquery(self):
9495
def test_safe_query_with_query_params(self):
9596
query = "SELECT * FROM (SELECT * FROM heart_data) AS filtered_data LIMIT %s OFFSET %s"
9697
assert is_sql_query_safe(query)
98+
99+
def test_plain_text(self):
100+
"""Test with plain text input that is not a SQL query."""
101+
assert not is_sql_query("Hello, how are you?")
102+
assert not is_sql_query("This is just some text.")
103+
104+
def test_sql_queries(self):
105+
"""Test with typical SQL queries."""
106+
assert is_sql_query("SELECT * FROM users")
107+
assert is_sql_query("insert into users values ('john', 25)")
108+
assert is_sql_query("delete from orders where id=10")
109+
assert is_sql_query("DROP TABLE users")
110+
assert is_sql_query("update products set price=100 where id=1")
111+
112+
def test_case_insensitivity(self):
113+
"""Test with queries in different cases."""
114+
assert is_sql_query("select id from users")
115+
assert is_sql_query("SeLeCt id FROM users")
116+
assert is_sql_query("DROP table orders")
117+
assert is_sql_query("cReAtE DATABASE testdb")
118+
119+
def test_edge_cases(self):
120+
"""Test with edge cases like empty strings and special characters."""
121+
assert not is_sql_query("")
122+
assert not is_sql_query(" ")
123+
assert not is_sql_query("1234567890")
124+
assert not is_sql_query("#$%^&*()")
125+
assert not is_sql_query("JOIN the party") # Not SQL context
126+
127+
def test_mixed_input(self):
128+
"""Test with mixed input containing SQL keywords in non-SQL contexts."""
129+
assert not is_sql_query("Let's SELECT a movie to watch")
130+
assert not is_sql_query("CREATE a new painting")
131+
assert not is_sql_query("DROP by my house later")

0 commit comments

Comments
 (0)