Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .devcontainer/post_create.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ npm i --verbose

cd ../backend
python -m script.create_db
python -m script.create_test_db
python -m script.reset_dev
3 changes: 2 additions & 1 deletion backend/pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@ python_files = *_test.py
python_classes = Test*
python_functions = test_*
asyncio_mode = auto
asyncio_default_fixture_loop_scope = function
asyncio_default_fixture_loop_scope = session
asyncio_default_test_loop_scope = session
Comment on lines +8 to +9
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Necessary for async session fixture shenanigans

addopts = --color=yes
14 changes: 10 additions & 4 deletions backend/script/create_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@
engine = create_engine(server_url(sync=True), isolation_level="AUTOCOMMIT")

with engine.connect() as connection:
print("Creating database...")
connection.execute(text(f"CREATE DATABASE {env.POSTGRES_DATABASE}"))

print("Database successfully created")
# Check if database already exists
result = connection.execute(
text(f"SELECT 1 FROM pg_database WHERE datname = '{env.POSTGRES_DATABASE}'")
)
if result.fetchone():
print(f"Database '{env.POSTGRES_DATABASE}' already exists")
else:
print(f"Creating database '{env.POSTGRES_DATABASE}'...")
connection.execute(text(f"CREATE DATABASE {env.POSTGRES_DATABASE}"))
print("Database successfully created")
16 changes: 16 additions & 0 deletions backend/script/create_test_db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from src.core.database import server_url
from sqlalchemy import create_engine, text

engine = create_engine(server_url(sync=True), isolation_level="AUTOCOMMIT")

with engine.connect() as connection:
# Check if test database already exists
result = connection.execute(
text("SELECT 1 FROM pg_database WHERE datname = 'ocsl_test'")
)
if result.fetchone():
print("Test database 'ocsl_test' already exists")
else:
print("Creating test database 'ocsl_test'...")
connection.execute(text("CREATE DATABASE ocsl_test"))
print("Test database successfully created")
5 changes: 5 additions & 0 deletions backend/src/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ def __init__(self, detail: str):
super().__init__(status_code=400, detail=detail)


class UnprocessableEntityException(HTTPException):
def __init__(self, detail: str):
super().__init__(status_code=422, detail=detail)


class CredentialsException(HTTPException):
def __init__(self):
super().__init__(
Expand Down
2 changes: 1 addition & 1 deletion backend/src/modules/account/account_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class AccountEntity(MappedAsDataclass, EntityBase):
pid: Mapped[str] = mapped_column(
String(9),
CheckConstraint(
"length(pid) = 9",
"length(pid) = 9 AND pid ~ '^[0-9]{9}$'",
name="check_pid_format",
),
nullable=False,
Expand Down
2 changes: 1 addition & 1 deletion backend/src/modules/complaint/complaint_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class ComplaintEntity(MappedAsDataclass, EntityBase):
location_id: Mapped[int] = mapped_column(
Integer, ForeignKey("locations.id", ondelete="CASCADE"), nullable=False
)
complaint_datetime: Mapped[datetime] = mapped_column(DateTime, nullable=False)
complaint_datetime: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
description: Mapped[str] = mapped_column(String, nullable=False, default="")

# Relationships
Expand Down
58 changes: 19 additions & 39 deletions backend/src/modules/party/party_router.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from datetime import datetime
from datetime import datetime, timezone

from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi import APIRouter, Depends, Query
from fastapi.responses import Response
from src.core.authentication import (
authenticate_admin,
Expand All @@ -9,7 +9,7 @@
authenticate_staff_or_admin,
authenticate_user,
)
from src.core.exceptions import BadRequestException, ForbiddenException
from src.core.exceptions import BadRequestException, ForbiddenException, UnprocessableEntityException
from src.modules.account.account_model import Account, AccountRole
from src.modules.location.location_service import LocationService

Expand Down Expand Up @@ -52,9 +52,7 @@ async def create_party(
return await party_service.create_party_from_student_dto(party_data, user.id)
elif isinstance(party_data, AdminCreatePartyDTO):
if user.role != AccountRole.ADMIN:
raise ForbiddenException(
detail="Only admins can use the admin party creation endpoint"
)
raise ForbiddenException(detail="Only admins can use the admin party creation endpoint")
return await party_service.create_party_from_admin_dto(party_data)
else:
raise ForbiddenException(detail="Invalid request type")
Expand All @@ -63,9 +61,7 @@ async def create_party(
@party_router.get("/")
async def list_parties(
page_number: int = Query(1, ge=1, description="Page number (1-indexed)"),
page_size: int | None = Query(
None, ge=1, le=100, description="Items per page (default: all)"
),
page_size: int | None = Query(None, ge=1, le=100, description="Items per page (default: all)"),
party_service: PartyService = Depends(),
_=Depends(authenticate_by_role("admin", "staff", "police")),
) -> PaginatedPartiesResponse:
Expand Down Expand Up @@ -104,9 +100,7 @@ async def list_parties(
parties = await party_service.get_parties(skip=skip, limit=page_size)

# Calculate total pages (ceiling division)
total_pages = (
(total_records + page_size - 1) // page_size if total_records > 0 else 0
)
total_pages = (total_records + page_size - 1) // page_size if total_records > 0 else 0

return PaginatedPartiesResponse(
items=parties,
Expand All @@ -120,8 +114,8 @@ async def list_parties(
@party_router.get("/nearby")
async def get_parties_nearby(
place_id: str = Query(..., description="Google Maps place ID"),
start_date: str = Query(..., description="Start date (YYYY-MM-DD format)"),
end_date: str = Query(..., description="End date (YYYY-MM-DD format)"),
start_date: str = Query(..., pattern=r"^\d{4}-\d{2}-\d{2}$", description="Start date (YYYY-MM-DD format)"),
end_date: str = Query(..., pattern=r"^\d{4}-\d{2}-\d{2}$", description="End date (YYYY-MM-DD format)"),
party_service: PartyService = Depends(),
location_service: LocationService = Depends(),
_=Depends(authenticate_police_or_admin),
Expand All @@ -145,12 +139,12 @@ async def get_parties_nearby(
"""
# Parse date strings to datetime objects
try:
start_datetime = datetime.strptime(start_date, "%Y-%m-%d")
end_datetime = datetime.strptime(end_date, "%Y-%m-%d")
start_datetime = datetime.strptime(start_date, "%Y-%m-%d").replace(tzinfo=timezone.utc)
end_datetime = datetime.strptime(end_date, "%Y-%m-%d").replace(tzinfo=timezone.utc)
# Set end_datetime to end of day (23:59:59)
end_datetime = end_datetime.replace(hour=23, minute=59, second=59)
except ValueError as e:
raise BadRequestException(f"Invalid date format. Expected YYYY-MM-DD: {str(e)}")
raise UnprocessableEntityException(f"Invalid date format. Expected YYYY-MM-DD: {str(e)}")

# Validate that start_date is not greater than end_date
if start_datetime > end_datetime:
Expand All @@ -172,8 +166,8 @@ async def get_parties_nearby(

@party_router.get("/csv")
async def get_parties_csv(
start_date: str = Query(..., description="Start date in YYYY-MM-DD format"),
end_date: str = Query(..., description="End date in YYYY-MM-DD format"),
start_date: str = Query(..., pattern=r"^\d{4}-\d{2}-\d{2}$", description="Start date in YYYY-MM-DD format"),
end_date: str = Query(..., pattern=r"^\d{4}-\d{2}-\d{2}$", description="End date in YYYY-MM-DD format"),
party_service: PartyService = Depends(),
_=Depends(authenticate_admin),
) -> Response:
Expand All @@ -194,25 +188,15 @@ async def get_parties_csv(
start_datetime = datetime.strptime(start_date, "%Y-%m-%d")
end_datetime = datetime.strptime(end_date, "%Y-%m-%d")

end_datetime = end_datetime.replace(
hour=23, minute=59, second=59, microsecond=999999
)
end_datetime = end_datetime.replace(hour=23, minute=59, second=59, microsecond=999999)
except ValueError:
raise HTTPException(
status_code=400,
detail="Invalid date format. Use YYYY-MM-DD format for dates.",
)
raise UnprocessableEntityException("Invalid date format. Use YYYY-MM-DD format for dates.")

# Validate that start_date is not greater than end_date
if start_datetime > end_datetime:
raise HTTPException(
status_code=400,
detail="Start date must be less than or equal to end date",
)
raise BadRequestException("Start date must be less than or equal to end date")

parties = await party_service.get_parties_by_date_range(
start_datetime, end_datetime
)
parties = await party_service.get_parties_by_date_range(start_datetime, end_datetime)
csv_content = await party_service.export_parties_to_csv(parties)

return Response(
Expand Down Expand Up @@ -247,14 +231,10 @@ async def update_party(
raise ForbiddenException(
detail="Only students can use the student party update endpoint"
)
return await party_service.update_party_from_student_dto(
party_id, party_data, user.id
)
return await party_service.update_party_from_student_dto(party_id, party_data, user.id)
elif isinstance(party_data, AdminCreatePartyDTO):
if user.role != AccountRole.ADMIN:
raise ForbiddenException(
detail="Only admins can use the admin party update endpoint"
)
raise ForbiddenException(detail="Only admins can use the admin party update endpoint")
return await party_service.update_party_from_admin_dto(party_id, party_data)
else:
raise ForbiddenException(detail="Invalid request type")
Expand Down
40 changes: 34 additions & 6 deletions backend/test/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os

from sqlalchemy import text

os.environ["GOOGLE_MAPS_API_KEY"] = "invalid_google_maps_api_key_for_tests"

from typing import Any, AsyncGenerator, Callable
Expand All @@ -12,8 +14,9 @@
import src.modules # Ensure all modules are imported so their entities are registered # noqa: F401
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.ext.asyncio.engine import AsyncEngine
from src.core.authentication import StringRole
from src.core.database import EntityBase, get_session
from src.core.database import EntityBase, database_url, get_session
from src.main import app
from src.modules.account.account_service import AccountService
from src.modules.complaint.complaint_service import ComplaintService
Expand All @@ -29,20 +32,45 @@
from test.modules.police.police_utils import PoliceTestUtils
from test.modules.student.student_utils import StudentTestUtils

DATABASE_URL = "sqlite+aiosqlite:///:memory:"
DATABASE_URL = database_url("ocsl_test")

# =================================== Database ======================================


@pytest_asyncio.fixture(scope="function")
async def test_session():
@pytest_asyncio.fixture(autouse=True, scope="session", loop_scope="session")
async def test_engine():
"""Create engine and tables once per test session."""
engine = create_async_engine(DATABASE_URL, echo=False)
async with engine.begin() as conn:
await conn.run_sync(EntityBase.metadata.drop_all)
await conn.run_sync(EntityBase.metadata.create_all)
TestAsyncSessionLocal = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
yield engine
async with engine.begin() as conn:
await conn.run_sync(EntityBase.metadata.drop_all)
await engine.dispose()


@pytest_asyncio.fixture(scope="function")
async def test_session(test_engine: AsyncEngine):
"""Create a new session and truncate all tables after each test."""
TestAsyncSessionLocal = async_sessionmaker(
bind=test_engine,
expire_on_commit=False,
class_=AsyncSession,
)

async with TestAsyncSessionLocal() as session:
yield session
await engine.dispose()

# Clean up: truncate all tables and reset sequences
async with test_engine.begin() as conn:
tables = [table.name for table in EntityBase.metadata.sorted_tables]
if tables:
# Disable foreign key checks, truncate, and reset sequences
await conn.execute(text("SET session_replication_role = 'replica';"))
for table in tables:
await conn.execute(text(f"TRUNCATE TABLE {table} RESTART IDENTITY CASCADE;"))
await conn.execute(text("SET session_replication_role = 'origin';"))
Comment on lines +65 to +73
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Truncate instead of drop all tables so that tests don't take a million years



# =================================== Clients =======================================
Expand Down
58 changes: 34 additions & 24 deletions backend/test/modules/party/party_router_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,31 +390,22 @@ async def test_get_parties_nearby_with_date_range(self):
assert data[0].id == party_valid.id

@pytest.mark.asyncio
async def test_get_parties_nearby_validation_errors(self):
@pytest.mark.parametrize(
"params",
[
{"start_date": "2024-01-01", "end_date": "2024-01-02"}, # missing place_id
{"place_id": "ChIJtest123", "end_date": "2024-01-02"}, # missing start_date
{"place_id": "ChIJtest123", "start_date": "2024-01-01"}, # missing end_date
{"place_id": "ChIJtest123", "start_date": "01-01-2024", "end_date": "2024-01-02"}, # MM-DD-YYYY
{"place_id": "ChIJtest123", "start_date": "2024-01-01", "end_date": "01/02/2024"}, # MM/DD/YYYY
{"place_id": "ChIJtest123", "start_date": "2024/01/01", "end_date": "2024-01-02"}, # slashes
{"place_id": "ChIJtest123", "start_date": "2024-1-1", "end_date": "2024-01-02"}, # no leading zeros
{"place_id": "ChIJtest123", "start_date": "2024-01-01", "end_date": "24-01-02"}, # 2-digit year
{"place_id": "ChIJtest123", "start_date": "not-a-date", "end_date": "2024-01-02"}, # invalid string
],
)
async def test_get_parties_nearby_validation_errors(self, params: dict[str, str]):
"""Test validation errors for nearby search."""
now = datetime.now(timezone.utc)

# Missing place_id
params = {
"start_date": now.strftime("%Y-%m-%d"),
"end_date": (now + timedelta(days=1)).strftime("%Y-%m-%d"),
}
response = await self.admin_client.get("/api/parties/nearby", params=params)
assert_res_validation_error(response)

# Missing start_date
params = {
"place_id": "ChIJtest123",
"end_date": (now + timedelta(days=1)).strftime("%Y-%m-%d"),
}
response = await self.admin_client.get("/api/parties/nearby", params=params)
assert_res_validation_error(response)

# Missing end_date
params = {
"place_id": "ChIJtest123",
"start_date": now.strftime("%Y-%m-%d"),
}
response = await self.admin_client.get("/api/parties/nearby", params=params)
assert_res_validation_error(response)

Expand Down Expand Up @@ -470,3 +461,22 @@ async def test_get_parties_csv_with_data(self):
# Verify party IDs are in CSV
for party in parties:
assert str(party.id) in csv_content

@pytest.mark.asyncio
@pytest.mark.parametrize(
"params",
[
{"end_date": "2024-01-02"}, # missing start_date
{"start_date": "2024-01-01"}, # missing end_date
{"start_date": "01-01-2024", "end_date": "2024-01-02"}, # MM-DD-YYYY
{"start_date": "2024-01-01", "end_date": "01/02/2024"}, # MM/DD/YYYY
{"start_date": "2024/01/01", "end_date": "2024-01-02"}, # slashes
{"start_date": "2024-1-1", "end_date": "2024-01-02"}, # no leading zeros
{"start_date": "2024-01-01", "end_date": "24-01-02"}, # 2-digit year
{"start_date": "not-a-date", "end_date": "2024-01-02"}, # invalid string
],
)
async def test_get_parties_csv_validation_errors(self, params: dict[str, str]):
"""Test validation errors for CSV export."""
response = await self.admin_client.get("/api/parties/csv", params=params)
assert_res_validation_error(response)