     ModalityPatchDiscriminationLossNew,
     ModalityPatchDiscriminationLossVec,
     ModalityPatchDiscriminationMaskedNegatives,
+    ModalityPatchDiscriminationMaskedNegativesVec,
     PatchDiscriminationLoss,
     PatchDiscriminationLossNew,
 )
 from olmoearth_pretrain.train.masking import MaskValue

 logger = logging.getLogger(__name__)

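+# Tolerances for comparing the sequential and vectorized implementations:
+# float32 reduction order differs between the two paths, so bit-exact
+# equality is not expected.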
+RTOL = 1e-4
+ATOL = 1e-6
+

 def test_patch_disc_loss() -> None:
     """Just test that it runs as expected."""
@@ -1142,3 +1146,209 @@ def test_modality_patch_discrimination_masked_negatives() -> None: |

     # Masking removes false negatives from denominator, so loss should be lower
     assert loss_value < loss_no_mask_value
+
+
+# ---------------------------------------------------------------------------
+# ModalityPatchDiscriminationMaskedNegativesVec vs sequential
+# ---------------------------------------------------------------------------
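+#
+# The Vec variant is assumed to be a vectorized drop-in for the sequential
+# loss: negatives whose targets are near-duplicates of the anchor's target
+# (similarity above same_target_threshold) are dropped from the contrastive
+# denominator. Every test below checks that the two paths agree to within
+# RTOL/ATOL, for both loss values and gradients.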
+
+
+def _make_masked_neg_pair(
+    tau: float = 0.1, threshold: float = 0.999, mask_modalities: list[str] | None = None
+) -> tuple:
+    """Return (sequential, vec) loss instances with matching params."""
+    seq = ModalityPatchDiscriminationMaskedNegatives(
+        tau=tau,
+        same_target_threshold=threshold,
+        mask_negatives_for_modalities=mask_modalities,
+    )
+    vec = ModalityPatchDiscriminationMaskedNegativesVec(
+        tau=tau,
+        same_target_threshold=threshold,
+        mask_negatives_for_modalities=mask_modalities,
+    )
+    return seq, vec
+
+
+def test_masked_neg_vec_matches_sequential_uniform() -> None:
+    """Vec matches sequential when all tokens are decoder tokens."""
+    b, t_h, t_w, t, d = 4, 3, 3, 2, 16
+    torch.manual_seed(42)
+
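+    # Every token is DECODER, so both paths see the same dense problem.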
+    preds = TokensAndMasks(
+        sentinel2_l2a=torch.randn((b, t_h, t_w, t, d)),
+        sentinel2_l2a_mask=torch.ones((b, t_h, t_w, t)) * MaskValue.DECODER.value,
+    )
+    targets = TokensAndMasks(
+        sentinel2_l2a=torch.randn((b, t_h, t_w, t, d)),
+        sentinel2_l2a_mask=torch.ones((b, t_h, t_w, t)) * MaskValue.DECODER.value,
+    )
+
+    seq, vec = _make_masked_neg_pair()
+    loss_seq = seq.compute(preds, targets)
+    loss_vec = vec.compute(preds, targets)
+    assert torch.isclose(loss_seq, loss_vec, rtol=RTOL, atol=ATOL), (
+        f"seq={loss_seq.item()}, vec={loss_vec.item()}"
+    )
+
+
+def test_masked_neg_vec_matches_sequential_uneven() -> None:
+    """Vec matches sequential with uneven decoder token counts."""
+    b, t_h, t_w, t, d = 6, 4, 4, 2, 8
+
+    for seed in range(20):
+        torch.manual_seed(seed)
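+        # Random mask mix (assumes MaskValue members span 0..3), giving each
+        # sample a different number of decoder tokens.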
+        s2_mask = torch.randint(0, 4, (b, t_h, t_w, t))
+        preds = TokensAndMasks(
+            sentinel2_l2a=torch.randn((b, t_h, t_w, t, d)),
+            sentinel2_l2a_mask=s2_mask,
+        )
+        targets = TokensAndMasks(
+            sentinel2_l2a=torch.randn((b, t_h, t_w, t, d)),
+            sentinel2_l2a_mask=s2_mask,
+        )
+        seq, vec = _make_masked_neg_pair()
+        loss_seq = seq.compute(preds, targets)
+        loss_vec = vec.compute(preds, targets)
+        assert torch.isclose(loss_seq, loss_vec, rtol=RTOL, atol=ATOL), (
+            f"seed={seed}: seq={loss_seq.item()}, vec={loss_vec.item()}"
+        )
+
+
+def test_masked_neg_vec_with_identical_targets() -> None:
+    """Test masking behavior when some targets are identical (triggers skip)."""
+    b, t_h, t_w, t, d = 4, 2, 2, 2, 8
+    torch.manual_seed(7)
+
+    target_s2 = torch.randn((b, t_h, t_w, t, d))
+    # Make ALL tokens in sample 0 identical → should be skipped
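+    # (identical targets should exceed same_target_threshold for every pair)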
+    target_s2[0] = target_s2[0, 0, 0, 0].expand_as(target_s2[0])
+
+    preds = TokensAndMasks(
+        sentinel2_l2a=torch.randn((b, t_h, t_w, t, d)),
+        sentinel2_l2a_mask=torch.ones((b, t_h, t_w, t)) * MaskValue.DECODER.value,
+    )
+    targets = TokensAndMasks(
+        sentinel2_l2a=target_s2,
+        sentinel2_l2a_mask=torch.ones((b, t_h, t_w, t)) * MaskValue.DECODER.value,
+    )
+
+    seq, vec = _make_masked_neg_pair()
+    loss_seq = seq.compute(preds, targets)
+    loss_vec = vec.compute(preds, targets)
+    assert torch.isclose(loss_seq, loss_vec, rtol=RTOL, atol=ATOL), (
+        f"identical targets: seq={loss_seq.item()}, vec={loss_vec.item()}"
+    )
+
+
+def test_masked_neg_vec_gradients() -> None:
+    """Gradients match between sequential and vec."""
+    b, t_h, t_w, t, d = 4, 3, 3, 2, 16
+
+    for seed in [0, 7, 42, 999]:
+        torch.manual_seed(seed)
+        s2_mask = torch.randint(0, 4, (b, t_h, t_w, t))
+        s2_data = torch.randn((b, t_h, t_w, t, d))
+        s2_tgt = torch.randn((b, t_h, t_w, t, d))
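+        # The same inputs feed both implementations through fresh leaf tensors,
+        # so gradients from the two loss paths are directly comparable.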
+
+        # Sequential
+        s2_seq = s2_data.clone().requires_grad_(True)
+        preds_s = TokensAndMasks(sentinel2_l2a=s2_seq, sentinel2_l2a_mask=s2_mask)
+        targets_s = TokensAndMasks(
+            sentinel2_l2a=s2_tgt.clone(), sentinel2_l2a_mask=s2_mask
+        )
+        seq, vec = _make_masked_neg_pair()
+        loss_s = seq.compute(preds_s, targets_s)
+        loss_s.backward()
+
+        # Vec
+        s2_vec = s2_data.clone().requires_grad_(True)
+        preds_v = TokensAndMasks(sentinel2_l2a=s2_vec, sentinel2_l2a_mask=s2_mask)
+        targets_v = TokensAndMasks(
+            sentinel2_l2a=s2_tgt.clone(), sentinel2_l2a_mask=s2_mask
+        )
+        loss_v = vec.compute(preds_v, targets_v)
+        loss_v.backward()
+
+        assert torch.isclose(loss_s, loss_v, rtol=RTOL, atol=ATOL), (
+            f"seed={seed}: loss seq={loss_s.item()}, vec={loss_v.item()}"
+        )
+        assert torch.allclose(s2_seq.grad, s2_vec.grad, rtol=RTOL, atol=ATOL), (
+            f"seed={seed}: grad max diff="
+            f"{(s2_seq.grad - s2_vec.grad).abs().max().item()}"
+        )
+
+
+def test_masked_neg_vec_missing_samples() -> None:
+    """Vec matches sequential when some samples have no decoder tokens."""
+    b, t_h, t_w, t, d = 5, 4, 4, 2, 8
+    torch.manual_seed(456)
+
+    s2_mask = torch.randint(0, 3, (b, t_h, t_w, t))
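+    # Force two degenerate samples: one with no decoder tokens, one fully missing.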
+    s2_mask[0] = MaskValue.ONLINE_ENCODER.value
+    s2_mask[2] = MaskValue.MISSING.value
+
+    preds = TokensAndMasks(
+        sentinel2_l2a=torch.randn((b, t_h, t_w, t, d)),
+        sentinel2_l2a_mask=s2_mask,
+    )
+    targets = TokensAndMasks(
+        sentinel2_l2a=torch.randn((b, t_h, t_w, t, d)),
+        sentinel2_l2a_mask=s2_mask,
+    )
+
+    seq, vec = _make_masked_neg_pair()
+    loss_seq = seq.compute(preds, targets)
+    loss_vec = vec.compute(preds, targets)
+    assert torch.isclose(loss_seq, loss_vec, rtol=RTOL, atol=ATOL), (
+        f"seq={loss_seq.item()}, vec={loss_vec.item()}"
+    )
+
+
+def test_masked_neg_vec_selective_modality_masking() -> None:
+    """Negative masking is applied only to the specified modalities."""
+    b, t_h, t_w, t, d = 4, 3, 3, 2, 16
+    torch.manual_seed(99)
+
+    preds = TokensAndMasks(
+        sentinel2_l2a=torch.randn((b, t_h, t_w, t, d)),
+        sentinel2_l2a_mask=torch.ones((b, t_h, t_w, t)) * MaskValue.DECODER.value,
+        worldcover=torch.randn((b, t_h, t_w, 1, d)),
+        worldcover_mask=torch.ones((b, t_h, t_w, 1)) * MaskValue.DECODER.value,
+    )
+    targets = TokensAndMasks(
+        sentinel2_l2a=torch.randn((b, t_h, t_w, t, d)),
+        sentinel2_l2a_mask=torch.ones((b, t_h, t_w, t)) * MaskValue.DECODER.value,
+        worldcover=torch.randn((b, t_h, t_w, 1, d)),
+        worldcover_mask=torch.ones((b, t_h, t_w, 1)) * MaskValue.DECODER.value,
+    )
+
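+    # Only worldcover gets false-negative masking; sentinel2_l2a keeps its full
+    # denominator, so both code paths are exercised in a single call.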
+    seq, vec = _make_masked_neg_pair(mask_modalities=["worldcover"])
+    loss_seq = seq.compute(preds, targets)
+    loss_vec = vec.compute(preds, targets)
+    assert torch.isclose(loss_seq, loss_vec, rtol=RTOL, atol=ATOL), (
+        f"selective: seq={loss_seq.item()}, vec={loss_vec.item()}"
+    )
+
+
+def test_masked_neg_vec_large_batch() -> None:
+    """Equivalence at a training-like batch size."""
+    b, t_h, t_w, t, d = 32, 4, 4, 2, 64
+    torch.manual_seed(2024)
+    s2_mask = torch.randint(0, 4, (b, t_h, t_w, t))
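+    # Larger tensors surface shape and broadcasting bugs that tiny cases can miss.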
+
+    preds = TokensAndMasks(
+        sentinel2_l2a=torch.randn((b, t_h, t_w, t, d)),
+        sentinel2_l2a_mask=s2_mask,
+    )
+    targets = TokensAndMasks(
+        sentinel2_l2a=torch.randn((b, t_h, t_w, t, d)),
+        sentinel2_l2a_mask=s2_mask,
+    )
+
+    seq, vec = _make_masked_neg_pair()
+    loss_seq = seq.compute(preds, targets)
+    loss_vec = vec.compute(preds, targets)
+    assert torch.isclose(loss_seq, loss_vec, rtol=RTOL, atol=ATOL), (
+        f"large batch: seq={loss_seq.item()}, vec={loss_vec.item()}"
+    )