diff --git a/.github/workflows/cd.yml b/.github/workflows/cd.yml index 46e2b99..a1111c6 100644 --- a/.github/workflows/cd.yml +++ b/.github/workflows/cd.yml @@ -12,6 +12,10 @@ concurrency: permissions: contents: write +# Note: We deploy the individual apps (server.py, search_app.py, processing_app.py) +# for optimal cold start performance. The dev_combined.py app is ONLY for local +# development and should never be deployed to staging/prod. + jobs: # ------------------------------------------------------------------ # STAGING DEPLOYMENT (Runs on push to 'staging') @@ -38,15 +42,29 @@ jobs: - name: Install dependencies run: uv sync --frozen - - name: Deploy to Modal (Staging) + - name: Deploy Processing App (Staging) + env: + MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }} + MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }} + ENVIRONMENT: staging + run: uv run modal deploy apps/processing_app.py + + - name: Deploy Search App (Staging) env: MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }} MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }} - ENVIRONMENT: staging # for OS - run: uv run modal deploy main.py + ENVIRONMENT: staging + run: uv run modal deploy apps/search_app.py + + - name: Deploy Server (Staging) + env: + MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }} + MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }} + ENVIRONMENT: staging + run: uv run modal deploy apps/server.py # ------------------------------------------------------------------ - # PRODUCTION DEPLOYMENT (Runs only when manually triggered) + # PRODUCTION DEPLOYMENT (Runs on push to 'main') # ------------------------------------------------------------------ deploy-prod: name: Deploy Production @@ -69,9 +87,23 @@ jobs: - name: Install dependencies run: uv sync --frozen - - name: Deploy to Modal (Prod) + - name: Deploy Processing App (Prod) + env: + MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }} + MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }} + ENVIRONMENT: prod + run: uv run modal deploy apps/processing_app.py + + - name: Deploy Search App (Prod) + env: + MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }} + MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }} + ENVIRONMENT: prod + run: uv run modal deploy apps/search_app.py + + - name: Deploy Server (Prod) env: MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }} MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }} ENVIRONMENT: prod - run: uv run modal deploy main.py \ No newline at end of file + run: uv run modal deploy apps/server.py \ No newline at end of file diff --git a/backend/README.md b/backend/README.md index 6a06389..2203340 100644 --- a/backend/README.md +++ b/backend/README.md @@ -1,6 +1,31 @@ # ClipABit Backend -Video processing backend that runs on Modal. Accepts video uploads via FastAPI and processes them in serverless containers. +Video processing backend that runs on Modal. Built as a microservices architecture with three specialized apps for optimized cold start times. 
+
+## Architecture
+
+The backend uses different architectures for development vs. production:
+
+### Production (staging/prod)
+
+Split into three Modal apps for optimal cold start times:
+
+| App | Purpose | Dependencies |
+|-----|---------|--------------|
+| **server** | API gateway, handles requests, lightweight operations | Minimal (FastAPI, boto3) |
+| **search** | Semantic search with CLIP text encoder | Medium (torch, transformers) |
+| **processing** | Video processing, embedding generation | Heavy (ffmpeg, opencv, torch, transformers) |
+
+This architecture ensures that lightweight API calls (health, status, list) never load the heavy ML models.
+
+### Local Development (dev)
+
+Single combined Modal app (`dev_combined.py`) with all three services:
+
+| App | Purpose |
+|-----|---------|
+| **dev-combined** | Server + Search + Processing in one app |
+
+This allows hot-reload on all services without cross-app lookup issues. Cold starts are slower since every dependency loads into one container - an acceptable trade-off for local development, where iteration speed matters more.
 
 ## Quick Start
 
@@ -11,34 +36,58 @@ uv sync
 # 2. Authenticate with Modal (first time only - opens browser)
 uv run modal token new
 
-# 3. Configure Modal secrets (dev and prod)
-modal secret create dev \
-  ENVIRONMENT=dev \
-  PINECONE_API_KEY=your_pinecone_api_key \
-  R2_ACCOUNT_ID=your_r2_account_id \
-  R2_ACCESS_KEY_ID=your_r2_access_key_id \
-  R2_SECRET_ACCESS_KEY=your_r2_secret_access_key
-
-modal secret create prod \
-  ENVIRONMENT=prod \
-  PINECONE_API_KEY=your_pinecone_api_key \
-  R2_ACCOUNT_ID=your_r2_account_id \
-  R2_ACCESS_KEY_ID=your_r2_access_key_id \
-  R2_SECRET_ACCESS_KEY=your_r2_secret_access_key
-
-# 4. Start dev server (hot-reloads on file changes, uses "dev" secret)
+# 3. Start dev server
 uv run dev
 ```
 
 Note: `uv run` automatically uses the virtual environment - no need to activate it manually.
 
+## Development CLI
+
+### Local Development (Combined App)
+
+For local development, use the combined app that includes all services in one:
+
+| Command | Description |
+|---------|-------------|
+| `uv run dev` | Run the combined dev app (server + search + processing in one) |
+
+This runs `apps/dev_combined.py`, which combines all three services into a single Modal app. Benefits:
+- Hot-reload works on all services (server, search, processing)
+- No cross-app lookup issues
+- Easy to iterate on any part of the system
+
+**Note:** Cold starts will be slower since all dependencies load together, but this is acceptable for local development, where iteration speed matters more than startup latency.
+
+### Individual Apps (For Testing/Debugging)
+
+You can also run individual apps if needed:
+
+| Command | Description |
+|---------|-------------|
+| `uv run server` | Run just the API server |
+| `uv run search` | Run just the search app |
+| `uv run processing` | Run just the processing app |
+
+**Note:** Cross-app communication only works between deployed apps, not ephemeral `modal serve` apps. For full system testing, use `uv run dev` (combined app) or deploy the apps - see the sketch below.
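+
+As a rough illustration of what "cross-app communication" means here (a sketch based on the lookup pattern in `api/server_fastapi_router.py`; the `staging` app name and the placeholder values are illustrative only):
+
+```python
+import modal
+
+# Placeholder payload, standing in for what the upload endpoint receives.
+video_bytes, filename = b"...", "clip.mp4"
+job_id, namespace, parent_batch_id = "job-123", "", None
+
+# Look up a class on another *deployed* Modal app (app name, then class name).
+# Apps are named f"{env}-processing", f"{env}-search", f"{env}-server".
+ProcessingService = modal.Cls.from_name("staging-processing", "ProcessingService")
+
+# spawn() queues the call on the remote app and returns immediately,
+# which is how the server hands uploads off to the processing app.
+ProcessingService().process_video_background.spawn(
+    video_bytes, filename, job_id, namespace, parent_batch_id
+)
+```
+
+Deploying an app is a single command, e.g. `uv run modal deploy apps/processing_app.py` (this is what CI runs for staging/prod).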
+
 ## How It Works
 
-- `main.py` defines a Modal App with a `Server` class
-- `/upload` endpoint accepts video files and spawns background processing jobs
+### Production Architecture (staging/prod)
+
+- `apps/server.py` - API gateway, delegates heavy work to other apps via `modal.Cls.from_name()`
+- `apps/search_app.py` - Handles semantic search queries with CLIP text encoder
+- `apps/processing_app.py` - Processes video uploads (chunking, embedding, storage)
+- Cross-app communication uses Modal's `Cls.from_name()` for lookups and `spawn()`/`remote()` for calls
 - Environment variables stored in Modal secrets (no .env files needed)
-- `uv run dev` automatically uses "dev" secret for development
-- Production deployment handled via CI/CD or direct Modal CLI
+
+### Local Development Architecture (dev)
+
+- `apps/dev_combined.py` - All three services in one app for easy iteration
+- Uses `api/server_fastapi_router.py` (configured for dev combined mode), which accepts worker class references instead of doing cross-app lookups
+- No cross-app lookups needed - services call each other directly within the same app
+- Hot-reload works on all services simultaneously
 
 ## Managing Dependencies
 
@@ -56,3 +105,13 @@ uv run pytest              # Run all tests
 uv run pytest -v           # Verbose output
 uv run pytest --cov        # With coverage
 ```
+
+Note: Some integration tests require `ffmpeg` to be installed locally.
+
+## Deployment
+
+Deployment is handled via GitHub Actions CI/CD. **Only the individual apps are deployed** (not `dev_combined.py`):
+
+1. `processing_app.py` (heavy dependencies)
+2. `search_app.py` (medium dependencies)
+3. `server.py` (API gateway - depends on the other two)
diff --git a/backend/api/__init__.py b/backend/api/__init__.py
index fcb15b2..9b1c473 100644
--- a/backend/api/__init__.py
+++ b/backend/api/__init__.py
@@ -1,3 +1,4 @@
-from .fastapi_router import FastAPIRouter
+from .server_fastapi_router import ServerFastAPIRouter
+from .search_fastapi_router import SearchFastAPIRouter
 
-__all__ = ["FastAPIRouter"]
+__all__ = ["ServerFastAPIRouter", "SearchFastAPIRouter"]
diff --git a/backend/api/fastapi_router.py b/backend/api/fastapi_router.py
deleted file mode 100644
index bc42805..0000000
--- a/backend/api/fastapi_router.py
+++ /dev/null
@@ -1,224 +0,0 @@
-__all__ = ["FastAPIRouter"]
-
-import logging
-import time
-import uuid
-from fastapi import APIRouter, Form, HTTPException, UploadFile
-
-logger = logging.getLogger(__name__)
-
-class FastAPIRouter:
-    def __init__(self, server_instance, is_internal_env):
-        """
-        Initializes the API routes, giving them access to the server instance
-        for calling background tasks and accessing shared state.
-        """
-        self.server_instance = server_instance
-        self.is_internal_env = is_internal_env
-        self.router = APIRouter()
-        self._register_routes()
-
-    def _register_routes(self):
-        """Registers all the FastAPI routes."""
-        self.router.add_api_route("/health", self.health, methods=["GET"])
-        self.router.add_api_route("/status", self.status, methods=["GET"])
-        self.router.add_api_route("/upload", self.upload, methods=["POST"])
-        self.router.add_api_route("/search", self.search, methods=["GET"])
-        self.router.add_api_route("/videos", self.list_videos, methods=["GET"])
-        self.router.add_api_route("/videos/{hashed_identifier}", self.delete_video, methods=["DELETE"])
-
-    async def health(self):
-        """
-        Health check endpoint.
-        Returns a simple status message indicating the service is running.
- """ - return {"status": "ok"} - - async def status(self, job_id: str): - """ - Check the status of a video processing job. - - Args: - job_id (str): The unique identifier for the video processing job. - - Returns: - dict: Contains: - - job_id (str): The job identifier - - status (str): 'processing', 'completed', or 'failed' - - message (str, optional): If still processing or not found - - result (dict, optional): Full job result if completed - - This endpoint allows clients (e.g., frontend) to poll for job progress and retrieve results when ready. - """ - job_data = self.server_instance.job_store.get_job(job_id) - if job_data is None: - return { - "job_id": job_id, - "status": "processing", - "message": "Job is still processing or not found" - } - return job_data - - async def upload(self, file: UploadFile, namespace: str = Form("")): - """ - Handle video file upload and start background processing. - - Args: - file (UploadFile): The uploaded video file. - namespace (str, optional): Namespace for Pinecone and R2 storage (default: "") - - Returns: - dict: Contains job_id, filename, content_type, size_bytes, status, and message. - """ - contents = await file.read() - file_size = len(contents) - job_id = str(uuid.uuid4()) - - self.server_instance.job_store.create_job(job_id, { - "job_id": job_id, - "filename": file.filename, - "status": "processing", - "size_bytes": file_size, - "content_type": file.content_type, - "namespace": namespace - }) - - self.server_instance.process_video_background.spawn(contents, file.filename, job_id, namespace) - - return { - "job_id": job_id, - "filename": file.filename, - "content_type": file.content_type, - "size_bytes": file_size, - "status": "processing", - "message": "Video uploaded successfully, processing in background" - } - - async def search(self, query: str, namespace: str = "", top_k: int = 10): - """ - Search endpoint - accepts a text query and returns semantic search results. - - Args: - query (str): The search query string (required) - namespace (str, optional): Namespace for Pinecone search (default: "") - top_k (int, optional): Number of top results to return (default: 10) - - Returns: - json: dict with 'query', 'results', and 'timing' - - Raises: - HTTPException: If search fails (500 Internal Server Error) - """ - try: - t_start = time.perf_counter() - logger.info(f"[Search] Query: '{query}' | namespace='{namespace}' | top_k={top_k}") - - results = self.server_instance.searcher.search( - query=query, - top_k=top_k, - namespace=namespace - ) - - t_done = time.perf_counter() - logger.info(f"[Search] Found {len(results)} chunk-level results in {t_done - t_start:.3f}s") - - return { - "query": query, - "results": results, - "timing": { - "total_s": round(t_done - t_start, 3) - } - } - except Exception as e: - logger.error(f"[Search] Error: {e}") - raise HTTPException(status_code=500, detail=str(e)) - - async def list_videos( - self, - namespace: str = "__default__", - page_size: int = 20, - page_token: str | None = None, - ): - """ - List all videos stored in R2 for the given namespace. 
- - Args: - namespace (str, optional): Namespace for R2 storage (default: "__default__") - - Returns: - json: dict with 'status', 'namespace', and 'videos' list - - Raises: - HTTPException: If fetching videos fails (500 Internal Server Error) - """ - if page_size <= 0: - raise HTTPException(status_code=400, detail="page_size must be positive") - - logger.info( - "[List Videos] Fetching videos for namespace: %s (page_size=%s, page_token=%s)", - namespace, - page_size, - page_token, - ) - try: - videos, next_token, total_videos, total_pages = ( - self.server_instance.r2_connector.list_videos_page( - namespace=namespace, - page_size=page_size, - continuation_token=page_token, - ) - ) - return { - "status": "success", - "namespace": namespace, - "videos": videos, - "next_page_token": next_token, - "total_videos": total_videos, - "total_pages": total_pages, - } - except Exception as e: - logger.error(f"[List Videos] Error fetching videos: {e}") - raise HTTPException(status_code=500, detail=str(e)) - - async def delete_video(self, hashed_identifier: str, filename: str, namespace: str = ""): - """ - Delete a video and its associated chunks from storage and database. - - Args: - hashed_identifier (str): The unique identifier of the video in R2 storage. - filename (str): The original filename of the video. - namespace (str, optional): Namespace for Pinecone and R2 storage (default: "") - - Returns: - dict: Contains status and message about deletion result. - - Raises: - HTTPException: If deletion fails at any step. - - 500 Internal Server Error with details. - - 400 Bad Request if parameters are missing. - - 404 Not Found if video does not exist. - - 403 Forbidden if deletion is not allowed. - """ - logger.info(f"[Delete Video] Request to delete video: {filename} ({hashed_identifier}) | namespace='{namespace}'") - if not self.is_internal_env: - raise HTTPException(status_code=403, detail="Video deletion is not allowed in the current environment.") - - job_id = str(uuid.uuid4()) - self.server_instance.job_store.create_job(job_id, { - "job_id": job_id, - "hashed_identifier": hashed_identifier, - "namespace": namespace, - "status": "processing", - "operation": "delete" - }) - - self.server_instance.delete_video_background.spawn(job_id, hashed_identifier, namespace) - - return { - "job_id": job_id, - "hashed_identifier": hashed_identifier, - "namespace": namespace, - "status": "processing", - "message": "Video deletion started, processing in background" - } - diff --git a/backend/api/search_fastapi_router.py b/backend/api/search_fastapi_router.py new file mode 100644 index 0000000..0f4248b --- /dev/null +++ b/backend/api/search_fastapi_router.py @@ -0,0 +1,80 @@ +""" +FastAPI router for the Search service. + +Exposes the search endpoint directly from the SearchService, +eliminating the need to go through the Server gateway. +""" + +__all__ = ["SearchFastAPIRouter"] + +import logging +import time + +from fastapi import APIRouter, HTTPException + +logger = logging.getLogger(__name__) + + +class SearchFastAPIRouter: + """ + FastAPI router for the Search service. + + Handles: semantic search queries. + Exposed directly by SearchService for lower latency (no server hop). + """ + + def __init__(self, search_service_instance): + """ + Initialize the search router. 
+ + Args: + search_service_instance: The SearchService instance with embedder and connectors + """ + self.search_service = search_service_instance + self.router = APIRouter() + self._register_routes() + + def _register_routes(self): + """Register all search routes.""" + self.router.add_api_route("/health", self.health, methods=["GET"]) + self.router.add_api_route("/search", self.search, methods=["GET"]) + + async def health(self): + """Health check endpoint.""" + return {"status": "ok", "service": "search"} + + async def search(self, query: str, namespace: str = "", top_k: int = 10): + """ + Search endpoint - accepts a text query and returns semantic search results. + + Args: + query (str): The search query string (required) + namespace (str, optional): Namespace for Pinecone search (default: "") + top_k (int, optional): Number of top results to return (default: 10) + + Returns: + json: dict with 'query', 'results', and 'timing' + + Raises: + HTTPException: If search fails (500 Internal Server Error) + """ + try: + t_start = time.perf_counter() + logger.info(f"[Search] Query: '{query}' | namespace='{namespace}' | top_k={top_k}") + + # Call search directly on the service instance (no RPC, no cross-app call) + results = self.search_service._search_internal(query, namespace, top_k) + + t_done = time.perf_counter() + logger.info(f"[Search] Found {len(results)} results in {t_done - t_start:.3f}s") + + return { + "query": query, + "results": results, + "timing": { + "total_s": round(t_done - t_start, 3) + } + } + except Exception as e: + logger.error(f"[Search] Error: {e}") + raise HTTPException(status_code=500, detail=str(e)) diff --git a/backend/api/server_fastapi_router.py b/backend/api/server_fastapi_router.py new file mode 100644 index 0000000..417a6cc --- /dev/null +++ b/backend/api/server_fastapi_router.py @@ -0,0 +1,357 @@ +__all__ = ["ServerFastAPIRouter"] + +import logging +import uuid + +import modal +from fastapi import APIRouter, File, Form, HTTPException, UploadFile + +logger = logging.getLogger(__name__) + + +class ServerFastAPIRouter: + """ + FastAPI router for the Server service. + + Handles: health, status, upload, list_videos, delete, cache operations. + Search is handled separately by SearchService with its own ASGI app. + """ + + def __init__( + self, + server_instance, + is_internal_env: bool, + environment: str = "dev", + processing_service_cls=None + ): + """ + Initializes the API routes, giving them access to the server instance + for calling background tasks and accessing shared state. + + Args: + server_instance: The Modal server instance for accessing connectors and spawning local methods + is_internal_env: Whether this is an internal (dev/staging) environment + environment: Environment name (dev, staging, prod) for cross-app lookups + processing_service_cls: Optional ProcessingService class for dev combined mode (direct access) + """ + self.server_instance = server_instance + self.is_internal_env = is_internal_env + self.environment = environment + self.processing_service_cls = processing_service_cls + self.router = APIRouter() + + # Initialize UploadHandler with process_video spawn function + from services.upload_handler import UploadHandler + self.upload_handler = UploadHandler( + job_store=server_instance.job_store, + process_video_spawn_fn=self._get_process_video_spawn_fn() + ) + + self._register_routes() + + def _get_process_video_spawn_fn(self): + """ + Create a spawn function that works in both dev combined and production modes. 
+ + Returns: + Callable that spawns process_video_background + """ + def spawn_process_video(video_bytes: bytes, filename: str, job_id: str, namespace: str, parent_batch_id: str): + try: + if self.processing_service_cls: + # Dev combined mode - direct access + self.processing_service_cls().process_video_background.spawn( + video_bytes, filename, job_id, namespace, parent_batch_id + ) + logger.info(f"[Upload] Spawned processing job {job_id} (dev combined mode)") + else: + # Production mode - cross-app call + from shared.config import get_modal_environment + processing_app_name = f"{self.environment}-processing" + ProcessingService = modal.Cls.from_name( + processing_app_name, + "ProcessingService", + environment_name=get_modal_environment() + ) + ProcessingService().process_video_background.spawn( + video_bytes, filename, job_id, namespace, parent_batch_id + ) + logger.info(f"[Upload] Spawned processing job {job_id} to {processing_app_name}") + except Exception as e: + logger.error(f"[Upload] Failed to spawn processing job {job_id}: {e}") + raise + + return spawn_process_video + + def _register_routes(self): + """Registers all the FastAPI routes.""" + self.router.add_api_route("/health", self.health, methods=["GET"]) + self.router.add_api_route("/status", self.status, methods=["GET"]) + self.router.add_api_route("/upload", self.upload, methods=["POST"]) + self.router.add_api_route("/videos", self.list_videos, methods=["GET"]) + self.router.add_api_route("/videos/{hashed_identifier}", self.delete_video, methods=["DELETE"]) + self.router.add_api_route("/cache/clear", self.clear_cache, methods=["POST"]) + self.router.add_api_route("/auth/device/code", self.request_device_code, methods=["POST"]) + self.router.add_api_route("/auth/device/poll", self.poll_device_code, methods=["POST"]) + + async def health(self): + """ + Health check endpoint. + Returns a simple status message indicating the service is running. + """ + return {"status": "ok"} + + async def status(self, job_id: str): + """ + Check the status of a video processing job. + + Args: + job_id (str): The unique identifier for the video processing job. + + Returns: + dict: Contains: + - job_id (str): The job identifier + - status (str): 'processing', 'completed', or 'failed' + - message (str, optional): If still processing or not found + - result (dict, optional): Full job result if completed + + This endpoint allows clients (e.g., frontend) to poll for job progress and retrieve results when ready. + """ + job_data = self.server_instance.job_store.get_job(job_id) + if job_data is None: + return { + "job_id": job_id, + "status": "processing", + "message": "Job is still processing or not found" + } + return job_data + + async def upload(self, files: list[UploadFile] = File(default=[]), namespace: str = Form("")): + """ + Handle video file upload and start background processing. + Supports both single and batch uploads. + + Args: + files (list[UploadFile]): List of uploaded video file(s). Client sends files with repeated 'files' field names, which FastAPI collects into a list. + namespace (str, optional): Namespace for Pinecone and R2 storage (default: "") + + Returns: + dict: For single upload: job_id, filename, etc. + For batch upload: batch_job_id, total_videos, child_jobs, etc. 
+ + Raises: + HTTPException: 400 if validation fails, 500 if processing errors + """ + return await self.upload_handler.handle_upload(files, namespace) + + async def list_videos( + self, + namespace: str = "__default__", + page_size: int = 20, + page_token: str | None = None, + ): + """ + List all videos stored in R2 for the given namespace. + + Args: + namespace (str, optional): Namespace for R2 storage (default: "__default__") + + Returns: + json: dict with 'status', 'namespace', and 'videos' list + + Raises: + HTTPException: If fetching videos fails (500 Internal Server Error) + """ + if page_size <= 0: + raise HTTPException(status_code=400, detail="page_size must be positive") + + logger.info( + "[List Videos] Fetching videos for namespace: %s (page_size=%s, page_token=%s)", + namespace, + page_size, + page_token, + ) + try: + videos, next_token, total_videos, total_pages = ( + self.server_instance.r2_connector.list_videos_page( + namespace=namespace, + page_size=page_size, + continuation_token=page_token, + ) + ) + return { + "status": "success", + "namespace": namespace, + "videos": videos, + "next_page_token": next_token, + "total_videos": total_videos, + "total_pages": total_pages, + } + except Exception as e: + logger.error(f"[List Videos] Error fetching videos: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + async def delete_video(self, hashed_identifier: str, filename: str, namespace: str = ""): + """ + Delete a video and its associated chunks from storage and database. + + Args: + hashed_identifier (str): The unique identifier of the video in R2 storage. + filename (str): The original filename of the video. + namespace (str, optional): Namespace for Pinecone and R2 storage (default: "") + + Returns: + dict: Contains status and message about deletion result. + + Raises: + HTTPException: If deletion fails at any step. + - 500 Internal Server Error with details. + - 400 Bad Request if parameters are missing. + - 404 Not Found if video does not exist. + - 403 Forbidden if deletion is not allowed. + """ + logger.info(f"[Delete Video] Request to delete video: {filename} ({hashed_identifier}) | namespace='{namespace}'") + if not self.is_internal_env: + raise HTTPException(status_code=403, detail="Video deletion is not allowed in the current environment.") + + job_id = str(uuid.uuid4()) + self.server_instance.job_store.create_job(job_id, { + "job_id": job_id, + "hashed_identifier": hashed_identifier, + "namespace": namespace, + "status": "processing", + "operation": "delete" + }) + + self.server_instance.delete_video_background.spawn(job_id, hashed_identifier, namespace) + + return { + "job_id": job_id, + "hashed_identifier": hashed_identifier, + "namespace": namespace, + "status": "processing", + "message": "Video deletion started, processing in background" + } + + async def clear_cache(self, namespace: str = "__default__"): + """ + Clear the URL cache for a given namespace. 
+ + Args: + namespace (str, optional): Namespace to clear cache for (default: "__default__") + + Returns: + dict: Contains status and number of entries cleared + + Raises: + HTTPException: If cache clearing is not allowed (403 Forbidden) + """ + logger.info(f"[Clear Cache] Request to clear cache for namespace: {namespace}") + if not self.is_internal_env: + raise HTTPException(status_code=403, detail="Cache clearing is not allowed in the current environment.") + + try: + cleared_count = self.server_instance.r2_connector.clear_cache(namespace) + logger.info(f"[Clear Cache] Cleared {cleared_count} cache entries for namespace: {namespace}") + return { + "status": "success", + "namespace": namespace, + "cleared_entries": cleared_count, + "message": f"Successfully cleared {cleared_count} cache entries" + } + except Exception as e: + logger.error(f"[Clear Cache] Error clearing cache: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + async def request_device_code(self): + """ + Request device codes for OAuth device flow. + + No request body needed. + + Returns: + dict: { + "device_code": str, # Long secret code + "user_code": str, # User-friendly code (e.g., "ABC-420") + "verification_url": str, # URL where user enters code + "expires_in": int, # Expiration time in seconds (600 = 10 minutes) + "interval": int # Polling interval in seconds (3) + } + """ + try: + device_code = self.server_instance.auth_connector.generate_device_code() + user_code = self.server_instance.auth_connector.generate_user_code() + + expires_in = 600 + success = self.server_instance.auth_connector.create_device_code_entry( + device_code=device_code, + user_code=user_code, + expires_in=expires_in + ) + + if not success: + raise HTTPException( + status_code=500, + detail="Failed to create device code entry" + ) + + logger.info(f"[Device Code] Generated device code for user_code: {user_code}") + + return { + "device_code": device_code, + "user_code": user_code, + "verification_url": "clipabit.web.app/auth/device", + "expires_in": expires_in, + "interval": 3 + } + except HTTPException: + raise + except Exception as e: + logger.error(f"[Device Code] Error generating device code: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + async def poll_device_code(self, device_code: str): + """ + Poll for device code authorization status. + + Request body: + { + "device_code": "a8f3j2k1..." 
+        }
+
+        Responses:
+        - Still waiting: {"status": "pending"}
+        - User authorized: {"status": "authorized", "user_id": "...", "id_token": "...", "refresh_token": "..."}
+        - Timed out: {"status": "expired", "error": "device_code_expired"}
+        - User denied: {"status": "denied", "error": "user_denied_authorization"}
+
+        Polling behavior:
+        - Client should poll every 3 seconds (interval from device/code response)
+        - Max 200 attempts (10 minutes total)
+        - Stop immediately if user closes dialog
+        """
+        try:
+            if not device_code:
+                raise HTTPException(
+                    status_code=400,
+                    detail="Missing required field: 'device_code'"
+                )
+
+            status = self.server_instance.auth_connector.get_device_code_poll_status(device_code)
+
+            if status is None:
+                return {
+                    "status": "expired",
+                    "error": "device_code_expired"
+                }
+
+            # Log only a prefix of the device code - it is a bearer secret.
+            logger.info(f"[Device Poll] Device code {device_code[:8]}... | status: {status.get('status')}")
+            return status
+
+        except HTTPException:
+            raise
+        except Exception as e:
+            logger.error(f"[Device Poll] Error polling device code: {e}")
+            raise HTTPException(status_code=500, detail=str(e))
diff --git a/backend/apps/__init__.py b/backend/apps/__init__.py
new file mode 100644
index 0000000..efea4c1
--- /dev/null
+++ b/backend/apps/__init__.py
@@ -0,0 +1,13 @@
+"""
+Modal apps package.
+
+Contains separate Modal app entry points:
+- dev_combined.py: Combined app for local development
+- server.py: Server app (upload, status, videos, delete)
+- search_app.py: Search app (semantic search)
+- processing_app.py: Processing app (video processing)
+
+These are entry points, not modules to be imported.
+"""
+
+__all__ = []
diff --git a/backend/apps/dev_combined.py b/backend/apps/dev_combined.py
new file mode 100644
index 0000000..d66a57c
--- /dev/null
+++ b/backend/apps/dev_combined.py
@@ -0,0 +1,76 @@
+"""
+Dev Combined Modal App
+
+ONLY FOR LOCAL DEVELOPMENT (ENVIRONMENT=dev)
+
+Combines Server, Search, and Processing into one app for easy local iteration.
+All three services run in the same Modal app, so no cross-app lookups are needed.
+
+For staging/prod deployments, use the separate apps:
+- apps/server.py
+- apps/search_app.py
+- apps/processing_app.py
+
+Endpoints:
+- DevServer: /health, /status, /upload, /videos, /videos/{id}, /cache/clear,
+  /auth/device/code, /auth/device/poll
+- DevSearchService: /health, /search (separate ASGI app for lower latency)
+"""
+
+import logging
+import modal
+
+from shared.config import get_environment, get_secrets
+from shared.images import get_dev_image
+from services.search_service import SearchService
+from services.processing_service import ProcessingService
+from services.http_server import ServerService
+
+logger = logging.getLogger(__name__)
+
+# Environment setup
+env = get_environment()
+
+if env != "dev":
+    raise ValueError(
+        f"dev_combined.py is ONLY for local development (ENVIRONMENT=dev). "
+        f"Current environment: {env}. Use separate apps for staging/prod."
+ ) + +logger.info("Starting Combined Dev App - all services in one app for local iteration") + +app = modal.App( + name=f"{env}-server", + image=get_dev_image(), + secrets=[get_secrets()] +) + +# SearchService exposes its own ASGI app for direct HTTP access (no server hop) +DevSearchService = app.cls( + cpu=2.0, + memory=2048, + timeout=60, + scaledown_window=120, + enable_memory_snapshot=True, # Snapshot after @enter() for faster subsequent cold starts +)(SearchService) + +DevProcessingService = app.cls(cpu=4.0, memory=4096, timeout=600)(ProcessingService) + + +# Define DevServer to add the asgi_app method and pass service classes +@app.cls(cpu=2.0, memory=2048, timeout=120, scaledown_window=120) +class DevServer(ServerService): + """Server with ASGI app for dev combined mode (excludes search).""" + + @modal.enter() + def startup(self): + """Initialize connectors and create FastAPI app with service classes.""" + self._initialize_connectors() + # Create FastAPI app (search is handled by DevSearchService's own ASGI app) + self.fastapi_app = self.create_fastapi_app( + processing_service_cls=DevProcessingService + ) + + @modal.asgi_app() + def asgi_app(self): + return self.fastapi_app diff --git a/backend/apps/processing_app.py b/backend/apps/processing_app.py new file mode 100644 index 0000000..7fee60a --- /dev/null +++ b/backend/apps/processing_app.py @@ -0,0 +1,31 @@ +""" +Processing Modal App + +Handles video processing with full CLIP image encoder and preprocessing pipeline. +Heavy dependencies (~15-20s cold start) - acceptable for background jobs. + +This app is spawned by the Server for video uploads. +""" + +import logging +import modal + +from shared.config import get_environment, get_secrets +from shared.images import get_processing_image +from services.processing_service import ProcessingService + +logger = logging.getLogger(__name__) + +# Environment setup +env = get_environment() +logger.info(f"Starting Processing App in '{env}' environment") + +# Create Modal app with processing-specific image +app = modal.App( + name=f"{env}-processing", + image=get_processing_image(), + secrets=[get_secrets()] +) + +# Register ProcessingService with this app +app.cls(cpu=4.0, memory=4096, timeout=600)(ProcessingService) diff --git a/backend/apps/search_app.py b/backend/apps/search_app.py new file mode 100644 index 0000000..8343caf --- /dev/null +++ b/backend/apps/search_app.py @@ -0,0 +1,31 @@ +""" +Search Modal App + +Handles semantic search with CLIP text encoder. +Medium-weight dependencies (~8-10s cold start) - lighter than full video processing. + +Uses CLIPTextModelWithProjection (~150MB) instead of full CLIPModel (~350MB). 
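+
+The SearchService class exposes its own ASGI app (routes: /health, /search)
+via SearchFastAPIRouter, so search requests skip the server gateway.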
+""" + +import logging +import modal + +from shared.config import get_environment, get_secrets +from shared.images import get_search_image +from services.search_service import SearchService + +logger = logging.getLogger(__name__) + +# Environment setup +env = get_environment() +logger.info(f"Starting Search App in '{env}' environment") + +# Create Modal app with search-specific image +app = modal.App( + name=f"{env}-search", + image=get_search_image(), + secrets=[get_secrets()] +) + +# Register SearchService with this app +app.cls(cpu=2.0, memory=2048, timeout=60, scaledown_window=120)(SearchService) diff --git a/backend/apps/server.py b/backend/apps/server.py new file mode 100644 index 0000000..4c863cf --- /dev/null +++ b/backend/apps/server.py @@ -0,0 +1,40 @@ +""" +Server Modal App + +Handles all HTTP endpoints with minimal dependencies for fast cold starts (~3-5s). +""" + +import logging +import modal + +from shared.config import get_environment, get_secrets +from shared.images import get_server_image +from services.http_server import ServerService + +logger = logging.getLogger(__name__) + +# Environment setup +env = get_environment() + +# Create Modal app with minimal image +app = modal.App( + name=f"{env}-server", + image=get_server_image(), + secrets=[get_secrets()] +) + + +@app.cls(cpu=2.0, memory=2048, timeout=120) +class Server(ServerService): + """Server with ASGI app for production deployment.""" + + @modal.enter() + def startup(self): + """Initialize connectors and create FastAPI app.""" + self._initialize_connectors() + # Create FastAPI app (no service classes for production - uses from_name) + self.fastapi_app = self.create_fastapi_app() + + @modal.asgi_app() + def asgi_app(self): + return self.fastapi_app diff --git a/backend/auth/__init__.py b/backend/auth/__init__.py new file mode 100644 index 0000000..a6e8307 --- /dev/null +++ b/backend/auth/__init__.py @@ -0,0 +1,7 @@ +""" +Auth module for device flow authentication. +""" + +from auth.auth_connector import AuthConnector + +__all__ = ["AuthConnector"] diff --git a/backend/auth/auth_connector.py b/backend/auth/auth_connector.py new file mode 100644 index 0000000..5129cc0 --- /dev/null +++ b/backend/auth/auth_connector.py @@ -0,0 +1,243 @@ +""" +Auth service for device flow authentication. +""" + +import logging +from typing import Optional, Dict, Any +from datetime import datetime, timezone, timedelta +import secrets +import string +import modal + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +class AuthConnector: + """ + Modal Dict wrapper for device flow authentication. + + Stores device codes with expiration (10 minutes) for OAuth device flow. 
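+
+    Backed by two Modal Dicts:
+    - device_store: device_code -> entry (status, expiry, tokens once authorized)
+    - user_store: user_code -> device_code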
+    """
+
+    DEFAULT_DEVICE_DICT = "clipabit-auth-device-codes"
+    DEFAULT_USER_DICT = "clipabit-auth-user-codes"
+
+    TOKEN_EXPIRY_TIME = 600
+
+    def __init__(self, device_dict_name: str = DEFAULT_DEVICE_DICT, user_dict_name: str = DEFAULT_USER_DICT):
+        self.device_dict_name = device_dict_name
+        self.user_dict_name = user_dict_name
+
+        self.device_store = modal.Dict.from_name(device_dict_name, create_if_missing=True)
+        self.user_store = modal.Dict.from_name(user_dict_name, create_if_missing=True)
+        logger.info(f"Initialized AuthConnector with device dict: {device_dict_name} and user dict: {user_dict_name}")
+
+    def _is_expired(self, entry: Dict[str, Any]) -> bool:
+        """Check if a device code entry is expired."""
+        expires_at = entry.get("expires_at")
+        if expires_at is None:
+            return False
+        expires_at = datetime.fromisoformat(expires_at.replace('Z', '+00:00'))
+        return datetime.now(timezone.utc) > expires_at
+
+    def _delete_session(self, device_code: str, entry: Optional[Dict[str, Any]]) -> None:
+        """
+        Remove a session from both stores:
+        device_store: device_code -> entry
+        user_store: user_code -> device_code
+        Safe to call when either entry is already gone.
+        """
+        if entry is None:
+            entry = self.device_store.get(device_code)
+        if entry:
+            user_code = entry.get("user_code")
+            if user_code:
+                if user_code in self.user_store:
+                    del self.user_store[user_code]
+                    logger.info(f"Deleted user code entry for user_code: {user_code}")
+        if device_code in self.device_store:
+            del self.device_store[device_code]
+            logger.info(f"Deleted device code entry for device_code: {device_code[:8]}...")
+
+    def generate_device_code(self) -> str:
+        """Generate a secure random device code."""
+        return secrets.token_urlsafe(48)
+
+    def generate_user_code(self) -> str:
+        """Generate a user-friendly code in format ABC-420."""
+        letters = ''.join(secrets.choice(string.ascii_uppercase) for _ in range(3))
+        digits = ''.join(secrets.choice(string.digits) for _ in range(3))
+        return f"{letters}-{digits}"
+
+    def create_device_code_entry(
+        self,
+        device_code: str,
+        user_code: str,
+        expires_in: int = TOKEN_EXPIRY_TIME
+    ) -> bool:
+        """Create a new device code entry with expiration."""
+        try:
+            expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in)
+            entry = {
+                "user_code": user_code,
+                "status": "pending",
+                "created_at": datetime.now(timezone.utc).isoformat(),
+                "expires_at": expires_at.isoformat()
+            }
+            self.device_store[device_code] = entry
+            self.user_store[user_code] = device_code
+            logger.info(f"Created device code entry for user_code: {user_code}")
+            return True
+        except Exception as e:
+            logger.error(f"Error creating device code entry: {e}")
+            return False
+
+    def get_device_code_entry(self, device_code: str) -> Optional[Dict[str, Any]]:
+        """Retrieve device code entry, returns None if not found or expired."""
+        try:
+            entry = self.device_store.get(device_code)
+            if entry is None:
+                return None
+            if self._is_expired(entry):
+                self._delete_session(device_code, entry)
+                return None
+            return entry
+        except Exception as e:
+            logger.error(f"Error retrieving device code entry: {e}")
+            return None
+
+    def get_device_code_by_user_code(self, user_code: str) -> Optional[str]:
+        """
+        Lookup device_code by user_code.
+
+        Returns the device_code if found and not expired, None otherwise.
+ """ + try: + device_code = self.user_store.get(user_code) + if not device_code: + return None + + entry = self.get_device_code_entry(device_code) + if entry is None: + if user_code in self.user_store: + del self.user_store[user_code] + return None + + return device_code + except Exception as e: + logger.error(f"Error looking up device_code by user_code: {e}") + return None + + def update_device_code_status(self, device_code: str, status: str) -> bool: + """Update the status of a device code (e.g., 'pending' -> 'authorized').""" + try: + entry = self.get_device_code_entry(device_code) + if entry is None: + return False + + entry["status"] = status + self.device_store[device_code] = entry + logger.info(f"Updated device code {device_code[:8]}... status to: {status}") + return True + except Exception as e: + logger.error(f"Error updating device code status: {e}") + return False + + def set_device_code_authorized( + self, + device_code: str, + user_id: str, + id_token: str, + refresh_token: str + ) -> bool: + """Mark device code as authorized and store user tokens.""" + try: + entry = self.get_device_code_entry(device_code) + if entry is None: + return False + + entry["status"] = "authorized" + entry["user_id"] = user_id + entry["id_token"] = id_token + entry["refresh_token"] = refresh_token + entry["authorized_at"] = datetime.now(timezone.utc).isoformat() + self.device_store[device_code] = entry + logger.info(f"Device code {device_code[:8]}... authorized for user {user_id}") + return True + except Exception as e: + logger.error(f"Error setting device code as authorized: {e}") + return False + + def set_device_code_denied(self, device_code: str) -> bool: + """Mark device code as denied by user.""" + try: + entry = self.get_device_code_entry(device_code) + if entry is None: + return False + + entry["status"] = "denied" + entry["denied_at"] = datetime.now(timezone.utc).isoformat() + self.device_store[device_code] = entry + logger.info(f"Device code {device_code[:8]}... denied by user") + return True + except Exception as e: + logger.error(f"Error setting device code as denied: {e}") + return False + + def get_device_code_poll_status(self, device_code: str) -> Optional[Dict[str, Any]]: + """ + Get device code status for polling endpoint. 
+
+        Returns status dict with appropriate fields based on state:
+        - pending: {"status": "pending"}
+        - authorized: {"status": "authorized", "user_id": ..., "id_token": ..., "refresh_token": ...}
+        - expired: {"status": "expired", "error": "device_code_expired"}
+        - denied: {"status": "denied", "error": "user_denied_authorization"}
+        - not_found: treated as expired (same payload as expired)
+        """
+        entry = self.get_device_code_entry(device_code)
+
+        if entry is None:
+            return {
+                "status": "expired",
+                "error": "device_code_expired"
+            }
+
+        status = entry.get("status", "pending")
+
+        if status == "authorized":
+            return {
+                "status": "authorized",
+                "user_id": entry.get("user_id"),
+                "id_token": entry.get("id_token"),
+                "refresh_token": entry.get("refresh_token")
+            }
+        elif status == "denied":
+            return {
+                "status": "denied",
+                "error": "user_denied_authorization"
+            }
+        elif status == "pending":
+            return {
+                "status": "pending"
+            }
+        else:
+            return {
+                "status": "expired",
+                "error": "device_code_expired"
+            }
+
+    def delete_device_code(self, device_code: str) -> bool:
+        """Remove device code entry."""
+        try:
+            entry = self.get_device_code_entry(device_code)
+            if entry is None:
+                return False
+            self._delete_session(device_code, entry)
+            logger.info(f"Deleted device code entry for device_code: {device_code[:8]}...")
+            return True
+        except Exception as e:
+            logger.error(f"Error deleting device code: {e}")
+            return False
diff --git a/backend/cli.py b/backend/cli.py
index cd959a7..7227512 100644
--- a/backend/cli.py
+++ b/backend/cli.py
@@ -1,4 +1,104 @@
+"""CLI for serving Modal apps locally."""
+
+import signal
 import subprocess
+import sys
+
+# Dev combined app - all services in one for local iteration
+DEV_COMBINED_APP = "apps/dev_combined.py"
+
+# Individual app entry points (the same files CI deploys for staging/prod)
+APPS = {
+    "server": ("apps/server.py", "\033[36m"),  # Cyan
+    "search": ("apps/search_app.py", "\033[33m"),  # Yellow
+    "processing": ("apps/processing_app.py", "\033[35m"),  # Magenta
+}
+RESET = "\033[0m"
+
+
+def _prefix_output(process, name, color):
+    """Read process output and prefix each line with the app name."""
+    prefix = f"{color}[{name:^10}]{RESET} "
+    try:
+        for line in iter(process.stdout.readline, ""):
+            if line:
+                print(f"{prefix}{line}", end="", flush=True)
+    except (ValueError, OSError):
+        # Process closed, ignore
+        pass
+
+
+def serve_all():
+    """
+    Serve the combined dev app (all services in one).
+
+    For local development, we use dev_combined.py, which includes
+    Server, Search, and Processing in a single Modal app.
+    This allows hot-reload on all services without cross-app lookup issues.
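+
+    SIGINT (Ctrl-C) and SIGTERM terminate the underlying `modal serve`
+    subprocess via the signal handlers registered below.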
+ """ + print("Starting combined dev app (all services in one)...\n") + print(f" \033[32m●{RESET} dev-combined (server + search + processing)\n") + print("Note: For staging/prod, deploy individual apps separately.\n") + print("-" * 60 + "\n") + + # Run with color-coded output prefixing + color = "\033[32m" # Green for combined dev app + process = subprocess.Popen( + ["modal", "serve", DEV_COMBINED_APP], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + ) + + # Handle graceful shutdown + def signal_handler(sig, frame): + process.terminate() + sys.exit(0) + + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + + # Stream output with color prefix + _prefix_output(process, "dev", color) + process.wait() + + +def _serve_single_app(name: str): + """Serve a single app with color-coded output.""" + path, color = APPS[name] + + process = subprocess.Popen( + ["modal", "serve", path], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + ) + + # Handle graceful shutdown + def signal_handler(sig, frame): + process.terminate() + sys.exit(0) + + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + + # Stream output with color prefix + _prefix_output(process, name, color) + process.wait() + + +def serve_server(): + """Serve the API server app.""" + _serve_single_app("server") + + +def serve_search(): + """Serve the search app.""" + _serve_single_app("search") + -def serve(): - subprocess.run(["modal", "serve", "main.py"]) +def serve_processing(): + """Serve the processing app.""" + _serve_single_app("processing") diff --git a/backend/database/r2_connector.py b/backend/database/r2_connector.py index 771b273..9bbdc97 100644 --- a/backend/database/r2_connector.py +++ b/backend/database/r2_connector.py @@ -364,15 +364,12 @@ def fetch_video_page( params = { "Bucket": self.bucket_name, "Prefix": prefix, - "MaxKeys": min(page_size + 1, 1000), + "MaxKeys": page_size, # Request exactly page_size items } if continuation_token: - cursor_key = self._decode_cursor_token(continuation_token) - if cursor_key: - params["StartAfter"] = cursor_key - else: - params["ContinuationToken"] = continuation_token + # Use S3's native ContinuationToken + params["ContinuationToken"] = continuation_token response = self.s3_client.list_objects_v2(**params) @@ -385,15 +382,19 @@ def fetch_video_page( continue filtered.append(obj) - has_more_flag = response.get("IsTruncated", False) - items = filtered[:page_size] - has_more = has_more_flag or len(filtered) > page_size + # Use S3's IsTruncated flag to determine if there are more pages + has_more = response.get("IsTruncated", False) + items = filtered # Use all filtered items since we requested exactly page_size videos: List[dict] = [] for obj in items: object_key = obj.get("Key") try: - filename = object_key.split('/', 1)[1] + parts = object_key.split('/', 1) + if len(parts) != 2: + logger.warning(f"Skipping object with unexpected key format: {object_key}") + continue + filename = parts[1] identifier = self._encode_path(self.bucket_name, namespace, filename) url = self.s3_client.generate_presigned_url( 'get_object', @@ -406,12 +407,12 @@ def fetch_video_page( "hashed_identifier": identifier, "presigned_url": url, }) - except ClientError as e: - logger.error(f"Error processing video {object_key}: {e}") + else: + logger.warning(f"Generated empty presigned URL for {object_key}") + except Exception as e: + logger.error(f"Error processing video 
{object_key}: {e}", exc_info=True) - next_token = response.get("NextContinuationToken") if has_more_flag else None - if not next_token and has_more and videos: - next_token = self._encode_cursor_token(items[-1].get("Key")) + next_token = response.get("NextContinuationToken") if has_more else None logger.info( "Fetched %s video objects for namespace %s (has_more=%s)", diff --git a/backend/embeddings/__init__.py b/backend/embeddings/__init__.py index ced4c89..21c34ef 100644 --- a/backend/embeddings/__init__.py +++ b/backend/embeddings/__init__.py @@ -1,5 +1,5 @@ # Make embeddings a proper Python package -from .embedder import VideoEmbedder +from .video_embedder import VideoEmbedder __all__ = ["VideoEmbedder"] \ No newline at end of file diff --git a/backend/embeddings/embedder.py b/backend/embeddings/video_embedder.py similarity index 100% rename from backend/embeddings/embedder.py rename to backend/embeddings/video_embedder.py diff --git a/backend/main.py b/backend/main.py deleted file mode 100644 index d4f7473..0000000 --- a/backend/main.py +++ /dev/null @@ -1,417 +0,0 @@ -import os -import logging -import modal - -# Configure logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger(__name__) - -# Configure Modal app and image -# dependencies found in pyproject.toml -image = ( - modal.Image.debian_slim(python_version="3.12") - .apt_install("ffmpeg", "libsm6", "libxext6") # for video processing - .uv_sync(extra_options="--no-dev") # exclude dev dependencies to avoid package conflicts - .add_local_python_source( # add all local modules here - "api", - "preprocessing", - "embeddings", - "models", - "database", - "search", - "services", - ) - ) - -# Environment: "dev" (default) or "prod" (set via ENVIRONMENT variable) -env = os.environ.get("ENVIRONMENT", "dev") -if env not in ["dev", "prod", "staging"]: - raise ValueError(f"Invalid ENVIRONMENT value: {env}. Must be one of: dev, prod, staging") -logger.info(f"Starting Modal app in '{env}' environment") - -IS_INTERNAL_ENV = env in ["dev", "staging"] - -# Create Modal app -app = modal.App( - name=env, - image=image, - secrets=[modal.Secret.from_name(env)] -) - - -@app.cls(cpu=4.0, memory=4096, timeout=600) -class Server: - - @modal.enter() - def startup(self): - """ - Startup logic. This runs once when the container starts. - Here is where you would instantiate classes and load models that are - reused across multiple requests to avoid reloading them each time. - """ - # Import local module inside class - import os - from datetime import datetime, timezone - - # Import classes here - from preprocessing.preprocessor import Preprocessor - from embeddings.embedder import VideoEmbedder - from database.pinecone_connector import PineconeConnector - from database.job_store_connector import JobStoreConnector - from search.searcher import Searcher - from database.r2_connector import R2Connector - from api import FastAPIRouter - from services.upload import UploadHandler - from fastapi import FastAPI - - logger.info(f"Container starting up! 
Environment = {env}") - self.start_time = datetime.now(timezone.utc) - - # Get environment variables (TODO: abstract to config module) - PINECONE_API_KEY = os.getenv("PINECONE_API_KEY") - if not PINECONE_API_KEY: - raise ValueError("PINECONE_API_KEY not found in environment variables") - - R2_ACCOUNT_ID = os.getenv("R2_ACCOUNT_ID") - if not R2_ACCOUNT_ID: - raise ValueError("R2_ACCOUNT_ID not found in environment variables") - - R2_ACCESS_KEY_ID = os.getenv("R2_ACCESS_KEY_ID") - if not R2_ACCESS_KEY_ID: - raise ValueError("R2_ACCESS_KEY_ID not found in environment variables") - - R2_SECRET_ACCESS_KEY = os.getenv("R2_SECRET_ACCESS_KEY") - if not R2_SECRET_ACCESS_KEY: - raise ValueError("R2_SECRET_ACCESS_KEY not found in environment variables") - - ENVIRONMENT = os.getenv("ENVIRONMENT", "dev") - if ENVIRONMENT not in ["dev", "prod", "staging"]: - raise ValueError(f"Invalid ENVIRONMENT value: {ENVIRONMENT}. Must be one of: dev, prod, staging") - logger.info(f"Running in environment: {ENVIRONMENT}") - - # Select Pinecone index based on environment - pinecone_index = f"{ENVIRONMENT}-chunks" - logger.info(f"Using Pinecone index: {pinecone_index}") - - # Instantiate classes - self.preprocessor = Preprocessor(min_chunk_duration=1.0, max_chunk_duration=10.0, scene_threshold=13.0) - self.video_embedder = VideoEmbedder() - self.pinecone_connector = PineconeConnector(api_key=PINECONE_API_KEY, index_name=pinecone_index) - self.job_store = JobStoreConnector(dict_name="clipabit-jobs") - - self.r2_connector = R2Connector( - account_id=R2_ACCOUNT_ID, - access_key_id=R2_ACCESS_KEY_ID, - secret_access_key=R2_SECRET_ACCESS_KEY, - environment=ENVIRONMENT - ) - - self.searcher = Searcher( - api_key=PINECONE_API_KEY, - index_name=pinecone_index, - r2_connector=self.r2_connector - ) - - self.upload_handler = UploadHandler( - job_store=self.job_store, - process_video_method=self.process_video - ) - - #FastAPI app - self.fastapi_app = FastAPI() - self.api = FastAPIRouter(self, IS_INTERNAL_ENV) - self.fastapi_app.include_router(self.api.router) - - logger.info("Container modules initialized and ready!") - - print(f"[Container] Started at {self.start_time.isoformat()}") - - - @modal.asgi_app() - def asgi_app(self): - return self.fastapi_app - - @modal.method() - async def process_video(self, video_bytes: bytes, filename: str, job_id: str, namespace: str = "", parent_batch_id: str = None): - logger.info(f"[Job {job_id}] Processing started: {filename} ({len(video_bytes)} bytes) | namespace='{namespace}' | batch={parent_batch_id or 'None'}") - - hashed_identifier = None - upserted_chunk_ids = [] - - try: - # Upload original video to R2 bucket - # TODO: do this in parallel with processing - success, hashed_identifier = self.r2_connector.upload_video( - video_data=video_bytes, - filename=filename, - namespace=namespace - ) - if not success: - # Capture error details returned in hashed_identifier before resetting it - upload_error_details = hashed_identifier - # Reset hashed_identifier if upload failed to avoid rollback attempting to delete it - hashed_identifier = None - raise Exception(f"Failed to upload video to R2 storage: {upload_error_details}") - - # Process video through preprocessing pipeline - processed_chunks = self.preprocessor.process_video_from_bytes( - video_bytes=video_bytes, - video_id=job_id, - filename=filename, - hashed_identifier=hashed_identifier - ) - - # Calculate summary statistics - total_frames = sum(chunk['metadata']['frame_count'] for chunk in processed_chunks) - total_memory = 
sum(chunk['memory_mb'] for chunk in processed_chunks) - avg_complexity = sum(chunk['metadata']['complexity_score'] for chunk in processed_chunks) / len(processed_chunks) if processed_chunks else 0 - - logger.info(f"[Job {job_id}] Complete: {len(processed_chunks)} chunks, {total_frames} frames, {total_memory:.2f} MB, avg_complexity={avg_complexity:.3f}") - - # Embed frames and store in Pinecone - logger.info(f"[Job {job_id}] Embedding and upserting {len(processed_chunks)} chunks") - - # Prepare chunk details for response (without frame arrays) - chunk_details = [] - for chunk in processed_chunks: - embedding = self.video_embedder._generate_clip_embedding(chunk["frames"], num_frames=8) - - logger.info(f"[Job {job_id}] Generated CLIP embedding for chunk {chunk['chunk_id']}") - logger.info(f"[Job {job_id}] Upserting embedding for chunk {chunk['chunk_id']} to Pinecone...") - - - # 1. Handle timestamp_range (List of Numbers -> Two Numbers) - if 'timestamp_range' in chunk['metadata']: - start_time, end_time = chunk['metadata'].pop('timestamp_range') - chunk['metadata']['start_time_s'] = start_time - chunk['metadata']['end_time_s'] = end_time - - # 2. Handle file_info (Nested Dict -> Flat Keys) - if 'file_info' in chunk['metadata']: - file_info = chunk['metadata'].pop('file_info') - for key, value in file_info.items(): - chunk['metadata'][f'file_{key}'] = value - - # 3. Final Check: Remove Nulls (Optional but good practice) - # Pinecone rejects keys with null values. - keys_to_delete = [k for k, v in chunk['metadata'].items() if v is None] - for k in keys_to_delete: - del chunk['metadata'][k] - - - success = self.pinecone_connector.upsert_chunk( - chunk_id=chunk['chunk_id'], - chunk_embedding=embedding.numpy(), - namespace=namespace, - metadata=chunk['metadata'] - ) - - if success: - upserted_chunk_ids.append(chunk['chunk_id']) - else: - raise Exception(f"Failed to upsert chunk {chunk['chunk_id']} to Pinecone") - - chunk_details.append({ - "chunk_id": chunk['chunk_id'], - "metadata": chunk['metadata'], - "memory_mb": chunk['memory_mb'], - }) - - result = { - "job_id": job_id, - "status": "completed", - "hashed_identifier": hashed_identifier, - "filename": filename, - "chunks": len(processed_chunks), - "total_frames": total_frames, - "total_memory_mb": total_memory, - "avg_complexity": avg_complexity, - "chunk_details": chunk_details, - } - - logger.info(f"[Job {job_id}] Finished processing {filename}") - - # Store result for polling endpoint in shared storage - self.job_store.set_job_completed(job_id, result) - - # Update parent batch if exists - if parent_batch_id: - update_success = self.job_store.update_batch_on_child_completion( - parent_batch_id, - job_id, - result - ) - if update_success: - logger.info(f"[Job {job_id}] Updated parent batch {parent_batch_id}") - else: - logger.error( - f"[Job {job_id}] CRITICAL: Failed to update parent batch {parent_batch_id} " - f"after max retries. Batch state may be inconsistent." - ) - - # Invalidate cached pages for namespace after successful processing - try: - self.r2_connector.clear_cache(namespace or "__default__") - except Exception as cache_exc: - logger.error(f"[Job {job_id}] Failed to clear cache for namespace {namespace}: {cache_exc}") - - return result - - except Exception as e: - logger.error(f"[Job {job_id}] Processing failed: {e}") - - # --- ROLLBACK LOGIC --- - logger.warning(f"[Job {job_id}] Initiating rollback due to failure...") - - # 1. 
Delete video from R2 - if hashed_identifier: - logger.info(f"[Job {job_id}] Rolling back: Deleting video from R2 ({hashed_identifier})") - success = self.r2_connector.delete_video(hashed_identifier) - if not success: - logger.error(f"[Job {job_id}] Rollback failed for R2 video deletion: {hashed_identifier}") - - # 2. Delete chunks from Pinecone - if upserted_chunk_ids: - logger.info(f"[Job {job_id}] Rolling back: Deleting {len(upserted_chunk_ids)} chunks from Pinecone") - success = self.pinecone_connector.delete_chunks(upserted_chunk_ids, namespace=namespace) - if not success: - logger.error(f"[Job {job_id}] Rollback failed for Pinecone chunks deletion: {len(upserted_chunk_ids)} chunks") - - logger.info(f"[Job {job_id}] Rollback complete.") - # ---------------------- - - import traceback - traceback.print_exc() # Print full stack trace for debugging - - # Store error result for polling endpoint in shared storage - self.job_store.set_job_failed(job_id, str(e)) - - # Update parent batch if exists - if parent_batch_id: - error_result = { - "job_id": job_id, - "status": "failed", - "filename": filename, - "error": str(e) - } - update_success = self.job_store.update_batch_on_child_completion( - parent_batch_id, - job_id, - error_result - ) - if update_success: - logger.info(f"[Job {job_id}] Updated parent batch {parent_batch_id} with failure status") - else: - logger.error( - f"[Job {job_id}] CRITICAL: Failed to update parent batch {parent_batch_id} " - f"after max retries. Batch state may be inconsistent." - ) - - return {"job_id": job_id, "status": "failed", "error": str(e)} - - - @modal.method() - async def delete_video_background(self, job_id: str, hashed_identifier: str, namespace: str = ""): - """ - Background job that deletes a video and all associated chunks from R2 and Pinecone. - - This method is intended to run asynchronously as part of a job lifecycle. It: - - 1. Attempts to delete all chunks in Pinecone associated with ``hashed_identifier`` and - the given ``namespace`` using ``pinecone_connector.delete_by_identifier``. - 2. If Pinecone deletion is successful, attempts to delete the corresponding video - object from R2 via ``r2_connector.delete_video``. - 3. On full success (both deletions succeed), builds a result payload and records - the job as completed in ``self.job_store`` by calling - ``self.job_store.set_job_completed(job_id, result)``. - 4. On any failure (including partial failures where Pinecone succeeds but R2 fails, - or Pinecone deletion itself fails), logs the error, records the job as failed in - ``self.job_store`` via ``self.job_store.set_job_failed(job_id, error_message)``, - and returns a failure payload. - - The return value is the same object stored in ``job_store`` and has the following - general shape: - - - On success:: - - { - "job_id": "", - "status": "completed", - "hashed_identifier": "", - "namespace": "", - "r2": {"deleted": true}, - "pinecone": {"deleted": true} - } - - - On failure (including partial deletion failures):: - - { - "job_id": "", - "status": "failed", - "error": "" - } - - In particular, if Pinecone deletion succeeds but R2 deletion fails, the method logs - a critical inconsistency, raises an exception internally, and ultimately marks the - job as failed in ``job_store`` with an appropriate error message, indicating that - chunks may have been removed while the video object remains in R2. 
- """ - logger.info(f"[Job {job_id}] Deletion started: {hashed_identifier} | namespace='{namespace}'") - - try: - # Delete chunks from Pinecone - pinecone_success = self.pinecone_connector.delete_by_identifier( - hashed_identifier=hashed_identifier, - namespace=namespace - ) - - # NOTE: idk if we acc need to raise exception here because this isn't a critical failure - if not pinecone_success: - raise Exception("Failed to delete chunks from Pinecone") - - # Delete from R2. If this fails, chunks are gone but video remains - notify client. - r2_success = self.r2_connector.delete_video(hashed_identifier) - if not r2_success: - logger.critical(f"[Job {job_id}] INCONSISTENCY: Chunks deleted but R2 deletion failed for {hashed_identifier}") - raise Exception("Failed to delete video from R2 after deleting chunks. System may be inconsistent.") - - # Build success response - result = { - "job_id": job_id, - "status": "completed", - "hashed_identifier": hashed_identifier, - "namespace": namespace, - "r2": { - "deleted": r2_success - }, - "pinecone": { - "deleted": pinecone_success - } - } - - logger.info(f"[Job {job_id}] Deletion completed: R2={r2_success}, Pinecone chunks={pinecone_success}") - - # Store result for polling endpoint - self.job_store.set_job_completed(job_id, result) - - try: - self.r2_connector.clear_cache(namespace or "__default__") - except Exception as cache_exc: - logger.error( - f"[Job {job_id}] Failed to clear cache after deletion for namespace {namespace}: {cache_exc}" - ) - return result - - except Exception as e: - error_msg = str(e) - logger.error(f"[Job {job_id}] Deletion failed: {error_msg}") - - import traceback - traceback.print_exc() - - # Store error result - self.job_store.set_job_failed(job_id, error_msg) - return {"job_id": job_id, "status": "failed", "error": error_msg} \ No newline at end of file diff --git a/backend/pyproject.toml b/backend/pyproject.toml index c170f46..d7650b0 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -20,7 +20,10 @@ dependencies = [ ] [project.scripts] -dev = "cli:serve" +dev = "cli:serve_all" +server = "cli:serve_server" +search = "cli:serve_search" +processing = "cli:serve_processing" [build-system] requires = ["hatchling"] diff --git a/backend/search/__init__.py b/backend/search/__init__.py index 7f7813b..e764137 100644 --- a/backend/search/__init__.py +++ b/backend/search/__init__.py @@ -2,7 +2,7 @@ Search module for semantic search using CLIP embeddings and Pinecone. """ -from search.embedder import TextEmbedder +from search.text_embedder import TextEmbedder from search.searcher import Searcher __all__ = ["TextEmbedder", "Searcher"] diff --git a/backend/search/embedder.py b/backend/search/embedder.py deleted file mode 100644 index 541a162..0000000 --- a/backend/search/embedder.py +++ /dev/null @@ -1,97 +0,0 @@ -import logging -from typing import Union, List -import numpy as np -import torch -from transformers import CLIPTextModelWithProjection, CLIPTokenizer - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -""" -Text Embedding module using CLIP model. - -Provides text-to-vector conversion using OpenAI's CLIP model -for semantic search capabilities. -""" - - -class TextEmbedder: - """ - CLIP-based text embedder for semantic search. - - Converts text queries into 512-dimensional embeddings using - OpenAI's CLIP text model (ViT-B/32 variant). - - Uses CLIPTextModelWithProjection for efficiency (loads only text encoder, - not the full CLIP model with vision encoder). 
- - Usage: - embedder = TextEmbedder() - vector = embedder.embed_text("woman on a train") - """ - - def __init__(self, model_name: str = "openai/clip-vit-base-patch32"): - """ - Initialize the text embedder. - - Args: - model_name: HuggingFace model identifier for CLIP. - Defaults to "openai/clip-vit-base-patch32". - """ - self.model_name = model_name - self.device = "cuda" if torch.cuda.is_available() else "cpu" - self.model = None - self.tokenizer = None - - logger.info(f"TextEmbedder initialized (device: {self.device})") - - def _load_model(self): - """Lazy load the CLIP text model on first use.""" - if self.model is None: - logger.info(f"Loading CLIP text model: {self.model_name}") - self.tokenizer = CLIPTokenizer.from_pretrained(self.model_name) - self.model = CLIPTextModelWithProjection.from_pretrained(self.model_name).to(self.device) - self.model.eval() - logger.info("CLIP text model loaded successfully") - - def embed_text(self, text: Union[str, List[str]]) -> np.ndarray: - """ - Generate embeddings for text input(s). - - Args: - text: Single text string or list of text strings - - Returns: - numpy array of embeddings (512-d, L2-normalized) - Shape: (512,) for single text, (N, 512) for batch - """ - self._load_model() - - # Handle single string - if isinstance(text, str): - text = [text] - - # Tokenize inputs - inputs = self.tokenizer( - text, - return_tensors="pt", - padding=True, - truncation=True, - max_length=77 # CLIP's max sequence length - ).to(self.device) - - # Generate embeddings - with torch.no_grad(): - # CLIPTextModelWithProjection outputs already-projected features - text_features = self.model(**inputs).text_embeds - # L2 normalize (essential for cosine similarity search) - text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True) - - # Convert to numpy - embeddings = text_features.cpu().numpy() - - # Return single vector if single input - if len(embeddings) == 1: - return embeddings[0] - - return embeddings \ No newline at end of file diff --git a/backend/search/searcher.py b/backend/search/searcher.py index b9f26e8..0a37765 100644 --- a/backend/search/searcher.py +++ b/backend/search/searcher.py @@ -10,7 +10,7 @@ from database.pinecone_connector import PineconeConnector from database.r2_connector import R2Connector -from search.embedder import TextEmbedder +from search.text_embedder import TextEmbedder logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) diff --git a/backend/search/text_embedder.py b/backend/search/text_embedder.py new file mode 100644 index 0000000..1e6f7cc --- /dev/null +++ b/backend/search/text_embedder.py @@ -0,0 +1,137 @@ +""" +ONNX-based Text Embedding module using CLIP model. + +Provides text-to-vector conversion using OpenAI's CLIP model +with ONNX Runtime for fast inference and minimal dependencies. + +This eliminates the need for PyTorch (~2GB) and transformers library, +reducing cold start times significantly. +""" + +import logging +from typing import Union, List +import numpy as np + +logger = logging.getLogger(__name__) + +# Default path where ONNX model is stored in the Modal image +DEFAULT_ONNX_MODEL_PATH = "/root/models/clip_text_encoder.onnx" +DEFAULT_TOKENIZER_PATH = "/root/models/clip_tokenizer/tokenizer.json" + + +class TextEmbedder: + """ + ONNX-based CLIP text embedder for semantic search. + + Converts text queries into 512-dimensional embeddings using + OpenAI's CLIP text model (ViT-B/32 variant) via ONNX Runtime. 
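+
+    Output embeddings remain 512-d and L2-normalized, matching the previous
+    PyTorch-based embedder, so existing vectors in Pinecone stay comparable.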
+
+    Uses the raw tokenizers library instead of transformers for faster imports.
+
+    Usage:
+        embedder = TextEmbedder()
+        vector = embedder.embed_text("woman on a train")
+    """
+
+    def __init__(
+        self,
+        model_path: str = DEFAULT_ONNX_MODEL_PATH,
+        tokenizer_path: str = DEFAULT_TOKENIZER_PATH
+    ):
+        """
+        Initialize the ONNX text embedder.
+
+        Args:
+            model_path: Path to the ONNX model file.
+            tokenizer_path: Path to the tokenizer.json file.
+        """
+        self.model_path = model_path
+        self.tokenizer_path = tokenizer_path
+        self.session = None
+        self.tokenizer = None
+
+        logger.info(f"TextEmbedder initialized (model: {model_path})")
+
+    def _load_model(self):
+        """Lazy load the ONNX model and tokenizer on first use."""
+        if self.session is None:
+            import onnxruntime as ort
+            from tokenizers import Tokenizer
+
+            logger.info(f"Loading ONNX model from: {self.model_path}")
+
+            # Configure ONNX Runtime for CPU (no CUDA dependency)
+            sess_options = ort.SessionOptions()
+            sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
+
+            self.session = ort.InferenceSession(
+                self.model_path,
+                sess_options,
+                providers=['CPUExecutionProvider']
+            )
+
+            # Load tokenizer directly from the tokenizers library (no transformers import)
+            logger.info(f"Loading tokenizer from: {self.tokenizer_path}")
+            self.tokenizer = Tokenizer.from_file(self.tokenizer_path)
+
+            # Configure padding and truncation for CLIP (max 77 tokens)
+            self.tokenizer.enable_padding(length=77, pad_id=0)
+            self.tokenizer.enable_truncation(max_length=77)
+
+            logger.info("ONNX model and tokenizer loaded successfully")
+
+    def embed_text(self, text: Union[str, List[str]]) -> np.ndarray:
+        """
+        Generate embeddings for text input(s).
+
+        Args:
+            text: Single text string or list of text strings
+
+        Returns:
+            numpy array of embeddings (512-d, L2-normalized)
+            Shape: (512,) for single text, (N, 512) for batch
+        """
+        self._load_model()
+
+        # Handle single string
+        single_input = isinstance(text, str)
+        if single_input:
+            text = [text]
+
+        # Tokenize inputs using the raw tokenizers library
+        encoded = self.tokenizer.encode_batch(text)
+
+        # Extract input_ids and attention_mask
+        input_ids = np.array([e.ids for e in encoded], dtype=np.int64)
+        attention_mask = np.array([e.attention_mask for e in encoded], dtype=np.int64)
+
+        # Run ONNX inference
+        onnx_inputs = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+        }
+
+        outputs = self.session.run(None, onnx_inputs)
+
+        # Handle different ONNX model output formats.
+        # The model might output multiple tensors; we need the one with shape (batch_size, 512).
+        text_embeds = None
+        for output in outputs:
+            if len(output.shape) == 2 and output.shape[1] == 512:
+                text_embeds = output
+                break
+
+        if text_embeds is None:
+            # Fallback: take the first output if none match the expected shape
+            logger.warning(f"Could not find 512-d output, using first output with shape {outputs[0].shape}")
+            text_embeds = outputs[0]
+
+        # L2 normalize (essential for cosine similarity search)
+        norms = np.linalg.norm(text_embeds, axis=-1, keepdims=True)
+        text_embeds = text_embeds / norms
+
+        # Return single vector if single input
+        if single_input:
+            return text_embeds[0]
+
+        return text_embeds
diff --git a/backend/services/__init__.py b/backend/services/__init__.py
index 50e916f..56f5df1 100644
--- a/backend/services/__init__.py
+++ b/backend/services/__init__.py
@@ -1,3 +1,3 @@
-from .upload import UploadHandler
+from .upload_handler import UploadHandler
 
 __all__ = ["UploadHandler"]
diff --git 
a/backend/services/http_server.py b/backend/services/http_server.py new file mode 100644 index 0000000..595462b --- /dev/null +++ b/backend/services/http_server.py @@ -0,0 +1,137 @@ +""" +ServerService class - base class shared between server.py and dev_combined.py +""" + +import logging +import modal + +from shared.config import get_environment, get_env_var, get_pinecone_index + +logger = logging.getLogger(__name__) + + +class ServerService: + """ + Server service base class - handles HTTP endpoints and background deletion. + """ + + def _initialize_connectors(self): + """Initialize connectors (non-decorated, can be called from subclasses).""" + from datetime import datetime, timezone + from database.pinecone_connector import PineconeConnector + from database.job_store_connector import JobStoreConnector + from database.r2_connector import R2Connector + from auth.auth_connector import AuthConnector + + env = get_environment() + logger.info(f"[{self.__class__.__name__}] Starting up in '{env}' environment") + self.start_time = datetime.now(timezone.utc) + + # Get environment variables + PINECONE_API_KEY = get_env_var("PINECONE_API_KEY") + R2_ACCOUNT_ID = get_env_var("R2_ACCOUNT_ID") + R2_ACCESS_KEY_ID = get_env_var("R2_ACCESS_KEY_ID") + R2_SECRET_ACCESS_KEY = get_env_var("R2_SECRET_ACCESS_KEY") + + pinecone_index = get_pinecone_index() + logger.info(f"[{self.__class__.__name__}] Using Pinecone index: {pinecone_index}") + + # Initialize lightweight connectors + self.pinecone_connector = PineconeConnector( + api_key=PINECONE_API_KEY, + index_name=pinecone_index + ) + self.job_store = JobStoreConnector(dict_name="clipabit-jobs") + self.r2_connector = R2Connector( + account_id=R2_ACCOUNT_ID, + access_key_id=R2_ACCESS_KEY_ID, + secret_access_key=R2_SECRET_ACCESS_KEY, + environment=env + ) + self.auth_connector = AuthConnector() + + # Store config for router + self.env = env + self.is_internal = env in ["dev", "staging"] + + logger.info(f"[{self.__class__.__name__}] Initialized and ready!") + + @modal.enter() + def startup(self): + """Initialize connectors and FastAPI app.""" + self._initialize_connectors() + + def create_fastapi_app(self, processing_service_cls=None): + """ + Create FastAPI app with routes. + + Note: Search is now handled by SearchService with its own ASGI app. 
+ + Args: + processing_service_cls: Optional ProcessingService class for dev combined mode + """ + from api import ServerFastAPIRouter + from fastapi import FastAPI + + self.fastapi_app = FastAPI(title="Clipabit Server") + api_router = ServerFastAPIRouter( + server_instance=self, + is_internal_env=self.is_internal, + environment=self.env, + processing_service_cls=processing_service_cls + ) + self.fastapi_app.include_router(api_router.router) + return self.fastapi_app + + @modal.method() + def delete_video_background(self, job_id: str, hashed_identifier: str, namespace: str = ""): + """Background job that deletes a video and all associated chunks.""" + logger.info(f"[{self.__class__.__name__}][Job {job_id}] Deletion started: {hashed_identifier} | namespace='{namespace}'") + + try: + # Delete chunks from Pinecone + pinecone_success = self.pinecone_connector.delete_by_identifier( + hashed_identifier=hashed_identifier, + namespace=namespace + ) + + if not pinecone_success: + raise Exception("Failed to delete chunks from Pinecone") + + # Delete from R2 + r2_success = self.r2_connector.delete_video(hashed_identifier) + if not r2_success: + logger.critical( + f"[{self.__class__.__name__}][Job {job_id}] INCONSISTENCY: Chunks deleted but R2 deletion failed" + ) + raise Exception("Failed to delete video from R2 after deleting chunks. System may be inconsistent.") + + result = { + "job_id": job_id, + "status": "completed", + "hashed_identifier": hashed_identifier, + "namespace": namespace, + "r2": {"deleted": r2_success}, + "pinecone": {"deleted": pinecone_success} + } + + logger.info(f"[{self.__class__.__name__}][Job {job_id}] Deletion completed") + self.job_store.set_job_completed(job_id, result) + + # Clear cache + try: + self.r2_connector.clear_cache(namespace or "__default__") + except Exception as cache_exc: + logger.error(f"[{self.__class__.__name__}][Job {job_id}] Failed to clear cache: {cache_exc}") + + return result + + except Exception as e: + error_msg = str(e) + logger.error(f"[{self.__class__.__name__}][Job {job_id}] Deletion failed: {error_msg}") + + import traceback + traceback.print_exc() + + self.job_store.set_job_failed(job_id, error_msg) + return {"job_id": job_id, "status": "failed", "error": error_msg} diff --git a/backend/services/processing_service.py b/backend/services/processing_service.py new file mode 100644 index 0000000..ac62442 --- /dev/null +++ b/backend/services/processing_service.py @@ -0,0 +1,227 @@ +""" +ProcessingService class - base class shared between processing_app.py and dev_combined.py +""" + +import logging +import modal + +from shared.config import get_environment, get_env_var, get_pinecone_index + +logger = logging.getLogger(__name__) + + +class ProcessingService: + """ + Processing service base class. + + Loads full CLIP model and preprocessing pipeline on startup. 
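+
+    Invoked via .spawn() from the upload path, so each video is processed as a
+    detached background call rather than blocking the HTTP request.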
+ """ + + @modal.enter() + def startup(self): + """Load CLIP model and initialize all connectors.""" + from preprocessing.preprocessor import Preprocessor + from embeddings.video_embedder import VideoEmbedder + from database.pinecone_connector import PineconeConnector + from database.job_store_connector import JobStoreConnector + from database.r2_connector import R2Connector + + env = get_environment() + logger.info(f"[{self.__class__.__name__}] Starting up in '{env}' environment") + + # Get environment variables + PINECONE_API_KEY = get_env_var("PINECONE_API_KEY") + R2_ACCOUNT_ID = get_env_var("R2_ACCOUNT_ID") + R2_ACCESS_KEY_ID = get_env_var("R2_ACCESS_KEY_ID") + R2_SECRET_ACCESS_KEY = get_env_var("R2_SECRET_ACCESS_KEY") + + pinecone_index = get_pinecone_index() + logger.info(f"[{self.__class__.__name__}] Using Pinecone index: {pinecone_index}") + + # Initialize preprocessor and embedder + self.preprocessor = Preprocessor( + min_chunk_duration=1.0, + max_chunk_duration=10.0, + scene_threshold=13.0 + ) + self.video_embedder = VideoEmbedder() + logger.info(f"[{self.__class__.__name__}] CLIP image encoder and preprocessor loaded") + + # Initialize connectors + self.pinecone_connector = PineconeConnector( + api_key=PINECONE_API_KEY, + index_name=pinecone_index + ) + self.job_store = JobStoreConnector(dict_name="clipabit-jobs") + self.r2_connector = R2Connector( + account_id=R2_ACCOUNT_ID, + access_key_id=R2_ACCESS_KEY_ID, + secret_access_key=R2_SECRET_ACCESS_KEY, + environment=env + ) + + logger.info(f"[{self.__class__.__name__}] Initialized and ready!") + + @modal.method() + def process_video_background( + self, + video_bytes: bytes, + filename: str, + job_id: str, + namespace: str = "", + parent_batch_id: str = None + ): + """Process an uploaded video through the full pipeline.""" + logger.info( + f"[{self.__class__.__name__}][Job {job_id}] Processing started: {filename} ({len(video_bytes)} bytes) " + f"| namespace='{namespace}' | batch={parent_batch_id or 'None'}" + ) + + hashed_identifier = None + upserted_chunk_ids = [] + + try: + # Stage 1: Upload original video to R2 bucket + success, hashed_identifier = self.r2_connector.upload_video( + video_data=video_bytes, + filename=filename, + namespace=namespace + ) + if not success: + upload_error_details = hashed_identifier + hashed_identifier = None + raise Exception(f"Failed to upload video to R2 storage: {upload_error_details}") + + # Stage 2: Process video through preprocessing pipeline + processed_chunks = self.preprocessor.process_video_from_bytes( + video_bytes=video_bytes, + video_id=job_id, + filename=filename, + hashed_identifier=hashed_identifier + ) + + # Calculate summary statistics + total_frames = sum(chunk['metadata']['frame_count'] for chunk in processed_chunks) + total_memory = sum(chunk['memory_mb'] for chunk in processed_chunks) + avg_complexity = ( + sum(chunk['metadata']['complexity_score'] for chunk in processed_chunks) + / len(processed_chunks) + if processed_chunks else 0 + ) + + logger.info( + f"[{self.__class__.__name__}][Job {job_id}] Complete: {len(processed_chunks)} chunks, " + f"{total_frames} frames, {total_memory:.2f} MB, avg_complexity={avg_complexity:.3f}" + ) + + # Stage 3-4: Embed frames and store in Pinecone + logger.info(f"[{self.__class__.__name__}][Job {job_id}] Embedding and upserting {len(processed_chunks)} chunks") + + chunk_details = [] + for chunk in processed_chunks: + embedding = self.video_embedder._generate_clip_embedding( + chunk["frames"], + num_frames=8 + ) + + 
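+                # This image embedding lives in the same 512-d CLIP space as the text
+                # embeddings produced by the search service's TextEmbedder, which is what
+                # makes text-to-video matching in Pinecone possible.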
logger.info(f"[{self.__class__.__name__}][Job {job_id}] Generated CLIP embedding for chunk {chunk['chunk_id']}") + + # Transform metadata for Pinecone compatibility + if 'timestamp_range' in chunk['metadata']: + start_time, end_time = chunk['metadata'].pop('timestamp_range') + chunk['metadata']['start_time_s'] = start_time + chunk['metadata']['end_time_s'] = end_time + + if 'file_info' in chunk['metadata']: + file_info = chunk['metadata'].pop('file_info') + for key, value in file_info.items(): + chunk['metadata'][f'file_{key}'] = value + + keys_to_delete = [k for k, v in chunk['metadata'].items() if v is None] + for k in keys_to_delete: + del chunk['metadata'][k] + + success = self.pinecone_connector.upsert_chunk( + chunk_id=chunk['chunk_id'], + chunk_embedding=embedding.numpy(), + namespace=namespace, + metadata=chunk['metadata'] + ) + + if success: + upserted_chunk_ids.append(chunk['chunk_id']) + else: + raise Exception(f"Failed to upsert chunk {chunk['chunk_id']} to Pinecone") + + chunk_details.append({ + "chunk_id": chunk['chunk_id'], + "metadata": chunk['metadata'], + "memory_mb": chunk['memory_mb'], + }) + + result = { + "job_id": job_id, + "status": "completed", + "hashed_identifier": hashed_identifier, + "filename": filename, + "chunks": len(processed_chunks), + "total_frames": total_frames, + "total_memory_mb": total_memory, + "avg_complexity": avg_complexity, + "chunk_details": chunk_details, + } + + logger.info(f"[{self.__class__.__name__}][Job {job_id}] Finished processing {filename}") + + # Stage 5: Store result + self.job_store.set_job_completed(job_id, result) + + # Update parent batch if exists + if parent_batch_id: + update_success = self.job_store.update_batch_on_child_completion( + parent_batch_id, + job_id, + result + ) + if update_success: + logger.info(f"[{self.__class__.__name__}][Job {job_id}] Updated parent batch {parent_batch_id}") + else: + logger.error( + f"[{self.__class__.__name__}][Job {job_id}] CRITICAL: Failed to update parent batch {parent_batch_id}" + ) + + # Invalidate cache + try: + self.r2_connector.clear_cache(namespace or "__default__") + except Exception as cache_exc: + logger.error(f"[{self.__class__.__name__}][Job {job_id}] Failed to clear cache: {cache_exc}") + + return result + + except Exception as e: + logger.error(f"[{self.__class__.__name__}][Job {job_id}] Processing failed: {e}") + + # Rollback + if hashed_identifier: + logger.info(f"[{self.__class__.__name__}][Job {job_id}] Rolling back: Deleting video from R2") + self.r2_connector.delete_video(hashed_identifier) + + if upserted_chunk_ids: + logger.info(f"[{self.__class__.__name__}][Job {job_id}] Rolling back: Deleting chunks from Pinecone") + self.pinecone_connector.delete_chunks(upserted_chunk_ids, namespace=namespace) + + import traceback + traceback.print_exc() + + self.job_store.set_job_failed(job_id, str(e)) + + if parent_batch_id: + error_result = { + "job_id": job_id, + "status": "failed", + "filename": filename, + "error": str(e) + } + self.job_store.update_batch_on_child_completion(parent_batch_id, job_id, error_result) + + return {"job_id": job_id, "status": "failed", "error": str(e)} diff --git a/backend/services/search_service.py b/backend/services/search_service.py new file mode 100644 index 0000000..27fcddc --- /dev/null +++ b/backend/services/search_service.py @@ -0,0 +1,146 @@ +""" +SearchService class - base class shared between search_app.py and dev_combined.py + +Exposes its own ASGI app for direct HTTP access (no server gateway hop). 
+""" + +import logging +import modal + +from shared.config import get_environment, get_env_var, get_pinecone_index + +logger = logging.getLogger(__name__) + + +class SearchService: + """ + Search service with direct HTTP endpoint. + + Loads CLIP text encoder on startup and handles semantic search queries. + Exposes its own ASGI app for lower latency (bypasses server gateway). + """ + + @modal.enter() + def startup(self): + """Load CLIP text encoder and initialize connectors.""" + from database.pinecone_connector import PineconeConnector + from database.r2_connector import R2Connector + from search.text_embedder import TextEmbedder + + env = get_environment() + logger.info(f"[{self.__class__.__name__}] Starting up in '{env}' environment") + + # Get environment variables + PINECONE_API_KEY = get_env_var("PINECONE_API_KEY") + R2_ACCOUNT_ID = get_env_var("R2_ACCOUNT_ID") + R2_ACCESS_KEY_ID = get_env_var("R2_ACCESS_KEY_ID") + R2_SECRET_ACCESS_KEY = get_env_var("R2_SECRET_ACCESS_KEY") + + pinecone_index = get_pinecone_index() + logger.info(f"[{self.__class__.__name__}] Using Pinecone index: {pinecone_index}") + + # Initialize text embedder (loads CLIP text encoder) + self.embedder = TextEmbedder() + self.embedder._load_model() + logger.info(f"[{self.__class__.__name__}] CLIP text encoder (ONNX) loaded on CPU") + + # Initialize connectors + self.pinecone_connector = PineconeConnector( + api_key=PINECONE_API_KEY, + index_name=pinecone_index + ) + self.r2_connector = R2Connector( + account_id=R2_ACCOUNT_ID, + access_key_id=R2_ACCESS_KEY_ID, + secret_access_key=R2_SECRET_ACCESS_KEY, + environment=env + ) + + # Create FastAPI app for direct HTTP access + self.fastapi_app = self._create_fastapi_app() + + logger.info(f"[{self.__class__.__name__}] Initialized and ready!") + + def _create_fastapi_app(self): + """Create FastAPI app with search routes.""" + from fastapi import FastAPI + from fastapi.middleware.cors import CORSMiddleware + from api.search_fastapi_router import SearchFastAPIRouter + + app = FastAPI(title="ClipABit Search API") + + # Add CORS middleware + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + # Add search routes + router = SearchFastAPIRouter(search_service_instance=self) + app.include_router(router.router) + + return app + + @modal.asgi_app() + def asgi_app(self): + """Expose FastAPI app as ASGI endpoint.""" + return self.fastapi_app + + def _search_internal(self, query: str, namespace: str = "", top_k: int = 10) -> list: + """ + Internal search implementation. + + Called directly by the FastAPI router (no RPC overhead). 
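+
+        Example (illustrative values):
+            results = self._search_internal("woman on a train", namespace="demo", top_k=5)
+            # -> [{"id": "...", "score": 0.87, "metadata": {..., "presigned_url": "..."}}]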
+ """ + logger.info(f"[{self.__class__.__name__}] Query: '{query}' | namespace='{namespace}' | top_k={top_k}") + + # Generate query embedding + query_embedding = self.embedder.embed_text(query) + + # Search Pinecone + matches = self.pinecone_connector.query_chunks( + query_embedding=query_embedding, + namespace=namespace, + top_k=top_k + ) + + # Format results with presigned URLs + results = [] + for match in matches: + metadata = match.get('metadata', {}) + + if 'file_hashed_identifier' not in metadata: + logger.warning( + "Skipping result %s: missing file_hashed_identifier", + match.get('id'), + ) + continue + + presigned_url = None + if self.r2_connector: + presigned_url = self.r2_connector.generate_presigned_url( + identifier=metadata['file_hashed_identifier'], + validate_exists=True, + ) + + if not presigned_url: + logger.warning( + "Skipping result %s: unable to generate presigned URL", + match.get('id'), + ) + continue + + metadata['presigned_url'] = presigned_url + + result = { + 'id': match.get('id'), + 'score': match.get('score', 0.0), + 'metadata': metadata + } + results.append(result) + + logger.info(f"[{self.__class__.__name__}] Found {len(results)} results") + return results diff --git a/backend/services/upload.py b/backend/services/upload_handler.py similarity index 95% rename from backend/services/upload.py rename to backend/services/upload_handler.py index a407c14..7161d93 100644 --- a/backend/services/upload.py +++ b/backend/services/upload_handler.py @@ -1,3 +1,5 @@ +"""Upload validation and orchestration service.""" + import logging import uuid from typing import Tuple @@ -24,16 +26,18 @@ class UploadHandler: MAX_FILE_SIZE = 2 * 1024 * 1024 * 1024 # 2GB in bytes MAX_BATCH_SIZE = 200 - def __init__(self, job_store, process_video_method): + def __init__(self, job_store, process_video_spawn_fn): """ Initialize upload handler. 
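+
+        Example (illustrative; the app name here is hypothetical):
+            ProcessingCls = modal.Cls.from_name("clipabit-processing", "ProcessingService")
+            handler = UploadHandler(job_store, ProcessingCls().process_video_background.spawn)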
Args: job_store: JobStoreConnector instance for job tracking and status updates - process_video_method: Modal method reference for spawning async video processing + process_video_spawn_fn: Callable that spawns async video processing + For dev mode: ProcessingService().process_video_background.spawn + For prod mode: modal.Cls.from_name(...).process_video_background.spawn """ self.job_store = job_store - self.process_video = process_video_method + self.process_video_spawn = process_video_spawn_fn def validate_file(self, file: UploadFile, file_contents: bytes = None) -> Tuple[bool, str]: """ @@ -112,7 +116,7 @@ async def handle_single_upload(self, file: UploadFile, namespace: str) -> dict: "namespace": namespace }) - self.process_video.spawn(contents, file.filename, job_id, namespace, None) + self.process_video_spawn(contents, file.filename, job_id, namespace, None) return { "job_id": job_id, @@ -193,7 +197,7 @@ async def handle_batch_upload(self, files: list[UploadFile], namespace: str) -> "namespace": namespace }) - self.process_video.spawn( + self.process_video_spawn( contents, meta["file"].filename, meta["job_id"], namespace, batch_job_id ) spawned.append(meta["job_id"]) diff --git a/backend/shared/__init__.py b/backend/shared/__init__.py new file mode 100644 index 0000000..69b1c4f --- /dev/null +++ b/backend/shared/__init__.py @@ -0,0 +1,15 @@ +"""Shared utilities package for config and image definitions.""" + +from .config import get_environment, get_env_var, get_pinecone_index, get_secrets +from .images import get_dev_image, get_server_image, get_search_image, get_processing_image + +__all__ = [ + "get_environment", + "get_env_var", + "get_pinecone_index", + "get_secrets", + "get_dev_image", + "get_server_image", + "get_search_image", + "get_processing_image", +] diff --git a/backend/shared/config.py b/backend/shared/config.py new file mode 100644 index 0000000..06fc008 --- /dev/null +++ b/backend/shared/config.py @@ -0,0 +1,127 @@ +""" +Shared configuration module for Modal apps. + +Provides environment handling, app naming, and secrets management +shared across all Modal apps (Server, Search, Processing). +""" + +import os +import sys +import logging +import modal + + +def configure_logging(level: int = logging.INFO) -> None: + """ + Configure logging to send INFO and DEBUG to stdout, WARNING+ to stderr. + + This ensures info logs appear in stdout for proper log routing in production. + + Args: + level: Minimum logging level (default: logging.INFO) + """ + root_logger = logging.getLogger() + + # Clear any existing handlers to avoid duplicates + root_logger.handlers.clear() + root_logger.setLevel(level) + + formatter = logging.Formatter( + '%(asctime)s - %(name)s - %(levelname)s - %(message)s' + ) + + # Handler for INFO and DEBUG -> stdout + stdout_handler = logging.StreamHandler(sys.stdout) + stdout_handler.setLevel(level) + stdout_handler.addFilter(lambda record: record.levelno < logging.WARNING) + stdout_handler.setFormatter(formatter) + + # Handler for WARNING and above -> stderr + stderr_handler = logging.StreamHandler(sys.stderr) + stderr_handler.setLevel(logging.WARNING) + stderr_handler.setFormatter(formatter) + + root_logger.addHandler(stdout_handler) + root_logger.addHandler(stderr_handler) + + +# Configure logging on module import +configure_logging() +logger = logging.getLogger(__name__) + +# Valid environments +VALID_ENVIRONMENTS = ["dev", "prod", "staging"] + + +def get_environment() -> str: + """ + Get the current environment from ENVIRONMENT variable. 
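+    Defaults to "dev" when the ENVIRONMENT variable is unset.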
+
+    Returns:
+        str: Environment name (dev, prod, or staging)
+
+    Raises:
+        ValueError: If ENVIRONMENT is not a valid value
+    """
+    env = os.environ.get("ENVIRONMENT", "dev")
+    if env not in VALID_ENVIRONMENTS:
+        raise ValueError(f"Invalid ENVIRONMENT: {env}. Must be one of: {VALID_ENVIRONMENTS}")
+    return env
+
+def get_modal_environment() -> str:
+    """Get the Modal environment name."""
+    return 'main'
+
+def get_secrets() -> modal.Secret:
+    """
+    Get Modal secrets for the current environment.
+
+    Returns:
+        modal.Secret: Secret object containing environment variables
+    """
+    env = get_environment()
+    return modal.Secret.from_name(env)
+
+
+def is_internal_env() -> bool:
+    """
+    Check if running in an internal (non-production) environment.
+
+    Returns:
+        bool: True if dev or staging, False if prod
+    """
+    env = get_environment()
+    return env in ["dev", "staging"]
+
+
+def get_pinecone_index() -> str:
+    """
+    Get the Pinecone index name for the current environment.
+
+    Returns:
+        str: Pinecone index name (e.g., "dev-chunks")
+    """
+    env = get_environment()
+    return f"{env}-chunks"
+
+
+# Environment variable helpers for connectors
+def get_env_var(key: str) -> str:
+    """
+    Get a required environment variable or raise an error.
+
+    Args:
+        key: Environment variable name
+
+    Returns:
+        str: Environment variable value
+
+    Raises:
+        ValueError: If the environment variable is not set
+    """
+    value = os.getenv(key)
+    if not value:
+        raise ValueError(f"{key} not found in environment variables")
+    return value
+
+
diff --git a/backend/shared/images.py b/backend/shared/images.py
new file mode 100644
index 0000000..93f9e2f
--- /dev/null
+++ b/backend/shared/images.py
@@ -0,0 +1,243 @@
+"""
+Modal Image definitions for each app.
+
+Separates dependencies to minimize cold start times:
+- Server: Minimal deps (~3-5s cold start)
+- Search: Medium deps with CLIP text encoder (~8-10s cold start)
+- Processing: Heavy deps with full video pipeline (~15-20s cold start)
+"""
+
+import modal
+
+def _download_clip_full_model_for_dev():
+    """Pre-download the full CLIP model for video processing at image build time."""
+    from transformers import CLIPModel, CLIPProcessor
+    model_name = "openai/clip-vit-base-patch32"
+    # Full model for video processing
+    CLIPModel.from_pretrained(model_name)
+    CLIPProcessor.from_pretrained(model_name, use_fast=True)
+
+
+def get_dev_image() -> modal.Image:
+    """
+    Create the Modal image for the dev app.
+
+    Pre-downloads all models at build time to eliminate cold start downloads.
+    Uses ONNX for text embedding (search) and PyTorch for video processing.
+    """
+    return (
+        modal.Image.debian_slim(python_version="3.12")
+        .apt_install("ffmpeg", "libsm6", "libxext6")
+        .pip_install(
+            "fastapi[standard]",
+            "python-multipart",
+            "boto3",
+            "pinecone",
+            "numpy",
+            "torch",
+            "torchvision",
+            "transformers",
+            "opencv-python-headless",
+            "scenedetect",
+            "pillow",
+            "onnxruntime",
+            "onnxscript",
+            "tokenizers",  # For text embedder (faster than transformers import)
+        )
+        .run_function(_download_clip_full_model_for_dev)
+        .run_function(_export_clip_text_to_onnx)
+        .add_local_python_source(
+            "api",
+            "auth",
+            "database",
+            "embeddings",
+            "models",
+            "shared",
+            "preprocessing",
+            "search",
+            "services"
+        )
+    )
+
+def get_server_image() -> modal.Image:
+    """
+    Create the Modal image for the Server app.
+
+    Minimal dependencies for fast cold starts.
+    Handles: health, status, upload, list_videos, delete operations.
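+
+    Torch, transformers, and opencv are deliberately left out; code paths that
+    need them live in the search and processing images instead.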
+ """ + return ( + modal.Image.debian_slim(python_version="3.12") + .pip_install( + "fastapi[standard]", + "python-multipart", + "boto3", + "pinecone", + "numpy", + ) + .add_local_python_source( + "database", + "models", + "api", + "auth", + "shared", + "services", + ) + ) + +def _export_clip_text_to_onnx(): + """ + Export CLIP text encoder to ONNX format at image build time. + + This runs ONCE during image build (not at cold start). + Uses PyTorch to load and export the model, then saves both: + - The ONNX model file (~150MB) + - The tokenizer + + At runtime, only onnxruntime is imported (no torch), so cold starts are fast. + """ + import os + import torch + from transformers import CLIPTextModelWithProjection, CLIPTokenizerFast + + model_name = "openai/clip-vit-base-patch32" + model_dir = "/root/models" + onnx_path = f"{model_dir}/clip_text_encoder.onnx" + tokenizer_path = f"{model_dir}/clip_tokenizer" + + os.makedirs(model_dir, exist_ok=True) + + print(f"[BUILD TIME] Loading PyTorch CLIP text model: {model_name}") + model = CLIPTextModelWithProjection.from_pretrained(model_name) + model.eval() + + # Create dummy inputs for ONNX export + dummy_input_ids = torch.randint(0, 49408, (1, 77)) # vocab_size=49408, seq_len=77 + dummy_attention_mask = torch.ones(1, 77, dtype=torch.long) + + print(f"[BUILD TIME] Exporting to ONNX format: {onnx_path}") + torch.onnx.export( + model, + (dummy_input_ids, dummy_attention_mask), + onnx_path, + input_names=["input_ids", "attention_mask"], + output_names=["text_embeds"], + dynamic_axes={ + "input_ids": {0: "batch_size", 1: "sequence_length"}, + "attention_mask": {0: "batch_size", 1: "sequence_length"}, + "text_embeds": {0: "batch_size"} + }, + opset_version=14, + do_constant_folding=True, + ) + + print(f"[BUILD TIME] ONNX model saved: {os.path.getsize(onnx_path) / 1024 / 1024:.1f} MB") + + # Save the fast tokenizer + print(f"[BUILD TIME] Saving tokenizer to: {tokenizer_path}") + tokenizer = CLIPTokenizerFast.from_pretrained(model_name) + tokenizer.save_pretrained(tokenizer_path) + + # Verify the ONNX model works + print("[BUILD TIME] Verifying ONNX model...") + import onnxruntime as ort + import numpy as np + + session = ort.InferenceSession(onnx_path, providers=['CPUExecutionProvider']) + test_input_ids = np.random.randint(0, 49408, (1, 77), dtype=np.int64) + test_attention_mask = np.ones((1, 77), dtype=np.int64) + outputs = session.run(None, { + "input_ids": test_input_ids, + "attention_mask": test_attention_mask + }) + + print(f"[BUILD TIME] ✓ ONNX model verified! Output shape: {outputs[0].shape}") + print("[BUILD TIME] ✓ Export complete!") + + +def get_search_image() -> modal.Image: + """ + Create the Modal image for the Search app. + + Uses ONNX Runtime and raw tokenizers for minimal cold starts. + - No PyTorch at runtime (~2GB saved) + - No transformers at runtime (~5-8s import time saved) + - Only onnxruntime + tokenizers (~100MB total) + + Build strategy: + 1. Install torch + transformers temporarily (build time only) + 2. Export CLIP model to ONNX format and save tokenizer + 3. Uninstall torch + transformers to reduce image size + 4. 
Install lightweight runtime deps (onnxruntime, tokenizers) + """ + return ( + modal.Image.debian_slim(python_version="3.12") + # Step 1: Install torch for model export (build time) + .pip_install( + "torch", + "transformers", + "onnxruntime", + "onnxscript", + ) + # Step 2: Export model to ONNX (build time) + .run_function(_export_clip_text_to_onnx) + # Step 3: Remove torch and transformers to save space and import time + .run_commands("pip uninstall -y torch transformers") + # Step 4: Install lightweight runtime dependencies + .pip_install( + "pinecone", + "boto3", + "numpy", + "tokenizers", + "fastapi[standard]", + ) + .add_local_python_source( + "api", + "database", + "search", + "shared", + "services", + ) + ) + +def _download_clip_full_model(): + """Pre-download full CLIP model (vision + text) at image build time.""" + from transformers import CLIPModel, CLIPProcessor + model_name = "openai/clip-vit-base-patch32" + CLIPModel.from_pretrained(model_name) + CLIPProcessor.from_pretrained(model_name, use_fast=True) + + +def get_processing_image() -> modal.Image: + """ + Create the Modal image for the Processing app. + + Heavy dependencies for video processing pipeline. + Includes: ffmpeg, opencv, scenedetect, full CLIP model, etc. + + Pre-downloads the model at build time to eliminate cold start downloads. + """ + return ( + modal.Image.debian_slim(python_version="3.12") + .apt_install("ffmpeg", "libsm6", "libxext6") + .pip_install( + "torch", + "torchvision", + "transformers", + "opencv-python-headless", + "scenedetect", + "pillow", + "numpy", + "pinecone", + "boto3", + ) + .run_function(_download_clip_full_model) + .add_local_python_source( + "database", + "preprocessing", + "embeddings", + "models", + "shared", + "services" + ) + ) diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 9018e22..6550087 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -331,11 +331,62 @@ def mock_r2_connector(mocker, mock_modal_dict): return connector, mock_client, mock_boto3 +@pytest.fixture +def processing_service(mocker): + """ + Creates a ProcessingService instance with all dependencies mocked. + We bypass the actual startup() logic and manually inject mocks. + Used for testing video processing pipeline. 
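+
+    Example (illustrative):
+        def test_r2_failure(processing_service):
+            processing_service.r2_connector.upload_video.side_effect = Exception("boom")
+            result = processing_service.process_video_background(b"bytes", "a.mp4", "job-1", "")
+            assert result["status"] == "failed"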
+ """ + # Create a mock for the modal module + mock_modal = MagicMock() + + # Configure the mock decorators to just return the original class/function + def identity_decorator(*args, **kwargs): + def wrapper(obj): + return obj + return wrapper + + # Handle @app.cls() -> returns decorator -> returns class + mock_modal.App.return_value.cls.side_effect = identity_decorator + + # Handle @modal.method() -> returns decorator -> returns function + mock_modal.method.side_effect = identity_decorator + + # Handle @modal.method_background() -> returns decorator -> returns function + mock_modal.method_background.side_effect = identity_decorator + + # Handle @modal.enter() -> returns decorator -> returns function + mock_modal.enter.side_effect = identity_decorator + + # Patch sys.modules to use our mock_modal + with patch.dict(sys.modules, {'modal': mock_modal}): + # Import the processing app module + if 'apps.processing_app' in sys.modules: + import apps.processing_app as processing_app + importlib.reload(processing_app) + else: + import apps.processing_app as processing_app + + # Now ProcessingService is a regular Python class + service = processing_app.ProcessingService() + + # Mock all the components that would be set in startup() + service.r2_connector = mocker.MagicMock() + service.pinecone_connector = mocker.MagicMock() + service.preprocessor = mocker.MagicMock() + service.video_embedder = mocker.MagicMock() + service.job_store = mocker.MagicMock() + + yield service + + @pytest.fixture def server_instance(mocker): """ Creates a Server instance with all dependencies mocked. We bypass the actual startup() logic and manually inject mocks. + Used for testing delete operations. """ # Create a mock for the modal module mock_modal = MagicMock() @@ -355,29 +406,24 @@ def wrapper(obj): # Handle @modal.enter() -> returns decorator -> returns function mock_modal.enter.side_effect = identity_decorator - # Handle @modal.fastapi_endpoint() -> returns decorator -> returns function - mock_modal.fastapi_endpoint.side_effect = identity_decorator + # Handle @modal.asgi_app() -> returns decorator -> returns function + mock_modal.asgi_app.side_effect = identity_decorator # Patch sys.modules to use our mock_modal with patch.dict(sys.modules, {'modal': mock_modal}): - # Now import main. It will use the mocked modal. 
- # We need to reload it if it was already imported - if 'main' in sys.modules: - import main - importlib.reload(main) + # Import the server service module + if 'services.http_server' in sys.modules: + import services.http_server as server_module + importlib.reload(server_module) else: - import main + import services.http_server as server_module - # Now Server is a regular Python class, not a Modal wrapped one - server = main.Server() + server = server_module.ServerService() # Mock all the components that would be set in startup() server.r2_connector = mocker.MagicMock() server.pinecone_connector = mocker.MagicMock() - server.preprocessor = mocker.MagicMock() - server.video_embedder = mocker.MagicMock() server.job_store = mocker.MagicMock() - server.searcher = mocker.MagicMock() yield server diff --git a/backend/tests/integration/test_api_endpoints.py b/backend/tests/integration/test_api_endpoints.py index 994ea12..8f7e208 100644 --- a/backend/tests/integration/test_api_endpoints.py +++ b/backend/tests/integration/test_api_endpoints.py @@ -1,11 +1,13 @@ import io from typing import Any, Dict, List, Tuple +from unittest.mock import patch +import modal import pytest from fastapi import FastAPI from fastapi.testclient import TestClient -from api.fastapi_router import FastAPIRouter +from api.server_fastapi_router import ServerFastAPIRouter class FakeJobStore: @@ -13,6 +15,11 @@ def __init__(self) -> None: self._jobs: Dict[str, Dict[str, Any]] = {} def create_job(self, job_id: str, data: Dict[str, Any]) -> None: + # Add backward compatible fields if not present + if "job_type" not in data: + data["job_type"] = "video" + if "parent_batch_id" not in data: + data["parent_batch_id"] = None self._jobs[job_id] = data def get_job(self, job_id: str) -> Dict[str, Any] | None: @@ -24,6 +31,63 @@ def set_job_completed(self, job_id: str, result: Dict[str, Any]) -> None: def set_job_failed(self, job_id: str, error: str) -> None: self._jobs[job_id] = {"job_id": job_id, "status": "failed", "error": error} + def create_batch_job( + self, batch_job_id: str, child_job_ids: List[str], namespace: str + ) -> bool: + """Create a new batch job entry.""" + batch_data = { + "batch_job_id": batch_job_id, + "job_type": "batch", + "status": "processing", + "namespace": namespace, + "total_videos": len(child_job_ids), + "child_job_ids": child_job_ids, + "completed_count": 0, + "failed_count": 0, + "processing_count": len(child_job_ids), + } + self._jobs[batch_job_id] = batch_data + return True + + def update_batch_on_child_completion( + self, batch_job_id: str, child_job_id: str, child_result: Dict[str, Any] + ) -> bool: + """Update batch job when a child completes.""" + if batch_job_id not in self._jobs: + return False + + batch_job = self._jobs[batch_job_id] + child_status = child_result.get("status") + + if child_status == "completed": + batch_job["completed_count"] += 1 + batch_job["processing_count"] -= 1 + elif child_status == "failed": + batch_job["failed_count"] += 1 + batch_job["processing_count"] -= 1 + + # Update batch status + total = batch_job["total_videos"] + completed = batch_job["completed_count"] + failed = batch_job["failed_count"] + + if completed + failed == total: + if failed == 0: + batch_job["status"] = "completed" + elif completed == 0: + batch_job["status"] = "failed" + else: + batch_job["status"] = "partial" + + return True + + def delete_job(self, job_id: str) -> bool: + """Delete a job from the store.""" + if job_id in self._jobs: + del self._jobs[job_id] + return True + return False + class 
FakeSpawner: def __init__(self) -> None: @@ -33,14 +97,19 @@ def spawn(self, *args: Any) -> None: self.calls.append(args) -class FakeSearcher: - def __init__(self, results: List[Dict[str, Any]] | None = None) -> None: - self.results = results or [] - self.last_call: Tuple[str, int, str] | None = None +class FakeModalFunction: + """Fake Modal function that tracks spawn/remote calls.""" + def __init__(self) -> None: + self.spawn_calls: List[Tuple[Any, ...]] = [] + self.remote_calls: List[Tuple[Any, ...]] = [] + self.remote_return_value: Any = [] - def search(self, query: str, top_k: int, namespace: str) -> List[Dict[str, Any]]: - self.last_call = (query, top_k, namespace) - return self.results + def spawn(self, *args: Any) -> None: + self.spawn_calls.append(args) + + def remote(self, *args: Any) -> Any: + self.remote_calls.append(args) + return self.remote_return_value class FakeR2Connector: @@ -66,26 +135,12 @@ def list_videos_page( class ServerStub: """ - Minimal server stub providing the attributes FastAPIRouter uses. + Minimal server stub providing the attributes ServerFastAPIRouter uses. """ def __init__(self) -> None: self.job_store = FakeJobStore() - self.process_video_background = FakeSpawner() self.delete_video_background = FakeSpawner() - self.searcher = FakeSearcher( - results=[ - { - "score": 0.99, - "metadata": { - "presigned_url": "https://example.com/video.mp4", - "start_time_s": 0, - "file_filename": "video.mp4", - "hashed_identifier": "abc123", - }, - } - ] - ) self.r2_connector = FakeR2Connector( videos=[ { @@ -98,38 +153,68 @@ def __init__(self) -> None: @pytest.fixture() -def test_client_internal() -> Tuple[TestClient, ServerStub]: +def mock_modal_lookup(): + """Mock modal.Cls.from_name to return fake service classes.""" + fake_process_fn = FakeModalFunction() + + class FakeServiceClass: + """Fake service class that returns fake modal functions.""" + def __init__(self, func): + self.func = func + + def __call__(self): + """Return self to allow chaining like ServiceClass().method""" + return self + + @property + def process_video_background(self): + return self.func + + def lookup_side_effect(app_name: str, class_name: str, **kwargs): + if "processing" in app_name or class_name == "ProcessingService": + return FakeServiceClass(fake_process_fn) + raise ValueError(f"Unknown app: {app_name}, class: {class_name}") + + # Mock modal.Cls.from_name + with patch.object(modal.Cls, "from_name", side_effect=lookup_side_effect): + yield { + "process_fn": fake_process_fn, + } + + +@pytest.fixture() +def test_client_internal(mock_modal_lookup) -> Tuple[TestClient, ServerStub, dict]: """ FastAPI TestClient with is_internal_env=True, so delete is allowed. """ server = ServerStub() app = FastAPI() - router = FastAPIRouter(server, is_internal_env=True) + router = ServerFastAPIRouter(server, is_internal_env=True, environment="dev") app.include_router(router.router) - return TestClient(app), server + return TestClient(app), server, mock_modal_lookup @pytest.fixture() -def test_client_external() -> Tuple[TestClient, ServerStub]: +def test_client_external(mock_modal_lookup) -> Tuple[TestClient, ServerStub, dict]: """ FastAPI TestClient with is_internal_env=False, so delete is forbidden. 
""" server = ServerStub() app = FastAPI() - router = FastAPIRouter(server, is_internal_env=False) + router = ServerFastAPIRouter(server, is_internal_env=False, environment="prod") app.include_router(router.router) - return TestClient(app), server + return TestClient(app), server, mock_modal_lookup -def test_health_ok(test_client_internal: Tuple[TestClient, ServerStub]) -> None: - client, _ = test_client_internal +def test_health_ok(test_client_internal: Tuple[TestClient, ServerStub, dict]) -> None: + client, _, _ = test_client_internal resp = client.get("/health") assert resp.status_code == 200 assert resp.json() == {"status": "ok"} -def test_list_videos_returns_data(test_client_internal: Tuple[TestClient, ServerStub]) -> None: - client, server = test_client_internal +def test_list_videos_returns_data(test_client_internal: Tuple[TestClient, ServerStub, dict]) -> None: + client, server, _ = test_client_internal resp = client.get("/videos", params={"namespace": "ns1"}) assert resp.status_code == 200 data = resp.json() @@ -143,30 +228,19 @@ def test_list_videos_returns_data(test_client_internal: Tuple[TestClient, Server assert server.r2_connector.last_namespace == "ns1" -def test_search_invokes_searcher_and_returns_results( - test_client_internal: Tuple[TestClient, ServerStub] -) -> None: - client, server = test_client_internal - resp = client.get("/search", params={"query": "hello world", "namespace": "web-demo", "top_k": 5}) - assert resp.status_code == 200 - data = resp.json() - assert "results" in data and isinstance(data["results"], list) - assert server.searcher.last_call == ("hello world", 5, "web-demo") - - -def test_status_processing_when_unknown_job(test_client_internal: Tuple[TestClient, ServerStub]) -> None: - client, _ = test_client_internal +def test_status_processing_when_unknown_job(test_client_internal: Tuple[TestClient, ServerStub, dict]) -> None: + client, _, _ = test_client_internal resp = client.get("/status", params={"job_id": "does-not-exist"}) assert resp.status_code == 200 data = resp.json() assert data["status"] == "processing" -def test_upload_creates_job_and_spawns_background( - test_client_internal: Tuple[TestClient, ServerStub] +def test_upload_creates_job_and_spawns_processing_app( + test_client_internal: Tuple[TestClient, ServerStub, dict] ) -> None: - client, server = test_client_internal - files = {"file": ("clip.mp4", io.BytesIO(b"fake-bytes"), "video/mp4")} + client, server, mock_fns = test_client_internal + files = [("files", ("clip.mp4", io.BytesIO(b"fake-bytes"), "video/mp4"))] resp = client.post("/upload", files=files, data={"namespace": "ns1"}) assert resp.status_code == 200 data = resp.json() @@ -174,19 +248,103 @@ def test_upload_creates_job_and_spawns_background( job_id = data["job_id"] # Job created assert server.job_store.get_job(job_id) is not None - # Background process triggered - assert len(server.process_video_background.calls) == 1 - call_args = server.process_video_background.calls[0] - # args: (contents, filename, job_id, namespace) + # Processing app spawn triggered + assert len(mock_fns["process_fn"].spawn_calls) == 1 + call_args = mock_fns["process_fn"].spawn_calls[0] + # args: (contents, filename, job_id, namespace, parent_batch_id) assert call_args[1] == "clip.mp4" assert call_args[2] == job_id assert call_args[3] == "ns1" + assert call_args[4] is None # No parent batch + + +def test_batch_upload_creates_batch_job_and_spawns_children( + test_client_internal: Tuple[TestClient, ServerStub, dict] +) -> None: + client, server, mock_fns = 
test_client_internal + # Upload 3 videos + files = [ + ("files", ("video1.mp4", io.BytesIO(b"fake-bytes-1"), "video/mp4")), + ("files", ("video2.mp4", io.BytesIO(b"fake-bytes-2"), "video/mp4")), + ("files", ("video3.mp4", io.BytesIO(b"fake-bytes-3"), "video/mp4")), + ] + resp = client.post("/upload", files=files, data={"namespace": "batch-ns"}) + assert resp.status_code == 200 + data = resp.json() + + # Check batch response + assert data["status"] == "processing" + assert "batch_job_id" in data + assert data["total_videos"] == 3 + assert data["successfully_spawned"] == 3 + assert data["failed_validation"] == 0 + + batch_job_id = data["batch_job_id"] + assert batch_job_id.startswith("batch-") + + # Batch job created + batch_job = server.job_store.get_job(batch_job_id) + assert batch_job is not None + assert batch_job["job_type"] == "batch" + assert len(batch_job["child_job_ids"]) == 3 + + # All child jobs spawned + assert len(mock_fns["process_fn"].spawn_calls) == 3 + + # Check each child job was created and linked to batch + for i, call_args in enumerate(mock_fns["process_fn"].spawn_calls): + filename = call_args[1] + job_id = call_args[2] + namespace = call_args[3] + parent_batch_id = call_args[4] + + assert filename in ["video1.mp4", "video2.mp4", "video3.mp4"] + assert namespace == "batch-ns" + assert parent_batch_id == batch_job_id + + # Child job exists and is linked to batch + child_job = server.job_store.get_job(job_id) + assert child_job is not None + assert child_job["parent_batch_id"] == batch_job_id + + +def test_batch_upload_with_validation_failures( + test_client_internal: Tuple[TestClient, ServerStub, dict] +) -> None: + client, server, mock_fns = test_client_internal + # Upload mix of valid and invalid files + files = [ + ("files", ("video1.mp4", io.BytesIO(b"fake-bytes-1"), "video/mp4")), + ("files", ("bad.txt", io.BytesIO(b"not-a-video"), "text/plain")), # Invalid extension + ("files", ("video2.mp4", io.BytesIO(b"fake-bytes-2"), "video/mp4")), + ] + resp = client.post("/upload", files=files, data={"namespace": "ns1"}) + assert resp.status_code == 200 + data = resp.json() + + # Check that only valid files were processed + assert data["total_submitted"] == 3 + assert data["failed_validation"] == 1 + assert data["total_videos"] == 2 + assert data["successfully_spawned"] == 2 + + # Only 2 spawns for valid files + assert len(mock_fns["process_fn"].spawn_calls) == 2 + + +def test_batch_upload_rejects_empty_list( + test_client_internal: Tuple[TestClient, ServerStub, dict] +) -> None: + client, _, _ = test_client_internal + resp = client.post("/upload", files=[], data={"namespace": "ns1"}) + assert resp.status_code == 400 + assert "No files provided" in resp.json()["detail"] def test_status_completed_after_job_store_update( - test_client_internal: Tuple[TestClient, ServerStub] + test_client_internal: Tuple[TestClient, ServerStub, dict] ) -> None: - client, server = test_client_internal + client, server, _ = test_client_internal # Create job and mark complete server.job_store.create_job("j1", {"job_id": "j1", "status": "processing"}) server.job_store.set_job_completed("j1", {"job_id": "j1", "status": "completed"}) @@ -195,16 +353,16 @@ def test_status_completed_after_job_store_update( assert resp.json()["status"] == "completed" -def test_delete_video_forbidden_when_external(test_client_external: Tuple[TestClient, ServerStub]) -> None: - client, _ = test_client_external +def test_delete_video_forbidden_when_external(test_client_external: Tuple[TestClient, ServerStub, dict]) -> None: 
+ client, _, _ = test_client_external resp = client.delete("/videos/abc123", params={"filename": "clip.mp4", "namespace": "ns1"}) assert resp.status_code == 403 def test_delete_video_triggers_background_when_internal( - test_client_internal: Tuple[TestClient, ServerStub] + test_client_internal: Tuple[TestClient, ServerStub, dict] ) -> None: - client, server = test_client_internal + client, server, _ = test_client_internal resp = client.delete("/videos/abc123", params={"filename": "clip.mp4", "namespace": "ns1"}) assert resp.status_code == 200 data = resp.json() @@ -214,4 +372,3 @@ def test_delete_video_triggers_background_when_internal( # args: (job_id, hashed_identifier, namespace) assert call_args[1] == "abc123" assert call_args[2] == "ns1" - diff --git a/backend/tests/integration/test_pipeline.py b/backend/tests/integration/test_pipeline.py index f4b6198..5ee2f33 100644 --- a/backend/tests/integration/test_pipeline.py +++ b/backend/tests/integration/test_pipeline.py @@ -1,9 +1,10 @@ import pytest from unittest.mock import MagicMock -class TestPipeline: + +class TestProcessingPipeline: """ - Integration tests for the video processing pipeline. + Integration tests for the video processing pipeline (ProcessingService). Covers success paths, failure/rollback scenarios, and edge cases. """ @@ -12,7 +13,7 @@ class TestPipeline: # ========================================================================== @pytest.mark.asyncio - async def test_process_video_success(self, server_instance, sample_video_bytes): + async def test_process_video_success(self, processing_service, sample_video_bytes): """ Scenario: Happy path - everything succeeds. Expectation: @@ -25,7 +26,7 @@ async def test_process_video_success(self, server_instance, sample_video_bytes): """ # Setup hashed_id = "hash-success" - server_instance.r2_connector.upload_video.return_value = (True, hashed_id) + processing_service.r2_connector.upload_video.return_value = (True, hashed_id) # Mock Preprocessor output chunks = [ @@ -42,18 +43,18 @@ async def test_process_video_success(self, server_instance, sample_video_bytes): "memory_mb": 1.5 } ] - server_instance.preprocessor.process_video_from_bytes.return_value = chunks + processing_service.preprocessor.process_video_from_bytes.return_value = chunks # Mock Embedder mock_embedding = MagicMock() mock_embedding.numpy.return_value = [0.1, 0.2] - server_instance.video_embedder._generate_clip_embedding.return_value = mock_embedding + processing_service.video_embedder._generate_clip_embedding.return_value = mock_embedding # Mock Pinecone success - server_instance.pinecone_connector.upsert_chunk.return_value = True + processing_service.pinecone_connector.upsert_chunk.return_value = True # Execute - result = await server_instance.process_video( + result = processing_service.process_video_background( video_bytes=sample_video_bytes, filename="success.mp4", job_id="job-success", @@ -68,16 +69,16 @@ async def test_process_video_success(self, server_instance, sample_video_bytes): assert result["total_memory_mb"] == 2.5 # Verify Interactions - server_instance.r2_connector.upload_video.assert_called_once() - server_instance.preprocessor.process_video_from_bytes.assert_called_once() - assert server_instance.video_embedder._generate_clip_embedding.call_count == 2 - assert server_instance.pinecone_connector.upsert_chunk.call_count == 2 + processing_service.r2_connector.upload_video.assert_called_once() + processing_service.preprocessor.process_video_from_bytes.assert_called_once() + assert 
processing_service.video_embedder._generate_clip_embedding.call_count == 2 + assert processing_service.pinecone_connector.upsert_chunk.call_count == 2 # Verify Job Store Update - server_instance.job_store.set_job_completed.assert_called_once_with("job-success", result) + processing_service.job_store.set_job_completed.assert_called_once_with("job-success", result) @pytest.mark.asyncio - async def test_process_video_empty_result(self, server_instance): + async def test_process_video_empty_result(self, processing_service): """ Scenario: Video processed but resulted in 0 chunks (e.g. too short). Expectation: @@ -86,11 +87,11 @@ async def test_process_video_empty_result(self, server_instance): - No Pinecone upserts. """ # Setup - server_instance.r2_connector.upload_video.return_value = (True, "hash-empty") - server_instance.preprocessor.process_video_from_bytes.return_value = [] # No chunks + processing_service.r2_connector.upload_video.return_value = (True, "hash-empty") + processing_service.preprocessor.process_video_from_bytes.return_value = [] # No chunks # Execute - result = await server_instance.process_video( + result = processing_service.process_video_background( video_bytes=b"short-video", filename="short.mp4", job_id="job-empty", @@ -101,15 +102,15 @@ async def test_process_video_empty_result(self, server_instance): assert result["status"] == "completed" assert result["chunks"] == 0 - server_instance.video_embedder._generate_clip_embedding.assert_not_called() - server_instance.pinecone_connector.upsert_chunk.assert_not_called() + processing_service.video_embedder._generate_clip_embedding.assert_not_called() + processing_service.pinecone_connector.upsert_chunk.assert_not_called() # ========================================================================== # ROLLBACK / FAILURE SCENARIOS # ========================================================================== @pytest.mark.asyncio - async def test_rollback_on_r2_upload_failure(self, server_instance): + async def test_rollback_on_r2_upload_failure(self, processing_service): """ Scenario: R2 upload fails immediately. Expectation: @@ -117,11 +118,10 @@ async def test_rollback_on_r2_upload_failure(self, server_instance): - No rollback actions (delete_video, delete_chunks) because nothing was created. """ # Setup - # Raise exception - server_instance.r2_connector.upload_video.side_effect = Exception("R2 Upload Error") + processing_service.r2_connector.upload_video.side_effect = Exception("R2 Upload Error") # Execute - result = await server_instance.process_video( + result = processing_service.process_video_background( video_bytes=b"fake-video-data", filename="test.mp4", job_id="job-1", @@ -133,11 +133,11 @@ async def test_rollback_on_r2_upload_failure(self, server_instance): assert "R2 Upload Error" in result["error"] # Rollback checks - server_instance.r2_connector.delete_video.assert_not_called() - server_instance.pinecone_connector.delete_chunks.assert_not_called() + processing_service.r2_connector.delete_video.assert_not_called() + processing_service.pinecone_connector.delete_chunks.assert_not_called() @pytest.mark.asyncio - async def test_rollback_on_preprocessing_failure(self, server_instance): + async def test_rollback_on_preprocessing_failure(self, processing_service): """ Scenario: R2 upload succeeds, but Preprocessing fails. 
Expectation: @@ -147,12 +147,11 @@ async def test_rollback_on_preprocessing_failure(self, server_instance): """ # Setup hashed_id = "hash-123" - server_instance.r2_connector.upload_video.return_value = (True, hashed_id) - - server_instance.preprocessor.process_video_from_bytes.side_effect = Exception("Preprocessing Failed") + processing_service.r2_connector.upload_video.return_value = (True, hashed_id) + processing_service.preprocessor.process_video_from_bytes.side_effect = Exception("Preprocessing Failed") # Execute - result = await server_instance.process_video( + result = processing_service.process_video_background( video_bytes=b"fake-video-data", filename="test.mp4", job_id="job-2", @@ -164,11 +163,11 @@ async def test_rollback_on_preprocessing_failure(self, server_instance): assert "Preprocessing Failed" in result["error"] # Rollback checks - server_instance.r2_connector.delete_video.assert_called_once_with(hashed_id) - server_instance.pinecone_connector.delete_chunks.assert_not_called() + processing_service.r2_connector.delete_video.assert_called_once_with(hashed_id) + processing_service.pinecone_connector.delete_chunks.assert_not_called() @pytest.mark.asyncio - async def test_rollback_on_embedding_failure(self, server_instance): + async def test_rollback_on_embedding_failure(self, processing_service): """ Scenario: R2 upload & Preprocessing succeed, but Embedding generation fails. Expectation: @@ -178,7 +177,7 @@ async def test_rollback_on_embedding_failure(self, server_instance): """ # Setup hashed_id = "hash-456" - server_instance.r2_connector.upload_video.return_value = (True, hashed_id) + processing_service.r2_connector.upload_video.return_value = (True, hashed_id) # Mock preprocessor to return one chunk chunk = { @@ -187,13 +186,13 @@ async def test_rollback_on_embedding_failure(self, server_instance): "metadata": {"frame_count": 10, "complexity_score": 0.5}, "memory_mb": 1.0 } - server_instance.preprocessor.process_video_from_bytes.return_value = [chunk] + processing_service.preprocessor.process_video_from_bytes.return_value = [chunk] # Fail embedding - server_instance.video_embedder._generate_clip_embedding.side_effect = Exception("Embedding Model Error") + processing_service.video_embedder._generate_clip_embedding.side_effect = Exception("Embedding Model Error") # Execute - result = await server_instance.process_video( + result = processing_service.process_video_background( video_bytes=b"fake-video-data", filename="test.mp4", job_id="job-3", @@ -205,11 +204,11 @@ async def test_rollback_on_embedding_failure(self, server_instance): assert "Embedding Model Error" in result["error"] # Rollback checks - server_instance.r2_connector.delete_video.assert_called_once_with(hashed_id) - server_instance.pinecone_connector.delete_chunks.assert_not_called() + processing_service.r2_connector.delete_video.assert_called_once_with(hashed_id) + processing_service.pinecone_connector.delete_chunks.assert_not_called() @pytest.mark.asyncio - async def test_rollback_on_partial_pinecone_failure(self, server_instance): + async def test_rollback_on_partial_pinecone_failure(self, processing_service): """ Scenario: - R2 upload succeeds. 
@@ -223,7 +222,7 @@ async def test_rollback_on_partial_pinecone_failure(self, server_instance): """ # Setup hashed_id = "hash-789" - server_instance.r2_connector.upload_video.return_value = (True, hashed_id) + processing_service.r2_connector.upload_video.return_value = (True, hashed_id) chunks = [ { @@ -239,24 +238,18 @@ async def test_rollback_on_partial_pinecone_failure(self, server_instance): "memory_mb": 1.0 } ] - server_instance.preprocessor.process_video_from_bytes.return_value = chunks + processing_service.preprocessor.process_video_from_bytes.return_value = chunks # Mock embedding to succeed mock_embedding = MagicMock() mock_embedding.numpy.return_value = [0.1, 0.2] - server_instance.video_embedder._generate_clip_embedding.return_value = mock_embedding + processing_service.video_embedder._generate_clip_embedding.return_value = mock_embedding # Mock Pinecone upsert: First succeeds, Second fails - # side_effect can be an iterable of return values or exceptions - # Call 1: True (Success) - # Call 2: False (Failure) OR Raise Exception - - # The code checks `if success:` then `else: raise Exception` - # So we can just return [True, False] - server_instance.pinecone_connector.upsert_chunk.side_effect = [True, False] + processing_service.pinecone_connector.upsert_chunk.side_effect = [True, False] # Execute - result = await server_instance.process_video( + result = processing_service.process_video_background( video_bytes=b"fake-video-data", filename="test.mp4", job_id="job-4", @@ -268,16 +261,16 @@ async def test_rollback_on_partial_pinecone_failure(self, server_instance): assert "Failed to upsert chunk chunk-2" in result["error"] # Rollback checks - server_instance.r2_connector.delete_video.assert_called_once_with(hashed_id) + processing_service.r2_connector.delete_video.assert_called_once_with(hashed_id) # Should delete the one that succeeded (chunk-1) - server_instance.pinecone_connector.delete_chunks.assert_called_once_with( + processing_service.pinecone_connector.delete_chunks.assert_called_once_with( ["chunk-1"], namespace="test-ns" ) @pytest.mark.asyncio - async def test_rollback_best_effort_when_cleanup_fails(self, server_instance): + async def test_rollback_best_effort_when_cleanup_fails(self, processing_service): """ Scenario: - Pipeline fails (triggering rollback). 
@@ -287,31 +280,31 @@ async def test_rollback_best_effort_when_cleanup_fails(self, server_instance): """ # Setup failure in pipeline (Partial Pinecone failure to ensure we have chunks to delete) hashed_id = "hash-fail-cleanup" - server_instance.r2_connector.upload_video.return_value = (True, hashed_id) + processing_service.r2_connector.upload_video.return_value = (True, hashed_id) chunks = [ {"chunk_id": "c1", "frames": [], "metadata": {"frame_count": 10, "complexity_score": 0.5}, "memory_mb": 1}, {"chunk_id": "c2", "frames": [], "metadata": {"frame_count": 10, "complexity_score": 0.5}, "memory_mb": 1} ] - server_instance.preprocessor.process_video_from_bytes.return_value = chunks - server_instance.video_embedder._generate_clip_embedding.return_value = MagicMock(numpy=lambda: [0.1]) + processing_service.preprocessor.process_video_from_bytes.return_value = chunks + processing_service.video_embedder._generate_clip_embedding.return_value = MagicMock(numpy=lambda: [0.1]) # Upsert: True, False (Trigger rollback) - server_instance.pinecone_connector.upsert_chunk.side_effect = [True, False] + processing_service.pinecone_connector.upsert_chunk.side_effect = [True, False] # Setup failure in R2 cleanup - server_instance.r2_connector.delete_video.return_value = False + processing_service.r2_connector.delete_video.return_value = False # Execute - result = await server_instance.process_video( + result = processing_service.process_video_background( video_bytes=b"data", filename="test.mp4", job_id="job-5", namespace="ns" ) # Verify R2 delete was called (and failed) - server_instance.r2_connector.delete_video.assert_called_once_with(hashed_id) + processing_service.r2_connector.delete_video.assert_called_once_with(hashed_id) # Verify Pinecone delete was called DESPITE R2 delete failure - server_instance.pinecone_connector.delete_chunks.assert_called_once_with(["c1"], namespace="ns") + processing_service.pinecone_connector.delete_chunks.assert_called_once_with(["c1"], namespace="ns") # Verify result is still failed assert result["status"] == "failed" @@ -321,7 +314,7 @@ async def test_rollback_best_effort_when_cleanup_fails(self, server_instance): # ========================================================================== @pytest.mark.asyncio - async def test_metadata_transformation(self, server_instance): + async def test_metadata_transformation(self, processing_service): """ Scenario: Metadata contains complex types (timestamp_range, file_info) that need flattening. Expectation: @@ -330,7 +323,7 @@ async def test_metadata_transformation(self, server_instance): - Null values are removed. 
""" # Setup - server_instance.r2_connector.upload_video.return_value = (True, "hash-meta") + processing_service.r2_connector.upload_video.return_value = (True, "hash-meta") raw_metadata = { "frame_count": 10, @@ -346,17 +339,17 @@ async def test_metadata_transformation(self, server_instance): "metadata": raw_metadata, "memory_mb": 1.0 }] - server_instance.preprocessor.process_video_from_bytes.return_value = chunks + processing_service.preprocessor.process_video_from_bytes.return_value = chunks # Mock embedding - server_instance.video_embedder._generate_clip_embedding.return_value = MagicMock(numpy=lambda: [0.1]) - server_instance.pinecone_connector.upsert_chunk.return_value = True + processing_service.video_embedder._generate_clip_embedding.return_value = MagicMock(numpy=lambda: [0.1]) + processing_service.pinecone_connector.upsert_chunk.return_value = True # Execute - await server_instance.process_video(b"data", "test.mp4", "job-meta", "ns") + processing_service.process_video_background(b"data", "test.mp4", "job-meta", "ns") # Verify Upsert Call Arguments - call_args = server_instance.pinecone_connector.upsert_chunk.call_args + call_args = processing_service.pinecone_connector.upsert_chunk.call_args upserted_metadata = call_args.kwargs['metadata'] # Check transformations @@ -370,12 +363,17 @@ async def test_metadata_transformation(self, server_instance): assert 'file_info' not in upserted_metadata assert 'optional_field' not in upserted_metadata + +class TestDeletionPipeline: + """ + Integration tests for the video deletion pipeline (ServerService.delete_video_background). + """ + # ========================================================================== # DELETION SCENARIOS # ========================================================================== - @pytest.mark.asyncio - async def test_delete_video_success(self, server_instance): + def test_delete_video_success(self, server_instance): """ Scenario: Happy path for deletion - everything succeeds. Expectation: @@ -392,7 +390,7 @@ async def test_delete_video_success(self, server_instance): server_instance.r2_connector.delete_video.return_value = True # Execute - result = await server_instance.delete_video_background( + result = server_instance.delete_video_background( job_id=job_id, hashed_identifier=hashed_id, namespace=namespace @@ -414,8 +412,7 @@ async def test_delete_video_success(self, server_instance): # Verify Job Store Update server_instance.job_store.set_job_completed.assert_called_once_with(job_id, result) - @pytest.mark.asyncio - async def test_delete_video_pinecone_failure(self, server_instance): + def test_delete_video_pinecone_failure(self, server_instance): """ Scenario: Pinecone deletion fails. Expectation: @@ -430,7 +427,7 @@ async def test_delete_video_pinecone_failure(self, server_instance): server_instance.pinecone_connector.delete_by_identifier.return_value = False # Execute - result = await server_instance.delete_video_background( + result = server_instance.delete_video_background( job_id=job_id, hashed_identifier=hashed_id, namespace=namespace @@ -447,8 +444,7 @@ async def test_delete_video_pinecone_failure(self, server_instance): # Verify Job Store Update server_instance.job_store.set_job_failed.assert_called_once_with(job_id, "Failed to delete chunks from Pinecone") - @pytest.mark.asyncio - async def test_delete_video_r2_failure(self, server_instance): + def test_delete_video_r2_failure(self, server_instance): """ Scenario: Pinecone deletion succeeds, but R2 deletion fails. 
Expectation: @@ -464,7 +460,7 @@ async def test_delete_video_r2_failure(self, server_instance): server_instance.r2_connector.delete_video.return_value = False # Execute - result = await server_instance.delete_video_background( + result = server_instance.delete_video_background( job_id=job_id, hashed_identifier=hashed_id, namespace=namespace @@ -481,4 +477,3 @@ async def test_delete_video_r2_failure(self, server_instance): # Verify Job Store Update error_msg = "Failed to delete video from R2 after deleting chunks. System may be inconsistent." server_instance.job_store.set_job_failed.assert_called_once_with(job_id, error_msg) - diff --git a/backend/tests/integration/test_preprocessor.py b/backend/tests/integration/test_preprocessor.py index 5e94d3f..f49ecf4 100644 --- a/backend/tests/integration/test_preprocessor.py +++ b/backend/tests/integration/test_preprocessor.py @@ -1,12 +1,29 @@ import pytest import numpy as np import os +import shutil from preprocessing.preprocessor import Preprocessor +# Check if ffmpeg is available on the system +FFMPEG_AVAILABLE = shutil.which("ffmpeg") is not None +FFPROBE_AVAILABLE = shutil.which("ffprobe") is not None + +requires_ffmpeg = pytest.mark.skipif( + not FFMPEG_AVAILABLE, + reason="ffmpeg not installed on this system" +) + +requires_ffprobe = pytest.mark.skipif( + not FFPROBE_AVAILABLE, + reason="ffprobe not installed on this system" +) + + class TestEndToEndProcessing: """Test complete preprocessing pipeline.""" + @requires_ffmpeg def test_process_video_from_bytes(self, preprocessor, sample_video_bytes): """Verify processing from video bytes works.""" result = preprocessor.process_video_from_bytes( @@ -206,6 +223,7 @@ def test_metadata_cache_improves_performance(self, preprocessor, sample_video_5s class TestErrorHandling: """Test error handling and edge cases.""" + @requires_ffmpeg def test_invalid_video_bytes_raises_error(self, preprocessor): """Verify invalid video data is handled gracefully.""" with pytest.raises(RuntimeError): @@ -235,6 +253,7 @@ def test_corrupted_video_handles_gracefully(self, preprocessor, temp_dir): class TestCodecSupport: """Test codec detection and transcoding capabilities.""" + @requires_ffprobe def test_detects_av1_codec(self, preprocessor, sample_video_av1): """Verify AV1 codec is correctly identified.""" # Ensure the file exists @@ -244,18 +263,21 @@ def test_detects_av1_codec(self, preprocessor, sample_video_av1): codec = preprocessor._get_video_codec(str(sample_video_av1)) assert codec == "av1" + @requires_ffprobe def test_detects_h264_codec(self, preprocessor, sample_video_h264): """Verify H.264 codec is correctly identified.""" assert os.path.exists(sample_video_h264) codec = preprocessor._get_video_codec(str(sample_video_h264)) assert codec == "h264" + @requires_ffprobe def test_detects_vp9_codec(self, preprocessor, sample_video_vp9): """Verify VP9 codec is correctly identified.""" assert os.path.exists(sample_video_vp9) codec = preprocessor._get_video_codec(str(sample_video_vp9)) assert codec == "vp9" + @requires_ffmpeg def test_transcodes_av1_to_h264(self, preprocessor, sample_video_av1): """Verify AV1 video is transcoded and processed.""" with open(sample_video_av1, "rb") as f: @@ -275,6 +297,7 @@ def test_transcodes_av1_to_h264(self, preprocessor, sample_video_av1): # Verify chunks have frames assert all(len(chunk['frames']) > 0 for chunk in result) + @requires_ffmpeg def test_transcodes_vp9_to_h264(self, preprocessor, sample_video_vp9): """Verify VP9 video is transcoded and processed.""" with 
open(sample_video_vp9, "rb") as f: @@ -292,6 +315,7 @@ def test_transcodes_vp9_to_h264(self, preprocessor, sample_video_vp9): assert len(result) > 0 assert all(len(chunk['frames']) > 0 for chunk in result) + @requires_ffmpeg def test_h264_codec_skips_transcoding(self, preprocessor, sample_video_5s): """Verify non-AV1 video is processed without transcoding.""" with open(sample_video_5s, "rb") as f: diff --git a/backend/tests/integration/test_search_service.py b/backend/tests/integration/test_search_service.py new file mode 100644 index 0000000..368a77e --- /dev/null +++ b/backend/tests/integration/test_search_service.py @@ -0,0 +1,378 @@ +""" +Integration tests for SearchService. + +Tests the full search flow with mocked external dependencies. +""" + +import sys +import importlib +from typing import Any, Dict, List +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest +from fastapi.testclient import TestClient + + +class FakePineconeConnector: + """Fake PineconeConnector for testing.""" + + def __init__(self, matches: List[Dict[str, Any]] | None = None): + self.matches = matches or [] + self.last_query_embedding: np.ndarray | None = None + self.last_namespace: str | None = None + self.last_top_k: int | None = None + + def query_chunks( + self, + query_embedding: np.ndarray, + namespace: str = "", + top_k: int = 5 + ) -> List[Dict[str, Any]]: + self.last_query_embedding = query_embedding + self.last_namespace = namespace + self.last_top_k = top_k + return self.matches + + +class FakeR2Connector: + """Fake R2Connector for testing.""" + + def __init__(self, presigned_urls: Dict[str, str] | None = None): + self.presigned_urls = presigned_urls or {} + self.generate_calls: List[tuple] = [] + + def generate_presigned_url( + self, + identifier: str, + validate_exists: bool = False + ) -> str | None: + self.generate_calls.append((identifier, validate_exists)) + return self.presigned_urls.get(identifier) + + +class FakeTextEmbedder: + """Fake TextEmbedder for testing.""" + + def __init__(self): + self.embed_calls: List[str] = [] + self._loaded = False + + def _load_model(self): + self._loaded = True + + def embed_text(self, text: str) -> np.ndarray: + self.embed_calls.append(text) + # Return a fake 512-d normalized embedding + embedding = np.random.randn(512).astype(np.float32) + return embedding / np.linalg.norm(embedding) + + +@pytest.fixture +def search_service_instance(mocker): + """ + Create a SearchService instance with mocked dependencies. + + Bypasses Modal decorators and manually injects mock components. 
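+
+    The modal module is replaced in sys.modules so services.search_service can be
+    imported without a Modal runtime; fake embedder, Pinecone, and R2 components
+    are then injected by hand.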
+ """ + # Create a mock for the modal module + mock_modal = MagicMock() + + # Configure the mock decorators to just return the original class/function + def identity_decorator(*args, **kwargs): + def wrapper(obj): + return obj + return wrapper + + mock_modal.enter.side_effect = identity_decorator + mock_modal.asgi_app.side_effect = identity_decorator + + # Mock shared.config functions + with patch.dict(sys.modules, {'modal': mock_modal}), \ + patch('shared.config.get_environment', return_value='test'), \ + patch('shared.config.get_env_var', return_value='test-value'), \ + patch('shared.config.get_pinecone_index', return_value='test-index'): + + # Import/reload the module + if 'services.search_service' in sys.modules: + import services.search_service as search_module + importlib.reload(search_module) + else: + import services.search_service as search_module + + # Create instance + service = search_module.SearchService() + + # Inject mock components + service.embedder = FakeTextEmbedder() + service.embedder._load_model() + + service.pinecone_connector = FakePineconeConnector( + matches=[ + { + 'id': 'chunk-1', + 'score': 0.95, + 'metadata': { + 'file_hashed_identifier': 'hash-abc', + 'file_name': 'video1.mp4', + 'start_time': 0.0, + 'end_time': 5.0 + } + }, + { + 'id': 'chunk-2', + 'score': 0.87, + 'metadata': { + 'file_hashed_identifier': 'hash-def', + 'file_name': 'video2.mp4', + 'start_time': 10.0, + 'end_time': 15.0 + } + }, + { + 'id': 'chunk-3', + 'score': 0.75, + 'metadata': { + # Missing file_hashed_identifier - should be skipped + 'file_name': 'video3.mp4' + } + } + ] + ) + + service.r2_connector = FakeR2Connector( + presigned_urls={ + 'hash-abc': 'https://r2.example.com/hash-abc/video1.mp4', + 'hash-def': 'https://r2.example.com/hash-def/video2.mp4', + # hash-ghi not present - will return None + } + ) + + yield service + + +@pytest.fixture +def search_test_client(search_service_instance): + """Create FastAPI test client for SearchService.""" + # Manually create the FastAPI app + service = search_service_instance + service.fastapi_app = service._create_fastapi_app() + + return TestClient(service.fastapi_app), service + + +class TestSearchServiceInternal: + """Test _search_internal method.""" + + def test_search_returns_results(self, search_service_instance): + """Verify search returns formatted results.""" + service = search_service_instance + + results = service._search_internal("woman on a train", namespace="test-ns", top_k=10) + + # Should return 2 results (chunk-3 skipped due to missing identifier) + assert len(results) == 2 + assert results[0]['id'] == 'chunk-1' + assert results[0]['score'] == 0.95 + assert results[1]['id'] == 'chunk-2' + + def test_search_generates_embeddings(self, search_service_instance): + """Verify embedder is called with query text.""" + service = search_service_instance + + service._search_internal("my search query") + + assert len(service.embedder.embed_calls) == 1 + assert service.embedder.embed_calls[0] == "my search query" + + def test_search_queries_pinecone(self, search_service_instance): + """Verify Pinecone is queried with correct parameters.""" + service = search_service_instance + + service._search_internal("test", namespace="my-namespace", top_k=20) + + assert service.pinecone_connector.last_namespace == "my-namespace" + assert service.pinecone_connector.last_top_k == 20 + assert service.pinecone_connector.last_query_embedding is not None + + def test_search_adds_presigned_urls(self, search_service_instance): + """Verify presigned URLs are 
added to results.""" + service = search_service_instance + + results = service._search_internal("test") + + assert results[0]['metadata']['presigned_url'] == 'https://r2.example.com/hash-abc/video1.mp4' + assert results[1]['metadata']['presigned_url'] == 'https://r2.example.com/hash-def/video2.mp4' + + def test_search_skips_missing_identifier(self, search_service_instance): + """Verify results without file_hashed_identifier are skipped.""" + service = search_service_instance + + # chunk-3 has no file_hashed_identifier + results = service._search_internal("test") + + ids = [r['id'] for r in results] + assert 'chunk-3' not in ids + + def test_search_skips_missing_presigned_url(self, search_service_instance): + """Verify results without presigned URL are skipped.""" + service = search_service_instance + + # Add a match with an identifier that has no presigned URL + service.pinecone_connector.matches.append({ + 'id': 'chunk-4', + 'score': 0.6, + 'metadata': { + 'file_hashed_identifier': 'hash-nonexistent', + 'file_name': 'video4.mp4' + } + }) + + results = service._search_internal("test") + + ids = [r['id'] for r in results] + assert 'chunk-4' not in ids + + def test_search_empty_results(self, search_service_instance): + """Verify empty results are handled.""" + service = search_service_instance + service.pinecone_connector.matches = [] + + results = service._search_internal("nonexistent query") + + assert results == [] + + def test_search_default_parameters(self, search_service_instance): + """Verify default namespace and top_k.""" + service = search_service_instance + + service._search_internal("test") + + # Default namespace should be "" + assert service.pinecone_connector.last_namespace == "" + # Default top_k should be 10 + assert service.pinecone_connector.last_top_k == 10 + + +class TestSearchServiceFastAPIApp: + """Test SearchService FastAPI app creation.""" + + def test_creates_fastapi_app(self, search_service_instance): + """Verify FastAPI app is created.""" + service = search_service_instance + app = service._create_fastapi_app() + + assert app is not None + assert app.title == "ClipABit Search API" + + def test_app_has_cors_middleware(self, search_service_instance): + """Verify CORS middleware is added.""" + service = search_service_instance + app = service._create_fastapi_app() + + # Check middleware is present + middleware_classes = [type(m).__name__ for m in app.user_middleware] + # Note: CORS middleware may appear differently in the list + # The important thing is the app is configured + assert app is not None + assert "Middleware" in middleware_classes + + +class TestSearchServiceHTTPEndpoints: + """Test SearchService HTTP endpoints via TestClient.""" + + def test_health_endpoint(self, search_test_client): + """Verify /health endpoint works.""" + client, _ = search_test_client + + resp = client.get("/health") + + assert resp.status_code == 200 + assert resp.json()["status"] == "ok" + assert resp.json()["service"] == "search" + + def test_search_endpoint(self, search_test_client): + """Verify /search endpoint returns results.""" + client, service = search_test_client + + resp = client.get("/search", params={"query": "test query"}) + + assert resp.status_code == 200 + data = resp.json() + assert data["query"] == "test query" + assert len(data["results"]) == 2 + assert "timing" in data + + def test_search_with_namespace(self, search_test_client): + """Verify namespace is passed through.""" + client, service = search_test_client + + client.get("/search", params={"query": "test", 
"namespace": "custom-ns"}) + + assert service.pinecone_connector.last_namespace == "custom-ns" + + def test_search_with_top_k(self, search_test_client): + """Verify top_k is passed through.""" + client, service = search_test_client + + client.get("/search", params={"query": "test", "top_k": 25}) + + assert service.pinecone_connector.last_top_k == 25 + + def test_search_missing_query(self, search_test_client): + """Verify missing query returns 422.""" + client, _ = search_test_client + + resp = client.get("/search") + + assert resp.status_code == 422 + + +class TestSearchServiceResultFormatting: + """Test result formatting in SearchService.""" + + def test_result_structure(self, search_service_instance): + """Verify result dictionary structure.""" + service = search_service_instance + + results = service._search_internal("test") + + result = results[0] + assert 'id' in result + assert 'score' in result + assert 'metadata' in result + assert 'presigned_url' in result['metadata'] + + def test_preserves_original_metadata(self, search_service_instance): + """Verify original metadata is preserved.""" + service = search_service_instance + + results = service._search_internal("test") + + metadata = results[0]['metadata'] + assert metadata['file_name'] == 'video1.mp4' + assert metadata['start_time'] == 0.0 + assert metadata['end_time'] == 5.0 + assert metadata['file_hashed_identifier'] == 'hash-abc' + + def test_score_is_float(self, search_service_instance): + """Verify score is a float.""" + service = search_service_instance + + results = service._search_internal("test") + + assert isinstance(results[0]['score'], float) + + +class TestSearchServiceWithNoR2Connector: + """Test SearchService behavior when R2 connector is not available.""" + + def test_handles_none_r2_connector(self, search_service_instance): + """Verify graceful handling when r2_connector is None.""" + service = search_service_instance + service.r2_connector = None + + results = service._search_internal("test") + + # All results should be skipped since no presigned URLs can be generated + assert results == [] diff --git a/backend/tests/unit/test_r2_connector.py b/backend/tests/unit/test_r2_connector.py index 7066971..6633bdd 100644 --- a/backend/tests/unit/test_r2_connector.py +++ b/backend/tests/unit/test_r2_connector.py @@ -142,7 +142,7 @@ def test_fetch_page_success(self, mock_r2_connector): mock_client.list_objects_v2.assert_called_once_with( Bucket="test", Prefix="ns/", - MaxKeys=3, + MaxKeys=2, # Uses exact page_size ) for call in mock_client.generate_presigned_url.call_args_list: kwargs = call.kwargs @@ -157,33 +157,50 @@ def test_fetch_page_handles_error(self, mock_r2_connector): assert videos == [] assert token is None - def test_fetch_page_cursor_fallback(self, mock_r2_connector): + def test_fetch_page_no_more_pages(self, mock_r2_connector): + """Test that no next_token is returned when S3 indicates no more pages.""" connector, mock_client, _ = mock_r2_connector mock_client.list_objects_v2.return_value = { 'Contents': [ {'Key': 'ns/vid1.mp4'}, {'Key': 'ns/vid2.mp4'}, - {'Key': 'ns/vid3.mp4'}, ], - 'IsTruncated': False, + 'IsTruncated': False, # S3 says no more pages } mock_client.generate_presigned_url.return_value = "http://url" videos, token = connector.fetch_video_page(namespace="ns", page_size=2) assert len(videos) == 2 - assert token is not None - assert token.startswith("cursor:") + assert token is None # No more pages + + def test_fetch_page_with_continuation_token(self, mock_r2_connector): + """Test that 
continuation tokens are passed to S3 correctly.""" + connector, mock_client, _ = mock_r2_connector - mock_client.list_objects_v2.reset_mock() - connector.fetch_video_page(namespace="ns", page_size=2, continuation_token=token) + mock_client.list_objects_v2.return_value = { + 'Contents': [ + {'Key': 'ns/vid3.mp4'}, + {'Key': 'ns/vid4.mp4'}, + ], + 'IsTruncated': False, + } + mock_client.generate_presigned_url.return_value = "http://url" + + videos, token = connector.fetch_video_page( + namespace="ns", + page_size=2, + continuation_token="s3-native-token" + ) + mock_client.list_objects_v2.assert_called_once_with( Bucket="test", Prefix="ns/", - MaxKeys=3, - StartAfter='ns/vid2.mp4', + MaxKeys=2, # Uses exact page_size + ContinuationToken="s3-native-token", ) + assert len(videos) == 2 for call in mock_client.generate_presigned_url.call_args_list: kwargs = call.kwargs assert kwargs['ExpiresIn'] == DEFAULT_PRESIGNED_URL_TTL diff --git a/backend/tests/unit/test_search_fastapi_router.py b/backend/tests/unit/test_search_fastapi_router.py new file mode 100644 index 0000000..415c742 --- /dev/null +++ b/backend/tests/unit/test_search_fastapi_router.py @@ -0,0 +1,212 @@ +""" +Unit tests for SearchFastAPIRouter. + +Tests the search API endpoint with mocked SearchService. +""" + +from typing import Any, List, Dict + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from api.search_fastapi_router import SearchFastAPIRouter + + +class FakeSearchService: + """Fake SearchService for testing the router.""" + + def __init__(self, results: List[Dict[str, Any]] | None = None): + self.results = results or [] + self.last_query: str | None = None + self.last_namespace: str | None = None + self.last_top_k: int | None = None + self.should_raise: Exception | None = None + + def _search_internal(self, query: str, namespace: str = "", top_k: int = 10) -> List[Dict[str, Any]]: + """Mock search implementation.""" + self.last_query = query + self.last_namespace = namespace + self.last_top_k = top_k + + if self.should_raise: + raise self.should_raise + + return self.results + + +@pytest.fixture() +def search_service() -> FakeSearchService: + """Create a fake search service with sample results.""" + return FakeSearchService( + results=[ + { + "id": "chunk-1", + "score": 0.95, + "metadata": { + "file_hashed_identifier": "abc123", + "file_name": "video1.mp4", + "start_time": 0.0, + "end_time": 5.0, + "presigned_url": "https://example.com/video1.mp4" + } + }, + { + "id": "chunk-2", + "score": 0.87, + "metadata": { + "file_hashed_identifier": "def456", + "file_name": "video2.mp4", + "start_time": 10.0, + "end_time": 15.0, + "presigned_url": "https://example.com/video2.mp4" + } + } + ] + ) + + +@pytest.fixture() +def test_client(search_service: FakeSearchService) -> tuple[TestClient, FakeSearchService]: + """Create FastAPI test client with search router.""" + app = FastAPI() + router = SearchFastAPIRouter(search_service_instance=search_service) + app.include_router(router.router) + return TestClient(app), search_service + + +class TestHealthEndpoint: + """Test /health endpoint.""" + + def test_health_returns_ok(self, test_client: tuple[TestClient, FakeSearchService]) -> None: + """Verify health check returns ok status.""" + client, _ = test_client + resp = client.get("/health") + + assert resp.status_code == 200 + data = resp.json() + assert data["status"] == "ok" + assert data["service"] == "search" + + +class TestSearchEndpoint: + """Test /search endpoint.""" + + def 
test_search_with_query_returns_results(self, test_client: tuple[TestClient, FakeSearchService]) -> None: + """Verify search returns results for a query.""" + client, service = test_client + resp = client.get("/search", params={"query": "woman on a train"}) + + assert resp.status_code == 200 + data = resp.json() + + assert data["query"] == "woman on a train" + assert len(data["results"]) == 2 + assert data["results"][0]["id"] == "chunk-1" + assert data["results"][0]["score"] == 0.95 + assert "timing" in data + assert "total_s" in data["timing"] + + def test_search_passes_query_to_service(self, test_client: tuple[TestClient, FakeSearchService]) -> None: + """Verify query is passed correctly to service.""" + client, service = test_client + client.get("/search", params={"query": "test query"}) + + assert service.last_query == "test query" + + def test_search_with_namespace(self, test_client: tuple[TestClient, FakeSearchService]) -> None: + """Verify namespace parameter is passed to service.""" + client, service = test_client + client.get("/search", params={"query": "test", "namespace": "my-namespace"}) + + assert service.last_namespace == "my-namespace" + + def test_search_with_default_namespace(self, test_client: tuple[TestClient, FakeSearchService]) -> None: + """Verify default namespace is empty string.""" + client, service = test_client + client.get("/search", params={"query": "test"}) + + assert service.last_namespace == "" + + def test_search_with_top_k(self, test_client: tuple[TestClient, FakeSearchService]) -> None: + """Verify top_k parameter is passed to service.""" + client, service = test_client + client.get("/search", params={"query": "test", "top_k": 20}) + + assert service.last_top_k == 20 + + def test_search_with_default_top_k(self, test_client: tuple[TestClient, FakeSearchService]) -> None: + """Verify default top_k is 10.""" + client, service = test_client + client.get("/search", params={"query": "test"}) + + assert service.last_top_k == 10 + + def test_search_with_all_parameters(self, test_client: tuple[TestClient, FakeSearchService]) -> None: + """Verify all parameters are passed correctly.""" + client, service = test_client + client.get("/search", params={ + "query": "my search", + "namespace": "custom-ns", + "top_k": 5 + }) + + assert service.last_query == "my search" + assert service.last_namespace == "custom-ns" + assert service.last_top_k == 5 + + def test_search_missing_query_returns_error(self, test_client: tuple[TestClient, FakeSearchService]) -> None: + """Verify missing query parameter returns 422.""" + client, _ = test_client + resp = client.get("/search") + + assert resp.status_code == 422 # FastAPI validation error + + def test_search_empty_results(self, test_client: tuple[TestClient, FakeSearchService]) -> None: + """Verify empty results are handled correctly.""" + client, service = test_client + service.results = [] + + resp = client.get("/search", params={"query": "nonexistent"}) + + assert resp.status_code == 200 + data = resp.json() + assert data["results"] == [] + + def test_search_service_error_returns_500(self, test_client: tuple[TestClient, FakeSearchService]) -> None: + """Verify service errors return 500.""" + client, service = test_client + service.should_raise = Exception("Database connection failed") + + resp = client.get("/search", params={"query": "test"}) + + assert resp.status_code == 500 + assert "Database connection failed" in resp.json()["detail"] + + def test_search_timing_is_positive(self, test_client: tuple[TestClient, 
FakeSearchService]) -> None: + """Verify timing is a positive number.""" + client, _ = test_client + resp = client.get("/search", params={"query": "test"}) + + data = resp.json() + assert data["timing"]["total_s"] >= 0 + + +class TestRouterInitialization: + """Test SearchFastAPIRouter initialization.""" + + def test_router_stores_service_instance(self) -> None: + """Verify router stores the service instance.""" + service = FakeSearchService() + router = SearchFastAPIRouter(search_service_instance=service) + + assert router.search_service is service + + def test_router_registers_routes(self) -> None: + """Verify router registers expected routes.""" + service = FakeSearchService() + router = SearchFastAPIRouter(search_service_instance=service) + + routes = [route.path for route in router.router.routes] + assert "/health" in routes + assert "/search" in routes diff --git a/backend/tests/unit/test_text_embedder.py b/backend/tests/unit/test_text_embedder.py new file mode 100644 index 0000000..6aead60 --- /dev/null +++ b/backend/tests/unit/test_text_embedder.py @@ -0,0 +1,343 @@ +""" +Unit tests for TextEmbedder (ONNX-based). + +Tests the text embedding functionality with mocked ONNX runtime and tokenizer. +""" + +import sys +from unittest.mock import MagicMock, patch +import numpy as np +import pytest + +from search.text_embedder import TextEmbedder, DEFAULT_ONNX_MODEL_PATH, DEFAULT_TOKENIZER_PATH + + +class FakeEncoding: + """Fake tokenizer encoding result.""" + + def __init__(self, ids: list[int], attention_mask: list[int]): + self.ids = ids + self.attention_mask = attention_mask + + +class FakeTokenizer: + """Fake tokenizer for testing.""" + + def __init__(self): + self.padding_enabled = False + self.truncation_enabled = False + self.padding_length = None + self.max_length = None + + def enable_padding(self, length: int, pad_id: int): + self.padding_enabled = True + self.padding_length = length + + def enable_truncation(self, max_length: int): + self.truncation_enabled = True + self.max_length = max_length + + def encode_batch(self, texts: list[str]) -> list[FakeEncoding]: + """Return fake encodings for batch.""" + return [ + FakeEncoding( + ids=[101] + [1000 + i for i in range(min(len(text.split()), 75))] + [102] + [0] * max(0, 77 - min(len(text.split()), 75) - 2), + attention_mask=[1] * (min(len(text.split()), 75) + 2) + [0] * max(0, 77 - min(len(text.split()), 75) - 2) + ) + for text in texts + ] + + +class FakeOnnxSession: + """Fake ONNX inference session.""" + + def __init__(self): + self.run_calls = [] + + def run(self, output_names, inputs): + self.run_calls.append((output_names, inputs)) + batch_size = inputs["input_ids"].shape[0] + # Return random 512-d embeddings + embeddings = np.random.randn(batch_size, 512).astype(np.float32) + return [embeddings] + + +@pytest.fixture +def mock_onnx_and_tokenizer(): + """Mock onnxruntime and tokenizers modules.""" + mock_ort = MagicMock() + mock_tokenizers = MagicMock() + + fake_session = FakeOnnxSession() + fake_tokenizer = FakeTokenizer() + + mock_ort.SessionOptions.return_value = MagicMock() + mock_ort.GraphOptimizationLevel.ORT_ENABLE_ALL = "ORT_ENABLE_ALL" + mock_ort.InferenceSession.return_value = fake_session + + mock_tokenizers.Tokenizer.from_file.return_value = fake_tokenizer + + with patch.dict(sys.modules, { + 'onnxruntime': mock_ort, + 'tokenizers': mock_tokenizers, + }): + yield mock_ort, mock_tokenizers, fake_session, fake_tokenizer + + +@pytest.fixture +def embedder_with_mocks(mock_onnx_and_tokenizer): + """Create TextEmbedder 
with mocked dependencies.""" + mock_ort, mock_tokenizers, fake_session, fake_tokenizer = mock_onnx_and_tokenizer + + embedder = TextEmbedder( + model_path="/test/model.onnx", + tokenizer_path="/test/tokenizer.json" + ) + embedder._load_model() + + return embedder, fake_session, fake_tokenizer, mock_ort, mock_tokenizers + + +class TestTextEmbedderInitialization: + """Test TextEmbedder initialization.""" + + def test_uses_default_paths(self): + """Verify default paths are set correctly.""" + embedder = TextEmbedder() + + assert embedder.model_path == DEFAULT_ONNX_MODEL_PATH + assert embedder.tokenizer_path == DEFAULT_TOKENIZER_PATH + + def test_accepts_custom_paths(self): + """Verify custom paths are stored.""" + embedder = TextEmbedder( + model_path="/custom/model.onnx", + tokenizer_path="/custom/tokenizer.json" + ) + + assert embedder.model_path == "/custom/model.onnx" + assert embedder.tokenizer_path == "/custom/tokenizer.json" + + def test_session_and_tokenizer_not_loaded_on_init(self): + """Verify lazy loading - session/tokenizer not loaded until needed.""" + embedder = TextEmbedder() + + assert embedder.session is None + assert embedder.tokenizer is None + + +class TestModelLoading: + """Test model loading behavior.""" + + def test_load_model_sets_session(self, mock_onnx_and_tokenizer): + """Verify _load_model sets the session.""" + mock_ort, mock_tokenizers, fake_session, fake_tokenizer = mock_onnx_and_tokenizer + + embedder = TextEmbedder( + model_path="/test/model.onnx", + tokenizer_path="/test/tokenizer.json" + ) + embedder._load_model() + + assert embedder.session is fake_session + assert embedder.tokenizer is fake_tokenizer + + def test_load_model_configures_tokenizer_padding(self, mock_onnx_and_tokenizer): + """Verify tokenizer padding is configured for CLIP (77 tokens).""" + _, _, _, fake_tokenizer = mock_onnx_and_tokenizer + + embedder = TextEmbedder() + embedder._load_model() + + assert fake_tokenizer.padding_enabled is True + assert fake_tokenizer.padding_length == 77 + + def test_load_model_configures_tokenizer_truncation(self, mock_onnx_and_tokenizer): + """Verify tokenizer truncation is configured for CLIP (77 tokens).""" + _, _, _, fake_tokenizer = mock_onnx_and_tokenizer + + embedder = TextEmbedder() + embedder._load_model() + + assert fake_tokenizer.truncation_enabled is True + assert fake_tokenizer.max_length == 77 + + def test_load_model_only_loads_once(self, mock_onnx_and_tokenizer): + """Verify model is only loaded once even if _load_model called multiple times.""" + mock_ort, mock_tokenizers, _, _ = mock_onnx_and_tokenizer + + embedder = TextEmbedder() + embedder._load_model() + embedder._load_model() + embedder._load_model() + + # Should only be called once + assert mock_ort.InferenceSession.call_count == 1 + assert mock_tokenizers.Tokenizer.from_file.call_count == 1 + + +class TestEmbedText: + """Test text embedding functionality.""" + + def test_embed_single_text_returns_1d_array(self, embedder_with_mocks): + """Verify single text returns 1D array.""" + embedder, _, _, _, _ = embedder_with_mocks + + result = embedder.embed_text("hello world") + + assert isinstance(result, np.ndarray) + assert result.ndim == 1 + assert result.shape == (512,) + + def test_embed_batch_returns_2d_array(self, embedder_with_mocks): + """Verify batch input returns 2D array.""" + embedder, _, _, _, _ = embedder_with_mocks + + result = embedder.embed_text(["hello", "world", "test"]) + + assert isinstance(result, np.ndarray) + assert result.ndim == 2 + assert result.shape == (3, 512) + + 
def test_embed_text_is_normalized(self, embedder_with_mocks): + """Verify output is L2 normalized.""" + embedder, _, _, _, _ = embedder_with_mocks + + result = embedder.embed_text("hello world") + norm = np.linalg.norm(result) + + assert np.isclose(norm, 1.0, atol=1e-5) + + def test_embed_batch_all_normalized(self, embedder_with_mocks): + """Verify all batch embeddings are L2 normalized.""" + embedder, _, _, _, _ = embedder_with_mocks + + result = embedder.embed_text(["hello", "world", "test"]) + + for i in range(3): + norm = np.linalg.norm(result[i]) + assert np.isclose(norm, 1.0, atol=1e-5) + + def test_embed_text_calls_tokenizer(self, embedder_with_mocks): + """Verify tokenizer is called with input text.""" + embedder, _, fake_tokenizer, _, _ = embedder_with_mocks + + # Spy on encode_batch + original_encode = fake_tokenizer.encode_batch + call_args = [] + + def spy_encode(texts): + call_args.append(texts) + return original_encode(texts) + + fake_tokenizer.encode_batch = spy_encode + + embedder.embed_text("test query") + + assert len(call_args) == 1 + assert call_args[0] == ["test query"] + + def test_embed_text_calls_session_run(self, embedder_with_mocks): + """Verify ONNX session.run is called.""" + embedder, fake_session, _, _, _ = embedder_with_mocks + + embedder.embed_text("test") + + assert len(fake_session.run_calls) == 1 + _, inputs = fake_session.run_calls[0] + assert "input_ids" in inputs + assert "attention_mask" in inputs + + def test_embed_text_input_ids_shape(self, embedder_with_mocks): + """Verify input_ids has correct shape (batch, 77).""" + embedder, fake_session, _, _, _ = embedder_with_mocks + + embedder.embed_text("test") + + _, inputs = fake_session.run_calls[0] + assert inputs["input_ids"].shape[0] == 1 # batch size + assert inputs["attention_mask"].shape[0] == 1 + + +class TestEdgeCases: + """Test edge cases and error handling.""" + + def test_embed_empty_string(self, embedder_with_mocks): + """Verify empty string can be embedded.""" + embedder, _, _, _, _ = embedder_with_mocks + + result = embedder.embed_text("") + + assert result.shape == (512,) + assert np.isclose(np.linalg.norm(result), 1.0, atol=1e-5) + + def test_embed_single_item_list(self, embedder_with_mocks): + """Verify single-item list returns 2D array.""" + embedder, _, _, _, _ = embedder_with_mocks + + result = embedder.embed_text(["single"]) + + assert result.ndim == 2 + assert result.shape == (1, 512) + + def test_embed_long_text_handled(self, embedder_with_mocks): + """Verify long text is handled (truncation configured).""" + embedder, _, _, _, _ = embedder_with_mocks + + # Very long text that would exceed 77 tokens + long_text = " ".join(["word"] * 200) + result = embedder.embed_text(long_text) + + assert result.shape == (512,) + + def test_embed_special_characters(self, embedder_with_mocks): + """Verify special characters are handled.""" + embedder, _, _, _, _ = embedder_with_mocks + + result = embedder.embed_text("hello! 
@#$%^&*() world?") + + assert result.shape == (512,) + + +class TestMultipleOutputFormats: + """Test handling of different ONNX output formats.""" + + def test_handles_multiple_outputs_finds_512d(self, mock_onnx_and_tokenizer): + """Verify correct output is selected when model returns multiple outputs.""" + mock_ort, mock_tokenizers, _, fake_tokenizer = mock_onnx_and_tokenizer + + # Create session that returns multiple outputs + class MultiOutputSession: + def run(self, output_names, inputs): + return [ + np.random.randn(1, 77, 768).astype(np.float32), # Hidden states + np.random.randn(1, 512).astype(np.float32), # Text embeds (correct) + np.random.randn(1, 768).astype(np.float32), # Other output + ] + + mock_ort.InferenceSession.return_value = MultiOutputSession() + + embedder = TextEmbedder() + result = embedder.embed_text("test") + + # Should find the 512-d output + assert result.shape == (512,) + + def test_fallback_to_first_output(self, mock_onnx_and_tokenizer): + """Verify fallback to first output if no 512-d output found.""" + mock_ort, mock_tokenizers, _, fake_tokenizer = mock_onnx_and_tokenizer + + # Create session that returns unexpected shape + class UnexpectedOutputSession: + def run(self, output_names, inputs): + return [ + np.random.randn(1, 768).astype(np.float32), # Wrong dimension + ] + + mock_ort.InferenceSession.return_value = UnexpectedOutputSession() + + embedder = TextEmbedder() + result = embedder.embed_text("test") + + # Falls back to first output (768-d) + assert result.shape == (768,) diff --git a/frontend/plugin/README.md b/frontend/plugin/README.md deleted file mode 100644 index e69de29..0000000 diff --git a/frontend/streamlit/config.py b/frontend/streamlit/config.py index 7678207..0bf7d9d 100644 --- a/frontend/streamlit/config.py +++ b/frontend/streamlit/config.py @@ -17,36 +17,39 @@ class Config: # Environment (defaults to "dev") ENVIRONMENT = os.environ.get("ENVIRONMENT", "dev") + IS_FILE_CHANGE_ENABLED = os.environ.get("IS_FILE_CHANGE_ENABLED", "true").lower() == "true" # Validate environment if ENVIRONMENT not in ["dev", "prod", "staging"]: raise ValueError(f"Invalid ENVIRONMENT value: {ENVIRONMENT}. 
Must be one of: dev, prod, staging")
-    
+
     print(f"Running in {ENVIRONMENT} environment")
 
     # Modal app name (matches backend app name)
     APP_NAME = f"clipabit-{ENVIRONMENT}"
-    
+
     # Determine url portion based on environment
-    url_portion = "" if ENVIRONMENT in ["prod", "staging"] else f"-{ENVIRONMENT}"
+    url_portion = "dev" if ENVIRONMENT == "dev" else ""
+    url_portion2 = "-dev" if ENVIRONMENT == "dev" else ""
+
 
-    # Base URL for single ASGI app exposed via Modal
-    # Pattern: https://clipabit01--{env}-server-asgi-app-{env}.modal.run (dev/staging)
-    BASE_API_URL = f"https://clipabit01--{ENVIRONMENT}-server-asgi-app{url_portion}.modal.run"
+    # Server API URL (handles upload, status, videos, delete, cache)
+    SERVER_BASE_URL = f"https://clipabit01--{ENVIRONMENT}-server-{url_portion}server-asgi-app{url_portion2}.modal.run"
 
-    # API Endpoints routed through the single FastAPI app
-    SEARCH_API_URL = f"{BASE_API_URL}/search"
-    UPLOAD_API_URL = f"{BASE_API_URL}/upload"
-    STATUS_API_URL = f"{BASE_API_URL}/status"
-    LIST_VIDEOS_API_URL = f"{BASE_API_URL}/videos"
-    DELETE_VIDEO_API_URL = f"{BASE_API_URL}/videos/{{hashed_identifier}}"  # with path param on call
+    # Search API URL (in dev it's server-searchservice-asgi-app, otherwise search-searchservice-asgi-app)
+    SEARCH_BASE_URL = f"https://clipabit01--{ENVIRONMENT}-{'server' if ENVIRONMENT == 'dev' else 'search'}-searchservice-asgi-app{url_portion2}.modal.run"
+
+    # API Endpoints
+    SERVER_UPLOAD_URL = f"{SERVER_BASE_URL}/upload"
+    SERVER_STATUS_URL = f"{SERVER_BASE_URL}/status"
+    SEARCH_STATUS_URL = f"{SEARCH_BASE_URL}/status"
+    SEARCH_SEARCH_URL = f"{SEARCH_BASE_URL}/search"
+    SERVER_LIST_VIDEOS_URL = f"{SERVER_BASE_URL}/videos"
+    SERVER_DELETE_VIDEO_URL = f"{SERVER_BASE_URL}/videos/{{hashed_identifier}}"
 
     # Namespace for Pinecone and R2 (web-demo for public demo)
     NAMESPACE = "web-demo"
 
-    # Flag to indicate if running in internal environment
-    IS_INTERNAL_ENV = ENVIRONMENT in ["dev", "staging"]
-
     @classmethod
     def get_config(cls):
         """Get configuration as a dictionary."""
@@ -57,14 +60,17 @@ def get_config(cls):
             "namespace": cls.NAMESPACE,
 
             # Flags
-            "is_internal_env": cls.IS_INTERNAL_ENV,
+            "is_file_change_enabled": cls.IS_FILE_CHANGE_ENABLED,
 
             # API Endpoints
-            "search_api_url": cls.SEARCH_API_URL,
-            "upload_api_url": cls.UPLOAD_API_URL,
-            "status_api_url": cls.STATUS_API_URL,
-            "list_videos_api_url": cls.LIST_VIDEOS_API_URL,
-            "delete_video_api_url": cls.DELETE_VIDEO_API_URL,
+            "server_base_url": cls.SERVER_BASE_URL,
+            "search_base_url": cls.SEARCH_BASE_URL,
+            "server_upload_url": cls.SERVER_UPLOAD_URL,
+            "server_status_url": cls.SERVER_STATUS_URL,
+            "search_status_url": cls.SEARCH_STATUS_URL,
+            "search_search_url": cls.SEARCH_SEARCH_URL,
+            "server_list_videos_url": cls.SERVER_LIST_VIDEOS_URL,
+            "server_delete_video_url": cls.SERVER_DELETE_VIDEO_URL,
         }
 
     @classmethod
@@ -73,6 +79,7 @@ def print_config_partial(cls):
         config = cls.get_config()
         logger.info("Current Configuration:")
         logger.info(f"  Environment: {config['environment']}")
+        logger.info(f"  File Change Enabled: {config['is_file_change_enabled']}")
         logger.info(f"  App Name: {config['app_name']}")
         logger.info(f"  Namespace: {config['namespace']}")
diff --git a/frontend/streamlit/pages/search_demo.py b/frontend/streamlit/pages/search_demo.py
index 1aa7b87..75aab08 100644
--- a/frontend/streamlit/pages/search_demo.py
+++ b/frontend/streamlit/pages/search_demo.py
@@ -35,14 +35,16 @@ st.session_state.repo_action = None
 
 # Configs
-SEARCH_API_URL = Config.SEARCH_API_URL
-UPLOAD_API_URL = Config.UPLOAD_API_URL
diff --git a/frontend/streamlit/pages/search_demo.py b/frontend/streamlit/pages/search_demo.py
index 1aa7b87..75aab08 100644
--- a/frontend/streamlit/pages/search_demo.py
+++ b/frontend/streamlit/pages/search_demo.py
@@ -35,14 +35,16 @@
     st.session_state.repo_action = None

 # Configs
-SEARCH_API_URL = Config.SEARCH_API_URL
-UPLOAD_API_URL = Config.UPLOAD_API_URL
-STATUS_API_URL = Config.STATUS_API_URL
-LIST_VIDEOS_API_URL = Config.LIST_VIDEOS_API_URL
-DELETE_VIDEO_API_URL = Config.DELETE_VIDEO_API_URL
 NAMESPACE = Config.NAMESPACE
 ENVIRONMENT = Config.ENVIRONMENT
-IS_INTERNAL_ENV = Config.IS_INTERNAL_ENV
+IS_FILE_CHANGE_ENABLED = Config.IS_FILE_CHANGE_ENABLED
+
+# API Endpoints
+SEARCH_SEARCH_URL = Config.SEARCH_SEARCH_URL
+SERVER_UPLOAD_URL = Config.SERVER_UPLOAD_URL
+SERVER_STATUS_URL = Config.SERVER_STATUS_URL
+SERVER_LIST_VIDEOS_URL = Config.SERVER_LIST_VIDEOS_URL
+SERVER_DELETE_VIDEO_URL = Config.SERVER_DELETE_VIDEO_URL


 def set_repo_action(action: str) -> None:
@@ -53,7 +55,7 @@ def set_repo_action(action: str) -> None:
 def search_videos(query: str):
     """Send search query to backend."""
     try:
-        resp = requests.get(SEARCH_API_URL, params={"query": query, "namespace": NAMESPACE}, timeout=30)
+        resp = requests.get(SEARCH_SEARCH_URL, params={"query": query, "namespace": NAMESPACE}, timeout=30)
         if resp.status_code == 200:
             return resp.json()
         else:
@@ -71,7 +73,7 @@ def fetch_videos_page(page_token: str | None = None, page_size: int = REPO_PAGE_
         params["page_token"] = page_token

     try:
-        resp = requests.get(LIST_VIDEOS_API_URL, params=params, timeout=30)
+        resp = requests.get(SERVER_LIST_VIDEOS_URL, params=params, timeout=30)
         if resp.status_code == 200:
             data = resp.json()
             return {
@@ -142,7 +144,7 @@ def upload_files_to_backend(files_data: list[tuple[bytes, str, str]]):
     # Handles large batches: 50 files = 1800s (30min), 200 files = 6300s (105min)
     timeout = max(600, 300 + (len(files) * 30))

-    resp = requests.post(UPLOAD_API_URL, files=files, data=data, timeout=timeout)
+    resp = requests.post(SERVER_UPLOAD_URL, files=files, data=data, timeout=timeout)

     return resp

@@ -151,7 +153,7 @@ def poll_job_status(job_id: str, max_wait: int = 300):
     start_time = time.time()
     while time.time() - start_time < max_wait:
         try:
-            resp = requests.get(STATUS_API_URL, params={"job_id": job_id}, timeout=30)
+            resp = requests.get(SERVER_STATUS_URL, params={"job_id": job_id}, timeout=30)
             if resp.status_code == 200:
                 data = resp.json()
                 status = data.get("status")
@@ -168,13 +170,13 @@ def poll_job_status(job_id: str, max_wait: int = 300):

 def delete_video(hashed_identifier: str, filename: str):
     """Delete video via API call."""
-    if not IS_INTERNAL_ENV:
+    if not IS_FILE_CHANGE_ENABLED:
         st.toast(f"Deletion not allowed in {ENVIRONMENT} environment", icon="🚫")
         return

     try:
         resp = requests.delete(
-            DELETE_VIDEO_API_URL.format(hashed_identifier=hashed_identifier),
+            SERVER_DELETE_VIDEO_URL.format(hashed_identifier=hashed_identifier),
             params={
                 "filename": filename,
                 "namespace": NAMESPACE
@@ -343,7 +345,7 @@ def delete_confirmation_dialog(hashed_identifier: str, filename: str):
 up_col1, up_col2, up_col3 = st.columns([1, 1, 6])

-# upload button in internal envs else info text
-if IS_INTERNAL_ENV:
+# upload button when file changes are enabled, else info text
+if IS_FILE_CHANGE_ENABLED:
     with up_col1:
         if st.button("Upload", disabled=(False), use_container_width=True):
             upload_dialog()
@@ -415,7 +417,7 @@ def delete_confirmation_dialog(hashed_identifier: str, filename: str):
             score = result.get("score", 0)

             # Video info and delete button row
-            if IS_INTERNAL_ENV:
+            if IS_FILE_CHANGE_ENABLED:
                 info_col, delete_col = st.columns([3, 1])  # NOTE: reenable with delete
                 with info_col:
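The next hunk reworks the repository pager's Next-button gating: instead of requiring a known `total_pages`, it enables the button whenever a next page is already cached or a continuation token says more can be fetched. A minimal standalone sketch of that predicate (parameter names mirror the session-state fields used in the hunk; the function itself is illustrative, not part of the page):

```python
def next_button_disabled(current_page_idx: int,
                         total_loaded_pages: int,
                         next_token: str | None,
                         loading: bool) -> bool:
    """Disable "Next" only when nothing is cached ahead and nothing is fetchable."""
    if loading:
        return True
    has_cached_next_page = (current_page_idx + 1) < total_loaded_pages
    has_more_pages_to_fetch = next_token is not None
    return not (has_cached_next_page or has_more_pages_to_fetch)

# On the last cached page, a server-side continuation token keeps "Next" enabled:
assert next_button_disabled(2, 3, "token123", loading=False) is False
```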
@@ -488,10 +490,13 @@ def delete_confirmation_dialog(hashed_identifier: str, filename: str):
         prev_disabled = st.session_state.repo_loading or current_page_idx <= 0
         if st.session_state.repo_loading:
             next_disabled = True
-        elif total_pages > 0:
-            next_disabled = (current_page_idx + 1) >= total_pages
         else:
-            next_disabled = True
+            # Enable next button if:
+            # 1. The next page is already cached in repo_pages, OR
+            # 2. There's a repo_next_token indicating more pages can be fetched
+            has_cached_next_page = (current_page_idx + 1) < total_loaded_pages
+            has_more_pages_to_fetch = st.session_state.repo_next_token is not None
+            next_disabled = not (has_cached_next_page or has_more_pages_to_fetch)

         nav_info_col, nav_prev_col, nav_next_col = st.columns([6, 0.3, 0.3])

@@ -537,7 +542,7 @@ def delete_confirmation_dialog(hashed_identifier: str, filename: str):
     for idx, video in enumerate(videos):
         with cols[idx%3]:
             # Video info and delete button row
-            if IS_INTERNAL_ENV:
+            if IS_FILE_CHANGE_ENABLED:
                 info_col, delete_col = st.columns([3, 1])
                 with info_col:
                     with st.expander("Info"):
diff --git a/installer b/installer
new file mode 160000
index 0000000..fdbada0
--- /dev/null
+++ b/installer
@@ -0,0 +1 @@
+Subproject commit fdbada01498b3d2fcd1ffb9989fa63b0738074db