Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 45 additions & 1 deletion src/routes/accounts.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from datetime import datetime, timezone
from typing import cast

from fastapi import APIRouter, Depends, status, HTTPException
from fastapi import APIRouter, Depends, status, HTTPException, BackgroundTasks
from sqlalchemy import select, delete
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
Expand Down Expand Up @@ -67,7 +67,9 @@
)
async def register_user(
user_data: UserRegistrationRequestSchema,
background_tasks: BackgroundTasks,
db: AsyncSession = Depends(get_db),
email_sender: EmailSenderInterface = Depends(get_accounts_email_notificator),
) -> UserRegistrationResponseSchema:
"""
Endpoint for user registration.
Expand All @@ -78,7 +80,9 @@ async def register_user(

Args:
user_data (UserRegistrationRequestSchema): The registration details including email and password.
background_tasks (BackgroundTasks): The background tasks manager.
db (AsyncSession): The asynchronous database session.
email_sender (EmailSenderInterface): The email notification service.

Returns:
UserRegistrationResponseSchema: The newly created user's details.
Expand Down Expand Up @@ -120,6 +124,13 @@ async def register_user(

await db.commit()
await db.refresh(new_user)

activation_link = f"http://127.0.0.1/accounts/activate/?email={new_user.email}&token={activation_token.token}"
background_tasks.add_task(
email_sender.send_activation_email,
str(new_user.email),
activation_link
)
except SQLAlchemyError as e:
await db.rollback()
raise HTTPException(
Expand Down Expand Up @@ -163,7 +174,9 @@ async def register_user(
)
async def activate_account(
activation_data: UserActivationRequestSchema,
background_tasks: BackgroundTasks,
db: AsyncSession = Depends(get_db),
email_sender: EmailSenderInterface = Depends(get_accounts_email_notificator),
) -> MessageResponseSchema:
"""
Endpoint to activate a user's account.
Expand All @@ -175,7 +188,9 @@ async def activate_account(

Args:
activation_data (UserActivationRequestSchema): Contains the user's email and activation token.
background_tasks (BackgroundTasks): The background tasks manager.
db (AsyncSession): The asynchronous database session.
email_sender (EmailSenderInterface): The email notification service.

Returns:
MessageResponseSchema: A response message confirming successful activation.
Expand Down Expand Up @@ -218,6 +233,13 @@ async def activate_account(
await db.delete(token_record)
await db.commit()

login_link = "http://127.0.0.1/accounts/login/"
background_tasks.add_task(
email_sender.send_activation_complete_email,
str(user.email),
login_link
)

return MessageResponseSchema(message="User account activated successfully.")


Expand All @@ -233,7 +255,9 @@ async def activate_account(
)
async def request_password_reset_token(
data: PasswordResetRequestSchema,
background_tasks: BackgroundTasks,
db: AsyncSession = Depends(get_db),
email_sender: EmailSenderInterface = Depends(get_accounts_email_notificator),
) -> MessageResponseSchema:
"""
Endpoint to request a password reset token.
Expand All @@ -243,7 +267,9 @@ async def request_password_reset_token(

Args:
data (PasswordResetRequestSchema): The request data containing the user's email.
background_tasks (BackgroundTasks): The background tasks manager.
db (AsyncSession): The asynchronous database session.
email_sender (EmailSenderInterface): The email notification service.

Returns:
MessageResponseSchema: A success message indicating that instructions will be sent.
Expand All @@ -263,6 +289,13 @@ async def request_password_reset_token(
db.add(reset_token)
await db.commit()

reset_link = f"http://127.0.0.1/accounts/password-reset/?email={user.email}&token={reset_token.token}"
background_tasks.add_task(
email_sender.send_password_reset_email,
str(user.email),
reset_link
)

return MessageResponseSchema(
message="If you are registered, you will receive an email with instructions."
)
Expand Down Expand Up @@ -313,7 +346,9 @@ async def request_password_reset_token(
)
async def reset_password(
data: PasswordResetCompleteRequestSchema,
background_tasks: BackgroundTasks,
db: AsyncSession = Depends(get_db),
email_sender: EmailSenderInterface = Depends(get_accounts_email_notificator),
) -> MessageResponseSchema:
"""
Endpoint for resetting a user's password.
Expand All @@ -324,7 +359,9 @@ async def reset_password(
Args:
data (PasswordResetCompleteRequestSchema): The request data containing the user's email,
token, and new password.
background_tasks (BackgroundTasks): The background tasks manager.
db (AsyncSession): The asynchronous database session.
email_sender (EmailSenderInterface): The email notification service.

Returns:
MessageResponseSchema: A response message indicating successful password reset.
Expand Down Expand Up @@ -369,6 +406,13 @@ async def reset_password(
user.password = data.password
await db.run_sync(lambda s: s.delete(token_record))
await db.commit()

login_link = "http://127.0.0.1/accounts/login/"
background_tasks.add_task(
email_sender.send_password_reset_complete_email,
str(user.email),
login_link
)
except SQLAlchemyError:
await db.rollback()
raise HTTPException(
Expand Down
131 changes: 129 additions & 2 deletions src/routes/profiles.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,132 @@
from fastapi import APIRouter
from fastapi import APIRouter, Depends, status, HTTPException, UploadFile, File
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from sqlalchemy.orm import joinedload
from sqlalchemy.exc import SQLAlchemyError

from database import get_db, UserModel, UserProfileModel, UserGroupEnum
from database.models.accounts import GenderEnum
from config import get_jwt_auth_manager, get_s3_storage_client
from security.interfaces import JWTAuthManagerInterface
from storages.interfaces import S3StorageInterface
from security.http import get_token
from exceptions import BaseSecurityError
from schemas.profiles import ProfileResponseSchema, ProfileCreateSchema
from validation import validate_image

router = APIRouter()

# Write your code here

@router.post(
"/users/{user_id}/profile/",
response_model=ProfileResponseSchema,
summary="Create User Profile",
description="Create a new profile for a user, including an avatar upload.",
status_code=status.HTTP_201_CREATED,
)
async def create_profile(
user_id: int,
profile_data: ProfileCreateSchema = Depends(ProfileCreateSchema.as_form),
avatar: UploadFile = File(...),
db: AsyncSession = Depends(get_db),
token: str = Depends(get_token),
jwt_manager: JWTAuthManagerInterface = Depends(get_jwt_auth_manager),
s3_client: S3StorageInterface = Depends(get_s3_storage_client),
) -> ProfileResponseSchema:
"""
Endpoint for creating a user profile.

This endpoint validates the user's token, checks for authorization,
ensures the user doesn't already have a profile, uploads an avatar to S3,
and saves the profile details to the database.
"""
# 1. Token validation
try:
payload = jwt_manager.decode_access_token(token)
current_user_id = payload.get("user_id")
except BaseSecurityError as e:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Token has expired." if "expired" in str(e).lower() else str(e)
)

# 2. Authorization rules
stmt = select(UserModel).options(joinedload(UserModel.group)).where(UserModel.id == current_user_id)
result = await db.execute(stmt)
current_user = result.scalars().first()

if not current_user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User not found or not active."
)

if current_user_id != user_id and current_user.group.name != UserGroupEnum.ADMIN:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have permission to edit this profile."
)

# 3. User existence and status
stmt = select(UserModel).where(UserModel.id == user_id)
result = await db.execute(stmt)
target_user = result.scalars().first()
if not target_user or not target_user.is_active:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User not found or not active."
)

# 4. Check for existing profile
stmt = select(UserProfileModel).where(UserProfileModel.user_id == user_id)
result = await db.execute(stmt)
if result.scalars().first():
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="User already has a profile."
)

# 5. Avatar upload to s3 storage
try:
validate_image(avatar)
ext = avatar.filename.split(".")[-1]

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method of getting the file extension might not work as expected for all filenames. For example, if a filename has no extension (e.g., myfile), this will use the entire filename as the extension. For more robust parsing, consider using Python's os.path.splitext or pathlib module.

object_key = f"avatars/{user_id}_avatar.{ext}"

content = await avatar.read()
await s3_client.upload_file(object_key, content)
avatar_url = await s3_client.get_file_url(object_key)
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=str(e)
)
except Exception:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to upload avatar. Please try again later."
)

# 6. Profile creation and storage
try:
new_profile = UserProfileModel(
user_id=user_id,
first_name=profile_data.first_name.lower(),
last_name=profile_data.last_name.lower(),
gender=GenderEnum(profile_data.gender),
date_of_birth=profile_data.date_of_birth,
info=profile_data.info,
avatar=object_key
)
db.add(new_profile)
await db.commit()
await db.refresh(new_profile)
except SQLAlchemyError:
await db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An error occurred while creating the profile."
)

response_data = ProfileResponseSchema.model_validate(new_profile)
response_data.avatar = avatar_url
return response_data
79 changes: 75 additions & 4 deletions src/schemas/profiles.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,84 @@
import json
from datetime import date
from typing import Optional

from fastapi import UploadFile, Form, File, HTTPException
from pydantic import BaseModel, field_validator, HttpUrl
from fastapi import Form, HTTPException, status
from pydantic import BaseModel, field_validator, ValidationError

from validation import (
validate_name,
validate_image,
validate_gender,
validate_birth_date
)
Comment on lines 8 to 13

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to the task description, you are required to import validate_image from the validation package in this file. Please add it to the import list.

from database.models.accounts import GenderEnum

# Write your code here

class ProfileResponseSchema(BaseModel):
id: int
user_id: int
first_name: str
last_name: str
gender: GenderEnum
date_of_birth: date
info: str
avatar: Optional[str] = None

model_config = {
"from_attributes": True
}


class ProfileCreateSchema(BaseModel):
first_name: str
last_name: str
gender: str
date_of_birth: date
info: str

@field_validator("first_name", "last_name")
@classmethod
def validate_names(cls, v: str) -> str:
validate_name(v)
return v

@field_validator("gender")
@classmethod
def validate_genders(cls, v: str) -> str:
validate_gender(v)
return v

@field_validator("date_of_birth")
@classmethod
def validate_birth_dates(cls, v: date) -> date:
validate_birth_date(v)
return v

@field_validator("info")
@classmethod
def validate_infos(cls, v: str) -> str:
if not v or v.isspace():
raise ValueError("Info field cannot be empty or contain only spaces.")
return v

@classmethod
def as_form(
cls,
first_name: str = Form(...),
last_name: str = Form(...),
gender: str = Form(...),
date_of_birth: date = Form(...),
info: str = Form(...)
):
try:
return cls(
first_name=first_name,
last_name=last_name,
gender=gender,
date_of_birth=date_of_birth,
info=info
)
except ValidationError as e:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=json.loads(e.json())
)