Skip to content

Commit b10ba3f

Browse files
fixes
1 parent c972c7c commit b10ba3f

File tree

5 files changed

+844
-99
lines changed

5 files changed

+844
-99
lines changed

backend/src/core/query_utils.py

Lines changed: 306 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,306 @@
1+
"""
2+
Core utilities for server-side pagination, sorting, and filtering.
3+
4+
This module provides reusable functions to apply pagination, sorting, and filtering
5+
to SQLAlchemy queries in a type-safe and flexible manner.
6+
"""
7+
8+
from enum import Enum
9+
from typing import Any
10+
11+
from pydantic import BaseModel, Field, field_validator
12+
from sqlalchemy import Select, asc, desc, func, select
13+
from sqlalchemy.ext.asyncio import AsyncSession
14+
from sqlalchemy.orm import DeclarativeMeta
15+
16+
17+
class SortOrder(str, Enum):
18+
"""Sort order enumeration."""
19+
20+
ASC = "asc"
21+
DESC = "desc"
22+
23+
24+
class FilterOperator(str, Enum):
25+
"""Filter operators for different comparison types."""
26+
27+
EQUALS = "eq"
28+
NOT_EQUALS = "ne"
29+
GREATER_THAN = "gt"
30+
GREATER_THAN_OR_EQUAL = "gte"
31+
LESS_THAN = "lt"
32+
LESS_THAN_OR_EQUAL = "lte"
33+
CONTAINS = "contains" # For string fields (case-insensitive LIKE)
34+
IN = "in" # For checking if value is in a list
35+
NOT_IN = "not_in"
36+
IS_NULL = "is_null"
37+
IS_NOT_NULL = "is_not_null"
38+
39+
40+
class PaginationParams(BaseModel):
41+
"""Parameters for pagination."""
42+
43+
page_number: int = Field(default=1, ge=1, description="Page number (1-indexed)")
44+
page_size: int | None = Field(
45+
default=None, ge=1, le=100, description="Items per page (None = all items)"
46+
)
47+
48+
@property
49+
def skip(self) -> int:
50+
"""Calculate offset from page_number and page_size."""
51+
if self.page_size is None:
52+
return 0
53+
return (self.page_number - 1) * self.page_size
54+
55+
@property
56+
def limit(self) -> int | None:
57+
"""Alias for page_size."""
58+
return self.page_size
59+
60+
61+
class SortParam(BaseModel):
62+
"""Single sorting parameter."""
63+
64+
field: str = Field(..., description="Field name to sort by")
65+
order: SortOrder = Field(default=SortOrder.ASC, description="Sort order")
66+
67+
68+
class FilterParam(BaseModel):
69+
"""Single filter parameter."""
70+
71+
field: str = Field(..., description="Field name to filter on")
72+
operator: FilterOperator = Field(..., description="Comparison operator")
73+
value: Any | None = Field(default=None, description="Value to compare against")
74+
75+
@field_validator("value")
76+
@classmethod
77+
def validate_value_for_operator(cls, v: Any, info: Any) -> Any:
78+
"""Validate that value is appropriate for the operator."""
79+
operator = info.data.get("operator")
80+
81+
if operator in [FilterOperator.IS_NULL, FilterOperator.IS_NOT_NULL]:
82+
# These operators don't need a value
83+
return None
84+
85+
if operator in [FilterOperator.IN, FilterOperator.NOT_IN] and not isinstance(v, list):
86+
# These operators need a list
87+
raise ValueError(f"Operator {operator} requires a list value")
88+
89+
return v
90+
91+
92+
class QueryParams(BaseModel):
93+
"""Combined parameters for pagination, sorting, and filtering."""
94+
95+
pagination: PaginationParams | None = Field(default=None)
96+
sort: list[SortParam] | None = Field(default=None)
97+
filters: list[FilterParam] | None = Field(default=None)
98+
99+
100+
def apply_pagination(query: Select, params: PaginationParams | None = None) -> Select:
101+
"""
102+
Apply pagination to a SQLAlchemy query.
103+
104+
Args:
105+
query: The base SQLAlchemy query
106+
params: Pagination parameters (page_number, page_size)
107+
108+
Returns:
109+
Modified query with LIMIT and OFFSET applied
110+
"""
111+
if params is None:
112+
return query
113+
114+
query = query.offset(params.skip)
115+
if params.limit is not None:
116+
query = query.limit(params.limit)
117+
118+
return query
119+
120+
121+
def apply_sorting[ModelType: DeclarativeMeta](
122+
query: Select,
123+
model: type[ModelType],
124+
sort_params: list[SortParam] | None = None,
125+
allowed_fields: list[str] | None = None,
126+
) -> Select:
127+
"""
128+
Apply sorting to a SQLAlchemy query.
129+
130+
Args:
131+
query: The base SQLAlchemy query
132+
model: The SQLAlchemy model class
133+
sort_params: List of sorting parameters
134+
allowed_fields: Optional list of fields that can be sorted on
135+
136+
Returns:
137+
Modified query with ORDER BY applied
138+
139+
Raises:
140+
ValueError: If attempting to sort on a disallowed or non-existent field
141+
"""
142+
if not sort_params:
143+
return query
144+
145+
for sort_param in sort_params:
146+
# Validate field exists and is allowed
147+
if not hasattr(model, sort_param.field):
148+
raise ValueError(f"Field '{sort_param.field}' does not exist on model")
149+
150+
if allowed_fields and sort_param.field not in allowed_fields:
151+
raise ValueError(f"Sorting on field '{sort_param.field}' is not allowed")
152+
153+
# Get the model attribute
154+
field = getattr(model, sort_param.field)
155+
156+
# Apply sort order
157+
if sort_param.order == SortOrder.DESC:
158+
query = query.order_by(desc(field))
159+
else:
160+
query = query.order_by(asc(field))
161+
162+
return query
163+
164+
165+
def apply_filters[ModelType: DeclarativeMeta](
166+
query: Select,
167+
model: type[ModelType],
168+
filter_params: list[FilterParam] | None = None,
169+
allowed_fields: list[str] | None = None,
170+
) -> Select:
171+
"""
172+
Apply filters to a SQLAlchemy query.
173+
174+
Args:
175+
query: The base SQLAlchemy query
176+
model: The SQLAlchemy model class
177+
filter_params: List of filter parameters
178+
allowed_fields: Optional list of fields that can be filtered on
179+
180+
Returns:
181+
Modified query with WHERE clauses applied
182+
183+
Raises:
184+
ValueError: If attempting to filter on a disallowed or non-existent field
185+
"""
186+
if not filter_params:
187+
return query
188+
189+
for filter_param in filter_params:
190+
# Validate field exists and is allowed
191+
if not hasattr(model, filter_param.field):
192+
raise ValueError(f"Field '{filter_param.field}' does not exist on model")
193+
194+
if allowed_fields and filter_param.field not in allowed_fields:
195+
raise ValueError(f"Filtering on field '{filter_param.field}' is not allowed")
196+
197+
# Get the model attribute
198+
field = getattr(model, filter_param.field)
199+
200+
# Apply the appropriate filter based on operator
201+
if filter_param.operator == FilterOperator.EQUALS:
202+
query = query.where(field == filter_param.value)
203+
elif filter_param.operator == FilterOperator.NOT_EQUALS:
204+
query = query.where(field != filter_param.value)
205+
elif filter_param.operator == FilterOperator.GREATER_THAN:
206+
query = query.where(field > filter_param.value)
207+
elif filter_param.operator == FilterOperator.GREATER_THAN_OR_EQUAL:
208+
query = query.where(field >= filter_param.value)
209+
elif filter_param.operator == FilterOperator.LESS_THAN:
210+
query = query.where(field < filter_param.value)
211+
elif filter_param.operator == FilterOperator.LESS_THAN_OR_EQUAL:
212+
query = query.where(field <= filter_param.value)
213+
elif filter_param.operator == FilterOperator.CONTAINS:
214+
# Case-insensitive LIKE for string fields
215+
query = query.where(field.ilike(f"%{filter_param.value}%"))
216+
elif filter_param.operator == FilterOperator.IN:
217+
query = query.where(field.in_(filter_param.value))
218+
elif filter_param.operator == FilterOperator.NOT_IN:
219+
query = query.where(~field.in_(filter_param.value))
220+
elif filter_param.operator == FilterOperator.IS_NULL:
221+
query = query.where(field.is_(None))
222+
elif filter_param.operator == FilterOperator.IS_NOT_NULL:
223+
query = query.where(field.is_not(None))
224+
225+
return query
226+
227+
228+
def apply_query_params[ModelType: DeclarativeMeta](
229+
query: Select,
230+
model: type[ModelType],
231+
params: QueryParams | None = None,
232+
allowed_sort_fields: list[str] | None = None,
233+
allowed_filter_fields: list[str] | None = None,
234+
) -> tuple[Select, PaginationParams | None]:
235+
"""
236+
Apply all query parameters (filtering, sorting, pagination) to a query.
237+
238+
This is the main utility function that combines all operations.
239+
240+
Args:
241+
query: The base SQLAlchemy query
242+
model: The SQLAlchemy model class
243+
params: Combined query parameters
244+
allowed_sort_fields: Optional list of fields that can be sorted on
245+
allowed_filter_fields: Optional list of fields that can be filtered on
246+
247+
Returns:
248+
Tuple of (modified query, pagination params used)
249+
"""
250+
if params is None:
251+
return query, None
252+
253+
# Apply filters first (narrows down the dataset)
254+
if params.filters:
255+
query = apply_filters(query, model, params.filters, allowed_filter_fields)
256+
257+
# Apply sorting (before pagination to ensure consistent ordering)
258+
if params.sort:
259+
query = apply_sorting(query, model, params.sort, allowed_sort_fields)
260+
261+
# Apply pagination last
262+
pagination_params = params.pagination
263+
if pagination_params:
264+
query = apply_pagination(query, pagination_params)
265+
266+
return query, pagination_params
267+
268+
269+
async def get_total_count(
270+
session: AsyncSession,
271+
base_query: Select,
272+
) -> int:
273+
"""
274+
Get the total count of results for a query (before pagination).
275+
276+
Args:
277+
session: SQLAlchemy async session
278+
base_query: The base query (with filters but before pagination)
279+
280+
Returns:
281+
Total number of results
282+
"""
283+
# Use select(func.count()).select_from(base_query.subquery()) for proper counting
284+
count_query = select(func.count()).select_from(base_query.subquery())
285+
result = await session.execute(count_query)
286+
return result.scalar() or 0
287+
288+
289+
class PaginatedResponse[ModelType](BaseModel):
290+
"""Generic paginated response wrapper matching existing PaginatedResponse."""
291+
292+
items: list[ModelType]
293+
total_records: int
294+
page_size: int
295+
page_number: int
296+
total_pages: int
297+
298+
@property
299+
def has_next(self) -> bool:
300+
"""Check if there's a next page."""
301+
return self.page_number < self.total_pages
302+
303+
@property
304+
def has_prev(self) -> bool:
305+
"""Check if there's a previous page."""
306+
return self.page_number > 1

backend/src/modules/party/party_router.py

Lines changed: 33 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -68,52 +68,50 @@ async def create_party(
6868
async def list_parties(
6969
page_number: int = Query(1, ge=1, description="Page number (1-indexed)"),
7070
page_size: int | None = Query(None, ge=1, le=100, description="Items per page (default: all)"),
71+
sort_by: str | None = Query(None, description="Field to sort by (e.g., 'party_datetime')"),
72+
sort_order: str = Query("asc", pattern="^(asc|desc)$", description="Sort order: asc or desc"),
73+
location_id: int | None = Query(None, description="Filter by location ID"),
74+
contact_one_id: int | None = Query(None, description="Filter by contact one (student) ID"),
7175
party_service: PartyService = Depends(),
7276
_=Depends(authenticate_by_role("admin", "staff", "police")),
7377
) -> PaginatedPartiesResponse:
7478
"""
75-
Returns all party registrations in the database with optional pagination.
79+
Returns all party registrations in the database with optional pagination, sorting, and filtering
7680
7781
Query Parameters:
78-
- page_number: The page number to retrieve (1-indexed)
82+
- page_number: The page number to retrieve (1-indexed, default: 1)
7983
- page_size: Number of items per page (max 100, default: returns all parties)
84+
- sort_by: Field to sort by (allowed: party_datetime, location_id, contact_one_id, id)
85+
- sort_order: Sort order (asc or desc, default: asc)
86+
- location_id: Filter by location ID (optional)
87+
- contact_one_id: Filter by contact one (student) ID (optional)
88+
89+
Features:
90+
- **Opt-in**: All features have sensible defaults - no parameters returns all parties
91+
- **Server-side**: All sorting, filtering, and pagination happens in the database
92+
- **Performant**: Scales well with large datasets
8093
8194
Returns:
82-
- items: List of party registrations
83-
- total_records: Total number of records in the database
84-
- page_size: Requested page size (or total_records if not specified)
85-
- page_number: Requested page number
86-
- total_pages: Total number of pages based on page size
95+
- items: List of party registrations for the current page
96+
- total_records: Total number of records matching filters (not just current page)
97+
- page_size: Items per page (equals total_records when page_size is None)
98+
- page_number: Current page number
99+
- total_pages: Total number of pages based on page size and total records
100+
101+
Examples:
102+
- Get all parties: GET /api/parties/
103+
- Get first page of 10: GET /api/parties/?page_size=10
104+
- Sort by date descending: GET /api/parties/?sort_by=party_datetime&sort_order=desc
105+
- Filter by location: GET /api/parties/?location_id=5
106+
- Combined: GET /api/parties/?location_id=5&sort_by=party_datetime&page_size=20
87107
"""
88-
# Get total count first
89-
total_records = await party_service.get_party_count()
90-
91-
# If page_size is None, return all parties
92-
if page_size is None:
93-
parties = await party_service.get_parties(skip=0, limit=None)
94-
return PaginatedPartiesResponse(
95-
items=parties,
96-
total_records=total_records,
97-
page_size=total_records,
98-
page_number=1,
99-
total_pages=1,
100-
)
101-
102-
# Calculate skip and limit for pagination
103-
skip = (page_number - 1) * page_size
104-
105-
# Get parties with pagination
106-
parties = await party_service.get_parties(skip=skip, limit=page_size)
107-
108-
# Calculate total pages (ceiling division)
109-
total_pages = (total_records + page_size - 1) // page_size if total_records > 0 else 0
110-
111-
return PaginatedPartiesResponse(
112-
items=parties,
113-
total_records=total_records,
114-
page_size=page_size,
108+
return await party_service.get_parties_paginated(
115109
page_number=page_number,
116-
total_pages=total_pages,
110+
page_size=page_size,
111+
sort_by=sort_by,
112+
sort_order=sort_order,
113+
location_id=location_id,
114+
contact_one_id=contact_one_id,
117115
)
118116

119117

0 commit comments

Comments
 (0)