Skip to content

Commit a0d5c48

Browse files
committed
Merge branch 'main' into joer/static-temporal-only
2 parents 174963a + 231384a commit a0d5c48

15 files changed

Lines changed: 720 additions & 117 deletions

olmoearth_pretrain/evals/datasets/__init__.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""OlmoEarth Pretrain eval datasets."""
22

33
import logging
4+
from typing import Any
45

56
from olmo_core.config import StrEnum
67
from torch.utils.data import Dataset
@@ -14,6 +15,7 @@
1415
from .mados_dataset import MADOSDataset
1516
from .normalize import NormMethod
1617
from .pastis_dataset import PASTISRDataset
18+
from .pretrain_subset import PretrainSubsetDataset
1719
from .rslearn_dataset import from_registry_entry
1820

1921
logger = logging.getLogger(__name__)
@@ -40,9 +42,19 @@ def get_eval_dataset(
4042
# Default to 2std no clip - this matches what our model sees in pretraining,
4143
# so when using dataset stats (e.g. for MADOS) consistency is important.
4244
norm_method: str = NormMethod.NORM_NO_CLIP_2_STD,
45+
**kwargs: Any,
4346
) -> Dataset:
4447
"""Retrieve an eval dataset from the dataset name."""
45-
if eval_dataset.startswith("m-"):
48+
if eval_dataset == "pretrain_subset":
49+
return PretrainSubsetDataset(
50+
h5py_dir=kwargs["h5py_dir"],
51+
training_modalities=kwargs.get("training_modalities", input_modalities),
52+
max_samples=kwargs.get("max_samples", 512),
53+
patch_size=kwargs.get("pretrain_patch_size", 4),
54+
hw_p=kwargs.get("pretrain_hw_p", 8),
55+
seed=kwargs.get("pretrain_seed", 42),
56+
)
57+
elif eval_dataset.startswith("m-"):
4658
# m- == "modified for geobench"
4759
return GeobenchDataset(
4860
geobench_dir=paths.GEOBENCH_DIR,

olmoearth_pretrain/evals/datasets/configs.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,18 @@ def from_dict(cls, d: dict[str, Any]) -> "EvalDatasetConfig":
4949

5050

5151
DATASET_TO_CONFIG = {
52+
# Dummy config — only used for embedding diagnostics, not actual classification.
53+
"pretrain_subset": EvalDatasetConfig(
54+
task_type=TaskType.CLASSIFICATION,
55+
imputes=[],
56+
num_classes=1,
57+
is_multilabel=False,
58+
supported_modalities=[
59+
Modality.SENTINEL2_L2A.name,
60+
Modality.SENTINEL1.name,
61+
Modality.LANDSAT.name,
62+
],
63+
),
5264
"m-eurosat": EvalDatasetConfig(
5365
task_type=TaskType.CLASSIFICATION,
5466
imputes=[],
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
"""Eval dataset adapter that loads a subset of pretraining data.
2+
3+
Wraps OlmoEarthDataset to expose the eval dataset interface
4+
(returns MaskedOlmoEarthSample, dummy_label) so it can be used
5+
with the downstream evaluator callback for embedding diagnostics.
6+
"""
7+
8+
from __future__ import annotations
9+
10+
import logging
11+
12+
import numpy as np
13+
import torch
14+
from torch.utils.data import Dataset
15+
from upath import UPath
16+
17+
from olmoearth_pretrain.data.dataset import GetItemArgs, OlmoEarthDataset
18+
from olmoearth_pretrain.datatypes import MaskedOlmoEarthSample
19+
20+
logger = logging.getLogger(__name__)
21+
22+
DEFAULT_PATCH_SIZE = 4
23+
DEFAULT_HW_P = 8
24+
DEFAULT_MAX_SAMPLES = 512
25+
26+
27+
class PretrainSubsetDataset(Dataset):
    """Adapter exposing OlmoEarthDataset through the eval-dataset interface.

    Each item is a (MaskedOlmoEarthSample, dummy_label) pair so the
    downstream evaluator callback can consume pretraining data for
    embedding diagnostics. A fixed, seeded subset of indices keeps
    results reproducible across runs.
    """

    def __init__(
        self,
        h5py_dir: str,
        training_modalities: list[str],
        max_samples: int = DEFAULT_MAX_SAMPLES,
        patch_size: int = DEFAULT_PATCH_SIZE,
        hw_p: int = DEFAULT_HW_P,
        seed: int = 42,
    ) -> None:
        """Initialize with a fixed reproducible subset of training indices."""
        self.patch_size = patch_size
        self.hw_p = hw_p
        self.max_samples = max_samples

        self._dataset = OlmoEarthDataset(
            h5py_dir=UPath(h5py_dir),
            training_modalities=training_modalities,
            dtype=np.float32,
            normalize=True,
        )
        self._dataset.prepare()

        # Seeded legacy RandomState so the sampled subset is identical
        # across runs for the same (dataset, seed) pair.
        total = len(self._dataset)
        subset_size = min(max_samples, total)
        rng = np.random.RandomState(seed)
        self._indices = rng.choice(total, size=subset_size, replace=False).tolist()

    def __len__(self) -> int:
        """Return the number of samples in the fixed subset."""
        return len(self._indices)

    def __getitem__(self, idx: int) -> tuple[MaskedOlmoEarthSample, torch.Tensor]:
        """Return (MaskedOlmoEarthSample, dummy_label) for the given index."""
        item_args = GetItemArgs(
            idx=self._indices[idx],
            patch_size=self.patch_size,
            sampled_hw_p=self.hw_p,
        )
        _, sample = self._dataset[item_args]
        # Label is a placeholder: this dataset only feeds embedding
        # diagnostics, never an actual classification head.
        return (
            MaskedOlmoEarthSample.from_olmoearthsample(sample),
            torch.tensor(0, dtype=torch.long),
        )
Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
"""Embedding quality diagnostics for detecting representation collapse.
2+
3+
Computes geometry metrics on embedding matrices to diagnose failure modes
4+
in self-supervised pretraining (dimensional collapse, crowding, etc.).
5+
6+
Supports two embedding shapes:
7+
- [N, D]: image-level (classification). One embedding per sample.
8+
- [N, P, D] or [N, H, W, D]: patch-level (segmentation). Multiple patches per sample.
9+
Computes global, inter-sample, and intra-sample diagnostics.
10+
11+
Can be used standalone on any embedding tensor, or integrated
12+
into the eval pipeline via the evaluator callback.
13+
"""
14+
15+
from __future__ import annotations
16+
17+
import logging
18+
19+
import torch
20+
from torch import Tensor
21+
22+
logger = logging.getLogger(__name__)
23+
24+
MAX_PAIRWISE_SAMPLES = 2048
25+
MAX_SVD_SAMPLES = 4096
26+
MAX_INTRA_SAMPLE_IMAGES = 256
27+
28+
29+
def effective_rank(embeddings: Tensor) -> float:
    """Effective rank via Shannon entropy of singular values.

    Returns a value between 1 (full collapse) and min(N, D) (maximally spread).
    Roy & Bhattacharyya (2007).
    """
    num_rows = embeddings.shape[0]
    # Cap the SVD cost on very large batches by subsampling rows.
    if num_rows > MAX_SVD_SAMPLES:
        keep = torch.randperm(num_rows, device=embeddings.device)[:MAX_SVD_SAMPLES]
        embeddings = embeddings[keep]

    singular = torch.linalg.svdvals(embeddings.float())
    singular = singular[singular > 0]
    if singular.numel() == 0:
        return 0.0

    # exp(entropy) of the normalized spectrum is the effective rank.
    probs = singular / singular.sum()
    entropy = -(probs * probs.log()).sum()
    return entropy.exp().item()
46+
47+
48+
def uniformity(embeddings: Tensor, t: float = 2.0) -> float:
    """Uniformity metric (Wang & Isola 2020). More negative = more uniform."""
    unit = torch.nn.functional.normalize(embeddings.float(), dim=-1)
    count = unit.shape[0]
    # Subsample to bound the O(n^2) pairwise-distance cost.
    if count > MAX_PAIRWISE_SAMPLES:
        keep = torch.randperm(count, device=unit.device)[:MAX_PAIRWISE_SAMPLES]
        unit = unit[keep]
        count = MAX_PAIRWISE_SAMPLES

    # Unique pairs only: strictly-upper-triangular mask over the
    # squared-distance matrix.
    pair_sq = torch.cdist(unit, unit, p=2).pow(2)
    upper = torch.triu(
        torch.ones(count, count, device=unit.device, dtype=torch.bool), diagonal=1
    )
    return torch.log(torch.exp(-t * pair_sq[upper]).mean()).item()
60+
61+
62+
def pairwise_cosine_stats(embeddings: Tensor) -> dict[str, float]:
    """Pairwise cosine similarity stats. High mean + low std = crowding."""
    unit = torch.nn.functional.normalize(embeddings.float(), dim=-1)
    count = unit.shape[0]
    # Subsample to bound the O(n^2) similarity matrix.
    if count > MAX_PAIRWISE_SAMPLES:
        keep = torch.randperm(count, device=unit.device)[:MAX_PAIRWISE_SAMPLES]
        unit = unit[keep]
        count = MAX_PAIRWISE_SAMPLES

    # Cosine similarity of unit vectors is a plain dot product;
    # keep only the strictly-upper-triangular (unique) pairs.
    upper = torch.triu(
        torch.ones(count, count, device=unit.device, dtype=torch.bool), diagonal=1
    )
    pair_sims = (unit @ unit.T)[upper]
    return {
        "cosine_sim_mean": pair_sims.mean().item(),
        "cosine_sim_std": pair_sims.std().item(),
        "cosine_sim_min": pair_sims.min().item(),
        "cosine_sim_max": pair_sims.max().item(),
    }
79+
80+
81+
def embedding_norm_stats(embeddings: Tensor) -> dict[str, float]:
    """L2 norm statistics across samples."""
    lengths = embeddings.float().norm(dim=-1)
    reductions = {
        "norm_mean": lengths.mean(),
        "norm_std": lengths.std(),
        "norm_min": lengths.min(),
        "norm_max": lengths.max(),
    }
    return {key: value.item() for key, value in reductions.items()}
90+
91+
92+
def compute_embedding_diagnostics(embeddings: Tensor) -> dict[str, float]:
    """Compute all embedding quality diagnostics on [N, D] embeddings."""
    if embeddings.ndim != 2:
        raise ValueError(f"Expected 2D embeddings [N, D], got shape {embeddings.shape}")
    num_samples, dim = embeddings.shape
    if num_samples < 2:
        logger.warning("Need at least 2 samples for embedding diagnostics")
        return {}

    metrics: dict[str, float] = {
        "effective_rank": effective_rank(embeddings),
        "embedding_dim": float(dim),
        "num_samples": float(num_samples),
        **embedding_norm_stats(embeddings),
    }

    # Pairwise metrics are only computed with at least 4 samples.
    if num_samples >= 4:
        metrics["uniformity"] = uniformity(embeddings)
        metrics.update(pairwise_cosine_stats(embeddings))

    return metrics
112+
113+
114+
def _compute_intra_sample_diagnostics(embeddings: Tensor) -> dict[str, float]:
    """Compute per-image patch diagnostics, averaged across images.

    Args:
        embeddings: [N, P, D] tensor where P is patches per image.

    Returns:
        Averaged per-image cosine-similarity and norm-spread metrics, or {}
        when there are no images or fewer than 2 patches per image.

    Measures whether patches within an image are diverse (good for segmentation)
    or collapsed (all patches identical = segmentation impossible).
    """
    n, p, _ = embeddings.shape
    if p < 2:
        logger.warning("Need at least 2 patches per image for intra-sample diagnostics")
        return {}
    if n == 0:
        # Reducing over an empty batch would yield NaN stats.
        logger.warning("Need at least 1 image for intra-sample diagnostics")
        return {}

    num_images = min(n, MAX_INTRA_SAMPLE_IMAGES)
    if num_images < n:
        idx = torch.randperm(n, device=embeddings.device)[:num_images]
        embeddings = embeddings[idx]

    # Batch cosine sim: normalize then bmm → [num_images, P, P]
    z = torch.nn.functional.normalize(embeddings.float(), dim=-1)
    sim_matrices = torch.bmm(z, z.transpose(1, 2))
    tri_mask = torch.triu(
        torch.ones(p, p, device=z.device, dtype=torch.bool), diagonal=1
    )

    # Vectorized over images: masking the last two dims yields
    # [num_images, num_pairs]. Avoids a Python loop with .item()
    # host syncs per image.
    per_image_sims = sim_matrices[:, tri_mask]
    cosine_sim_mean = per_image_sims.mean(dim=1).mean().item()
    cosine_sim_std = per_image_sims.std(dim=1).mean().item()

    # Spread of patch L2 norms within each image, averaged across images.
    norms = embeddings.float().norm(dim=-1)  # [num_images, P]
    norm_std = norms.std(dim=1).mean().item()  # [num_images] -> scalar

    return {
        "norm_std": norm_std,
        "num_patches": float(p),
        "num_images_sampled": float(num_images),
        "cosine_sim_mean": cosine_sim_mean,
        "cosine_sim_std": cosine_sim_std,
    }
160+
161+
162+
def compute_spatial_embedding_diagnostics(embeddings: Tensor) -> dict[str, float]:
    """Compute diagnostics for spatial (patch-level) embeddings.

    Accepts [N, *, D] where * is one or more spatial dims (e.g. [N, H, W, D]
    or [N, P, D]). Returns metrics with flat prefixes (global_, inter_, intra_)
    to avoid deep nesting in wandb.
    """
    if embeddings.ndim < 3:
        raise ValueError(
            f"Expected 3+ dim embeddings [N, *, D], got shape {embeddings.shape}"
        )

    num_images = embeddings.shape[0]
    dim = embeddings.shape[-1]
    patches = embeddings.reshape(num_images, -1, dim)
    patches_per_image = patches.shape[1]

    if num_images < 2:
        logger.warning("Need at least 2 samples for spatial embedding diagnostics")
        return {}

    metrics: dict[str, float] = {}

    # Global view: treat every patch as an independent sample,
    # subsampling when the flattened set is huge.
    flat = patches.reshape(-1, dim)
    if flat.shape[0] > MAX_SVD_SAMPLES:
        keep = torch.randperm(flat.shape[0], device=flat.device)[:MAX_SVD_SAMPLES]
        flat = flat[keep]
    metrics.update(
        {f"global_{key}": val for key, val in compute_embedding_diagnostics(flat).items()}
    )

    # Inter-sample view: mean-pool each image's patches to one [N, D] vector.
    pooled = patches.float().mean(dim=1)
    metrics.update(
        {f"inter_{key}": val for key, val in compute_embedding_diagnostics(pooled).items()}
    )

    # Intra-sample view: patch diversity within each image.
    if patches_per_image >= 2:
        metrics.update(
            {
                f"intra_{key}": val
                for key, val in _compute_intra_sample_diagnostics(patches).items()
            }
        )

    return metrics

0 commit comments

Comments
 (0)