Skip to content

Commit 78319ba

Browse files
committed
chore: merge staging into auth-endpoint (use staging auth implementation)
2 parents c5690b8 + b1b25cc commit 78319ba

File tree

7 files changed

+971
-75
lines changed

7 files changed

+971
-75
lines changed

backend/api/server_fastapi_router.py

Lines changed: 67 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,11 @@
44
import uuid
55

66
import modal
7-
from fastapi import APIRouter, File, Form, HTTPException, UploadFile
8-
from pydantic import BaseModel
7+
from fastapi import APIRouter, Body, File, Form, HTTPException, UploadFile
98

109
logger = logging.getLogger(__name__)
1110

1211

13-
class AuthorizeDeviceRequest(BaseModel):
14-
"""Request body for device code authorization."""
15-
user_code: str
16-
user_id: str
17-
id_token: str
18-
refresh_token: str
19-
20-
2112
class ServerFastAPIRouter:
2213
"""
2314
FastAPI router for the Server service.
@@ -102,7 +93,7 @@ def _register_routes(self):
10293
self.router.add_api_route("/cache/clear", self.clear_cache, methods=["POST"])
10394
self.router.add_api_route("/auth/device/code", self.request_device_code, methods=["POST"])
10495
self.router.add_api_route("/auth/device/poll", self.poll_device_code, methods=["POST"])
105-
self.router.add_api_route("/auth/device/authorize", self.authorize_device_code, methods=["POST"])
96+
self.router.add_api_route("/auth/device/authorize", self.authorize_device, methods=["POST"])
10697

10798
async def health(self):
10899
"""
@@ -320,28 +311,28 @@ async def request_device_code(self):
320311
logger.error(f"[Device Code] Error generating device code: {e}")
321312
raise HTTPException(status_code=500, detail=str(e))
322313

323-
async def poll_device_code(self, device_code: str):
314+
async def poll_device_code(self, device_code: str = Body(..., embed=True)):
324315
"""
325316
Poll for device code authorization status.
326-
317+
327318
Request body:
328319
{
329320
"device_code": "a8f3j2k1..."
330321
}
331-
322+
332323
Responses:
333324
- Still waiting: {"status": "pending"}
334325
- User authorized: {"status": "authorized", "user_id": "...", "id_token": "...", "refresh_token": "..."}
335326
- Timed out: {"status": "expired", "error": "device_code_expired"}
336327
- User denied: {"status": "denied", "error": "user_denied_authorization"}
337-
328+
338329
Polling behavior:
339330
- Client should poll every 3 seconds (interval from device/code response)
340331
- Max 200 attempts (10 minutes total)
341332
- Stop immediately if user closes dialog
342333
"""
343334
try:
344-
335+
345336
if not device_code:
346337
raise HTTPException(
347338
status_code=400,
@@ -359,44 +350,84 @@ async def poll_device_code(self, device_code: str):
359350

360351
logger.info(f"[Device Poll] Device code {device_code} | status: {status.get('status')}")
361352
return status
362-
353+
363354
except HTTPException:
364355
raise
365356
except Exception as e:
366357
logger.error(f"[Device Poll] Error polling device code: {e}")
367358
raise HTTPException(status_code=500, detail=str(e))
368359

369-
async def authorize_device_code(self, request: AuthorizeDeviceRequest):
360+
async def authorize_device(
361+
self,
362+
user_code: str = Body(...),
363+
firebase_id_token: str = Body(...),
364+
firebase_refresh_token: str = Body("")
365+
):
366+
"""
367+
Authorize a device after user logs in on website.
368+
369+
Request body:
370+
{
371+
"user_code": "ABC-420",
372+
"firebase_id_token": "eyJhbGc...",
373+
"firebase_refresh_token": "AOEOulbB..." (optional for now)
374+
}
375+
376+
Response:
377+
- Success: {"success": true}
378+
- Errors: 400 (missing fields), 401 (invalid token), 404 (code not found), 500 (server error)
379+
"""
370380
try:
371-
# Look up device_code by user_code
372-
device_code = self.server_instance.auth_connector.get_device_code_by_user_code(request.user_code)
373-
374-
if device_code is None:
381+
# Validate required fields
382+
if not user_code:
383+
raise HTTPException(
384+
status_code=400,
385+
detail="Missing required field: 'user_code'"
386+
)
387+
if not firebase_id_token:
388+
raise HTTPException(
389+
status_code=400,
390+
detail="Missing required field: 'firebase_id_token'"
391+
)
392+
393+
# Verify Firebase token
394+
user_info = self.server_instance.auth_connector.verify_firebase_token(firebase_id_token)
395+
if not user_info:
396+
raise HTTPException(
397+
status_code=401,
398+
detail="Invalid Firebase token"
399+
)
400+
401+
user_id = user_info["user_id"]
402+
logger.info(f"[Device Authorize] Verified token for user: {user_id}")
403+
404+
# Lookup device_code from user_code
405+
device_code = self.server_instance.auth_connector.get_device_code_by_user_code(user_code)
406+
if not device_code:
375407
raise HTTPException(
376408
status_code=404,
377409
detail="User code not found or expired"
378410
)
379-
380-
# Mark device code as authorized with user tokens
411+
412+
# Mark device as authorized with tokens
381413
success = self.server_instance.auth_connector.set_device_code_authorized(
382-
device_code=device_code,
383-
user_id=request.user_id,
384-
id_token=request.id_token,
385-
refresh_token=request.refresh_token
414+
device_code,
415+
user_id,
416+
firebase_id_token,
417+
firebase_refresh_token
386418
)
387-
419+
388420
if not success:
389421
raise HTTPException(
390422
status_code=500,
391-
detail="Failed to authorize device code"
423+
detail="Failed to authorize device"
392424
)
393-
394-
logger.info(f"[Device Authorize] User code {request.user_code} authorized for user {request.user_id}")
395-
396-
return {"status": "success"}
397-
425+
426+
logger.info(f"[Device Authorize] Device authorized for user_code: {user_code}, user: {user_id}")
427+
return {"success": True}
428+
398429
except HTTPException:
399430
raise
400431
except Exception as e:
401-
logger.error(f"[Device Authorize] Error authorizing device code: {e}")
432+
logger.error(f"[Device Authorize] Error authorizing device: {e}")
402433
raise HTTPException(status_code=500, detail=str(e))

backend/auth/auth_connector.py

Lines changed: 41 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Auth service for device flow authentication.
33
"""
44

5+
from firebase_admin import auth
56
import logging
67
from typing import Optional, Dict, Any
78
from datetime import datetime, timezone, timedelta
@@ -16,7 +17,7 @@
1617
class AuthConnector:
1718
"""
1819
Modal Dict wrapper for device flow authentication.
19-
20+
2021
Stores device codes with expiration (10 minutes) for OAuth device flow.
2122
"""
2223

@@ -35,12 +36,12 @@ def __init__(self, device_dict_name: str = DEFAULT_DEVICE_DICT, user_dict_name:
3536

3637
def _is_expired(self, entry: Dict[str, Any]) -> bool:
3738
"""Check if a device code entry is expired."""
38-
expires_at= entry.get("expires_at")
39+
expires_at = entry.get("expires_at")
3940
if expires_at is None:
4041
return False
4142
expires_at = datetime.fromisoformat(expires_at.replace('Z', '+00:00'))
4243
return datetime.now(timezone.utc) > expires_at
43-
44+
4445
def _delete_session(self, device_code: str, entry: Optional[Dict[str, Any]]) -> None:
4546
"""
4647
Delete both dicts safely if the entry is expired.
@@ -65,7 +66,6 @@ def generate_device_code(self) -> str:
6566

6667
def generate_user_code(self) -> str:
6768
"""Generate a user-friendly code in format ABC-420."""
68-
6969
letters = ''.join(secrets.choice(string.ascii_uppercase) for _ in range(3))
7070
digits = ''.join(secrets.choice(string.digits) for _ in range(3))
7171
return f"{letters}-{digits}"
@@ -96,7 +96,7 @@ def create_device_code_entry(
9696
def get_device_code_entry(self, device_code: str) -> Optional[Dict[str, Any]]:
9797
"""Retrieve device code entry, returns None if not found or expired."""
9898
try:
99-
99+
100100
entry = self.device_store.get(device_code)
101101
if entry is None:
102102
return None
@@ -111,7 +111,7 @@ def get_device_code_entry(self, device_code: str) -> Optional[Dict[str, Any]]:
111111
def get_device_code_by_user_code(self, user_code: str) -> Optional[str]:
112112
"""
113113
Lookup device_code by user_code.
114-
114+
115115
Returns the device_code if found and not expired, None otherwise.
116116
"""
117117
try:
@@ -124,7 +124,7 @@ def get_device_code_by_user_code(self, user_code: str) -> Optional[str]:
124124
if user_code in self.user_store:
125125
del self.user_store[user_code]
126126
return None
127-
127+
128128
return device_code
129129
except Exception as e:
130130
logger.error(f"Error looking up device_code by user_code: {e}")
@@ -136,7 +136,7 @@ def update_device_code_status(self, device_code: str, status: str) -> bool:
136136
entry = self.get_device_code_entry(device_code)
137137
if entry is None:
138138
return False
139-
139+
140140
entry["status"] = status
141141
self.device_store[device_code] = entry
142142
logger.info(f"Updated device code {device_code[:8]}... status to: {status}")
@@ -157,7 +157,7 @@ def set_device_code_authorized(
157157
entry = self.get_device_code_entry(device_code)
158158
if entry is None:
159159
return False
160-
160+
161161
entry["status"] = "authorized"
162162
entry["user_id"] = user_id
163163
entry["id_token"] = id_token
@@ -176,7 +176,7 @@ def set_device_code_denied(self, device_code: str) -> bool:
176176
entry = self.get_device_code_entry(device_code)
177177
if entry is None:
178178
return False
179-
179+
180180
entry["status"] = "denied"
181181
entry["denied_at"] = datetime.now(timezone.utc).isoformat()
182182
self.device_store[device_code] = entry
@@ -189,31 +189,35 @@ def set_device_code_denied(self, device_code: str) -> bool:
189189
def get_device_code_poll_status(self, device_code: str) -> Optional[Dict[str, Any]]:
190190
"""
191191
Get device code status for polling endpoint.
192-
192+
193193
Returns status dict with appropriate fields based on state:
194194
- pending: {"status": "pending"}
195195
- authorized: {"status": "authorized", "user_id": ..., "id_token": ..., "refresh_token": ...}
196196
- expired: {"status": "expired", "error": "device_code_expired"}
197197
- denied: {"status": "denied", "error": "user_denied_authorization"}
198198
- not_found: None (treat as expired)
199+
200+
Tokens are deleted after retrieval (one-time use).
199201
"""
200202
entry = self.get_device_code_entry(device_code)
201-
203+
202204
if entry is None:
203205
return {
204206
"status": "expired",
205207
"error": "device_code_expired"
206208
}
207-
209+
208210
status = entry.get("status", "pending")
209-
211+
210212
if status == "authorized":
211-
return {
213+
result = {
212214
"status": "authorized",
213215
"user_id": entry.get("user_id"),
214216
"id_token": entry.get("id_token"),
215217
"refresh_token": entry.get("refresh_token")
216218
}
219+
self._delete_session(device_code, entry)
220+
return result
217221
elif status == "denied":
218222
return {
219223
"status": "denied",
@@ -241,3 +245,25 @@ def delete_device_code(self, device_code: str) -> bool:
241245
except Exception as e:
242246
logger.error(f"Error deleting device code: {e}")
243247
return False
248+
249+
def verify_firebase_token(self, id_token: str) -> Optional[Dict[str, Any]]:
250+
"""Verify Firebase ID token from website/plugin."""
251+
try:
252+
decoded_token = auth.verify_id_token(id_token)
253+
return {
254+
"user_id": decoded_token['uid'],
255+
"email": decoded_token.get('email'),
256+
"email_verified": decoded_token.get('email_verified', False)
257+
}
258+
except auth.InvalidIdTokenError as e:
259+
logger.error(f"Invalid Firebase token: {e}")
260+
return None
261+
except auth.ExpiredIdTokenError as e:
262+
logger.error(f"Expired Firebase token: {e}")
263+
return None
264+
except auth.RevokedIdTokenError as e:
265+
logger.error(f"Revoked Firebase token: {e}")
266+
return None
267+
except auth.CertificateFetchError as e:
268+
logger.error(f"Firebase certificate fetch error: {e}")
269+
return None

backend/pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ dependencies = [
1616
"transformers",
1717
"scenedetect",
1818
"boto3",
19-
"torchvision"
19+
"torchvision",
20+
"firebase-admin>=7.1.0",
2021
]
2122

2223
[project.scripts]
@@ -34,7 +35,6 @@ packages = ["cli.py"]
3435

3536
[dependency-groups]
3637
dev = [
37-
"opencv-python",
3838
"pytest",
3939
"pytest-asyncio>=1.3.0",
4040
"pytest-cov",

backend/services/http_server.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,21 @@ def _initialize_connectors(self):
2727
logger.info(f"[{self.__class__.__name__}] Starting up in '{env}' environment")
2828
self.start_time = datetime.now(timezone.utc)
2929

30+
# Initialize Firebase Admin SDK (required for token verification)
31+
try:
32+
import firebase_admin
33+
import json
34+
firebase_credentials = json.loads(get_env_var("FIREBASE_ADMIN_KEY"))
35+
from firebase_admin import credentials
36+
cred = credentials.Certificate(firebase_credentials)
37+
firebase_admin.initialize_app(cred)
38+
logger.info(f"[{self.__class__.__name__}] Firebase Admin SDK initialized")
39+
except ValueError:
40+
# Already initialized, which is fine
41+
pass
42+
except Exception as e:
43+
logger.warning(f"[{self.__class__.__name__}] Firebase initialization failed: {e}")
44+
3045
# Get environment variables
3146
PINECONE_API_KEY = get_env_var("PINECONE_API_KEY")
3247
R2_ACCOUNT_ID = get_env_var("R2_ACCOUNT_ID")

backend/shared/images.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def get_dev_image() -> modal.Image:
4343
"onnxruntime",
4444
"onnxscript",
4545
"tokenizers", # For text embedder (faster than transformers import)
46+
"firebase-admin",
4647
)
4748
.run_function(_download_clip_full_model_for_dev)
4849
.run_function(_export_clip_text_to_onnx)
@@ -74,6 +75,7 @@ def get_server_image() -> modal.Image:
7475
"boto3",
7576
"pinecone",
7677
"numpy",
78+
"firebase-admin",
7779
)
7880
.add_local_python_source(
7981
"database",

0 commit comments

Comments
 (0)