diff --git a/.devcontainer/post_create.sh b/.devcontainer/post_create.sh index 6f92fd6..c1f5433 100755 --- a/.devcontainer/post_create.sh +++ b/.devcontainer/post_create.sh @@ -5,4 +5,5 @@ npm i --verbose cd ../backend python -m script.create_db +python -m script.create_test_db python -m script.reset_dev diff --git a/.github/workflows/backend-test.yml b/.github/workflows/backend-test.yml index 2763e80..ef65c69 100644 --- a/.github/workflows/backend-test.yml +++ b/.github/workflows/backend-test.yml @@ -2,7 +2,9 @@ name: Run Backend Tests on: pull_request: - branches: [main] + paths: + - "backend/**" + - ".github/workflows/backend-test.yml" push: branches: [main] @@ -10,6 +12,21 @@ jobs: backend-test: runs-on: ubuntu-latest + services: + postgres: + image: postgres:latest + env: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_DB: ocsl_test + ports: + - 5432:5432 + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + steps: - name: Checkout code uses: actions/checkout@v4 @@ -32,6 +49,11 @@ jobs: PYTHONPATH=backend:backend/src pytest backend/test -v --tb=long --junitxml=backend/test-results/junit.xml env: GOOGLE_MAPS_API_KEY: "test_api_key_not_used" + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_HOST: localhost + POSTGRES_PORT: 5432 + POSTGRES_DB: ocsl_test - name: Publish Test Results uses: EnricoMi/publish-unit-test-result-action@v2 diff --git a/backend/pytest.ini b/backend/pytest.ini index 055a309..66e8e3e 100644 --- a/backend/pytest.ini +++ b/backend/pytest.ini @@ -5,5 +5,6 @@ python_files = *_test.py python_classes = Test* python_functions = test_* asyncio_mode = auto -asyncio_default_fixture_loop_scope = function +asyncio_default_fixture_loop_scope = session +asyncio_default_test_loop_scope = session addopts = --color=yes diff --git a/backend/script/create_db.py b/backend/script/create_db.py index 5bc28c7..cf1babe 100644 --- a/backend/script/create_db.py +++ b/backend/script/create_db.py @@ -5,7 +5,13 @@ engine = create_engine(server_url(sync=True), isolation_level="AUTOCOMMIT") with engine.connect() as connection: - print("Creating database...") - connection.execute(text(f"CREATE DATABASE {env.POSTGRES_DATABASE}")) - -print("Database successfully created") + # Check if database already exists + result = connection.execute( + text(f"SELECT 1 FROM pg_database WHERE datname = '{env.POSTGRES_DATABASE}'") + ) + if result.fetchone(): + print(f"Database '{env.POSTGRES_DATABASE}' already exists") + else: + print(f"Creating database '{env.POSTGRES_DATABASE}'...") + connection.execute(text(f"CREATE DATABASE {env.POSTGRES_DATABASE}")) + print("Database successfully created") diff --git a/backend/script/create_test_db.py b/backend/script/create_test_db.py new file mode 100644 index 0000000..2846ff3 --- /dev/null +++ b/backend/script/create_test_db.py @@ -0,0 +1,16 @@ +from src.core.database import server_url +from sqlalchemy import create_engine, text + +engine = create_engine(server_url(sync=True), isolation_level="AUTOCOMMIT") + +with engine.connect() as connection: + # Check if test database already exists + result = connection.execute( + text("SELECT 1 FROM pg_database WHERE datname = 'ocsl_test'") + ) + if result.fetchone(): + print("Test database 'ocsl_test' already exists") + else: + print("Creating test database 'ocsl_test'...") + connection.execute(text("CREATE DATABASE ocsl_test")) + print("Test database successfully created") diff --git a/backend/src/core/exceptions.py b/backend/src/core/exceptions.py index 18e3628..044b68a 100644 --- a/backend/src/core/exceptions.py +++ b/backend/src/core/exceptions.py @@ -27,6 +27,11 @@ def __init__(self, detail: str): super().__init__(status_code=400, detail=detail) +class UnprocessableEntityException(HTTPException): + def __init__(self, detail: str): + super().__init__(status_code=422, detail=detail) + + class CredentialsException(HTTPException): def __init__(self): super().__init__( diff --git a/backend/src/modules/account/account_entity.py b/backend/src/modules/account/account_entity.py index c77cc40..cda3ffe 100644 --- a/backend/src/modules/account/account_entity.py +++ b/backend/src/modules/account/account_entity.py @@ -16,7 +16,7 @@ class AccountEntity(MappedAsDataclass, EntityBase): pid: Mapped[str] = mapped_column( String(9), CheckConstraint( - "length(pid) = 9", + "length(pid) = 9 AND pid ~ '^[0-9]{9}$'", name="check_pid_format", ), nullable=False, diff --git a/backend/src/modules/complaint/complaint_entity.py b/backend/src/modules/complaint/complaint_entity.py index b77a2c8..c5a1a62 100644 --- a/backend/src/modules/complaint/complaint_entity.py +++ b/backend/src/modules/complaint/complaint_entity.py @@ -17,7 +17,7 @@ class ComplaintEntity(MappedAsDataclass, EntityBase): location_id: Mapped[int] = mapped_column( Integer, ForeignKey("locations.id", ondelete="CASCADE"), nullable=False ) - complaint_datetime: Mapped[datetime] = mapped_column(DateTime, nullable=False) + complaint_datetime: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) description: Mapped[str] = mapped_column(String, nullable=False, default="") # Relationships diff --git a/backend/src/modules/party/party_router.py b/backend/src/modules/party/party_router.py index afad34d..59e717c 100644 --- a/backend/src/modules/party/party_router.py +++ b/backend/src/modules/party/party_router.py @@ -1,6 +1,6 @@ -from datetime import datetime +from datetime import datetime, timezone -from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi import APIRouter, Depends, Query from fastapi.responses import Response from src.core.authentication import ( authenticate_admin, @@ -9,7 +9,7 @@ authenticate_staff_or_admin, authenticate_user, ) -from src.core.exceptions import BadRequestException, ForbiddenException +from src.core.exceptions import BadRequestException, ForbiddenException, UnprocessableEntityException from src.modules.account.account_model import Account, AccountRole from src.modules.location.location_service import LocationService @@ -52,9 +52,7 @@ async def create_party( return await party_service.create_party_from_student_dto(party_data, user.id) elif isinstance(party_data, AdminCreatePartyDTO): if user.role != AccountRole.ADMIN: - raise ForbiddenException( - detail="Only admins can use the admin party creation endpoint" - ) + raise ForbiddenException(detail="Only admins can use the admin party creation endpoint") return await party_service.create_party_from_admin_dto(party_data) else: raise ForbiddenException(detail="Invalid request type") @@ -63,9 +61,7 @@ async def create_party( @party_router.get("/") async def list_parties( page_number: int = Query(1, ge=1, description="Page number (1-indexed)"), - page_size: int | None = Query( - None, ge=1, le=100, description="Items per page (default: all)" - ), + page_size: int | None = Query(None, ge=1, le=100, description="Items per page (default: all)"), party_service: PartyService = Depends(), _=Depends(authenticate_by_role("admin", "staff", "police")), ) -> PaginatedPartiesResponse: @@ -104,9 +100,7 @@ async def list_parties( parties = await party_service.get_parties(skip=skip, limit=page_size) # Calculate total pages (ceiling division) - total_pages = ( - (total_records + page_size - 1) // page_size if total_records > 0 else 0 - ) + total_pages = (total_records + page_size - 1) // page_size if total_records > 0 else 0 return PaginatedPartiesResponse( items=parties, @@ -120,8 +114,8 @@ async def list_parties( @party_router.get("/nearby") async def get_parties_nearby( place_id: str = Query(..., description="Google Maps place ID"), - start_date: str = Query(..., description="Start date (YYYY-MM-DD format)"), - end_date: str = Query(..., description="End date (YYYY-MM-DD format)"), + start_date: str = Query(..., pattern=r"^\d{4}-\d{2}-\d{2}$", description="Start date (YYYY-MM-DD format)"), + end_date: str = Query(..., pattern=r"^\d{4}-\d{2}-\d{2}$", description="End date (YYYY-MM-DD format)"), party_service: PartyService = Depends(), location_service: LocationService = Depends(), _=Depends(authenticate_police_or_admin), @@ -145,12 +139,12 @@ async def get_parties_nearby( """ # Parse date strings to datetime objects try: - start_datetime = datetime.strptime(start_date, "%Y-%m-%d") - end_datetime = datetime.strptime(end_date, "%Y-%m-%d") + start_datetime = datetime.strptime(start_date, "%Y-%m-%d").replace(tzinfo=timezone.utc) + end_datetime = datetime.strptime(end_date, "%Y-%m-%d").replace(tzinfo=timezone.utc) # Set end_datetime to end of day (23:59:59) end_datetime = end_datetime.replace(hour=23, minute=59, second=59) except ValueError as e: - raise BadRequestException(f"Invalid date format. Expected YYYY-MM-DD: {str(e)}") + raise UnprocessableEntityException(f"Invalid date format. Expected YYYY-MM-DD: {str(e)}") # Validate that start_date is not greater than end_date if start_datetime > end_datetime: @@ -172,8 +166,8 @@ async def get_parties_nearby( @party_router.get("/csv") async def get_parties_csv( - start_date: str = Query(..., description="Start date in YYYY-MM-DD format"), - end_date: str = Query(..., description="End date in YYYY-MM-DD format"), + start_date: str = Query(..., pattern=r"^\d{4}-\d{2}-\d{2}$", description="Start date in YYYY-MM-DD format"), + end_date: str = Query(..., pattern=r"^\d{4}-\d{2}-\d{2}$", description="End date in YYYY-MM-DD format"), party_service: PartyService = Depends(), _=Depends(authenticate_admin), ) -> Response: @@ -194,25 +188,15 @@ async def get_parties_csv( start_datetime = datetime.strptime(start_date, "%Y-%m-%d") end_datetime = datetime.strptime(end_date, "%Y-%m-%d") - end_datetime = end_datetime.replace( - hour=23, minute=59, second=59, microsecond=999999 - ) + end_datetime = end_datetime.replace(hour=23, minute=59, second=59, microsecond=999999) except ValueError: - raise HTTPException( - status_code=400, - detail="Invalid date format. Use YYYY-MM-DD format for dates.", - ) + raise UnprocessableEntityException("Invalid date format. Use YYYY-MM-DD format for dates.") # Validate that start_date is not greater than end_date if start_datetime > end_datetime: - raise HTTPException( - status_code=400, - detail="Start date must be less than or equal to end date", - ) + raise BadRequestException("Start date must be less than or equal to end date") - parties = await party_service.get_parties_by_date_range( - start_datetime, end_datetime - ) + parties = await party_service.get_parties_by_date_range(start_datetime, end_datetime) csv_content = await party_service.export_parties_to_csv(parties) return Response( @@ -247,14 +231,10 @@ async def update_party( raise ForbiddenException( detail="Only students can use the student party update endpoint" ) - return await party_service.update_party_from_student_dto( - party_id, party_data, user.id - ) + return await party_service.update_party_from_student_dto(party_id, party_data, user.id) elif isinstance(party_data, AdminCreatePartyDTO): if user.role != AccountRole.ADMIN: - raise ForbiddenException( - detail="Only admins can use the admin party update endpoint" - ) + raise ForbiddenException(detail="Only admins can use the admin party update endpoint") return await party_service.update_party_from_admin_dto(party_id, party_data) else: raise ForbiddenException(detail="Invalid request type") diff --git a/backend/test/conftest.py b/backend/test/conftest.py index 6b4dcd4..425bbe6 100644 --- a/backend/test/conftest.py +++ b/backend/test/conftest.py @@ -1,5 +1,7 @@ import os +from sqlalchemy import text + os.environ["GOOGLE_MAPS_API_KEY"] = "invalid_google_maps_api_key_for_tests" from typing import Any, AsyncGenerator, Callable @@ -12,8 +14,9 @@ import src.modules # Ensure all modules are imported so their entities are registered # noqa: F401 from httpx import ASGITransport, AsyncClient from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.ext.asyncio.engine import AsyncEngine from src.core.authentication import StringRole -from src.core.database import EntityBase, get_session +from src.core.database import EntityBase, database_url, get_session from src.main import app from src.modules.account.account_service import AccountService from src.modules.complaint.complaint_service import ComplaintService @@ -29,20 +32,45 @@ from test.modules.police.police_utils import PoliceTestUtils from test.modules.student.student_utils import StudentTestUtils -DATABASE_URL = "sqlite+aiosqlite:///:memory:" +DATABASE_URL = database_url("ocsl_test") # =================================== Database ====================================== -@pytest_asyncio.fixture(scope="function") -async def test_session(): +@pytest_asyncio.fixture(autouse=True, scope="session", loop_scope="session") +async def test_engine(): + """Create engine and tables once per test session.""" engine = create_async_engine(DATABASE_URL, echo=False) async with engine.begin() as conn: + await conn.run_sync(EntityBase.metadata.drop_all) await conn.run_sync(EntityBase.metadata.create_all) - TestAsyncSessionLocal = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession) + yield engine + async with engine.begin() as conn: + await conn.run_sync(EntityBase.metadata.drop_all) + await engine.dispose() + + +@pytest_asyncio.fixture(scope="function") +async def test_session(test_engine: AsyncEngine): + """Create a new session and truncate all tables after each test.""" + TestAsyncSessionLocal = async_sessionmaker( + bind=test_engine, + expire_on_commit=False, + class_=AsyncSession, + ) + async with TestAsyncSessionLocal() as session: yield session - await engine.dispose() + + # Clean up: truncate all tables and reset sequences + async with test_engine.begin() as conn: + tables = [table.name for table in EntityBase.metadata.sorted_tables] + if tables: + # Disable foreign key checks, truncate, and reset sequences + await conn.execute(text("SET session_replication_role = 'replica';")) + for table in tables: + await conn.execute(text(f"TRUNCATE TABLE {table} RESTART IDENTITY CASCADE;")) + await conn.execute(text("SET session_replication_role = 'origin';")) # =================================== Clients ======================================= diff --git a/backend/test/modules/party/party_router_test.py b/backend/test/modules/party/party_router_test.py index 29b4f37..95498da 100644 --- a/backend/test/modules/party/party_router_test.py +++ b/backend/test/modules/party/party_router_test.py @@ -390,31 +390,22 @@ async def test_get_parties_nearby_with_date_range(self): assert data[0].id == party_valid.id @pytest.mark.asyncio - async def test_get_parties_nearby_validation_errors(self): + @pytest.mark.parametrize( + "params", + [ + {"start_date": "2024-01-01", "end_date": "2024-01-02"}, # missing place_id + {"place_id": "ChIJtest123", "end_date": "2024-01-02"}, # missing start_date + {"place_id": "ChIJtest123", "start_date": "2024-01-01"}, # missing end_date + {"place_id": "ChIJtest123", "start_date": "01-01-2024", "end_date": "2024-01-02"}, # MM-DD-YYYY + {"place_id": "ChIJtest123", "start_date": "2024-01-01", "end_date": "01/02/2024"}, # MM/DD/YYYY + {"place_id": "ChIJtest123", "start_date": "2024/01/01", "end_date": "2024-01-02"}, # slashes + {"place_id": "ChIJtest123", "start_date": "2024-1-1", "end_date": "2024-01-02"}, # no leading zeros + {"place_id": "ChIJtest123", "start_date": "2024-01-01", "end_date": "24-01-02"}, # 2-digit year + {"place_id": "ChIJtest123", "start_date": "not-a-date", "end_date": "2024-01-02"}, # invalid string + ], + ) + async def test_get_parties_nearby_validation_errors(self, params: dict[str, str]): """Test validation errors for nearby search.""" - now = datetime.now(timezone.utc) - - # Missing place_id - params = { - "start_date": now.strftime("%Y-%m-%d"), - "end_date": (now + timedelta(days=1)).strftime("%Y-%m-%d"), - } - response = await self.admin_client.get("/api/parties/nearby", params=params) - assert_res_validation_error(response) - - # Missing start_date - params = { - "place_id": "ChIJtest123", - "end_date": (now + timedelta(days=1)).strftime("%Y-%m-%d"), - } - response = await self.admin_client.get("/api/parties/nearby", params=params) - assert_res_validation_error(response) - - # Missing end_date - params = { - "place_id": "ChIJtest123", - "start_date": now.strftime("%Y-%m-%d"), - } response = await self.admin_client.get("/api/parties/nearby", params=params) assert_res_validation_error(response) @@ -470,3 +461,22 @@ async def test_get_parties_csv_with_data(self): # Verify party IDs are in CSV for party in parties: assert str(party.id) in csv_content + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "params", + [ + {"end_date": "2024-01-02"}, # missing start_date + {"start_date": "2024-01-01"}, # missing end_date + {"start_date": "01-01-2024", "end_date": "2024-01-02"}, # MM-DD-YYYY + {"start_date": "2024-01-01", "end_date": "01/02/2024"}, # MM/DD/YYYY + {"start_date": "2024/01/01", "end_date": "2024-01-02"}, # slashes + {"start_date": "2024-1-1", "end_date": "2024-01-02"}, # no leading zeros + {"start_date": "2024-01-01", "end_date": "24-01-02"}, # 2-digit year + {"start_date": "not-a-date", "end_date": "2024-01-02"}, # invalid string + ], + ) + async def test_get_parties_csv_validation_errors(self, params: dict[str, str]): + """Test validation errors for CSV export.""" + response = await self.admin_client.get("/api/parties/csv", params=params) + assert_res_validation_error(response)