Skip to content

Commit 4cc45e5

Browse files
committed
Add tests
1 parent 75627ed commit 4cc45e5

File tree

11 files changed

+823
-47
lines changed

11 files changed

+823
-47
lines changed

backend/python/app/routers/driver_routes.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
from typing import Any
23
from uuid import UUID
34

45
from fastapi import APIRouter, Depends, HTTPException, Query, status
@@ -15,6 +16,25 @@
1516
router = APIRouter(prefix="/drivers", tags=["drivers"])
1617

1718

19+
def driver_to_driver_read(driver: Any) -> DriverRead:
20+
"""Convert a Driver model instance to DriverRead."""
21+
return DriverRead(
22+
driver_id=driver.driver_id,
23+
user_id=driver.user_id,
24+
phone=driver.phone,
25+
license_plate=driver.license_plate,
26+
car_make_model=driver.car_make_model,
27+
active=driver.active,
28+
notes=driver.notes,
29+
address=driver.address,
30+
# User fields
31+
auth_id=driver.user.auth_id,
32+
name=driver.user.name,
33+
email=driver.user.email,
34+
role=driver.user.role,
35+
)
36+
37+
1838
@router.get("/", response_model=list[DriverRead])
1939
async def get_drivers(
2040
session: AsyncSession = Depends(get_session),
@@ -39,7 +59,7 @@ async def get_drivers(
3959
status_code=status.HTTP_404_NOT_FOUND,
4060
detail=f"Driver with id {driver_id} not found",
4161
)
42-
return [DriverRead.model_validate(driver)]
62+
return [driver_to_driver_read(driver)]
4363

4464
elif email:
4565
driver = await driver_service.get_driver_by_email(session, email)
@@ -48,11 +68,11 @@ async def get_drivers(
4868
status_code=status.HTTP_404_NOT_FOUND,
4969
detail=f"Driver with email {email} not found",
5070
)
51-
return [DriverRead.model_validate(driver)]
71+
return [driver_to_driver_read(driver)]
5272

5373
else:
5474
drivers = await driver_service.get_drivers(session)
55-
return [DriverRead.model_validate(driver) for driver in drivers]
75+
return [driver_to_driver_read(driver) for driver in drivers]
5676

5777
except HTTPException:
5878
raise
@@ -77,7 +97,7 @@ async def get_driver(
7797
status_code=status.HTTP_404_NOT_FOUND,
7898
detail=f"Driver with id {driver_id} not found",
7999
)
80-
return DriverRead.model_validate(driver)
100+
return driver_to_driver_read(driver)
81101

82102

83103
@router.post("/", response_model=DriverRead, status_code=status.HTTP_201_CREATED)
@@ -91,7 +111,7 @@ async def create_driver(
91111
"""
92112
try:
93113
created_driver = await driver_service.create_driver(session, driver)
94-
return DriverRead.model_validate(created_driver)
114+
return driver_to_driver_read(created_driver)
95115
except Exception as e:
96116
raise HTTPException(
97117
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
@@ -116,7 +136,7 @@ async def update_driver(
116136
status_code=status.HTTP_404_NOT_FOUND,
117137
detail=f"Driver with id {driver_id} not found",
118138
)
119-
return DriverRead.model_validate(updated_driver)
139+
return driver_to_driver_read(updated_driver)
120140

121141

122142
@router.delete("/{driver_id}", status_code=status.HTTP_204_NO_CONTENT)

backend/python/app/routers/route_group_routes.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,9 @@ async def get_route_groups(
3333
)
3434
result = []
3535
for route_group in route_groups:
36-
data = RouteGroupRead.model_validate(route_group).model_dump()
37-
membership_count = len(route_group.route_group_memberships)
38-
data["num_routes"] = membership_count
36+
data = RouteGroupRead.model_validate(
37+
route_group, from_attributes=True
38+
).model_dump()
3939
if include_routes:
4040
data["routes"] = [
4141
{
@@ -76,15 +76,7 @@ async def create_route_group(
7676
created_route_group = await route_group_service.create_route_group(
7777
session, route_group
7878
)
79-
return RouteGroupRead(
80-
route_group_id=created_route_group.route_group_id,
81-
name=created_route_group.name,
82-
notes=created_route_group.notes,
83-
drive_date=created_route_group.drive_date,
84-
created_at=created_route_group.created_at,
85-
updated_at=created_route_group.updated_at,
86-
num_routes=0,
87-
)
79+
return RouteGroupRead.model_validate(created_route_group, from_attributes=True)
8880
except Exception as e:
8981
raise HTTPException(
9082
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
@@ -110,7 +102,9 @@ async def update_route_group(
110102
status_code=status.HTTP_404_NOT_FOUND,
111103
detail=f"RouteGroup with id {route_group_id} not found",
112104
)
113-
return RouteGroupRead.model_validate(updated_route_group)
105+
return RouteGroupRead.model_validate(updated_route_group, from_attributes=True)
106+
except HTTPException:
107+
raise
114108
except Exception as e:
115109
raise HTTPException(
116110
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
@@ -133,6 +127,8 @@ async def delete_route_group(
133127
status_code=status.HTTP_404_NOT_FOUND,
134128
detail=f"RouteGroup with id {route_group_id} not found",
135129
)
130+
except HTTPException:
131+
raise
136132
except Exception as e:
137133
raise HTTPException(
138134
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)

backend/python/app/seed_database.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,8 @@ def main() -> None:
405405

406406
# Create locations from CSV
407407
print("Creating locations from CSV...")
408-
csv_path = "app/data/locations.csv"
408+
# Allow CSV path to be overridden via environment variable for testing
409+
csv_path = os.getenv("LOCATIONS_CSV_PATH", "app/data/locations.csv")
409410
locations_created = 0
410411

411412
non_school_groups = [
@@ -507,12 +508,22 @@ def main() -> None:
507508
print("Creating drivers...")
508509
num_drivers = max(routes_created, MIN_DRIVERS)
509510
drivers_created = 0
511+
used_emails: set[str] = set()
510512

511513
for _ in range(num_drivers):
512514
# Create a single driver with fake data
515+
# Ensure unique email
516+
email = fake.email()
517+
max_email_attempts = 100
518+
email_attempts = 0
519+
while email in used_emails and email_attempts < max_email_attempts:
520+
email = fake.email()
521+
email_attempts += 1
522+
used_emails.add(email)
523+
513524
user = User(
514525
name=fake.name(),
515-
email=fake.email(),
526+
email=email,
516527
auth_id=f"seed_driver_{uuid.uuid4().hex[:8]}",
517528
)
518529
set_timestamps(user)

backend/python/app/services/implementations/driver_service.py

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from uuid import UUID
33

44
from sqlalchemy.ext.asyncio import AsyncSession
5+
from sqlalchemy.orm import selectinload
56
from sqlmodel import select
67

78
from app.models.driver import Driver, DriverCreate, DriverUpdate
@@ -19,7 +20,11 @@ async def get_driver_by_id(
1920
) -> Driver | None:
2021
"""Get driver by ID - returns SQLModel instance"""
2122
try:
22-
statement = select(Driver).where(Driver.driver_id == driver_id)
23+
statement = (
24+
select(Driver)
25+
.options(selectinload(Driver.user)) # type: ignore[arg-type]
26+
.where(Driver.driver_id == driver_id)
27+
)
2328
result = await session.execute(statement)
2429
driver = result.scalars().first()
2530

@@ -37,7 +42,12 @@ async def get_driver_by_email(
3742
) -> Driver | None:
3843
"""Get driver by email using Firebase"""
3944
try:
40-
statement = select(Driver).join(Driver.user).where(User.email == email) # type: ignore[arg-type]
45+
statement = (
46+
select(Driver)
47+
.options(selectinload(Driver.user)) # type: ignore[arg-type]
48+
.join(Driver.user) # type: ignore[arg-type]
49+
.where(User.email == email)
50+
)
4151
result = await session.execute(statement)
4252
driver = result.scalars().first()
4353

@@ -55,7 +65,11 @@ async def get_driver_by_auth_id(
5565
) -> Driver | None:
5666
"""Get driver by auth_id"""
5767
try:
58-
statement = select(Driver).join(Driver.user).where(User.auth_id == auth_id) # type: ignore[arg-type]
68+
statement = (
69+
select(Driver)
70+
.join(Driver.user) # type: ignore[arg-type]
71+
.where(User.auth_id == auth_id)
72+
)
5973
result = await session.execute(statement)
6074
driver = result.scalars().first()
6175

@@ -71,7 +85,7 @@ async def get_driver_by_auth_id(
7185
async def get_drivers(self, session: AsyncSession) -> list[Driver]:
7286
"""Get all drivers - returns SQLModel instances"""
7387
try:
74-
statement = select(Driver)
88+
statement = select(Driver).options(selectinload(Driver.user)) # type: ignore[arg-type]
7589
result = await session.execute(statement)
7690
return list(result.scalars().all())
7791
except Exception as e:
@@ -98,7 +112,7 @@ async def create_driver(
98112
try:
99113
session.add(driver)
100114
await session.commit()
101-
await session.refresh(driver)
115+
await session.refresh(driver, attribute_names=["user"])
102116
return driver
103117

104118
except Exception as db_error:
@@ -113,7 +127,11 @@ async def update_driver_by_id(
113127
) -> Driver | None:
114128
"""Update driver by ID"""
115129
try:
116-
statement = select(Driver).where(Driver.driver_id == driver_id)
130+
statement = (
131+
select(Driver)
132+
.options(selectinload(Driver.user)) # type: ignore[arg-type]
133+
.where(Driver.driver_id == driver_id)
134+
)
117135
result = await session.execute(statement)
118136
driver = result.scalars().first()
119137

@@ -144,7 +162,7 @@ async def update_driver_by_id(
144162
driver.notes = driver_data.notes
145163

146164
await session.commit()
147-
await session.refresh(driver)
165+
await session.refresh(driver, attribute_names=["user"])
148166
return driver
149167

150168
except Exception as e:
@@ -201,7 +219,11 @@ async def get_driver_id_by_auth_id(
201219
) -> UUID | None:
202220
"""Get driver_id by auth_id"""
203221
try:
204-
statement = select(Driver).join(Driver.user).where(User.auth_id == auth_id) # type: ignore[arg-type]
222+
statement = (
223+
select(Driver)
224+
.join(Driver.user) # type: ignore[arg-type]
225+
.where(User.auth_id == auth_id)
226+
)
205227
result = await session.execute(statement)
206228
driver = result.scalars().first()
207229

@@ -217,7 +239,11 @@ async def get_driver_id_by_auth_id(
217239
async def delete_driver_by_email(self, session: AsyncSession, email: str) -> None:
218240
"""Delete driver by email"""
219241
try:
220-
statement = select(Driver).join(Driver.user).where(User.email == email) # type: ignore[arg-type]
242+
statement = (
243+
select(Driver)
244+
.join(Driver.user) # type: ignore[arg-type]
245+
.where(User.email == email)
246+
)
221247
result = await session.execute(statement)
222248
driver = result.scalars().first()
223249

backend/python/app/services/implementations/route_group_service.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ async def create_route_group(
2121
route_group = RouteGroup.model_validate(route_group_data)
2222
session.add(route_group)
2323
await session.commit()
24-
await session.refresh(route_group)
24+
await session.refresh(route_group, ["route_group_memberships"])
2525
return route_group
2626

2727
async def update_route_group(
@@ -46,7 +46,7 @@ async def update_route_group(
4646
setattr(route_group, field, value)
4747

4848
await session.commit()
49-
await session.refresh(route_group)
49+
await session.refresh(route_group, ["route_group_memberships"])
5050

5151
return route_group
5252

backend/python/app/services/jobs/email_reminder_jobs.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from app.models.driver import Driver
1313
from app.models.driver_assignment import DriverAssignment
1414
from app.models.route import Route
15+
from app.models.user import User
1516
from app.services.implementations.email_service import EmailService
1617

1718

@@ -40,21 +41,21 @@ async def process_daily_reminder_emails() -> None:
4041
# Get all drivers assigned to routes tomorrow
4142
statement = (
4243
select(
43-
Driver.email,
44+
User.email,
4445
DriverAssignment.time,
4546
Route.length,
4647
)
4748
.join(Route, DriverAssignment.route_id == Route.route_id) # type: ignore[arg-type]
4849
.join(Driver, DriverAssignment.driver_id == Driver.driver_id) # type: ignore[arg-type]
50+
.join(User, Driver.user_id == User.user_id) # type: ignore[arg-type]
4951
.where(
5052
and_(
51-
Driver.email is not None, # type: ignore[arg-type]
5253
DriverAssignment.time >= start_of_day, # type: ignore[arg-type]
5354
DriverAssignment.time <= end_of_day, # type: ignore[arg-type]
5455
DriverAssignment.completed.is_(False), # type: ignore[attr-defined]
5556
)
5657
)
57-
.order_by(Driver.email)
58+
.order_by(User.email)
5859
)
5960

6061
result = await session.execute(statement)

backend/python/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ numpy==1.26.4
7474
scikit-learn==1.5.0
7575
scikit-learn-extra==0.2.0
7676
seaborn==0.13.2
77-
matplotlib==3.10.0
77+
matplotlib>=3.10.8
7878
pandas==2.3.3
7979
pandas-stubs
8080
types-seaborn

backend/python/tests/conftest.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ async def test_db_engine() -> AsyncGenerator[Any, None]:
3636

3737
# Use PostgreSQL for testing to support ARRAY types
3838
database_url = os.getenv(
39-
"TEST_DATABASE_URL", "postgresql+asyncpg://postgres:postgres@db:5432/f4k"
39+
"TEST_DATABASE_URL", "postgresql+asyncpg://postgres:postgres@db:5432/f4k_test"
4040
)
4141

4242
engine = create_async_engine(
@@ -105,11 +105,8 @@ def client(test_session: AsyncSession) -> Generator[TestClient, None, None]:
105105
app = create_app()
106106

107107
# Override the database session dependency
108-
def override_get_session() -> AsyncGenerator[AsyncSession, None]:
109-
async def _get_session() -> AsyncGenerator[AsyncSession, None]:
110-
yield test_session
111-
112-
return _get_session()
108+
async def override_get_session() -> AsyncGenerator[AsyncSession, None]:
109+
yield test_session
113110

114111
app.dependency_overrides[get_session] = override_get_session
115112

@@ -125,11 +122,8 @@ async def async_client(
125122
app = create_app()
126123

127124
# Override the database session dependency
128-
def override_get_session() -> AsyncGenerator[AsyncSession, None]:
129-
async def _get_session() -> AsyncGenerator[AsyncSession, None]:
130-
yield test_session
131-
132-
return _get_session()
125+
async def override_get_session() -> AsyncGenerator[AsyncSession, None]:
126+
yield test_session
133127

134128
app.dependency_overrides[get_session] = override_get_session
135129

0 commit comments

Comments
 (0)