Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 16 additions & 39 deletions app/api/v1/endpoints/auth.py
Original file line number Diff line number Diff line change
@@ -1,60 +1,33 @@
from datetime import timedelta
from typing import Any

from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.orm import Session

from app.core.config import settings
from app.core.security import create_access_token, get_password_hash, verify_password
from app.db.database import get_db
from app.models.user import User
from app.schemas.token import Token
from app.schemas.user import UserCreate
from app.services.auth_service import auth_service
from app.services.user_service import user_service

router = APIRouter()


def get_access_token(id: int) -> Any:
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
access_token = create_access_token(subject=id, expires_delta=access_token_expires)

return access_token


@router.post("/register", response_model=Token)
async def register_user(user_in: UserCreate, db: Session = Depends(get_db)) -> Any:
"""
Register a new user.
"""
user = db.query(User).filter(User.username == user_in.username).first()
if user:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="A user with this username already exists",
)
# Hash the password
hashed_password = auth_service.get_password_hash(user_in.password)

if len(user_in.password) < 4:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Password must be at least 4 characters",
)
# Create the user (this includes validation)
user = user_service.create_user(db, user_in, hashed_password)

if len(user_in.username) < 3:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Username must be at least 3 characters",
)
# Generate access token
access_token = auth_service.generate_access_token(int(user.id))

user = User(
username=user_in.username,
hashed_password=get_password_hash(user_in.password),
)
db.add(user)
db.commit()
db.refresh(user)

return {"access_token": get_access_token(int(user.id)), "token_type": "bearer"}
return {"access_token": access_token, "token_type": "bearer"}


@router.post("/login", response_model=Token)
Expand All @@ -64,13 +37,17 @@ async def login_for_access_token(
"""
OAuth2 compatible token login, get an access token for future requests.
"""
user = db.query(User).filter(User.username == form_data.username).first()
# Authenticate user
user = auth_service.authenticate_user(db, form_data.username, form_data.password)

if not user or not verify_password(form_data.password, str(user.hashed_password)):
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password",
headers={"WWW-Authenticate": "Bearer"},
)

return {"access_token": get_access_token(int(user.id)), "token_type": "bearer"}
# Generate access token
access_token = auth_service.generate_access_token(int(user.id))

return {"access_token": access_token, "token_type": "bearer"}
4 changes: 2 additions & 2 deletions app/api/v1/endpoints/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@

from fastapi import APIRouter, Depends

from app.core.security import get_current_user
from app.models.user import User
from app.schemas.user import UserResponse
from app.services.auth_service import auth_service

router = APIRouter()


@router.get("/me", response_model=UserResponse)
async def read_current_user(
current_user: User = Depends(get_current_user),
current_user: User = Depends(auth_service.get_current_user),
) -> Any:
"""
Get current user.
Expand Down
55 changes: 0 additions & 55 deletions app/core/security.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,14 @@
from datetime import UTC, datetime, timedelta
from typing import Optional, Union

from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import jwt
from passlib.context import CryptContext
from pydantic import ValidationError
from sqlalchemy.orm import Session

from app.core.config import settings
from app.db.database import get_db
from app.models.user import User
from app.schemas.token import TokenPayload

pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.API_V1_STR}/auth/login")


def get_password_hash(password: str) -> str:
"""
Hash a password using bcrypt.

Args:
password: Plain text password to hash

Returns:
Hashed password string
"""
return pwd_context.hash(password)


def verify_password(plain_password: str, hashed_password: str) -> bool:
"""
Verify a password against its hash.

Args:
plain_password: Plain text password to verify
hashed_password: Previously hashed password to check against

Returns:
True if password matches, False otherwise
"""
return pwd_context.verify(plain_password, hashed_password)


def create_access_token(subject: Union[str, int], expires_delta: Optional[timedelta] = None) -> str:
"""
Create a JWT token with the provided subject (typically user ID)
Expand All @@ -57,22 +21,3 @@ def create_access_token(subject: Union[str, int], expires_delta: Optional[timede
to_encode = {"exp": expire, "sub": str(subject)}
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
return encoded_jwt


def get_current_user(db: Session = Depends(get_db), token: str = Depends(oauth2_scheme)) -> User:
"""
Decode JWT token and return the current user.
"""
try:
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
token_data = TokenPayload(**payload)
except (jwt.JWTError, ValidationError):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Could not validate credentials",
)

user = db.query(User).filter(User.id == int(token_data.sub)).first() # type: ignore[arg-type]
if not user:
raise HTTPException(status_code=404, detail="User not found")
return user
2 changes: 2 additions & 0 deletions app/services/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
from .ai_agents_service import ai_agents_service
from .auth_service import auth_service
from .routing_service import routing_service
from .user_service import user_service
121 changes: 121 additions & 0 deletions app/services/auth_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
"""
Authentication service for handling password hashing, verification, and token generation.
"""

from datetime import timedelta
from typing import Optional

from fastapi import Depends, HTTPException, status
from jose import jwt
from passlib.context import CryptContext
from pydantic import ValidationError
from sqlalchemy.orm import Session

from app.core.config import settings
from app.core.security import create_access_token, oauth2_scheme
from app.db.database import get_db
from app.models.user import User
from app.schemas.token import TokenPayload
from app.services.user_service import user_service

pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")


class AuthService:
"""Service for handling authentication operations."""

@staticmethod
def get_password_hash(password: str) -> str:
"""
Hash a password using bcrypt.

Args:
password: Plain text password to hash

Returns:
Hashed password string
"""
return pwd_context.hash(password)

@staticmethod
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""
Verify a password against its hash.

Args:
plain_password: Plain text password to verify
hashed_password: Previously hashed password to check against

Returns:
True if password matches, False otherwise
"""
return pwd_context.verify(plain_password, hashed_password)

@staticmethod
def generate_access_token(user_id: int) -> str:
"""
Generate an access token for a user.

Args:
user_id: The ID of the user to generate token for

Returns:
JWT access token string
"""
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
return create_access_token(subject=user_id, expires_delta=access_token_expires)

@staticmethod
def authenticate_user(db: Session, username: str, password: str) -> Optional[User]:
"""
Authenticate a user by username and password.

Args:
db: Database session
username: Username to authenticate
password: Plain text password to verify

Returns:
User object if authentication successful, None otherwise
"""
user = user_service.get_user_by_username(db, username)
if not user:
return None
if not AuthService.verify_password(password, str(user.hashed_password)):
return None
return user

@staticmethod
def get_current_user(
db: Session = Depends(get_db), token: str = Depends(oauth2_scheme)
) -> User:
"""
Decode JWT token and return the current user.

Args:
db: Database session
token: JWT token from request

Returns:
User object for the authenticated user

Raises:
HTTPException: If token is invalid or user not found
"""
try:
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
token_data = TokenPayload(**payload)
except (jwt.JWTError, ValidationError):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Could not validate credentials",
)

user = db.query(User).filter(User.id == int(token_data.sub)).first() # type: ignore
if not user:
raise HTTPException(status_code=404, detail="User not found")
return user


# Create a singleton instance
auth_service = AuthService()
Loading