
Commit 8f44a43

Merge pull request ClipABit#72 from ClipABit/staging
Staging
2 parents 107aeba + 5be9510 commit 8f44a43

File tree: 2 files changed (+263, -1 lines changed)


backend/embeddings/video_embedder.py

Lines changed: 2 additions & 1 deletion
@@ -55,7 +55,8 @@ def _generate_clip_embedding(self, frames, num_frames: int = 8) -> torch.Tensor:
         inputs = processor(images=sampled_frames, return_tensors="pt", size=224).to(self._device)
 
         with torch.no_grad():
-            frame_features = model.get_image_features(**inputs)
+            output = model.get_image_features(**inputs)
+            frame_features = output.pooler_output if hasattr(output, 'pooler_output') else output
             frame_features = frame_features / frame_features.norm(p=2, dim=-1, keepdim=True)
 
         video_embedding = frame_features.mean(dim=0)
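
For context on the guard added above: per this commit (and the regression test added below), some transformers versions return a pooled-output object from get_image_features instead of a raw tensor. The following is a minimal, self-contained sketch of how the hasattr check covers both shapes; unwrap_features and FakeOutput are hypothetical names used only for illustration and are not part of the codebase.

# Illustrative sketch only -- unwrap_features and FakeOutput are hypothetical names.
import torch

def unwrap_features(output):
    # Same guard as the patched line: prefer pooler_output when the model
    # returns an output object, otherwise treat the return value as the tensor.
    return output.pooler_output if hasattr(output, "pooler_output") else output

class FakeOutput:
    def __init__(self, pooler_output):
        self.pooler_output = pooler_output

raw = torch.randn(8, 512)                  # plain tensor (older behavior)
wrapped = FakeOutput(torch.randn(8, 512))  # output object (behavior this commit guards against)

assert unwrap_features(raw).shape == (8, 512)
assert unwrap_features(wrapped).shape == (8, 512)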
Lines changed: 261 additions & 0 deletions
@@ -0,0 +1,261 @@
"""
Unit tests for VideoEmbedder (CLIP-based).

Tests the video embedding functionality with mocked CLIP model and processor.
"""

import sys
from unittest.mock import MagicMock, patch
import numpy as np
import pytest
import torch


class FakeBaseModelOutputWithPooling:
    """Fake HuggingFace output object to simulate newer transformers behavior."""

    def __init__(self, pooler_output: torch.Tensor):
        self.pooler_output = pooler_output
        self.last_hidden_state = torch.randn(1, 50, 768)


class FakeCLIPModel:
    """Fake CLIP model for testing."""

    def __init__(self, return_output_object: bool = False):
        self.return_output_object = return_output_object
        self.get_image_features_calls = []

    def to(self, device):
        return self

    def get_image_features(self, **inputs):
        self.get_image_features_calls.append(inputs)
        batch_size = inputs["pixel_values"].shape[0]
        embeddings = torch.randn(batch_size, 512)

        if self.return_output_object:
            return FakeBaseModelOutputWithPooling(embeddings)
        return embeddings


class FakeProcessorOutput:
    """Fake processor output that supports .to() method."""

    def __init__(self, pixel_values: torch.Tensor):
        self.pixel_values = pixel_values
        self._data = {"pixel_values": pixel_values}

    def to(self, device):
        return self

    def keys(self):
        return self._data.keys()

    def __getitem__(self, key):
        return self._data[key]

    def __iter__(self):
        return iter(self._data)


class FakeCLIPProcessor:
    """Fake CLIP processor for testing."""

    def __init__(self):
        self.call_args = []

    def __call__(self, images, return_tensors, size):
        self.call_args.append((images, return_tensors, size))
        batch_size = len(images)
        pixel_values = torch.randn(batch_size, 3, 224, 224)
        return FakeProcessorOutput(pixel_values)


@pytest.fixture
def mock_transformers_tensor_output():
    """Mock transformers module with tensor output (older behavior)."""
    mock_transformers = MagicMock()

    fake_model = FakeCLIPModel(return_output_object=False)
    fake_processor = FakeCLIPProcessor()

    mock_transformers.CLIPModel.from_pretrained.return_value = fake_model
    mock_transformers.CLIPProcessor.from_pretrained.return_value = fake_processor

    with patch.dict(sys.modules, {'transformers': mock_transformers}):
        yield mock_transformers, fake_model, fake_processor


@pytest.fixture
def mock_transformers_output_object():
    """Mock transformers module with BaseModelOutputWithPooling output (newer behavior)."""
    mock_transformers = MagicMock()

    fake_model = FakeCLIPModel(return_output_object=True)
    fake_processor = FakeCLIPProcessor()

    mock_transformers.CLIPModel.from_pretrained.return_value = fake_model
    mock_transformers.CLIPProcessor.from_pretrained.return_value = fake_processor

    with patch.dict(sys.modules, {'transformers': mock_transformers}):
        yield mock_transformers, fake_model, fake_processor


@pytest.fixture
def embedder_with_tensor_output(mock_transformers_tensor_output):
    """Create VideoEmbedder with mocked dependencies returning tensor."""
    mock_transformers, fake_model, fake_processor = mock_transformers_tensor_output

    from embeddings.video_embedder import VideoEmbedder
    embedder = VideoEmbedder()

    return embedder, fake_model, fake_processor


@pytest.fixture
def embedder_with_output_object(mock_transformers_output_object):
    """Create VideoEmbedder with mocked dependencies returning output object."""
    mock_transformers, fake_model, fake_processor = mock_transformers_output_object

    from embeddings.video_embedder import VideoEmbedder
    embedder = VideoEmbedder()

    return embedder, fake_model, fake_processor


class TestVideoEmbedderInitialization:
    """Test VideoEmbedder initialization."""

    def test_initializes_with_correct_device(self, mock_transformers_tensor_output):
        """Verify device is set based on CUDA availability."""
        from embeddings.video_embedder import VideoEmbedder
        embedder = VideoEmbedder()

        assert embedder._device in ["cuda", "cpu"]

    def test_loads_clip_model_on_init(self, mock_transformers_tensor_output):
        """Verify CLIP model is loaded during initialization."""
        mock_transformers, _, _ = mock_transformers_tensor_output

        from embeddings.video_embedder import VideoEmbedder
        # Instance unused - we only care about the side effect of loading the model
        _ = VideoEmbedder()

        mock_transformers.CLIPModel.from_pretrained.assert_called_once()
        mock_transformers.CLIPProcessor.from_pretrained.assert_called_once()


class TestGenerateClipEmbedding:
    """Test _generate_clip_embedding functionality."""

    def test_returns_tensor(self, embedder_with_tensor_output, sample_frames):
        """Verify embedding is returned as a tensor."""
        embedder, _, _ = embedder_with_tensor_output

        result = embedder._generate_clip_embedding(sample_frames)

        assert isinstance(result, torch.Tensor)

    def test_returns_1d_embedding(self, embedder_with_tensor_output, sample_frames):
        """Verify embedding is 1D (single video embedding)."""
        embedder, _, _ = embedder_with_tensor_output

        result = embedder._generate_clip_embedding(sample_frames)

        assert result.ndim == 1
        assert result.shape == (512,)

    def test_embedding_is_normalized(self, embedder_with_tensor_output, sample_frames):
        """Verify embedding is L2 normalized."""
        embedder, _, _ = embedder_with_tensor_output

        result = embedder._generate_clip_embedding(sample_frames)
        norm = torch.linalg.norm(result)

        assert torch.isclose(norm, torch.tensor(1.0), atol=1e-5)

    def test_handles_output_object_from_newer_transformers(self, embedder_with_output_object, sample_frames):
        """
        Verify embedding works when model returns BaseModelOutputWithPooling.

        This test catches the regression where newer transformers versions
        return an output object instead of a raw tensor.
        """
        embedder, _, _ = embedder_with_output_object

        result = embedder._generate_clip_embedding(sample_frames)

        assert isinstance(result, torch.Tensor)
        assert result.ndim == 1
        assert result.shape == (512,)

    def test_samples_frames_evenly(self, embedder_with_tensor_output):
        """Verify frames are sampled evenly across the video."""
        embedder, _, fake_processor = embedder_with_tensor_output

        frames = np.random.randint(0, 255, (100, 480, 640, 3), dtype=np.uint8)
        embedder._generate_clip_embedding(frames, num_frames=8)

        assert len(fake_processor.call_args) == 1
        images, _, _ = fake_processor.call_args[0]
        assert len(images) == 8

    def test_handles_fewer_frames_than_requested(self, embedder_with_tensor_output):
        """Verify it handles videos with fewer frames than num_frames."""
        embedder, _, fake_processor = embedder_with_tensor_output

        frames = np.random.randint(0, 255, (3, 480, 640, 3), dtype=np.uint8)
        embedder._generate_clip_embedding(frames, num_frames=8)

        images, _, _ = fake_processor.call_args[0]
        assert len(images) == 3

    def test_calls_model_with_correct_inputs(self, embedder_with_tensor_output, sample_frames):
        """Verify model is called with processed inputs."""
        embedder, fake_model, _ = embedder_with_tensor_output

        embedder._generate_clip_embedding(sample_frames)

        assert len(fake_model.get_image_features_calls) == 1
        call_inputs = fake_model.get_image_features_calls[0]
        assert "pixel_values" in call_inputs

    def test_returns_cpu_tensor(self, embedder_with_tensor_output, sample_frames):
        """Verify result is moved to CPU."""
        embedder, _, _ = embedder_with_tensor_output

        result = embedder._generate_clip_embedding(sample_frames)

        assert result.device.type == "cpu"


class TestEdgeCases:
    """Test edge cases and error handling."""

    def test_single_frame_video(self, embedder_with_tensor_output):
        """Verify single frame video can be embedded."""
        embedder, _, _ = embedder_with_tensor_output

        frames = np.random.randint(0, 255, (1, 480, 640, 3), dtype=np.uint8)
        result = embedder._generate_clip_embedding(frames)

        assert result.shape == (512,)
        assert torch.isclose(torch.linalg.norm(result), torch.tensor(1.0), atol=1e-5)

    def test_large_number_of_frames(self, embedder_with_tensor_output):
        """Verify large videos are handled with frame sampling."""
        embedder, _, fake_processor = embedder_with_tensor_output

        frames = np.random.randint(0, 255, (1000, 480, 640, 3), dtype=np.uint8)
        result = embedder._generate_clip_embedding(frames, num_frames=8)

        images, _, _ = fake_processor.call_args[0]
        assert len(images) == 8
        assert result.shape == (512,)


@pytest.fixture
def sample_frames():
    """Array of 10 test frames. Shape: (10, 480, 640, 3)"""
    return np.random.randint(0, 255, (10, 480, 640, 3), dtype=np.uint8)
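
One detail worth noting about the fixtures above: VideoEmbedder is imported inside the fixture bodies, while patch.dict(sys.modules, ...) is still active, so any import of transformers performed when embeddings.video_embedder loads resolves to the MagicMock rather than the real library. Below is a minimal sketch of that mocking pattern on its own, using only the standard library; the fake_transformers name is illustrative.

# Minimal sketch of the sys.modules patching pattern used by the fixtures.
import sys
from unittest.mock import MagicMock, patch

fake_transformers = MagicMock()  # illustrative stand-in for the real package

with patch.dict(sys.modules, {"transformers": fake_transformers}):
    # Any import executed while the patch is active resolves to the mock,
    # which is why the fixtures import VideoEmbedder inside this context.
    import transformers
    assert transformers is fake_transformers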
