Commit 1e309ac

Merge pull request #11 from ClipABit/videoembed
Integrated embedding logic into the process/upload video pipeline
2 parents: 267b60b + 6f095f1

File tree: 3 files changed, +107 −4 lines changed

backend/embeddings/__init__.py

Lines changed: 5 additions & 1 deletion
@@ -1 +1,5 @@
-# Make embeddings a proper Python package
+# Make embeddings a proper Python package
+
+from .embedder import VideoEmbedder
+
+__all__ = ["VideoEmbedder"]
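
With the re-export in place, callers can import the embedder from the package root rather than reaching into the submodule. A minimal sketch of what the change enables:

    # The package root now re-exports the class that __all__ advertises.
    from embeddings import VideoEmbedder

    embedder = VideoEmbedder()  # constructing it eagerly loads the CLIP model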

backend/embeddings/embedder.py

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
+import torch
+import numpy as np
+from PIL import Image
+from transformers import (
+    CLIPModel,
+    CLIPProcessor
+)
+
+
+class VideoEmbedder:
+    """
+    A class to handle video embedding generation using various models.
+    """
+    def __init__(self):
+        self._device = "cuda" if torch.cuda.is_available() else "cpu"
+        self._clip_model = None
+        self._clip_processor = None
+        self._get_clip_model()
+
+    def _get_clip_model(self):
+        """Lazily load and return CLIP model + processor."""
+        if self._clip_model is None or self._clip_processor is None:
+            print("Loading CLIP model into memory...")
+            self._clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(
+                self._device
+            )
+            self._clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+        return self._clip_model, self._clip_processor
+
+    def _generate_clip_embedding(self, frames, num_frames: int = 8) -> torch.Tensor:
+        """
+        Generate a single embedding for a video chunk by averaging the normalized
+        embeddings of sampled frames using the OpenAI CLIP model.
+        Args:
+            frames: Decoded frames for the chunk, as a (T, H, W, 3) uint8 array.
+            num_frames (int): Number of frames to sample evenly across the video.
+
+        Returns:
+            torch.Tensor: A single, normalized embedding tensor for the video chunk.
+        """
+
+        # Fetch the preloaded model and processor
+        model, processor = self._get_clip_model()
+
+        # Cap num_frames at the available frame count, then sample evenly across the video
+        num_frames = min(num_frames, frames.shape[0])
+        frame_indices = np.linspace(0, frames.shape[0] - 1, num_frames).astype(int)
+        sampled_frames = [Image.fromarray(frames[idx]) for idx in frame_indices]
+
+        # Resize the frames to the model's standard input dimensions and normalize
+        # pixel values to the ranges of the data the model was trained on.
+        inputs = processor(images=sampled_frames, return_tensors="pt", size=224).to(self._device)
+
+        with torch.no_grad():
+            frame_features = model.get_image_features(**inputs)
+            frame_features = frame_features / frame_features.norm(p=2, dim=-1, keepdim=True)
+
+        video_embedding = frame_features.mean(dim=0)
+        video_embedding = video_embedding / video_embedding.norm(p=2, dim=-1, keepdim=True)
+
+
+        return video_embedding.cpu()
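
For reference, _generate_clip_embedding expects a chunk's frames as a stacked (num_frames, height, width, 3) uint8 array: it reads frames.shape[0] and converts each sampled frame with Image.fromarray. A minimal usage sketch with dummy data (the random frames below are stand-ins for real decoded video):

    import numpy as np

    from embeddings import VideoEmbedder

    # 16 fake 240x320 RGB frames; a real caller passes decoded video frames
    frames = np.random.randint(0, 256, size=(16, 240, 320, 3), dtype=np.uint8)

    embedder = VideoEmbedder()
    embedding = embedder._generate_clip_embedding(frames, num_frames=8)
    print(embedding.shape)  # torch.Size([512]) for clip-vit-base-patch32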

backend/main.py

Lines changed: 39 additions & 3 deletions
@@ -48,6 +48,7 @@ def startup(self):
 
         # Import classes here
         from preprocessing.preprocessor import Preprocessor
+        from embeddings.embedder import VideoEmbedder
         from database.pinecone_connector import PineconeConnector
         from database.job_store_connector import JobStoreConnector
         from database.r2_connector import R2Connector
@@ -81,6 +82,7 @@ def startup(self):
         # Instantiate classes
 
         self.preprocessor = Preprocessor(min_chunk_duration=1.0, max_chunk_duration=10.0, scene_threshold=13.0)
+        self.video_embedder = VideoEmbedder()
         self.pinecone_connector = PineconeConnector(api_key=PINECONE_API_KEY, index_name=PINECONE_CHUNKS_INDEX)
         self.job_store = JobStoreConnector(dict_name="clipabit-jobs")
         self.r2_connector = R2Connector(account_id=R2_ACCOUNT_ID,
@@ -128,11 +130,45 @@ async def process_video(self, video_bytes: bytes, filename: str, job_id: str):
         # Prepare chunk details for response (without frame arrays)
         chunk_details = []
         for chunk in processed_chunks:
+            embedding = self.video_embedder._generate_clip_embedding(chunk["frames"], num_frames=8)
+
+            logger.info(f"[Job {job_id}] Generated CLIP embedding for chunk {chunk['chunk_id']}")
+            logger.info(f"[Job {job_id}] Upserting embedding for chunk {chunk['chunk_id']} to Pinecone...")
+
+
+            # 1. Handle timestamp_range (list of two numbers -> two flat keys)
+            if 'timestamp_range' in chunk['metadata']:
+                start_time, end_time = chunk['metadata'].pop('timestamp_range')
+                chunk['metadata']['start_time_s'] = start_time
+                chunk['metadata']['end_time_s'] = end_time
+
+            # 2. Handle file_info (nested dict -> flat keys)
+            if 'file_info' in chunk['metadata']:
+                file_info = chunk['metadata'].pop('file_info')
+                for key, value in file_info.items():
+                    chunk['metadata'][f'file_{key}'] = value
+
+            # 3. Final check: remove nulls (optional but good practice);
+            # Pinecone rejects metadata keys with null values.
+            keys_to_delete = [k for k, v in chunk['metadata'].items() if v is None]
+            for k in keys_to_delete:
+                del chunk['metadata'][k]
+
+
+            self.pinecone_connector.upsert_chunk(
+                chunk_id=chunk['chunk_id'],
+                chunk_embedding=embedding.numpy(),
+                namespace="test",
+                metadata=chunk['metadata']
+            )
+
             chunk_details.append({
                 "chunk_id": chunk['chunk_id'],
                 "metadata": chunk['metadata'],
-                "memory_mb": chunk['memory_mb']
+                "memory_mb": chunk['memory_mb'],
             })
+
+        # TODO: Upload processed data to S3
 
         result = {
             "job_id": job_id,
@@ -143,7 +179,7 @@ async def process_video(self, video_bytes: bytes, filename: str, job_id: str):
             "total_frames": total_frames,
             "total_memory_mb": total_memory,
             "avg_complexity": avg_complexity,
-            "chunk_details": chunk_details
+            "chunk_details": chunk_details,
         }
 
         logger.info(f"[Job {job_id}] Finished processing {filename}")
@@ -219,7 +255,7 @@ async def upload(self, file: UploadFile = None):
             "message": "Video uploaded successfully, processing in background"
         }
 
-    @modal.fastapi_endpoint(method="POST")
+    @modal.fastapi_endpoint(method="GET")
     async def search(self, query: str):
         """Search endpoint - accepts a text query and returns semantic search results."""
         logger.info(f"[Search] Query: {query}")
