
Commit d5b7dc2

Merge pull request #5 from cleong110/signclip_metric
Signclip metric
2 parents 1ca9565 + 73ebd75 commit d5b7dc2

7 files changed, +1063 -0 lines
@@ -0,0 +1,318 @@
import argparse
from pathlib import Path
import time
import json
import random
import pandas as pd
import numpy as np
import torch
from tqdm import tqdm
from pose_evaluation.metrics.embedding_distance_metric import EmbeddingDistanceMetric


def load_embedding(file_path: Path) -> np.ndarray:
    """
    Load a SignCLIP embedding from a .npy file, ensuring it has the correct shape.

    Args:
        file_path (Path): Path to the .npy file.

    Returns:
        np.ndarray: The embedding with shape (768,).
    """
    embedding = np.load(file_path)
    if embedding.ndim == 2 and embedding.shape[0] == 1:
        embedding = embedding[0]  # Reduce shape from (1, 768) to (768,)
    return embedding
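For context, a minimal usage sketch (not part of the commit; the file name is hypothetical, and SignCLIP embeddings are assumed to be stored as (768,) or (1, 768) arrays):

emb = load_embedding(Path("embeddings/05723-HOUSE.npy"))  # hypothetical file
assert emb.shape == (768,)  # squeezed even if the file stored shape (1, 768)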

def match_embeddings_to_glosses(emb_dir: Path, split_df: pd.DataFrame) -> pd.DataFrame:
    """
    Match .npy embeddings to the corresponding glosses based on the numerical ID.

    Args:
        emb_dir (Path): Directory containing the .npy files.
        split_df (pd.DataFrame): DataFrame containing the split file with the "Video file" column.

    Returns:
        pd.DataFrame: Updated DataFrame with an additional column for embeddings.
    """
    # Step 1: Create a mapping of numerical IDs to .npy files
    map_start = time.perf_counter()
    embeddings_map = {npy_file.stem.split("-")[0]: npy_file for npy_file in emb_dir.glob("*.npy")}
    map_end = time.perf_counter()
    print(f"Creating embeddings map took {map_end - map_start:.4f} seconds")

    # Step 2: Match each row's "Video file" to an embedding by numerical ID
    match_start = time.perf_counter()

    def get_embedding(video_file):
        numerical_id = video_file.split("-")[0]
        npy_file = embeddings_map.get(numerical_id)
        if npy_file is not None:
            return load_embedding(npy_file)
        return None

    split_df["embedding"] = split_df["Video file"].apply(get_embedding)
    match_end = time.perf_counter()
    print(f"Matching embeddings to glosses took {match_end - match_start:.4f} seconds")

    return split_df
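To illustrate the matching convention (a sketch, not part of the commit; the exact filename pattern is an assumption inferred from the split("-")[0] key):

# An embedding file "05723-HOUSE.npy" maps to key "05723"; a split row whose
# "Video file" is "05723-HOUSE.mp4" yields the same key, so get_embedding
# returns the (768,) array loaded from that .npy file.
df = pd.DataFrame({"Video file": ["05723-HOUSE.mp4"], "Gloss": ["HOUSE"]})
df = match_embeddings_to_glosses(Path("embeddings"), df)  # hypothetical directory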

def calculate_mean_distances(
    distance_matrix: torch.Tensor, indices_a: torch.Tensor, indices_b: torch.Tensor, exclude_self: bool = False
) -> float:
    """
    Calculate the mean of distances between two sets of indices in a 2D distance matrix.

    Args:
        distance_matrix (torch.Tensor): A 2D tensor representing pairwise distances.
        indices_a (torch.Tensor): A tensor of row indices.
        indices_b (torch.Tensor): A tensor of column indices.
        exclude_self (bool): Whether to exclude distances where indices_a == indices_b.

    Returns:
        float: The mean distance between all pairs of (indices_a, indices_b).
    """
    # Create all pair combinations
    row_indices, col_indices = torch.meshgrid(indices_a, indices_b, indexing="ij")

    if exclude_self:
        # Apply a mask to remove self-distances
        mask = row_indices != col_indices
        row_indices = row_indices[mask]
        col_indices = col_indices[mask]

    # Gather distances
    selected_distances = distance_matrix[row_indices.flatten(), col_indices.flatten()]

    # Return the mean
    return selected_distances.mean().item()
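A quick worked example of the mean calculation (a sketch; values chosen by hand):

dm = torch.tensor([[0.0, 1.0, 2.0],
                   [1.0, 0.0, 3.0],
                   [2.0, 3.0, 0.0]])
idx = torch.tensor([0, 1])
calculate_mean_distances(dm, idx, idx, exclude_self=True)  # mean of dm[0,1], dm[1,0] -> 1.0
calculate_mean_distances(dm, idx, torch.tensor([2]))       # mean of dm[0,2], dm[1,2] -> 2.5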

def generate_synthetic_data(num_items, num_classes, num_items_per_class=4):
    """Build a deterministic synthetic distance matrix for sanity-checking the class-mean logic."""
    torch.manual_seed(42)
    random.seed(42)
    # distance_matrix = torch.rand((num_items, num_items)) * 100
    distance_matrix = torch.full((num_items, num_items), 10.0)
    distance_matrix.fill_diagonal_(0)
    indices = list(range(num_items))
    random.shuffle(indices)

    classes = {
        f"CLASS_{i}": torch.tensor([indices.pop() for _ in range(num_items_per_class)]) for i in range(num_classes)
    }
    # Assign intra-class distances: members of CLASS_i sit at distance i + 1 from each other
    mean_values_by_class = {}
    for i, class_name in enumerate(classes.keys()):
        mean_value = i + 1
        mean_values_by_class[class_name] = mean_value
    for class_name, class_indices in classes.items():
        mean_value = mean_values_by_class[class_name]
        for i in class_indices:
            for j in class_indices:
                if i != j:  # Exclude self-distances
                    distance_matrix[i, j] = mean_value
    return classes, distance_matrix
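Because every member of CLASS_i sits at distance i + 1 from its classmates and 10.0 from everything else, the synthetic data gives the class-mean computation a known answer to check against. A sketch:

classes, dm = generate_synthetic_data(num_items=12, num_classes=3, num_items_per_class=4)
means = calculate_class_means({k: v.tolist() for k, v in classes.items()}, dm)
# Expect means["CLASS_0"] == {"in_class": 1.0, "out_of_class": 10.0}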

def calculate_class_means(gloss_indices, scores):
    """For each gloss, compute the mean within-class and out-of-class distance from the full score matrix."""
    class_means_by_gloss = {}
    all_indices = torch.arange(scores.size(0), dtype=int)

    for gloss, indices in tqdm(gloss_indices.items(), desc="Finding mean values by gloss"):
        indices = torch.LongTensor(indices)
        class_means_by_gloss[gloss] = {}
        within_class_mean = calculate_mean_distances(scores, indices, indices, exclude_self=True)

        class_means_by_gloss[gloss]["in_class"] = within_class_mean

        complement_indices = all_indices[~torch.isin(all_indices, indices)]
        without_class_mean = calculate_mean_distances(scores, indices, complement_indices)
        class_means_by_gloss[gloss]["out_of_class"] = without_class_mean

    return class_means_by_gloss
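The returned mapping has one entry per gloss, e.g. (illustrative values only):

# {"HOUSE": {"in_class": 0.31, "out_of_class": 0.97}, ...}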

# Earlier NumPy-based implementation, kept for reference:
# def calculate_class_means(gloss_indices, scores):
#     all_within_class_distances = np.array([])  # Initialize as empty NumPy array
#     all_between_class_distances = np.array([])  # Initialize as empty NumPy array
#     within_class_means_by_gloss = {}
#     for gloss, indices in tqdm(gloss_indices.items(), desc="Finding mean values by gloss"):
#         # Within-class distances
#         within_class_distances = scores[np.ix_(indices, indices)]
#         within_class_mean = torch.mean(within_class_distances)
#         within_class_means_by_gloss[gloss] = within_class_mean
#         within_class_distances = within_class_distances[np.triu_indices(len(indices), k=1)]
#         all_within_class_distances = np.concatenate([all_within_class_distances, within_class_distances.ravel()])
#
#         # Between-class distances
#         other_indices = np.setdiff1d(np.arange(len(scores)), indices)
#         between_class_distances = scores[np.ix_(indices, other_indices)]
#         all_between_class_distances = np.concatenate([all_between_class_distances, between_class_distances.ravel()])
#
#     for gloss, mean in within_class_means_by_gloss.items():
#         print(f"Within {gloss}: {within_class_means_by_gloss[gloss]}")
#
#     print(f"Mean within classes: {np.mean(all_within_class_distances)}")
#     print(f"Mean between classes: {np.mean(all_between_class_distances)}")
#     return within_class_means_by_gloss

def evaluate_signclip(emb_dir: Path, split_file: Path, out_path: Path, kind: str = "cosine"):
    """
    Evaluate SignCLIP embeddings using score_all.

    Args:
        emb_dir (Path): Directory containing .npy embeddings.
        split_file (Path): Path to the split CSV file.
        out_path (Path): Output path for the saved .npz of scores and file names.
        kind (str): Metric type ("cosine" or "l2"). Default is "cosine".
    """
    overall_start = time.perf_counter()  # Start overall benchmarking

    # Step 1: Load split file
    split_load_start = time.perf_counter()
    split_df = pd.read_csv(split_file)
    split_load_end = time.perf_counter()
    print(f"Loading split file took {split_load_end - split_load_start:.4f} seconds")
    # print(f"{split_df.info()}")

    # Step 2: Match embeddings to glosses
    match_start = time.perf_counter()
    split_df = match_embeddings_to_glosses(emb_dir, split_df)
    match_end = time.perf_counter()
    print(f"Matching embeddings to glosses took {match_end - match_start:.4f} seconds")
    # print(split_df.info())

    # Step 3: Filter out rows without embeddings
    filter_start = time.perf_counter()
    items_with_embeddings_df = split_df.dropna(subset=["embedding"]).reset_index(drop=True)
    embeddings = items_with_embeddings_df["embedding"].tolist()
    filter_end = time.perf_counter()
    print(f"Filtering embeddings took {filter_end - filter_start:.4f} seconds")
    print(items_with_embeddings_df.info())

    # Step 4: Initialize the distance metric
    metric_start = time.perf_counter()
    # metric = EmbeddingDistanceMetric(kind=kind, device="cpu")
    metric = EmbeddingDistanceMetric(kind=kind)
    metric_end = time.perf_counter()
    print(f"Initializing metric took {metric_end - metric_start:.4f} seconds")

    # Step 5: Compute all pairwise scores
    score_start = time.perf_counter()
    print(f"Computing {kind} distances for {len(embeddings)} embeddings...")
    scores = metric.score_all(embeddings, embeddings)
    score_end = time.perf_counter()
    print(f"score_all took {score_end - score_start:.3f} seconds")

    # Step 6: Extract file list from DataFrame
    files_start = time.perf_counter()
    files = items_with_embeddings_df["Video file"].tolist()
    files_end = time.perf_counter()
    print(f"Extracting file list took {files_end - files_start:.4f} seconds")

    # Step 7: Analyze within-class and out-of-class distances
    analysis_start = time.perf_counter()
    index_to_check = 0
    number_to_check = 10
    print(f"The first {number_to_check} scores from {files[index_to_check]} to...")
    for ref, score in list(zip(files, scores[index_to_check]))[:number_to_check]:
        print("\t*------------->", f"{ref}".ljust(35), "\t", score.item())

    unique_glosses = items_with_embeddings_df["Gloss"].unique()
    print(f"We have a vocabulary of {len(unique_glosses)} glosses")
    gloss_indices = {}
    for gloss in items_with_embeddings_df["Gloss"].unique():
        gloss_indices[gloss] = items_with_embeddings_df.index[items_with_embeddings_df["Gloss"] == gloss].tolist()

    for gloss, indices in list(gloss_indices.items())[:10]:
        print(f"Here are the {len(indices)} indices for {gloss}: {indices}")

    find_class_distances_start = time.perf_counter()
    # synthetic_classes, synthetic_distances = generate_synthetic_data(30000, 2700, 8)
    # class_means = calculate_class_means(synthetic_classes, synthetic_distances)
    class_means = calculate_class_means(gloss_indices, scores)
    find_class_distances_end = time.perf_counter()
    print(f"Finding within- and out-of-class means took {find_class_distances_end - find_class_distances_start} seconds")

    analysis_end = time.perf_counter()
    analysis_duration = analysis_end - analysis_start

    in_class_means = [mean_dict["in_class"] for mean_dict in class_means.values()]
    out_class_means = [mean_dict["out_of_class"] for mean_dict in class_means.values()]

    for gloss, means in list(class_means.items())[:10]:
        print(gloss, means)

    print(f"Mean of in-class means: {np.mean(in_class_means)}")
    print(f"Mean of out-of-class means: {np.mean(out_class_means)}")
    print(f"Analysis took {analysis_duration} seconds")

    # Step 8: Save the scores and files to a compressed file
    save_start = time.perf_counter()
    class_means_json = out_path.with_name(f"{out_path.stem}_class_means").with_suffix(".json")
    with open(class_means_json, "w") as f:
        print(f"Writing class means to {class_means_json}")
        json.dump(class_means, f)
    np.savez(out_path, scores=scores, files=files)
    save_end = time.perf_counter()
    print(f"Saving scores and files took {save_end - save_start:.4f} seconds")
    print(f"Scores of shape {scores.shape} with files list of length {len(files)} saved to {out_path}")

    # Step 9: Read back the saved scores
    read_start = time.perf_counter()
    read_back_in = np.load(f"{out_path}")
    read_end = time.perf_counter()
    print(f"Reading back the file took {read_end - read_start:.4f} seconds")

    # Step 10: Verify that the read data matches the original scores
    verify_start = time.perf_counter()
    if np.allclose(read_back_in["scores"], scores):
        print("Yay! All the same!")
    else:
        print("Mismatch found!")
    verify_end = time.perf_counter()
    print(f"Verification step took {verify_end - verify_start:.4f} seconds")

    # Overall time
    overall_end = time.perf_counter()
    print(f"Total script runtime: {overall_end - overall_start:.4f} seconds")

def main():
    parser = argparse.ArgumentParser(description="Evaluate SignCLIP embeddings with score_all.")
    parser.add_argument("emb_dir", type=Path, help="Path to the directory containing SignCLIP .npy files")
    parser.add_argument("--split_file", type=Path, required=True, help="Path to the split CSV file (e.g., test.csv)")
    parser.add_argument(
        "--kind",
        type=str,
        choices=["cosine", "l2"],
        default="cosine",
        help="Type of distance metric to use (default: cosine)",
    )

    parser.add_argument("--out_path", type=Path, help="Where to save output distance npz matrix+file list")

    args = parser.parse_args()

    output_file = args.out_path
    if output_file is None:
        output_file = Path(f"signclip_scores_{args.split_file.name}").with_suffix(".npz")

    if output_file.suffix != ".npz":
        output_file = Path(f"{output_file}.npz")

    print(f"Scores will be saved to {output_file}")

    evaluate_signclip(emb_dir=args.emb_dir, split_file=args.split_file, out_path=output_file, kind=args.kind)


if __name__ == "__main__":
    main()
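An example invocation (the script's filename is not shown in this diff, so the name below is a placeholder):

python evaluate_signclip.py /path/to/embeddings --split_file test.csv --kind cosine --out_path scores
# Prints "Scores will be saved to scores.npz", then writes scores.npz
# (score matrix + file list) and scores_class_means.json alongside it.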

pose_evaluation/metrics/.gitignore

+1
@@ -0,0 +1 @@
temp/
@@ -0,0 +1,9 @@
from typing import TypeVar
import torch
from pose_evaluation.metrics.base import BaseMetric


# Define a type alias for embeddings (e.g., torch.Tensor)
Embedding = TypeVar("Embedding", bound=torch.Tensor)

EmbeddingMetric = BaseMetric[Embedding]
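A sketch of how the alias might be used in type annotations (BaseMetric's interface is not shown in this diff; score_all is assumed from its use elsewhere in this PR):

def score_embeddings(metric: EmbeddingMetric, hyps: list, refs: list) -> torch.Tensor:
    # Returns the full pairwise distance matrix between hypotheses and references.
    return metric.score_all(hyps, refs)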

pose_evaluation/metrics/conftest.py

+50
@@ -0,0 +1,50 @@
import shutil
from pathlib import Path
from typing import Callable, Union
import torch
import numpy as np
import pytest


@pytest.fixture(scope="session", autouse=True)
def clean_test_artifacts():
    """Fixture to clean up test artifacts before each test session."""
    test_artifacts_dir = Path(__file__).parent / "tests"  # Using Path
    if test_artifacts_dir.exists():
        shutil.rmtree(test_artifacts_dir)  # shutil.rmtree still works with Path
    test_artifacts_dir.mkdir(parents=True, exist_ok=True)  # Using Path.mkdir
    yield  # This allows the test session to run
    # (Optional) You can add cleanup logic here to run after the session if needed
@pytest.fixture(name="distance_matrix_shape_checker")
21+
def fixture_distance_matrix_shape_checker() -> Callable[[torch.Tensor, torch.Tensor], None]:
22+
def _check_shape(hyp_count: int, ref_count: int, distance_matrix: torch.Tensor):
23+
24+
expected_shape = torch.Size([hyp_count, ref_count])
25+
assert (
26+
distance_matrix.shape == expected_shape
27+
), f"For M={hyp_count} hypotheses, N={ref_count} references, Distance Matrix should be MxN={expected_shape}. Instead, received {distance_matrix.shape}"
28+
29+
return _check_shape
30+
31+
32+
@pytest.fixture(name="distance_range_checker")
33+
def fixture_distance_range_checker() -> Callable[[Union[torch.Tensor, np.ndarray], float, float], None]:
34+
def _check_range(
35+
distances: Union[torch.Tensor, np.ndarray],
36+
min_val: float = 0,
37+
max_val: float = 2,
38+
) -> None:
39+
max_distance = distances.max().item()
40+
min_distance = distances.min().item()
41+
42+
# Use np.isclose for comparisons with tolerance
43+
assert (
44+
np.isclose(min_distance, min_val, atol=1e-6) or min_val <= min_distance <= max_val
45+
), f"Minimum distance ({min_distance}) is outside the expected range [{min_val}, {max_val}]"
46+
assert (
47+
np.isclose(max_distance, max_val, atol=1e-6) or min_val <= max_distance <= max_val
48+
), f"Maximum distance ({max_distance}) is outside the expected range [{min_val}, {max_val}]"
49+
50+
return _check_range
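For reference, a test might combine the two fixtures like this (a sketch; the EmbeddingDistanceMetric import path and score_all behavior are assumed from elsewhere in this PR):

from pose_evaluation.metrics.embedding_distance_metric import EmbeddingDistanceMetric

def test_score_all_shape_and_range(distance_matrix_shape_checker, distance_range_checker):
    hyps = [torch.rand(768) for _ in range(4)]
    refs = [torch.rand(768) for _ in range(5)]
    scores = EmbeddingDistanceMetric(kind="cosine").score_all(hyps, refs)
    distance_matrix_shape_checker(4, 5, scores)
    distance_range_checker(scores, min_val=0, max_val=2)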
