Hi,
Thank you for the amazing work!
I found that BioCLIP-2 produces a higher similarity between text embeddings than the original CLIP.
Do you have any insight into this observation? Based on my understanding, BioCLIP-2 should discern different species more easily than CLIP, yet my code surprisingly shows a higher similarity. Here is the test script:
#!/usr/bin/env python
"""
Test if BioCLIP-2 embeddings are too similar (causing guidance to have no effect).
Compare CLIP ViT-L/14 vs BioCLIP-2 embedding similarities.
"""
import sys
from pathlib import Path

import torch
import open_clip
from torch.nn.functional import cosine_similarity
from transformers import CLIPTextModel, CLIPTokenizer

repo_root = Path(__file__).resolve().parents[2]
sys.path.insert(0, str(repo_root))
def test_embedding_similarity():
    print("=" * 80)
    print("Testing Embedding Similarity: CLIP ViT-L/14 vs BioCLIP-2")
    print("=" * 80)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Load CLIP ViT-L/14 (the standard text encoder used in SD)
    print("\n[INFO] Loading CLIP ViT-L/14...")
    clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    clip_text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device).eval()

    # Load BioCLIP-2
    print("[INFO] Loading BioCLIP-2...")
    bioclip2_model, _, _ = open_clip.create_model_and_transforms('hf-hub:imageomics/bioclip-2')
    bioclip2_tokenizer = open_clip.get_tokenizer('hf-hub:imageomics/bioclip-2')
    bioclip2_model = bioclip2_model.to(device).eval()

    # Test prompts (family, genus, species) spanning different families
    prompts = [
        "accipitridae, buteo, jamaicensis",         # Red-tailed hawk
        "accipitridae, Haliaeetus, leucocephalus",  # Bald eagle (same family, different genus)
        "Icteridae, Icterus, pustulatus",           # Streak-backed oriole (completely different family)
    ]

    print(f"\n[INFO] Testing {len(prompts)} prompts:")
    for i, p in enumerate(prompts):
        print(f"  {i}: {p}")
    # Helper to encode with BioCLIP-2 (full sequence output for SD cross-attention).
    # Mirrors open_clip's encode_text() up to, but not including, EOT pooling
    # and the text projection.
    def encode_bioclip2_sequence(model, tokens):
        """Get full sequence embeddings [B, 77, 768] for SD cross-attention."""
        with torch.no_grad():
            x = model.token_embedding(tokens)
            x = x + model.positional_embedding
            x = x.permute(1, 0, 2)  # NLD -> LND (drop the permutes if your open_clip transformer is batch-first)
            # Bug fix: the causal attention mask must be passed explicitly here;
            # without it every token attends to the full sequence, unlike encode_text().
            x = model.transformer(x, attn_mask=model.attn_mask)
            x = x.permute(1, 0, 2)  # LND -> NLD
            x = model.ln_final(x)
        return x
    # Test both models
    results = {}
    for model_name, tokenizer, encoder in [
        ('CLIP ViT-L/14', clip_tokenizer, clip_text_encoder),
        ('BioCLIP-2', bioclip2_tokenizer, bioclip2_model),
    ]:
        print(f"\n{'=' * 80}")
        print(f"Testing: {model_name}")
        print(f"{'=' * 80}")

        embeddings = []
        for i, prompt in enumerate(prompts):
            if model_name == 'CLIP ViT-L/14':
                # CLIP from transformers
                tokens = tokenizer(
                    prompt,
                    padding="max_length",
                    max_length=77,
                    truncation=True,
                    return_tensors="pt",
                ).to(device)
                with torch.no_grad():
                    # Sequence output [1, 77, 768]
                    emb = encoder(tokens.input_ids)[0]
            else:
                # BioCLIP-2 from open_clip
                tokens = tokenizer([prompt]).to(device)
                # Sequence output [1, 77, 768]
                emb = encode_bioclip2_sequence(encoder, tokens)
            embeddings.append(emb)
            print(f"Prompt {i}: shape={emb.shape}, mean={emb.mean():.4f}, std={emb.std():.4f}")
        # Pairwise cosine similarities over the flattened [77, 768] sequences
        sim_01 = cosine_similarity(embeddings[0].flatten(), embeddings[1].flatten(), dim=0).item()
        sim_12 = cosine_similarity(embeddings[1].flatten(), embeddings[2].flatten(), dim=0).item()
        sim_02 = cosine_similarity(embeddings[0].flatten(), embeddings[2].flatten(), dim=0).item()

        print("\nPairwise Cosine Similarities:")
        print(f"  Prompt 0 vs 1 (same family, diff genus): {sim_01:.6f}")
        print(f"  Prompt 1 vs 2 (diff families):           {sim_12:.6f}")
        print(f"  Prompt 0 vs 2 (diff families):           {sim_02:.6f}")
        avg_sim = (sim_01 + sim_12 + sim_02) / 3
        print(f"  Average similarity: {avg_sim:.6f}")

        # Store results
        results[model_name] = {
            'sim_01': sim_01,
            'sim_12': sim_12,
            'sim_02': sim_02,
            'avg': avg_sim,
        }

    # Final comparison across the two encoders
    print(f"\n{'=' * 80}")
    for name, r in results.items():
        print(f"{name}: average pairwise similarity = {r['avg']:.6f}")


if __name__ == "__main__":
    test_embedding_similarity()
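One thing I am not sure about: the script compares the flattened 77x768 sequences, so for these short prompts most of each vector comes from padding positions, which could inflate similarity for both models. As a sanity check, here is a minimal sketch (same model and tokenizer names as above; treating the pooled features as the more diagnostic metric is my assumption, since that is the space the contrastive objective actually trains) that compares the pooled, L2-normalized text features instead:

#!/usr/bin/env python
"""Sketch: compare pooled (projected, L2-normalized) text features
instead of the flattened 77-token sequences."""
import torch
import open_clip
from transformers import CLIPModel, CLIPTokenizer

device = 'cuda' if torch.cuda.is_available() else 'cpu'
prompts = [
    "accipitridae, buteo, jamaicensis",
    "accipitridae, Haliaeetus, leucocephalus",
    "Icteridae, Icterus, pustulatus",
]

# CLIP ViT-L/14: get_text_features returns the projected pooled embedding
clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(device).eval()
clip_tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
with torch.no_grad():
    batch = clip_tok(prompts, padding="max_length", max_length=77,
                     truncation=True, return_tensors="pt").to(device)
    clip_feats = clip_model.get_text_features(**batch)

# BioCLIP-2: encode_text applies the causal mask and pooling internally
bio_model, _, _ = open_clip.create_model_and_transforms('hf-hub:imageomics/bioclip-2')
bio_tok = open_clip.get_tokenizer('hf-hub:imageomics/bioclip-2')
bio_model = bio_model.to(device).eval()
with torch.no_grad():
    bio_feats = bio_model.encode_text(bio_tok(prompts).to(device))

for name, feats in [("CLIP ViT-L/14", clip_feats), ("BioCLIP-2", bio_feats)]:
    feats = feats / feats.norm(dim=-1, keepdim=True)  # L2-normalize
    sims = feats @ feats.T                            # [3, 3] cosine matrix
    print(f"{name}: 0v1={sims[0, 1].item():.4f}  "
          f"1v2={sims[1, 2].item():.4f}  0v2={sims[0, 2].item():.4f}")

If BioCLIP-2 still shows higher similarities on the pooled features, the effect is not just a padding artifact.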
Thank you!