Skip to content

Commit b004c2f

Browse files
committed
Add mypy
1 parent 29237d5 commit b004c2f

17 files changed

+89
-61
lines changed

.pre-commit-config.yaml

+6-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
repos:
22
- repo: https://github.com/astral-sh/ruff-pre-commit
3-
rev: v0.7.1
3+
rev: v0.7.2
44
hooks:
55
# Run the linter.
66
- id: ruff
@@ -9,3 +9,8 @@ repos:
99
# Run the formatter.
1010
- id: ruff-format
1111
types_or: [ python, pyi ]
12+
- repo: https://github.com/pre-commit/mirrors-mypy
13+
rev: 'v1.13.0' # Use the sha / tag you want to point at
14+
hooks:
15+
- id: mypy
16+
args: [--strict, --ignore-missing-imports]

backend/app/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from .app import app
22

3-
__all__ = [app]
3+
__all__ = ["app"]

backend/app/app.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from typing import AsyncGenerator
12
from .tasks.routers import router as task_router
23
from .users.routers import router as user_router
34
from fastapi import FastAPI
@@ -9,7 +10,7 @@
910

1011

1112
@asynccontextmanager
12-
async def lifespan(app: FastAPI):
13+
async def lifespan(app: FastAPI) -> AsyncGenerator[None]:
1314
# Load the ML model
1415
create_db_and_tables()
1516
yield

backend/app/auth.py

+14-8
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import jwt
2-
from typing import Annotated
2+
from typing import Dict, Annotated, cast, Any
33
from passlib.hash import pbkdf2_sha256
44
from jwt.exceptions import InvalidTokenError
5+
56
from datetime import datetime, timedelta, timezone
67

78
from fastapi import Depends, HTTPException, status
89
from fastapi.security import OAuth2PasswordBearer
10+
from sqlalchemy.sql import ColumnElement
911
from sqlalchemy.orm import Session
1012
from sqlalchemy import select
1113

@@ -20,26 +22,28 @@
2022
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/users/token")
2123

2224

23-
def create_access_token(data: dict, expires_delta: timedelta | None = None):
25+
def create_access_token(
26+
data: Dict[str, str | datetime], expires_delta: timedelta | None = None
27+
) -> str:
2428
to_encode = data.copy()
2529
if expires_delta:
2630
expire = datetime.now(timezone.utc) + expires_delta
2731
else:
2832
expire = datetime.now(timezone.utc) + timedelta(minutes=15)
2933
to_encode.update({"exp": expire})
30-
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
34+
encoded_jwt = str(jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM))
3135
return encoded_jwt
3236

3337

3438
def hash_password(password: str) -> str:
35-
return pbkdf2_sha256.hash(password)
39+
return str(pbkdf2_sha256.hash(password))
3640

3741

3842
async def get_current_user(
3943
*,
4044
session: Session = Depends(get_session),
4145
token: Annotated[str, Depends(oauth2_scheme)],
42-
):
46+
) -> Any:
4347
credentials_exception = HTTPException(
4448
status_code=status.HTTP_401_UNAUTHORIZED,
4549
detail="Could not validate credentials",
@@ -54,16 +58,18 @@ async def get_current_user(
5458
except InvalidTokenError:
5559
raise credentials_exception
5660

57-
stmt = select(User).where(User.username == token_data.username)
58-
db_user = session.exec(stmt).first()[0]
61+
stmt = select(User).where(
62+
cast("ColumnElement[bool]", User.username == token_data.username)
63+
)
64+
db_user = session.execute(stmt).one()
5965
if not db_user:
6066
raise credentials_exception
6167
return db_user
6268

6369

6470
async def get_current_active_user(
6571
current_user: Annotated[User, Depends(get_current_user)],
66-
):
72+
) -> User:
6773
# if current_user.disabled:
6874
# raise HTTPException(status_code=400, detail="Inactive user")
6975
return current_user

backend/app/db.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,23 @@
11
from sqlalchemy_utils import database_exists, create_database
2-
from sqlmodel import Session, create_engine, SQLModel
2+
from sqlmodel import Session, create_engine, SQLModel, Engine
33
from .settings import settings
44

55

6-
def get_engine(url=settings.database_url):
6+
def get_engine(url: str = settings.database_url) -> Engine:
77
engine = create_engine(url)
88
if not database_exists(engine.url):
99
create_database(engine.url)
1010
return engine
1111

1212

13-
def get_session():
13+
def get_session() -> Session:
1414
with Session(get_engine()) as session:
1515
yield session
1616

1717

18-
def create_db_and_tables(engine=get_engine()):
18+
def create_db_and_tables(engine: Engine = get_engine()) -> None:
1919
SQLModel.metadata.create_all(engine)
2020

2121

22-
def drop_db_and_tables(engine=get_engine()):
22+
def drop_db_and_tables(engine: Engine = get_engine()) -> None:
2323
SQLModel.metadata.drop_all(engine)

backend/app/settings.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from pydantic_settings import BaseSettings
22

33

4-
class Settings(BaseSettings):
4+
class Settings(BaseSettings): # type: ignore
55
# TODO: Remove this hardcoded default from here
66
database_url: str = (
77
# "postgresql+psycopg2://myuser:mypassword@db:5432/mydatabase"

backend/app/tasks/models.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@ class TaskStatus(str, enum.Enum):
1111
created = "created"
1212

1313

14-
class TaskBase(SQLModel):
14+
class TaskBase(SQLModel): # type: ignore
1515
title: str
16-
description: str
16+
description: str | None
1717
due_date: datetime | None = Field(nullable=True, default=None)
1818
status: TaskStatus = Field(
1919
sa_column=Column(Enum(TaskStatus)), default=TaskStatus.created
@@ -22,6 +22,6 @@ class TaskBase(SQLModel):
2222

2323

2424
class Task(TaskBase, table=True):
25-
model_config = ConfigDict(validate_assignment=True)
25+
model_config = ConfigDict(validate_assignment=True) # type: ignore
2626

2727
id: int | None = Field(default=None, primary_key=True, index=True)

backend/app/tasks/repository.py

+16-8
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from typing import List, Optional
1+
from typing import List, cast
2+
from sqlalchemy.sql import ColumnElement
23
from sqlalchemy.orm import Session
34
from sqlalchemy import select, update
45
from dataclasses import dataclass
@@ -9,7 +10,7 @@
910
@dataclass
1011
class TaskRepository:
1112
session: Session
12-
current_user_id: int
13+
current_user_id: int | None
1314

1415
def create(self, data: TaskInput) -> TaskOutput:
1516
task = Task(**data.model_dump(exclude_none=True))
@@ -18,25 +19,32 @@ def create(self, data: TaskInput) -> TaskOutput:
1819
self.session.refresh(task)
1920
return TaskOutput(**dict(task))
2021

21-
def list(self, offset: int, limit: int) -> List[Optional[TaskOutput]]:
22+
def list(self, offset: int, limit: int) -> List[TaskOutput]:
23+
# import pdb; pdb.set_trace()
2224
stmt = (
2325
select(Task)
2426
.where(
25-
Task.status != TaskStatus.deleted,
26-
Task.user_id == self.current_user_id,
27+
cast("ColumnElement[bool]", Task.status != TaskStatus.deleted)
28+
)
29+
.where(
30+
cast(
31+
"ColumnElement[bool]", Task.user_id == self.current_user_id
32+
)
2733
)
2834
.offset(offset)
2935
.limit(limit)
3036
)
3137
tasks = self.session.execute(stmt).all()
3238
return [TaskOutput(**dict(task[0])) for task in tasks]
3339

34-
def get_by_id(self, id: int) -> Task:
35-
return self.session.get(Task, id)
40+
def get_by_id(self, id: int) -> Task | None:
41+
return self.session.get(Task, id) # type: ignore
3642

3743
def update(self, task: Task) -> TaskOutput:
3844
self.session.execute(
39-
update(Task).where(Task.id == task.id).values(**dict(task))
45+
update(Task)
46+
.where(cast("ColumnElement[bool]", Task.id == task.id))
47+
.values(**dict(task))
4048
)
4149
self.session.commit()
4250
self.session.refresh(task)

backend/app/tasks/routers.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Annotated
1+
from typing import List, Annotated
22
from fastapi import Depends, Query, APIRouter
33
from sqlalchemy.orm import Session
44
from ..users.models import User
@@ -11,49 +11,49 @@
1111
router = APIRouter(prefix="/tasks")
1212

1313

14-
@router.get("/")
14+
@router.get("/") # type: ignore
1515
def read_tasks(
1616
*,
1717
current_user: Annotated[User, Depends(get_current_active_user)],
1818
session: Session = Depends(get_session),
1919
offset: int = 0,
2020
limit: int = Query(default=100, le=100),
21-
):
21+
) -> List[TaskOutput]:
2222
return TaskService(session, current_user.id).list(offset, limit)
2323

2424

25-
@router.post("/", response_model=TaskOutput)
25+
@router.post("/", response_model=TaskOutput) # type: ignore
2626
def create_task(
2727
*,
2828
current_user: Annotated[User, Depends(get_current_active_user)],
2929
session: Session = Depends(get_session),
3030
data: TaskInput,
31-
):
31+
) -> TaskOutput:
3232
return TaskService(session, current_user.id).create(data)
3333

3434

35-
@router.get("/{task_id}")
35+
@router.get("/{task_id}") # type: ignore
3636
def read_task(
3737
*,
3838
session: Session = Depends(get_session),
3939
current_user: Annotated[User, Depends(get_current_active_user)],
4040
task_id: int,
41-
):
41+
) -> TaskOutput:
4242
return TaskService(session, current_user.id).read(task_id)
4343

4444

45-
@router.patch("/{task_id}")
45+
@router.patch("/{task_id}") # type: ignore
4646
def update_task(
4747
*,
4848
session: Session = Depends(get_session),
4949
current_user: Annotated[User, Depends(get_current_active_user)],
5050
task_id: int,
5151
task: TaskInput,
52-
):
52+
) -> TaskOutput:
5353
return TaskService(session, current_user.id).update(task_id, task)
5454

5555

56-
@router.delete("/{task_id}")
56+
@router.delete("/{task_id}") # type: ignore
5757
def delete_task(
5858
*,
5959
current_user: Annotated[User, Depends(get_current_active_user)],

backend/app/tasks/schema.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@
44
from .models import TaskStatus
55

66

7-
class TaskInput(BaseModel):
7+
class TaskInput(BaseModel): # type: ignore
88
title: str = Field(min_length=1, max_length=50)
99
description: str | None = None
1010
user_id: int | None = None
1111
due_date: datetime.datetime | None = None
1212
status: TaskStatus = TaskStatus.created
1313

1414

15-
class TaskOutput(BaseModel):
15+
class TaskOutput(BaseModel): # type: ignore
1616
id: int
1717
title: str
1818
description: Optional[str] = ""

backend/app/tasks/service.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
@dataclass
1111
class TaskService:
1212
session: Session
13-
current_user_id: int
13+
current_user_id: int | None
1414

15-
def __post_init__(self):
15+
def __post_init__(self) -> None:
1616
self.repository: TaskRepository = TaskRepository(
1717
self.session, self.current_user_id
1818
)
@@ -49,4 +49,4 @@ def delete(self, task_id: int) -> TaskOutput:
4949

5050
task.status = TaskStatus.deleted
5151
self.repository.update(task)
52-
return task
52+
return TaskOutput(**dict(task))

backend/app/users/models.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from pydantic import BaseModel, ConfigDict
44

55

6-
class User(SQLModel, table=True):
6+
class User(SQLModel, table=True): # type: ignore
77
id: int | None = Field(
88
default=None,
99
sa_column=Column(
@@ -14,17 +14,17 @@ class User(SQLModel, table=True):
1414
hashed_password: str
1515

1616

17-
class UserCreate(BaseModel):
17+
class UserCreate(BaseModel): # type: ignore
1818
model_config = ConfigDict(validate_assignment=True)
1919

2020
username: str
2121
password: str
2222

2323

24-
class Token(BaseModel):
24+
class Token(BaseModel): # type: ignore
2525
access_token: str
2626
token_type: str
2727

2828

29-
class TokenData(BaseModel):
29+
class TokenData(BaseModel): # type: ignore
3030
username: str | None = None

backend/app/users/routers.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Annotated
1+
from typing import Annotated, Dict
22
from datetime import timedelta
33

44
from sqlmodel import Session, select
@@ -20,12 +20,12 @@
2020
router = APIRouter(prefix="/users")
2121

2222

23-
@router.post("/token")
23+
@router.post("/token") # type: ignore
2424
async def login(
2525
*,
2626
session: Session = Depends(get_session),
2727
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
28-
):
28+
) -> Dict[str, str]:
2929
stmt = select(User).where(User.username == form_data.username)
3030
db_user = session.exec(stmt).first()
3131
if not db_user or (
@@ -47,14 +47,14 @@ async def login(
4747
}
4848

4949

50-
@router.get("/me")
50+
@router.get("/me") # type: ignore
5151
async def read_users_me(
5252
current_user: Annotated[User, Depends(get_current_active_user)],
53-
):
53+
) -> User:
5454
return current_user
5555

5656

57-
@router.post("/")
57+
@router.post("/") # type: ignore
5858
def create_user(*, session: Session = Depends(get_session), user: UserCreate):
5959
stmt = select(User).where(User.username == user.username)
6060
db_user = session.exec(stmt).first()

0 commit comments

Comments
 (0)