Commit b7d305a

Merge pull request #542 from allenai/henryh/vectorized-masked-neg-loss
Add vectorized masked-negatives patch discrimination loss
2 parents dc49639 + 2b7e341 commit b7d305a

2 files changed: 366 additions & 0 deletions

olmoearth_pretrain/train/loss.py

Lines changed: 156 additions & 0 deletions

@@ -510,6 +510,162 @@ def compute(
        return self.weight * total_loss


@LOSS_REGISTRY.register("modality_patch_discrimination_masked_negatives_vec")
class ModalityPatchDiscriminationMaskedNegativesVec(Loss):
    """Vectorized patch discrimination with same-target negative masking.

    Equivalent to ModalityPatchDiscriminationMaskedNegatives but fully batched:
    no per-sample Python loops, no .item() syncs, no repeated torch.eye allocations.
    """

    name = "ModalityPatchDiscMaskedVec"

    def __init__(
        self,
        tau: float = 0.1,
        pred2unit: bool = False,
        weight: float = 1.0,
        modality_weights: dict[str, float] | None = None,
        same_target_threshold: float = 0.999,
        mask_negatives_for_modalities: list[str] | None = None,
    ) -> None:
        """Initialize with the same params as ModalityPatchDiscriminationMaskedNegatives."""
        self.tau = tau
        self.pred2unit = pred2unit
        self.weight = weight
        self.modality_weights = modality_weights
        self.same_target_threshold = same_target_threshold
        self.mask_negatives_for_modalities = mask_negatives_for_modalities

    def _compute_modality_loss_parallel(
        self,
        all_preds: Tensor,
        all_masks: Tensor,
        all_targets: Tensor,
        modality: str,
    ) -> Tensor:
        batch_size, num_tokens, dim = all_preds.shape
        decoder_mask = all_masks == MaskValue.DECODER.value
        count = decoder_mask.sum(dim=-1)  # (batch,)

        # Sort so decoder tokens come first per sample
        _, sort_indices = decoder_mask.long().sort(dim=1, descending=True, stable=True)
        sort_expanded = sort_indices.unsqueeze(-1).expand(-1, -1, dim)
        sorted_preds = all_preds.gather(1, sort_expanded).float()
        sorted_targets = all_targets.gather(1, sort_expanded).float()
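        # A stable descending sort on the 0/1 decoder mask packs decoder tokens
        # to the front of each sample while preserving their relative order;
        # `count` then marks where each sample's valid prefix ends. This is the
        # trick that replaces the per-sample Python loop of the sequential version.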

        # valid_mask[b, i] = True iff position i is a decoder token for sample b
        range_tensor = torch.arange(num_tokens, device=count.device)
        valid_mask = range_tensor.unsqueeze(0) < count.unsqueeze(1)  # (batch, T)

        if self.pred2unit:
            mask_float = valid_mask.unsqueeze(-1).float()
            total_decoder = mask_float.sum().clamp(min=1)
            pred_mu = (sorted_preds * mask_float).sum(
                dim=(0, 1), keepdim=True
            ) / total_decoder
            centered = sorted_preds - pred_mu
            pred_var = (centered**2 * mask_float).sum(dim=(0, 1), keepdim=True) / (
                total_decoder - 1
            ).clamp(min=1)
            sorted_preds = (sorted_preds - pred_mu) / (pred_var.sqrt() + 1e-4)
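            # Note: mean and variance are computed only over decoder tokens via
            # mask_float, with an (n - 1) denominator and a small 1e-4 epsilon
            # added to the standard deviation before dividing.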

        sorted_preds = F.normalize(sorted_preds, p=2, dim=-1)
        sorted_targets = F.normalize(sorted_targets, p=2, dim=-1)

        # Score matrix: (batch, T, T) — each sample independent
        scores = torch.bmm(sorted_preds, sorted_targets.transpose(1, 2)) / self.tau
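        # After L2 normalization these are temperature-scaled cosine logits:
        # scores[b, i, j] = cos(pred_i, target_j) / tau. Row i later feeds an
        # InfoNCE-style objective with the positive on the diagonal:
        #   L_i = -log(exp(s_ii) / sum_j exp(s_ij))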

        should_mask = (
            self.mask_negatives_for_modalities is None
            or modality in self.mask_negatives_for_modalities
        )

        # Track which samples to skip (default: none)
        sample_skip = torch.zeros(batch_size, dtype=torch.bool, device=scores.device)

        if should_mask:
            # Target self-similarity per sample: (batch, T, T)
            target_sim = torch.bmm(sorted_targets, sorted_targets.transpose(1, 2))
            same_target = target_sim > self.same_target_threshold

            # Only consider valid token pairs
            valid_2d = valid_mask.unsqueeze(1) & valid_mask.unsqueeze(
                2
            )  # (batch, T, T)

            # Diagonal (self) is never an invalid negative
            diag = torch.eye(num_tokens, dtype=torch.bool, device=scores.device)
            invalid_negatives = same_target & ~diag.unsqueeze(0) & valid_2d

            # The original only applies masking when c_val > 1, so restrict
            # invalid_negatives and skip-detection to samples with count > 1.
            multi_token = (count > 1).unsqueeze(1).unsqueeze(2)  # (batch, 1, 1)
            invalid_negatives = invalid_negatives & multi_token

            # Skip samples where any valid token has zero valid negatives
            valid_neg_count = (~same_target & valid_2d).sum(dim=-1)  # (batch, T)
            has_zero_neg = (
                (valid_neg_count == 0) & valid_mask & (count > 1).unsqueeze(1)
            )
            sample_skip = has_zero_neg.any(dim=1)

            scores = scores.masked_fill(invalid_negatives, float("-inf"))
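            # Filling with -inf is safe here: the diagonal (each row's positive)
            # is excluded from invalid_negatives, so every contributing row
            # keeps at least one finite logit for the softmax.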

        # Mask out non-decoder columns
        col_mask = valid_mask.unsqueeze(1).expand_as(scores)
        scores = scores.masked_fill(~col_mask, -torch.finfo(scores.dtype).max)

        # Mask rows for zero-count samples to prevent NaN
        row_mask = valid_mask.unsqueeze(2).expand_as(scores)
        scores = scores.masked_fill(~row_mask, 0.0)

        # Labels: diagonal (token i matches target i)
        labels = range_tensor.unsqueeze(0).expand(batch_size, -1)

        loss_per_pos = F.cross_entropy(
            scores.reshape(-1, num_tokens),
            labels.reshape(-1),
            reduction="none",
        ) * (self.tau * 2)
        loss_per_pos = loss_per_pos.reshape(batch_size, num_tokens)
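        # Two distinct fill values on purpose: invalid columns get the largest
        # negative finite value rather than -inf, and padding rows are
        # overwritten with zeros, so cross_entropy stays finite everywhere.
        # Padding rows are dropped from the loss below; the (tau * 2) factor
        # mirrors the scaling of the sequential implementation.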

        # Zero out invalid positions and skipped samples
        sample_contributes = (count > 0) & ~sample_skip
        effective_valid = valid_mask.float() * sample_contributes.unsqueeze(1).float()
        effective_count = count.float() * sample_contributes.float()
        num_contributing = sample_contributes.sum()

        loss_per_sample = (loss_per_pos * effective_valid).sum(
            dim=1
        ) / effective_count.clamp(min=1)
        loss = loss_per_sample.sum() / num_contributing.float().clamp(min=1)
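        # Same reduction as the sequential version: mean over decoder tokens
        # within each sample, then mean over samples that contribute (count > 0
        # and not skipped); clamp(min=1) only guards the empty edge cases.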

        return loss

    def compute(
        self, predictions: TokensAndMasks, targets: TokensAndMasks, **kwargs: Any
    ) -> Tensor:
        """Compute patch discrimination loss with masked same-target negatives (vectorized)."""
        modality_preds, modality_masks = (
            predictions.flatten_tokens_and_masks_per_modality()
        )
        modality_targets = targets.flatten_tokens_and_masks_per_modality()[0]

        total_loss = 0
        for all_preds, all_masks, all_targets, modality in zip(
            modality_preds, modality_masks, modality_targets, targets.modalities
        ):
            loss = self._compute_modality_loss_parallel(
                all_preds, all_masks, all_targets, modality
            )
            if self.modality_weights is not None:
                loss = loss * self.modality_weights.get(modality, 1.0)
            total_loss += loss

        return self.weight * total_loss


@LOSS_REGISTRY.register("modality_patch_discrimination_vec")
class ModalityPatchDiscriminationLossVec(Loss):
    """Loss function for per-modality patch discrimination task.

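For orientation, a minimal usage sketch of the new loss follows. The constructor arguments and the compute() signature are taken from the diff above; the `predictions` and `targets` TokensAndMasks values are assumed to be produced elsewhere in the training pipeline and are not part of this commit.

# Hypothetical usage sketch: `predictions` and `targets` are TokensAndMasks
# pairs built by the model and dataloader, which this diff does not show.
loss_fn = ModalityPatchDiscriminationMaskedNegativesVec(
    tau=0.1,
    same_target_threshold=0.999,
    mask_negatives_for_modalities=["worldcover"],
)
loss = loss_fn.compute(predictions, targets)
loss.backward()
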
tests/unit/train/test_loss.py

Lines changed: 210 additions & 0 deletions

@@ -15,13 +15,17 @@
    ModalityPatchDiscriminationLossNew,
    ModalityPatchDiscriminationLossVec,
    ModalityPatchDiscriminationMaskedNegatives,
    ModalityPatchDiscriminationMaskedNegativesVec,
    PatchDiscriminationLoss,
    PatchDiscriminationLossNew,
)
from olmoearth_pretrain.train.masking import MaskValue

logger = logging.getLogger(__name__)

RTOL = 1e-4
ATOL = 1e-6


def test_patch_disc_loss() -> None:
    """Just test that it runs as expected."""

@@ -1142,3 +1146,209 @@ def test_modality_patch_discrimination_masked_negatives() -> None:

    # Masking removes false negatives from denominator, so loss should be lower
    assert loss_value < loss_no_mask_value


# ---------------------------------------------------------------------------
# ModalityPatchDiscriminationMaskedNegativesVec vs sequential
# ---------------------------------------------------------------------------


def _make_masked_neg_pair(
    tau: float = 0.1, threshold: float = 0.999, mask_modalities: list[str] | None = None
) -> tuple:
    """Return (sequential, vec) loss instances with matching params."""
    seq = ModalityPatchDiscriminationMaskedNegatives(
        tau=tau,
        same_target_threshold=threshold,
        mask_negatives_for_modalities=mask_modalities,
    )
    vec = ModalityPatchDiscriminationMaskedNegativesVec(
        tau=tau,
        same_target_threshold=threshold,
        mask_negatives_for_modalities=mask_modalities,
    )
    return seq, vec


def test_masked_neg_vec_matches_sequential_uniform() -> None:
    """Vec matches sequential when all tokens are decoder tokens."""
    b, t_h, t_w, t, d = 4, 3, 3, 2, 16
    torch.manual_seed(42)

    preds = TokensAndMasks(
        sentinel2_l2a=torch.randn((b, t_h, t_w, t, d)),
        sentinel2_l2a_mask=torch.ones((b, t_h, t_w, t)) * MaskValue.DECODER.value,
    )
    targets = TokensAndMasks(
        sentinel2_l2a=torch.randn((b, t_h, t_w, t, d)),
        sentinel2_l2a_mask=torch.ones((b, t_h, t_w, t)) * MaskValue.DECODER.value,
    )

    seq, vec = _make_masked_neg_pair()
    loss_seq = seq.compute(preds, targets)
    loss_vec = vec.compute(preds, targets)
    assert torch.isclose(loss_seq, loss_vec, rtol=RTOL, atol=ATOL), (
        f"seq={loss_seq.item()}, vec={loss_vec.item()}"
    )


def test_masked_neg_vec_matches_sequential_uneven() -> None:
    """Vec matches sequential with uneven decoder token counts."""
    b, t_h, t_w, t, d = 6, 4, 4, 2, 8

    for seed in range(20):
        torch.manual_seed(seed)
        s2_mask = torch.randint(0, 4, (b, t_h, t_w, t))
        preds = TokensAndMasks(
            sentinel2_l2a=torch.randn((b, t_h, t_w, t, d)),
            sentinel2_l2a_mask=s2_mask,
        )
        targets = TokensAndMasks(
            sentinel2_l2a=torch.randn((b, t_h, t_w, t, d)),
            sentinel2_l2a_mask=s2_mask,
        )
        seq, vec = _make_masked_neg_pair()
        loss_seq = seq.compute(preds, targets)
        loss_vec = vec.compute(preds, targets)
        assert torch.isclose(loss_seq, loss_vec, rtol=RTOL, atol=ATOL), (
            f"seed={seed}: seq={loss_seq.item()}, vec={loss_vec.item()}"
        )


def test_masked_neg_vec_with_identical_targets() -> None:
    """Test masking behavior when some targets are identical (triggers skip)."""
    b, t_h, t_w, t, d = 4, 2, 2, 2, 8
    torch.manual_seed(7)

    target_s2 = torch.randn((b, t_h, t_w, t, d))
    # Make ALL tokens in sample 0 identical → should be skipped
    target_s2[0] = target_s2[0, 0, 0, 0].expand_as(target_s2[0])

    preds = TokensAndMasks(
        sentinel2_l2a=torch.randn((b, t_h, t_w, t, d)),
        sentinel2_l2a_mask=torch.ones((b, t_h, t_w, t)) * MaskValue.DECODER.value,
    )
    targets = TokensAndMasks(
        sentinel2_l2a=target_s2,
        sentinel2_l2a_mask=torch.ones((b, t_h, t_w, t)) * MaskValue.DECODER.value,
    )

    seq, vec = _make_masked_neg_pair()
    loss_seq = seq.compute(preds, targets)
    loss_vec = vec.compute(preds, targets)
    assert torch.isclose(loss_seq, loss_vec, rtol=RTOL, atol=ATOL), (
        f"identical targets: seq={loss_seq.item()}, vec={loss_vec.item()}"
    )


def test_masked_neg_vec_gradients() -> None:
    """Gradients match between sequential and vec."""
    b, t_h, t_w, t, d = 4, 3, 3, 2, 16

    for seed in [0, 7, 42, 999]:
        torch.manual_seed(seed)
        s2_mask = torch.randint(0, 4, (b, t_h, t_w, t))
        s2_data = torch.randn((b, t_h, t_w, t, d))
        s2_tgt = torch.randn((b, t_h, t_w, t, d))

        # Sequential
        s2_seq = s2_data.clone().requires_grad_(True)
        preds_s = TokensAndMasks(sentinel2_l2a=s2_seq, sentinel2_l2a_mask=s2_mask)
        targets_s = TokensAndMasks(
            sentinel2_l2a=s2_tgt.clone(), sentinel2_l2a_mask=s2_mask
        )
        seq, vec = _make_masked_neg_pair()
        loss_s = seq.compute(preds_s, targets_s)
        loss_s.backward()

        # Vec
        s2_vec = s2_data.clone().requires_grad_(True)
        preds_v = TokensAndMasks(sentinel2_l2a=s2_vec, sentinel2_l2a_mask=s2_mask)
        targets_v = TokensAndMasks(
            sentinel2_l2a=s2_tgt.clone(), sentinel2_l2a_mask=s2_mask
        )
        loss_v = vec.compute(preds_v, targets_v)
        loss_v.backward()

        assert torch.isclose(loss_s, loss_v, rtol=RTOL, atol=ATOL), (
            f"seed={seed}: loss seq={loss_s.item()}, vec={loss_v.item()}"
        )
        assert torch.allclose(s2_seq.grad, s2_vec.grad, rtol=RTOL, atol=ATOL), (
            f"seed={seed}: grad max diff="
            f"{(s2_seq.grad - s2_vec.grad).abs().max().item()}"
        )


def test_masked_neg_vec_missing_samples() -> None:
    """Vec matches sequential when some samples have no decoder tokens."""
    b, t_h, t_w, t, d = 5, 4, 4, 2, 8
    torch.manual_seed(456)

    s2_mask = torch.randint(0, 3, (b, t_h, t_w, t))
    s2_mask[0] = MaskValue.ONLINE_ENCODER.value
    s2_mask[2] = MaskValue.MISSING.value

    preds = TokensAndMasks(
        sentinel2_l2a=torch.randn((b, t_h, t_w, t, d)),
        sentinel2_l2a_mask=s2_mask,
    )
    targets = TokensAndMasks(
        sentinel2_l2a=torch.randn((b, t_h, t_w, t, d)),
        sentinel2_l2a_mask=s2_mask,
    )

    seq, vec = _make_masked_neg_pair()
    loss_seq = seq.compute(preds, targets)
    loss_vec = vec.compute(preds, targets)
    assert torch.isclose(loss_seq, loss_vec, rtol=RTOL, atol=ATOL), (
        f"seq={loss_seq.item()}, vec={loss_vec.item()}"
    )


def test_masked_neg_vec_selective_modality_masking() -> None:
    """Masking only applied to specified modalities."""
    b, t_h, t_w, t, d = 4, 3, 3, 2, 16
    torch.manual_seed(99)

    preds = TokensAndMasks(
        sentinel2_l2a=torch.randn((b, t_h, t_w, t, d)),
        sentinel2_l2a_mask=torch.ones((b, t_h, t_w, t)) * MaskValue.DECODER.value,
        worldcover=torch.randn((b, t_h, t_w, 1, d)),
        worldcover_mask=torch.ones((b, t_h, t_w, 1)) * MaskValue.DECODER.value,
    )
    targets = TokensAndMasks(
        sentinel2_l2a=torch.randn((b, t_h, t_w, t, d)),
        sentinel2_l2a_mask=torch.ones((b, t_h, t_w, t)) * MaskValue.DECODER.value,
        worldcover=torch.randn((b, t_h, t_w, 1, d)),
        worldcover_mask=torch.ones((b, t_h, t_w, 1)) * MaskValue.DECODER.value,
    )

    seq, vec = _make_masked_neg_pair(mask_modalities=["worldcover"])
    loss_seq = seq.compute(preds, targets)
    loss_vec = vec.compute(preds, targets)
    assert torch.isclose(loss_seq, loss_vec, rtol=RTOL, atol=ATOL), (
        f"selective: seq={loss_seq.item()}, vec={loss_vec.item()}"
    )


def test_masked_neg_vec_large_batch() -> None:
    """Equivalence at training-like batch size."""
    b, t_h, t_w, t, d = 32, 4, 4, 2, 64
    torch.manual_seed(2024)
    s2_mask = torch.randint(0, 4, (b, t_h, t_w, t))

    preds = TokensAndMasks(
        sentinel2_l2a=torch.randn((b, t_h, t_w, t, d)),
        sentinel2_l2a_mask=s2_mask,
    )
    targets = TokensAndMasks(
        sentinel2_l2a=torch.randn((b, t_h, t_w, t, d)),
        sentinel2_l2a_mask=s2_mask,
    )

    seq, vec = _make_masked_neg_pair()
    loss_seq = seq.compute(preds, targets)
    loss_vec = vec.compute(preds, targets)
    assert torch.isclose(loss_seq, loss_vec, rtol=RTOL, atol=ATOL), (
        f"large batch: seq={loss_seq.item()}, vec={loss_vec.item()}"
    )

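These equivalence tests can be run on their own. Assuming pytest is the project's test runner, something like the following selects just the sequential-vs-vectorized comparisons by name:

pytest tests/unit/train/test_loss.py -k masked_neg_vec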