Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 81 additions & 55 deletions backend/app/routers/uploads.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,66 @@ async def _get_upload_record(db: AsyncSession, upload_id: str) -> models.UploadR
return record


async def _finalize_upload(
    db: AsyncSession,
    record: models.UploadRecord,
    queue: ProcessingQueue | None,
) -> models.UploadRecord:
    """Validate and finalize an uploaded file after all bytes have been received.

    Args:
        db (AsyncSession): Database session.
        record (models.UploadRecord): The upload record to finalize.
        queue (ProcessingQueue | None): Queue used to hand multimedia uploads
            over for post-processing, when one is configured.

    Returns:
        models.UploadRecord: The refreshed record, marked either "completed"
        or "postprocessing".

    Raises:
        HTTPException: 409 if the declared length is unknown or not yet
            reached; 404 if the stored file or owning token is missing;
            500 if content-type detection fails; 415 (after removing both the
            file and the record) if the detected type is not allowed.

    """
    # The upload must have a declared total length, and every byte must be in.
    if record.upload_length is None:
        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Upload length unknown")
    if record.upload_offset < record.upload_length:
        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Upload not finished")

    # Finalization is idempotent: already-finalized records are returned as-is.
    if record.status in {"completed", "postprocessing"}:
        return record

    stored_file = Path(record.storage_path)
    if not stored_file.exists():
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Uploaded file not found")

    # Sniff the real content type from the bytes on disk instead of trusting
    # whatever the client declared when the upload was initiated.
    try:
        detected_type: str = detect_mimetype(stored_file)
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to detect file type: {e}",
        ) from e

    token_query: Select[tuple[models.UploadToken]] = select(models.UploadToken).where(models.UploadToken.id == record.token_id)
    token_result: Result[tuple[models.UploadToken]] = await db.execute(token_query)
    owning_token = token_result.scalar_one_or_none()
    if not owning_token:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Token not found")

    # A disallowed type is fatal: discard both the bytes and the record.
    if not mime_allowed(detected_type, owning_token.allowed_mime):
        stored_file.unlink(missing_ok=True)
        await db.delete(record)
        await db.commit()
        raise HTTPException(
            status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
            detail=f"Actual file type '{detected_type}' does not match allowed types",
        )

    record.mimetype = detected_type

    if is_multimedia(detected_type):
        # Multimedia needs a post-processing pass before it counts as done.
        record.status = "postprocessing"
        record.completed_at = None
    else:
        record.status = "completed"
        record.completed_at = datetime.now(UTC)

    await db.commit()
    await db.refresh(record)

    # Enqueue only after the status change has been durably committed.
    if record.status == "postprocessing" and queue:
        await queue.enqueue(record.public_id)
    return record


@router.post("/initiate", response_model=schemas.InitiateUploadResponse, status_code=status.HTTP_201_CREATED, name="initiate_upload")
async def initiate_upload(
request: Request,
Expand Down Expand Up @@ -207,7 +267,6 @@ async def tus_patch(
upload_id: str,
request: Request,
db: Annotated[AsyncSession, Depends(get_db)],
queue: Annotated[ProcessingQueue | None, Depends(get_processing_queue)],
upload_offset: Annotated[int, Header(convert_underscores=False, alias="Upload-Offset")] = ...,
content_length: Annotated[int | None, Header()] = None,
content_type: Annotated[str, Header(convert_underscores=False, alias="Content-Type")] = ...,
Expand All @@ -219,7 +278,6 @@ async def tus_patch(
upload_id (str): The public ID of the upload.
request (Request): The incoming HTTP request.
db (AsyncSession): Database session.
queue (ProcessingQueue | None): The processing queue for post-processing.
upload_offset (int): The current upload offset from the client.
content_length (int | None): The Content-Length header value.
content_type (str): The Content-Type header value.
Expand Down Expand Up @@ -268,52 +326,14 @@ async def tus_patch(
if record.upload_offset > record.upload_length:
raise HTTPException(status_code=status.HTTP_413_CONTENT_TOO_LARGE, detail="Upload exceeds declared length")

if record.upload_offset == record.upload_length:
try:
actual_mimetype: str = detect_mimetype(path)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to detect file type: {e}",
)

stmt: Select[tuple[models.UploadToken]] = select(models.UploadToken).where(models.UploadToken.id == record.token_id)
res: Result[tuple[models.UploadToken]] = await db.execute(stmt)

if not (token := res.scalar_one_or_none()):
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Token not found")

if not mime_allowed(actual_mimetype, token.allowed_mime):
path.unlink(missing_ok=True)
await db.delete(record)
await db.commit()
raise HTTPException(
status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
detail=f"Actual file type '{actual_mimetype}' does not match allowed types",
)

record.mimetype = actual_mimetype

if is_multimedia(actual_mimetype):
record.status = "postprocessing"
await db.commit()
await db.refresh(record)
if queue:
await queue.enqueue(record.public_id)
else:
record.status = "completed"
record.completed_at = datetime.now(UTC)
await db.commit()
await db.refresh(record)
else:
record.status = "in_progress"

try:
await db.commit()
await db.refresh(record)
except Exception:
await db.rollback()
await db.refresh(record)
record.status = "in_progress"

try:
await db.commit()
await db.refresh(record)
except Exception:
await db.rollback()
await db.refresh(record)

return Response(
status_code=status.HTTP_204_NO_CONTENT,
Expand Down Expand Up @@ -366,26 +386,32 @@ async def tus_delete(upload_id: str, db: Annotated[AsyncSession, Depends(get_db)


@router.post("/{upload_id}/complete", response_model=schemas.UploadRecordResponse, name="mark_complete")
async def mark_complete(upload_id: str, db: Annotated[AsyncSession, Depends(get_db)]) -> models.UploadRecord:
async def mark_complete(
upload_id: str,
db: Annotated[AsyncSession, Depends(get_db)],
queue: Annotated[ProcessingQueue | None, Depends(get_processing_queue)],
token: Annotated[str, Query(description="Upload token")] = ...,
) -> models.UploadRecord:
"""
Mark an upload as complete.

Args:
upload_id (str): The public ID of the upload.
db (AsyncSession): Database session.
queue (ProcessingQueue | None): The processing queue for post-processing.
token (str): The upload token string.

Returns:
UploadRecord: The updated upload record.

"""
record: models.UploadRecord = await _get_upload_record(db, upload_id)
await _ensure_token(db, token_id=record.token_id, check_remaining=False)
token_row: models.UploadToken = await _ensure_token(db, token_value=token, check_remaining=False)

record.status = "completed"
record.completed_at = datetime.now(UTC)
await db.commit()
await db.refresh(record)
return record
if record.token_id != token_row.id:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Upload does not belong to this token")

return await _finalize_upload(db, record, queue)


@router.delete("/{upload_id}/cancel", response_model=dict, name="cancel_upload")
Expand Down
10 changes: 5 additions & 5 deletions backend/tests/test_download_restrictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ async def test_download_blocked_for_disabled_token(client):

upload_data = await initiate_upload(client, upload_token, "test.txt", 12)
upload_id = upload_data["upload_id"]
await upload_file_via_tus(client, upload_id, b"test content")
await upload_file_via_tus(client, upload_id, b"test content", upload_token)

await client.patch(
app.url_path_for("update_token", token_value=upload_token),
Expand All @@ -45,7 +45,7 @@ async def test_download_blocked_for_expired_token(client):

upload_data = await initiate_upload(client, upload_token, "test.txt", 12)
upload_id = upload_data["upload_id"]
await upload_file_via_tus(client, upload_id, b"test content")
await upload_file_via_tus(client, upload_id, b"test content", upload_token)

expired_time = datetime.now(UTC) - timedelta(hours=1)
await client.patch(
Expand All @@ -70,7 +70,7 @@ async def test_download_allowed_for_disabled_token_with_admin_key(client):

upload_data = await initiate_upload(client, upload_token, "test.txt", 12)
upload_id = upload_data["upload_id"]
await upload_file_via_tus(client, upload_id, b"test content")
await upload_file_via_tus(client, upload_id, b"test content", upload_token)

await client.patch(
app.url_path_for("update_token", token_value=upload_token),
Expand All @@ -94,7 +94,7 @@ async def test_get_file_info_blocked_for_disabled_token(client):

upload_data = await initiate_upload(client, upload_token, "test.txt", 12)
upload_id = upload_data["upload_id"]
await upload_file_via_tus(client, upload_id, b"test content")
await upload_file_via_tus(client, upload_id, b"test content", upload_token)

await client.patch(
app.url_path_for("update_token", token_value=upload_token),
Expand All @@ -118,7 +118,7 @@ async def test_get_file_info_allowed_for_disabled_token_with_admin_key(client):

upload_data = await initiate_upload(client, upload_token, "test.txt", 12)
upload_id = upload_data["upload_id"]
await upload_file_via_tus(client, upload_id, b"test content")
await upload_file_via_tus(client, upload_id, b"test content", upload_token)

await client.patch(
app.url_path_for("update_token", token_value=upload_token),
Expand Down
4 changes: 2 additions & 2 deletions backend/tests/test_download_url_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ async def test_list_token_uploads_does_not_expose_api_key():
upload_data = await initiate_upload(
client, token_data["token"], filename="test.txt", size_bytes=11, filetype="text/plain", meta_data={}
)
await upload_file_via_tus(client, upload_data["upload_id"], b"hello world")
await upload_file_via_tus(client, upload_data["upload_id"], b"hello world", token_data["token"])

# Get uploads list as admin
response = await client.get(
Expand All @@ -48,7 +48,7 @@ async def test_get_file_info_does_not_expose_api_key():
upload_data = await initiate_upload(
client, token_data["token"], filename="test.txt", size_bytes=11, filetype="text/plain", meta_data={}
)
await upload_file_via_tus(client, upload_data["upload_id"], b"hello world")
await upload_file_via_tus(client, upload_data["upload_id"], b"hello world", token_data["token"])

response = await client.get(
app.url_path_for(
Expand Down
23 changes: 21 additions & 2 deletions backend/tests/test_mimetype_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from backend.app.config import settings
from backend.app.db import SessionLocal
from backend.app.main import app
from backend.tests.utils import complete_upload


@pytest.mark.asyncio
Expand Down Expand Up @@ -56,8 +57,11 @@ async def test_mimetype_spoofing_rejected(client):
},
)

assert patch_resp.status_code == status.HTTP_415_UNSUPPORTED_MEDIA_TYPE, "Fake video file should be rejected with 415"
assert "does not match allowed types" in patch_resp.json()["detail"], "Error should indicate type mismatch"
assert patch_resp.status_code == status.HTTP_204_NO_CONTENT, "TUS PATCH should accept bytes before explicit completion"

complete_status, complete_data = await complete_upload(client, upload_id, token_value)
assert complete_status == status.HTTP_415_UNSUPPORTED_MEDIA_TYPE, "Fake video file should be rejected during explicit completion"
assert "does not match allowed types" in complete_data["detail"], "Error should indicate type mismatch"

head_resp = await client.head(app.url_path_for("tus_head", upload_id=upload_id))
assert head_resp.status_code == status.HTTP_404_NOT_FOUND, "Rejected upload should be removed"
Expand Down Expand Up @@ -111,6 +115,10 @@ async def test_valid_mimetype_accepted(client):

assert patch_resp.status_code == status.HTTP_204_NO_CONTENT, "Valid text file should be accepted"

complete_status, complete_data = await complete_upload(client, upload_id, token_value)
assert complete_status == status.HTTP_200_OK, "Completion endpoint should finalize valid text uploads"
assert complete_data["status"] == "completed", "Text upload should be marked completed after explicit completion"

head_resp = await client.head(app.url_path_for("tus_head", upload_id=upload_id))
assert head_resp.status_code == status.HTTP_200_OK, "Upload should still exist after completion"

Expand Down Expand Up @@ -158,6 +166,9 @@ async def test_mimetype_updated_on_completion(client):
)
assert patch_resp.status_code == status.HTTP_204_NO_CONTENT, "Upload completion should return 204"

complete_status, _ = await complete_upload(client, upload_id, token_value)
assert complete_status == status.HTTP_200_OK, "Completion endpoint should succeed for uploaded text files"

async with SessionLocal() as session:
stmt = select(models.UploadRecord).where(models.UploadRecord.public_id == upload_id)
res = await session.execute(stmt)
Expand Down Expand Up @@ -211,6 +222,10 @@ async def test_ffprobe_extracts_metadata_for_video(client):
)
assert patch_resp.status_code == status.HTTP_204_NO_CONTENT, "Video upload should complete successfully"

complete_status, complete_data = await complete_upload(client, upload_id, token_value)
assert complete_status == status.HTTP_200_OK, "Completion endpoint should accept uploaded video files"
assert complete_data["status"] == "postprocessing", "Video upload should enter postprocessing after explicit completion"

from backend.tests.test_postprocessing import wait_for_processing

await wait_for_processing([upload_id], timeout=10.0)
Expand Down Expand Up @@ -272,6 +287,10 @@ async def test_ffprobe_not_run_for_non_multimedia(client):
)
assert patch_resp.status_code == status.HTTP_204_NO_CONTENT, "Text upload should complete successfully"

complete_status, complete_data = await complete_upload(client, upload_id, token_value)
assert complete_status == status.HTTP_200_OK, "Completion endpoint should succeed for text uploads"
assert complete_data["status"] == "completed", "Text upload should complete immediately after explicit completion"

async with SessionLocal() as session:
stmt = select(models.UploadRecord).where(models.UploadRecord.public_id == upload_id)
res = await session.execute(stmt)
Expand Down
47 changes: 2 additions & 45 deletions backend/tests/test_postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ async def test_multimedia_upload_enters_postprocessing(client):
)
upload_id = upload_data["upload_id"]

await upload_file_via_tus(client, upload_id, video_content)
await upload_file_via_tus(client, upload_id, video_content, token_value)

async with SessionLocal() as session:
stmt = select(models.UploadRecord).where(models.UploadRecord.public_id == upload_id)
Expand All @@ -79,7 +79,7 @@ async def test_non_multimedia_upload_completes_immediately(client):
)
upload_id = upload_data["upload_id"]

await upload_file_via_tus(client, upload_id, pdf_content)
await upload_file_via_tus(client, upload_id, pdf_content, token_value)

async with SessionLocal() as session:
stmt = select(models.UploadRecord).where(models.UploadRecord.public_id == upload_id)
Expand All @@ -90,49 +90,6 @@ async def test_non_multimedia_upload_completes_immediately(client):
assert record.completed_at is not None, "Upload should be marked complete"


@pytest.mark.asyncio
async def test_postprocessing_worker_processes_queue(client):
    """Verify the post-processing worker drains queued multimedia uploads.

    Uploads two MP4 files under one token, confirms they enter (or pass
    through) the "postprocessing" state, then waits for the worker and
    asserts both records end up "completed" with a completion timestamp.
    """
    # One token shared by both uploads; max_uploads=2 allows exactly the pair.
    token_data = await create_token(client, max_uploads=2)
    token_value = token_data["token"]

    # Real MP4 fixture so MIME sniffing classifies it as multimedia.
    video_file = Path(__file__).parent / "fixtures" / "sample.mp4"
    video_content = video_file.read_bytes()

    upload1_data = await initiate_upload(
        client, token_value, filename="video1.mp4", size_bytes=len(video_content), filetype="video/mp4", meta_data={"title": "Video 1"}
    )
    upload1_id = upload1_data["upload_id"]

    upload2_data = await initiate_upload(
        client, token_value, filename="video2.mp4", size_bytes=len(video_content), filetype="video/mp4", meta_data={"title": "Video 2"}
    )
    upload2_id = upload2_data["upload_id"]

    await upload_file_via_tus(client, upload1_id, video_content)
    await upload_file_via_tus(client, upload2_id, video_content)

    # Immediately after upload, each record should be queued for (or may have
    # already finished) post-processing — the worker runs concurrently, so
    # either state is acceptable here.
    async with SessionLocal() as session:
        stmt = select(models.UploadRecord).where(models.UploadRecord.public_id.in_([upload1_id, upload2_id]))
        result = await session.execute(stmt)
        records = result.scalars().all()

        for record in records:
            assert record.status in ("postprocessing", "completed"), "Upload should be in postprocessing or already completed"

    # Block until the worker reports both uploads processed (bounded timeout).
    completed = await wait_for_processing([upload1_id, upload2_id])
    assert completed, "Processing should complete within timeout"

    # Re-read from a fresh session so we observe the worker's committed state.
    async with SessionLocal() as session:
        stmt = select(models.UploadRecord).where(models.UploadRecord.public_id.in_([upload1_id, upload2_id]))
        result = await session.execute(stmt)
        records = result.scalars().all()

        for record in records:
            assert record.status == "completed", "Both uploads should be completed after processing"
            assert record.completed_at is not None, "Both uploads should have completion time"


@pytest.mark.asyncio
async def test_postprocessing_handles_missing_file():
"""Test that post-processing handles missing files gracefully."""
Expand Down
Loading
Loading