Skip to content

High Similarity between Embeddings #22

@LeoHsuProgrammingLab

Description

@LeoHsuProgrammingLab

Hi,

Thank you for the amazing work!

I found that BioCLIP2 produces a higher similarity between embeddings than the original CLIP.
Do you have any insight into this observation? Based on my understanding, BioCLIP2 should discern different species more easily than CLIP, yet my code surprisingly yields a higher similarity.

#!/usr/bin/env python
"""
Test if BioCLIP2 embeddings are too similar (causing guidance to have no effect).
Compare CLIP ViT-L/14 vs BioCLIP-2 embedding similarities.
"""

import torch
import sys
from pathlib import Path
from transformers import CLIPTextModel, CLIPTokenizer
import open_clip

# Make the repository root importable no matter where the script is run from.
# Assumes this file lives two directories below the repo root — TODO confirm.
repo_root = Path(__file__).resolve().parents[2]
sys.path.insert(0, str(repo_root))


def test_embedding_similarity():
    """Compare prompt-embedding similarity of CLIP ViT-L/14 vs BioCLIP-2.

    Encodes three taxonomic prompts with both text encoders, prints the
    pairwise cosine similarities of the full [1, 77, 768] sequence outputs
    (the representation Stable Diffusion cross-attends to), and returns the
    collected numbers.

    Returns:
        dict: {model_name: {'sim_01', 'sim_12', 'sim_02', 'avg'}} with the
        three pairwise similarities and their mean per encoder.

    Note:
        Requires network access to download both checkpoints on first run.
    """
    from torch.nn.functional import cosine_similarity  # hoisted out of the model loop

    print("="*80)
    print("Testing Embedding Similarity: CLIP ViT-L/14 vs BioCLIP-2")
    print("="*80)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Load CLIP ViT-L/14 (the text encoder used by SD v1.x).
    print("\n[INFO] Loading CLIP ViT-L/14...")
    clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    clip_text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device).eval()

    # Load BioCLIP-2 via open_clip.
    print("[INFO] Loading BioCLIP-2...")
    bioclip2_model, _, _ = open_clip.create_model_and_transforms('hf-hub:imageomics/bioclip-2')
    bioclip2_tokenizer = open_clip.get_tokenizer('hf-hub:imageomics/bioclip-2')
    bioclip2_model = bioclip2_model.to(device).eval()

    # Test prompts - different species from different families
    prompts = [
        "accipitridae, buteo, jamaicensis",        # Red-tailed hawk
        "accipitridae, Haliaeetus, leucocephalus", # Bald eagle (same family, different genus)
        "Icteridae, Icterus, pustulatus",          # Streak-backed oriole (completely different family)
    ]

    print(f"\n[INFO] Testing {len(prompts)} prompts:")
    for i, p in enumerate(prompts):
        print(f"  {i}: {p}")

    def encode_bioclip2_sequence(model, tokens):
        """Get full sequence embeddings [B, 77, 768] for SD cross-attention.

        Mirrors open_clip's ``CLIP.encode_text`` up to (but not including)
        the EOT-token pooling and text projection.
        """
        with torch.no_grad():
            x = model.token_embedding(tokens)
            x = x + model.positional_embedding
            x = x.permute(1, 0, 2)  # NLD -> LND
            # BUG FIX: open_clip's encode_text passes the causal attention
            # mask to the transformer. Omitting it lets every token attend
            # bidirectionally, which smears the per-position outputs and can
            # artificially inflate similarity between different prompts.
            x = model.transformer(x, attn_mask=model.attn_mask)
            x = x.permute(1, 0, 2)  # LND -> NLD
            x = model.ln_final(x)
        return x

    # Test both models
    results = {}

    for model_name, tokenizer, encoder in [
        ('CLIP ViT-L/14', clip_tokenizer, clip_text_encoder),
        ('BioCLIP-2', bioclip2_tokenizer, bioclip2_model)
    ]:
        print(f"\n{'='*80}")
        print(f"Testing: {model_name}")
        print(f"{'='*80}")

        embeddings = []

        for i, prompt in enumerate(prompts):
            if model_name == 'CLIP ViT-L/14':
                # CLIP from transformers: pad/truncate to the fixed 77-token context.
                tokens = tokenizer(
                    prompt,
                    padding="max_length",
                    max_length=77,
                    truncation=True,
                    return_tensors="pt"
                ).to(device)

                with torch.no_grad():
                    # Get sequence output [1, 77, 768]
                    emb = encoder(tokens.input_ids)[0]
            else:
                # BioCLIP-2 from open_clip (tokenizer pads to 77 internally).
                tokens = tokenizer([prompt]).to(device)
                # Get sequence output [1, 77, 768]
                emb = encode_bioclip2_sequence(encoder, tokens)

            embeddings.append(emb)
            print(f"Prompt {i}: shape={emb.shape}, mean={emb.mean():.4f}, std={emb.std():.4f}")

        # Calculate pairwise cosine similarities.
        # NOTE(review): flattening the full [1, 77, 768] sequence includes the
        # ~70 padding positions, whose representations are near-identical
        # across prompts and dominate the dot product — this biases ALL
        # similarities toward 1.0 regardless of encoder. For a fairer
        # comparison, consider comparing only the pooled (EOT-token) or
        # non-padding-token embeddings.
        sim_01 = cosine_similarity(embeddings[0].flatten(), embeddings[1].flatten(), dim=0).item()
        sim_12 = cosine_similarity(embeddings[1].flatten(), embeddings[2].flatten(), dim=0).item()
        sim_02 = cosine_similarity(embeddings[0].flatten(), embeddings[2].flatten(), dim=0).item()

        print(f"\nPairwise Cosine Similarities:")
        print(f"  Prompt 0 vs 1 (same family, diff genus): {sim_01:.6f}")
        print(f"  Prompt 1 vs 2 (diff families):           {sim_12:.6f}")
        print(f"  Prompt 0 vs 2 (diff families):           {sim_02:.6f}")

        avg_sim = (sim_01 + sim_12 + sim_02) / 3
        print(f"  Average similarity: {avg_sim:.6f}")

        # Store results
        results[model_name] = {
            'sim_01': sim_01,
            'sim_12': sim_12,
            'sim_02': sim_02,
            'avg': avg_sim
        }

    # BUG FIX: the collected results were built but never returned, making the
    # function unusable programmatically. Returning them is backward-compatible.
    return results
    
# Allow this module to be run directly as a diagnostic script.
if __name__ == "__main__":
    test_embedding_similarity()

Thank you!

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions