|
| 1 | +"""Embedding quality diagnostics for detecting representation collapse. |
| 2 | +
|
| 3 | +Computes geometry metrics on embedding matrices to diagnose failure modes |
| 4 | +in self-supervised pretraining (dimensional collapse, crowding, etc.). |
| 5 | +
|
| 6 | +Supports two embedding shapes: |
| 7 | +- [N, D]: image-level (classification). One embedding per sample. |
| 8 | +- [N, P, D] or [N, H, W, D]: patch-level (segmentation). Multiple patches per sample. |
| 9 | + Computes global, inter-sample, and intra-sample diagnostics. |
| 10 | +
|
| 11 | +Can be used standalone on any embedding tensor, or integrated |
| 12 | +into the eval pipeline via the evaluator callback. |
| 13 | +""" |
| 14 | + |
| 15 | +from __future__ import annotations |
| 16 | + |
| 17 | +import logging |
| 18 | + |
| 19 | +import torch |
| 20 | +from torch import Tensor |
| 21 | + |
logger = logging.getLogger(__name__)

# Subsampling caps that bound the cost of the quadratic / SVD computations
# below; each function draws a random subset when its input exceeds the cap.
MAX_PAIRWISE_SAMPLES = 2048  # rows used for O(N^2) pairwise cosine/uniformity stats
MAX_SVD_SAMPLES = 4096  # rows used for the effective-rank SVD
MAX_INTRA_SAMPLE_IMAGES = 256  # images used for per-image patch diagnostics
| 27 | + |
| 28 | + |
def effective_rank(embeddings: Tensor, max_samples: int | None = None) -> float:
    """Effective rank via Shannon entropy of singular values.

    Roy & Bhattacharyya (2007): exp(H(p)), where p is the L1-normalized
    singular-value spectrum. Returns a value between 1 (full collapse onto
    a single direction) and min(N, D) (maximally spread). As an edge case,
    an all-zero input has no nonzero singular values and returns 0.0.

    Args:
        embeddings: [N, D] embedding matrix.
        max_samples: Row cap before the SVD; a random subset is taken when
            N exceeds it. Defaults to MAX_SVD_SAMPLES.

    Returns:
        The effective rank as a Python float.
    """
    if max_samples is None:
        max_samples = MAX_SVD_SAMPLES
    n = embeddings.shape[0]
    if n > max_samples:
        # Random row subsample keeps the SVD tractable on large batches.
        idx = torch.randperm(n, device=embeddings.device)[:max_samples]
        embeddings = embeddings[idx]
    S = torch.linalg.svdvals(embeddings.float())
    S = S[S > 0]  # drop exact zeros so p.log() stays finite
    if S.numel() == 0:
        return 0.0
    p = S / S.sum()
    entropy = -(p * p.log()).sum()
    return entropy.exp().item()
| 46 | + |
| 47 | + |
def uniformity(
    embeddings: Tensor, t: float = 2.0, max_samples: int | None = None
) -> float:
    """Uniformity metric (Wang & Isola 2020). More negative = more uniform.

    Computes log E[exp(-t * ||z_i - z_j||^2)] over the unique off-diagonal
    pairs of L2-normalized embeddings.

    Args:
        embeddings: [N, D] embedding matrix.
        t: Temperature applied to the squared pairwise distances.
        max_samples: Row cap for the O(N^2) distance matrix; a random
            subset is taken when N exceeds it. Defaults to
            MAX_PAIRWISE_SAMPLES.

    Returns:
        The log-mean-exp uniformity value as a Python float.
    """
    if max_samples is None:
        max_samples = MAX_PAIRWISE_SAMPLES
    z = torch.nn.functional.normalize(embeddings.float(), dim=-1)
    n = z.shape[0]
    if n > max_samples:
        idx = torch.randperm(n, device=z.device)[:max_samples]
        z = z[idx]
        n = max_samples
    sq_dists = torch.cdist(z, z, p=2).pow(2)
    # Strict upper triangle: each unordered pair once, no self-distances.
    mask = torch.triu(torch.ones(n, n, device=z.device, dtype=torch.bool), diagonal=1)
    sq_dists_upper = sq_dists[mask]
    return torch.log(torch.exp(-t * sq_dists_upper).mean()).item()
| 60 | + |
| 61 | + |
def pairwise_cosine_stats(
    embeddings: Tensor, max_samples: int | None = None
) -> dict[str, float]:
    """Pairwise cosine similarity stats. High mean + low std = crowding.

    Args:
        embeddings: [N, D] embedding matrix.
        max_samples: Row cap for the O(N^2) similarity matrix; a random
            subset is taken when N exceeds it. Defaults to
            MAX_PAIRWISE_SAMPLES.

    Returns:
        Mean/std/min/max over the N*(N-1)/2 unique off-diagonal pairs.
    """
    if max_samples is None:
        max_samples = MAX_PAIRWISE_SAMPLES
    z = torch.nn.functional.normalize(embeddings.float(), dim=-1)
    n = z.shape[0]
    if n > max_samples:
        idx = torch.randperm(n, device=z.device)[:max_samples]
        z = z[idx]
        n = max_samples
    sim = z @ z.T
    # Strict upper triangle: each unordered pair once, no self-similarity.
    mask = torch.triu(torch.ones(n, n, device=z.device, dtype=torch.bool), diagonal=1)
    sims = sim[mask]
    return {
        "cosine_sim_mean": sims.mean().item(),
        "cosine_sim_std": sims.std().item(),
        "cosine_sim_min": sims.min().item(),
        "cosine_sim_max": sims.max().item(),
    }
| 79 | + |
| 80 | + |
def embedding_norm_stats(embeddings: Tensor) -> dict[str, float]:
    """Summarize the per-sample L2 norms (mean/std/min/max) across samples."""
    lengths = embeddings.float().norm(dim=-1)
    summary = {
        "norm_mean": lengths.mean(),
        "norm_std": lengths.std(),
        "norm_min": lengths.min(),
        "norm_max": lengths.max(),
    }
    return {key: value.item() for key, value in summary.items()}
| 90 | + |
| 91 | + |
def compute_embedding_diagnostics(embeddings: Tensor) -> dict[str, float]:
    """Compute all embedding quality diagnostics on [N, D] embeddings."""
    if embeddings.ndim != 2:
        raise ValueError(f"Expected 2D embeddings [N, D], got shape {embeddings.shape}")
    num_samples, dim = embeddings.shape
    if num_samples < 2:
        logger.warning("Need at least 2 samples for embedding diagnostics")
        return {}

    metrics: dict[str, float] = {
        "effective_rank": effective_rank(embeddings),
        "embedding_dim": float(dim),
        "num_samples": float(num_samples),
        **embedding_norm_stats(embeddings),
    }

    # NOTE(review): the >= 4 threshold presumably avoids degenerate pairwise
    # stats on tiny batches — confirm against original intent.
    if num_samples >= 4:
        metrics["uniformity"] = uniformity(embeddings)
        metrics.update(pairwise_cosine_stats(embeddings))

    return metrics
| 112 | + |
| 113 | + |
def _compute_intra_sample_diagnostics(
    embeddings: Tensor, max_images: int | None = None
) -> dict[str, float]:
    """Compute per-image patch diagnostics, averaged across images.

    Measures whether patches within an image are diverse (good for segmentation)
    or collapsed (all patches identical = segmentation impossible).

    Args:
        embeddings: [N, P, D] tensor where P is patches per image.
        max_images: Cap on images processed; a random subset is taken when
            N exceeds it. Defaults to MAX_INTRA_SAMPLE_IMAGES.

    Returns:
        Dict of scalar metrics; empty when P < 2.
    """
    if max_images is None:
        max_images = MAX_INTRA_SAMPLE_IMAGES
    n, p, d = embeddings.shape
    if p < 2:
        logger.warning("Need at least 2 patches per image for intra-sample diagnostics")
        return {}

    num_images = min(n, max_images)
    if num_images < n:
        idx = torch.randperm(n, device=embeddings.device)[:num_images]
        embeddings = embeddings[idx]

    # Batch cosine sim: normalize then bmm → [num_images, P, P]
    z = torch.nn.functional.normalize(embeddings.float(), dim=-1)
    sim_matrices = torch.bmm(z, z.transpose(1, 2))
    tri_mask = torch.triu(
        torch.ones(p, p, device=z.device, dtype=torch.bool), diagonal=1
    )
    # Boolean-mask the last two dims at once → [num_images, P*(P-1)/2].
    # Fully vectorized: the previous per-image loop called .item() every
    # iteration, forcing one host sync per image on GPU.
    pair_sims = sim_matrices[:, tri_mask]

    # Per-image std of patch norms, averaged across images.
    norms = embeddings.float().norm(dim=-1)  # [num_images, P]

    metrics: dict[str, float] = {
        "norm_std": norms.std(dim=1).mean().item(),
        "num_patches": float(p),
        "num_images_sampled": float(num_images),
    }
    if num_images > 0:
        # Mean/std computed per image, then averaged over images — matches
        # the loop-based semantics (not a pooled mean over all pairs).
        metrics["cosine_sim_mean"] = pair_sims.mean(dim=1).mean().item()
        metrics["cosine_sim_std"] = pair_sims.std(dim=1).mean().item()
    return metrics
| 160 | + |
| 161 | + |
def compute_spatial_embedding_diagnostics(embeddings: Tensor) -> dict[str, float]:
    """Compute diagnostics for spatial (patch-level) embeddings.

    Accepts [N, *, D] where * is one or more spatial dims (e.g. [N, H, W, D]
    or [N, P, D]). Returns metrics with flat prefixes (global_, inter_, intra_)
    to avoid deep nesting in wandb.
    """
    if embeddings.ndim < 3:
        raise ValueError(
            f"Expected 3+ dim embeddings [N, *, D], got shape {embeddings.shape}"
        )

    num_images = embeddings.shape[0]
    feat_dim = embeddings.shape[-1]
    patches = embeddings.reshape(num_images, -1, feat_dim)

    if num_images < 2:
        logger.warning("Need at least 2 samples for spatial embedding diagnostics")
        return {}

    metrics: dict[str, float] = {}

    # Global view: every patch treated as an independent sample; subsample
    # when the total patch count is huge.
    flat = patches.reshape(-1, feat_dim)
    total = flat.shape[0]
    if total > MAX_SVD_SAMPLES:
        keep = torch.randperm(total, device=flat.device)[:MAX_SVD_SAMPLES]
        flat = flat[keep]
    metrics.update(
        {f"global_{key}": val for key, val in compute_embedding_diagnostics(flat).items()}
    )

    # Inter-sample view: mean-pool each image's patches down to [N, D].
    pooled = patches.float().mean(dim=1)
    metrics.update(
        {f"inter_{key}": val for key, val in compute_embedding_diagnostics(pooled).items()}
    )

    # Intra-sample view: patch diversity within each image.
    if patches.shape[1] >= 2:
        metrics.update(
            {
                f"intra_{key}": val
                for key, val in _compute_intra_sample_diagnostics(patches).items()
            }
        )

    return metrics
0 commit comments