Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 83 additions & 7 deletions backend/api/server_fastapi_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import uuid

import modal
from fastapi import APIRouter, File, Form, HTTPException, UploadFile
from fastapi import APIRouter, Body, File, Form, HTTPException, UploadFile

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -93,6 +93,7 @@ def _register_routes(self):
self.router.add_api_route("/cache/clear", self.clear_cache, methods=["POST"])
self.router.add_api_route("/auth/device/code", self.request_device_code, methods=["POST"])
self.router.add_api_route("/auth/device/poll", self.poll_device_code, methods=["POST"])
self.router.add_api_route("/auth/device/authorize", self.authorize_device, methods=["POST"])

async def health(self):
"""
Expand Down Expand Up @@ -310,28 +311,28 @@ async def request_device_code(self):
logger.error(f"[Device Code] Error generating device code: {e}")
raise HTTPException(status_code=500, detail=str(e))

async def poll_device_code(self, device_code: str):
async def poll_device_code(self, device_code: str = Body(..., embed=True)):
"""
Poll for device code authorization status.

Request body:
{
"device_code": "a8f3j2k1..."
}

Responses:
- Still waiting: {"status": "pending"}
- User authorized: {"status": "authorized", "user_id": "...", "id_token": "...", "refresh_token": "..."}
- Timed out: {"status": "expired", "error": "device_code_expired"}
- User denied: {"status": "denied", "error": "user_denied_authorization"}

Polling behavior:
- Client should poll every 3 seconds (interval from device/code response)
- Max 200 attempts (10 minutes total)
- Stop immediately if user closes dialog
"""
try:

if not device_code:
raise HTTPException(
status_code=400,
Expand All @@ -349,9 +350,84 @@ async def poll_device_code(self, device_code: str):

logger.info(f"[Device Poll] Device code {device_code} | status: {status.get('status')}")
return status

except HTTPException:
raise
except Exception as e:
logger.error(f"[Device Poll] Error polling device code: {e}")
raise HTTPException(status_code=500, detail=str(e))

async def authorize_device(
self,
user_code: str = Body(...),
firebase_id_token: str = Body(...),
firebase_refresh_token: str = Body("")
):
"""
Authorize a device after user logs in on website.

Request body:
{
"user_code": "ABC-420",
"firebase_id_token": "eyJhbGc...",
"firebase_refresh_token": "AOEOulbB..." (optional for now)
}

Response:
- Success: {"success": true}
- Errors: 400 (missing fields), 401 (invalid token), 404 (code not found), 500 (server error)
"""
try:
# Validate required fields
if not user_code:
raise HTTPException(
status_code=400,
detail="Missing required field: 'user_code'"
)
if not firebase_id_token:
raise HTTPException(
status_code=400,
detail="Missing required field: 'firebase_id_token'"
)

# Verify Firebase token
user_info = self.server_instance.auth_connector.verify_firebase_token(firebase_id_token)
if not user_info:
raise HTTPException(
status_code=401,
detail="Invalid Firebase token"
)

user_id = user_info["user_id"]
logger.info(f"[Device Authorize] Verified token for user: {user_id}")

# Lookup device_code from user_code
device_code = self.server_instance.auth_connector.get_device_code_by_user_code(user_code)
if not device_code:
raise HTTPException(
status_code=404,
detail="User code not found or expired"
)

# Mark device as authorized with tokens
success = self.server_instance.auth_connector.set_device_code_authorized(
device_code,
user_id,
firebase_id_token,
firebase_refresh_token
)

if not success:
raise HTTPException(
status_code=500,
detail="Failed to authorize device"
)

logger.info(f"[Device Authorize] Device authorized for user_code: {user_code}, user: {user_id}")
return {"success": True}

except HTTPException:
raise
except Exception as e:
logger.error(f"[Device Authorize] Error authorizing device: {e}")
raise HTTPException(status_code=500, detail=str(e))
56 changes: 41 additions & 15 deletions backend/auth/auth_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Auth service for device flow authentication.
"""

from firebase_admin import auth
import logging
from typing import Optional, Dict, Any
from datetime import datetime, timezone, timedelta
Expand All @@ -16,7 +17,7 @@
class AuthConnector:
"""
Modal Dict wrapper for device flow authentication.
Stores device codes with expiration (10 minutes) for OAuth device flow.
"""

Expand All @@ -35,12 +36,12 @@ def __init__(self, device_dict_name: str = DEFAULT_DEVICE_DICT, user_dict_name:

def _is_expired(self, entry: Dict[str, Any]) -> bool:
"""Check if a device code entry is expired."""
expires_at= entry.get("expires_at")
expires_at = entry.get("expires_at")
if expires_at is None:
return False
expires_at = datetime.fromisoformat(expires_at.replace('Z', '+00:00'))
return datetime.now(timezone.utc) > expires_at

def _delete_session(self, device_code: str, entry: Optional[Dict[str, Any]]) -> None:
"""
Delete both dicts safely if the entry is expired.
Expand All @@ -65,7 +66,6 @@ def generate_device_code(self) -> str:

def generate_user_code(self) -> str:
"""Generate a user-friendly code in format ABC-420."""

letters = ''.join(secrets.choice(string.ascii_uppercase) for _ in range(3))
digits = ''.join(secrets.choice(string.digits) for _ in range(3))
return f"{letters}-{digits}"
Expand Down Expand Up @@ -96,7 +96,7 @@ def create_device_code_entry(
def get_device_code_entry(self, device_code: str) -> Optional[Dict[str, Any]]:
"""Retrieve device code entry, returns None if not found or expired."""
try:

entry = self.device_store.get(device_code)
if entry is None:
return None
Expand All @@ -111,7 +111,7 @@ def get_device_code_entry(self, device_code: str) -> Optional[Dict[str, Any]]:
def get_device_code_by_user_code(self, user_code: str) -> Optional[str]:
"""
Lookup device_code by user_code.
Returns the device_code if found and not expired, None otherwise.
"""
try:
Expand All @@ -124,7 +124,7 @@ def get_device_code_by_user_code(self, user_code: str) -> Optional[str]:
if user_code in self.user_store:
del self.user_store[user_code]
return None

return device_code
except Exception as e:
logger.error(f"Error looking up device_code by user_code: {e}")
Expand All @@ -136,7 +136,7 @@ def update_device_code_status(self, device_code: str, status: str) -> bool:
entry = self.get_device_code_entry(device_code)
if entry is None:
return False

entry["status"] = status
self.device_store[device_code] = entry
logger.info(f"Updated device code {device_code[:8]}... status to: {status}")
Expand All @@ -157,7 +157,7 @@ def set_device_code_authorized(
entry = self.get_device_code_entry(device_code)
if entry is None:
return False

entry["status"] = "authorized"
entry["user_id"] = user_id
entry["id_token"] = id_token
Expand All @@ -176,7 +176,7 @@ def set_device_code_denied(self, device_code: str) -> bool:
entry = self.get_device_code_entry(device_code)
if entry is None:
return False

entry["status"] = "denied"
entry["denied_at"] = datetime.now(timezone.utc).isoformat()
self.device_store[device_code] = entry
Expand All @@ -189,31 +189,35 @@ def set_device_code_denied(self, device_code: str) -> bool:
def get_device_code_poll_status(self, device_code: str) -> Optional[Dict[str, Any]]:
"""
Get device code status for polling endpoint.
Returns status dict with appropriate fields based on state:
- pending: {"status": "pending"}
- authorized: {"status": "authorized", "user_id": ..., "id_token": ..., "refresh_token": ...}
- expired: {"status": "expired", "error": "device_code_expired"}
- denied: {"status": "denied", "error": "user_denied_authorization"}
- not_found: None (treat as expired)
Tokens are deleted after retrieval (one-time use).
"""
entry = self.get_device_code_entry(device_code)

if entry is None:
return {
"status": "expired",
"error": "device_code_expired"
}

status = entry.get("status", "pending")

if status == "authorized":
return {
result = {
"status": "authorized",
"user_id": entry.get("user_id"),
"id_token": entry.get("id_token"),
"refresh_token": entry.get("refresh_token")
}
self._delete_session(device_code, entry)
return result
elif status == "denied":
return {
"status": "denied",
Expand Down Expand Up @@ -241,3 +245,25 @@ def delete_device_code(self, device_code: str) -> bool:
except Exception as e:
logger.error(f"Error deleting device code: {e}")
return False

def verify_firebase_token(self, id_token: str) -> Optional[Dict[str, Any]]:
"""Verify Firebase ID token from website/plugin."""
try:
decoded_token = auth.verify_id_token(id_token)
return {
"user_id": decoded_token['uid'],
"email": decoded_token.get('email'),
"email_verified": decoded_token.get('email_verified', False)
}
except auth.InvalidIdTokenError as e:
logger.error(f"Invalid Firebase token: {e}")
return None
except auth.ExpiredIdTokenError as e:
logger.error(f"Expired Firebase token: {e}")
return None
except auth.RevokedIdTokenError as e:
logger.error(f"Revoked Firebase token: {e}")
return None
except auth.CertificateFetchError as e:
logger.error(f"Firebase certificate fetch error: {e}")
return None
4 changes: 2 additions & 2 deletions backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ dependencies = [
"transformers",
"scenedetect",
"boto3",
"torchvision"
"torchvision",
"firebase-admin>=7.1.0",
]

[project.scripts]
Expand All @@ -34,7 +35,6 @@ packages = ["cli.py"]

[dependency-groups]
dev = [
"opencv-python",
"pytest",
"pytest-asyncio>=1.3.0",
"pytest-cov",
Expand Down
15 changes: 15 additions & 0 deletions backend/services/http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,21 @@ def _initialize_connectors(self):
logger.info(f"[{self.__class__.__name__}] Starting up in '{env}' environment")
self.start_time = datetime.now(timezone.utc)

# Initialize Firebase Admin SDK (required for token verification)
try:
import firebase_admin
import json
firebase_credentials = json.loads(get_env_var("FIREBASE_ADMIN_KEY"))
from firebase_admin import credentials
cred = credentials.Certificate(firebase_credentials)
firebase_admin.initialize_app(cred)
logger.info(f"[{self.__class__.__name__}] Firebase Admin SDK initialized")
except ValueError:
# Already initialized, which is fine
pass
except Exception as e:
logger.warning(f"[{self.__class__.__name__}] Firebase initialization failed: {e}")

# Get environment variables
PINECONE_API_KEY = get_env_var("PINECONE_API_KEY")
R2_ACCOUNT_ID = get_env_var("R2_ACCOUNT_ID")
Expand Down
Loading
Loading