-
Notifications
You must be signed in to change notification settings - Fork 31
Expand file tree
/
Copy pathloss.py
More file actions
1143 lines (960 loc) · 41.6 KB
/
loss.py
File metadata and controls
1143 lines (960 loc) · 41.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""Loss functions for training."""
import logging
import math
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any
import torch
import torch.nn.functional as F
from class_registry import ClassRegistry
from einops import rearrange, repeat
from olmo_core.config import Config
from torch import Tensor
from helios.data.constants import Modality
from helios.nn.flexihelios import PoolingType, TokensAndMasks
from helios.train.masking import MaskedHeliosSample, MaskValue
# Module-level logger; the losses below use it to warn when there is nothing to decode.
logger = logging.getLogger(__name__)
class Loss(ABC):
    """Common interface implemented by every training loss in this module.

    Subclasses provide a ``name`` and implement :meth:`compute`.
    """

    name: str

    @abstractmethod
    def compute(self, predictions: Any, targets: Any, **kwargs: Any) -> Tensor:
        """Return the loss of ``predictions`` measured against ``targets``."""
        ...

    @staticmethod
    def _expand_and_reciprocate(t: Tensor) -> Tensor:
        """Repeat the reciprocal of each entry of ``t`` that many times.

        ``t`` must hold integer counts because it is also used as the repeat
        count for each entry.

        >>> Loss._expand_and_reciprocate(torch.tensor([1, 2, 3]))
        tensor([1.0000, 0.5000, 0.5000, 0.3333, 0.3333, 0.3333])
        """
        inverse = t.float().reciprocal()
        return inverse.repeat_interleave(t)
# Registry of Loss subclasses; each concrete loss below registers itself under a
# string key through the ``@LOSS_REGISTRY.register(...)`` decorator.
LOSS_REGISTRY = ClassRegistry[Loss]()
@LOSS_REGISTRY.register("clip_patch_discrimination")
class ClipPatchDiscriminationLoss(Loss):
    """Configurable CLIP-like patch discrimination loss.

    Closer to the original loss from the CLIP paper: every prediction token is
    scored against a group of target tokens and trained with cross entropy to
    pick out its own target. The enabled ``*_loss`` flags decide how tokens
    are grouped into those contrastive sets.
    """

    name = "ClipPatchDisc"

    def __init__(
        self,
        label_smoothing: float = 0,
        prediction_norm: float | None = None,
        target_norm: float | None = None,
        modality_loss: bool = True,
        symmetric: bool = True,
        batch_loss: bool = False,
        bandset_loss: bool = False,
        spatial_loss: bool = False,
        time_loss: bool = False,
        mean_of_modalities: bool = True,
        sum_of_modalities: bool = False,
        decode_only: bool = True,
        weight: float = 1.0,
    ):
        """Initialize patch discrimination loss.

        Args:
            label_smoothing: label smoothing in [0, 1] passed to cross entropy
                (0 = none).
            prediction_norm: if set, predictions are L2-normalized over the
                last dimension and multiplied by this value.
            target_norm: if set, targets are L2-normalized over the last
                dimension and multiplied by this value.
            modality_loss: contrast each token against all tokens of the same
                sample within a modality.
            symmetric: also compute every enabled loss with predictions and
                targets swapped.
            batch_loss: contrast each token against the tokens at the same
                position in the other samples of the batch.
            bandset_loss: contrast tokens within each (sample, band set) group.
            spatial_loss: contrast tokens within each (sample, timestep) group;
                modalities that are not multitemporal use the modality loss.
            time_loss: contrast tokens within each (sample, patch position)
                group; modalities that are not multitemporal use the modality
                loss.
            mean_of_modalities: mean of per-term means instead of mean of all
                losses.
            sum_of_modalities: sum of per-term means instead of mean of all
                losses (only consulted when ``mean_of_modalities`` is False).
            decode_only: only compare to targets masked as decode.
            weight: the weight to apply to this loss.
        """
        self.label_smoothing = label_smoothing
        self.weight = weight
        self.prediction_norm = prediction_norm
        self.target_norm = target_norm
        self.modality_loss = modality_loss
        self.symmetric = symmetric
        self.batch_loss = batch_loss
        self.bandset_loss = bandset_loss
        self.mean_of_modalities = mean_of_modalities
        self.sum_of_modalities = sum_of_modalities
        self.decode_only = decode_only
        self.spatial_loss = spatial_loss
        self.time_loss = time_loss
        # Set by compute(); None means the similarity scores are left unscaled.
        self.logit_scale: Tensor | None = None

    def _contrastive_loss(
        self, preds_flat: Tensor, targs_flat: Tensor, masks_flat: Tensor
    ) -> Tensor:
        """Cross entropy of each decoder token against its contrastive group.

        Args:
            preds_flat: queries laid out as (group, token, dim).
            targs_flat: keys laid out as (group, token, dim).
            masks_flat: mask values laid out as (group, token).

        Returns:
            1-D tensor with one loss per token whose mask value is DECODER.
        """
        score = torch.einsum("bxd,byd->bxy", preds_flat, targs_flat)
        if self.logit_scale is not None:
            score = score * self.logit_scale
        is_decoder = masks_flat == MaskValue.DECODER.value
        if self.decode_only:
            # Push non-decoder keys to the lowest representable logit so they
            # cannot act as candidates.
            blocked = (~is_decoder).unsqueeze(1).expand_as(score)
            score[blocked] = torch.finfo(score.dtype).min
        # The matching key of query i sits at column i of its group.
        label = torch.arange(score.shape[1], dtype=torch.long, device=score.device)
        loss = F.cross_entropy(
            score.flatten(0, 1),
            label.repeat(score.shape[0]),
            reduction="none",
            label_smoothing=self.label_smoothing,
        )
        # Only decoder tokens act as queries.
        return loss[is_decoder.flatten()]

    def _calculate_modality_loss(
        self, preds: Tensor, targs: Tensor, masks: Tensor
    ) -> Tensor:
        """Contrast each token against all tokens of the same sample."""
        return self._contrastive_loss(
            rearrange(preds, "b ... d -> b (...) d"),
            rearrange(targs, "b ... d -> b (...) d"),
            masks.flatten(start_dim=1),
        )

    def _calculate_batch_loss(
        self, preds: Tensor, targs: Tensor, masks: Tensor
    ) -> Tensor:
        """Contrast each token against the same position in other samples."""
        return self._contrastive_loss(
            rearrange(preds, "b ... d -> (...) b d"),
            rearrange(targs, "b ... d -> (...) b d"),
            rearrange(masks, "b ... -> (...) b"),
        )

    def _calculate_bandset_loss(
        self, preds: Tensor, targs: Tensor, masks: Tensor
    ) -> Tensor:
        """Contrast tokens within each (sample, band set) group."""
        # (B, P_H, P_W, T, Band_Sets, D)
        return self._contrastive_loss(
            rearrange(preds, "b ... bs d -> (b bs) (...) d"),
            rearrange(targs, "b ... bs d -> (b bs) (...) d"),
            rearrange(masks, "b ... bs -> (b bs) (...)"),
        )

    def _calculate_spatial_loss(
        self, preds: Tensor, targs: Tensor, masks: Tensor
    ) -> Tensor:
        """Contrast tokens within each (sample, timestep) group."""
        # (B, P_H, P_W, T, Band_Sets, D)
        return self._contrastive_loss(
            rearrange(preds, "b ph pw t bs d -> (b t) (ph pw bs) d"),
            rearrange(targs, "b ph pw t bs d -> (b t) (ph pw bs) d"),
            rearrange(masks, "b ph pw t bs -> (b t) (ph pw bs)"),
        )

    def _calculate_time_loss(
        self, preds: Tensor, targs: Tensor, masks: Tensor
    ) -> Tensor:
        """Contrast tokens within each (sample, patch position) group."""
        # (B, P_H, P_W, T, Band_Sets, D)
        return self._contrastive_loss(
            rearrange(preds, "b ph pw t bs d -> (b ph pw) (t bs) d"),
            rearrange(targs, "b ph pw t bs d -> (b ph pw) (t bs) d"),
            rearrange(masks, "b ph pw t bs -> (b ph pw) (t bs)"),
        )

    def compute(
        self,
        predictions: TokensAndMasks,
        targets: TokensAndMasks,
        logit_scale: Tensor | None = None,
        **kwargs: Any,
    ) -> Tensor:
        """Compute patch discrimination loss between predictions and targets.

        Args:
            predictions: Model predictions; their masks select the tokens used.
            targets: Ground truth targets.
            logit_scale: multiplier applied to the similarity scores; ``None``
                leaves the scores unscaled.
            **kwargs: Additional keyword arguments (ignored).

        Returns:
            The weighted loss. When no loss term could be computed the result
            is ``weight * 0``.
        """
        # sentinel2: sentinel 2 data of shape (B, P_H, P_W, T, Band_Sets, D)
        self.logit_scale = logit_scale
        losses: list[Tensor] = []
        for modality_name in predictions.modalities:
            preds = getattr(predictions, modality_name)
            targs = getattr(targets, modality_name)
            masks = getattr(
                predictions, predictions.get_masked_modality_name(modality_name)
            )
            if self.target_norm is not None:
                targs = self.target_norm * F.normalize(targs, p=2, dim=-1)
            if self.prediction_norm is not None:
                preds = self.prediction_norm * F.normalize(preds, p=2, dim=-1)

            # Collect the enabled loss variants in a fixed order.
            loss_fns = []
            if self.modality_loss:
                loss_fns.append(self._calculate_modality_loss)
            if self.batch_loss:
                loss_fns.append(self._calculate_batch_loss)
            if self.bandset_loss:
                loss_fns.append(self._calculate_bandset_loss)
            if self.time_loss or self.spatial_loss:
                # Time/spatial grouping is only used for multitemporal
                # modalities; the others use the per-sample modality loss.
                multitemporal = Modality.get(modality_name).is_multitemporal
                if self.time_loss:
                    loss_fns.append(
                        self._calculate_time_loss
                        if multitemporal
                        else self._calculate_modality_loss
                    )
                if self.spatial_loss:
                    loss_fns.append(
                        self._calculate_spatial_loss
                        if multitemporal
                        else self._calculate_modality_loss
                    )

            for loss_fn in loss_fns:
                losses.append(loss_fn(preds, targs, masks))
                if self.symmetric:
                    losses.append(loss_fn(targs, preds, masks))

        # torch.stack / torch.cat reject empty lists, so every reduction is
        # guarded explicitly and falls back to a plain 0.
        non_empty = [loss for loss in losses if loss.numel() > 0]
        total_loss: Tensor | int
        if self.mean_of_modalities:
            if losses:
                # Terms without decoder tokens count as 0 in the mean; the
                # placeholder shares the dtype of the losses so mean() works
                # even when every term is empty.
                total_loss = torch.stack(
                    [
                        loss.mean() if loss.numel() > 0 else loss.new_zeros(())
                        for loss in losses
                    ]
                ).mean()
            else:
                total_loss = 0
        elif self.sum_of_modalities:
            total_loss = (
                torch.stack([loss.mean() for loss in non_empty]).sum()
                if non_empty
                else 0
            )
        else:
            total_loss = (
                torch.cat([loss.flatten() for loss in non_empty]).mean()
                if non_empty
                else 0
            )
        return self.weight * total_loss
@LOSS_REGISTRY.register("all_discrimination")
class AllDiscriminationLoss(Loss):
    """Contrastive loss over all decoder tokens in the batch.

    Each decoder token is discriminated against every other decoder token of
    every sample in the batch.
    """

    name = "AllDisc"

    def __init__(self, tau: float = 0.1, pred2unit: bool = False):
        """Initialize all patch discrimination loss.

        Args:
            tau: the softmax temperature
            pred2unit: whether to standardize the predictions using batch statistics
        """
        self.tau = tau
        self.pred2unit = pred2unit

    def compute(
        self, predictions: TokensAndMasks, targets: TokensAndMasks, **kwargs: Any
    ) -> Tensor:
        """Compute all patch discrimination loss between predictions and targets.

        Args:
            predictions: Model predictions.
            targets: Ground truth targets.
            **kwargs: Additional keyword arguments.

        Returns:
            The computed loss value.
        """
        flat_preds, flat_masks = predictions.flatten_tokens_and_masks()
        flat_targets = targets.flatten_tokens_and_masks()[0]
        is_decoder = flat_masks == MaskValue.DECODER.value

        # Gather the decoder tokens of all samples into a single "batch" of one.
        pred = flat_preds[is_decoder].unsqueeze(dim=0)
        target = flat_targets[is_decoder].unsqueeze(dim=0)
        bs, nt, _ = pred.shape

        if self.pred2unit:
            centered = pred - pred.mean(1, keepdim=True)
            pred = centered / (pred.std(1, keepdim=True) + 1e-4)
        pred = F.normalize(pred, p=2, dim=-1)
        target = F.normalize(target, p=2, dim=-1)

        scores = torch.einsum("npd,nqd->npq", pred, target) / self.tau
        # Token i has to pick target i.
        labels = torch.arange(nt, dtype=torch.long, device=pred.device).repeat(bs)
        per_token = F.cross_entropy(
            scores.flatten(0, 1), labels, reduction="none"
        ) * (self.tau * 2)

        # Emulate averaging across the batch dimension: weight each token by
        # 1 / (number of decoder tokens in its sample).
        per_sample_count = is_decoder.sum(dim=-1)
        token_weight = self._expand_and_reciprocate(per_sample_count)
        # bs == 1 after the unsqueeze, so the sample count is read from the
        # flattened predictions instead.
        return (per_token * token_weight).sum() / flat_preds.shape[0]
@LOSS_REGISTRY.register("modality_all_discrimination")
class ModalityAllDiscriminationLoss(Loss):
    """Per-modality variant of the all discrimination loss.

    For each modality, decoder tokens are discriminated against all decoder
    tokens of that modality in the batch; the per-modality losses are summed.
    """

    name = "ModalityAllDisc"

    def __init__(self, tau: float = 0.1, pred2unit: bool = False):
        """Initialize all patch discrimination loss.

        Args:
            tau: the softmax temperature
            pred2unit: whether to standardize the predictions using batch statistics
        """
        self.tau = tau
        self.pred2unit = pred2unit

    def compute(
        self, predictions: TokensAndMasks, targets: TokensAndMasks, **kwargs: Any
    ) -> Tensor:
        """Compute all patch discrimination loss between predictions and targets.

        Args:
            predictions: Model predictions.
            targets: Ground truth targets.
            **kwargs: Additional keyword arguments.

        Returns:
            The sum of the per-modality losses, or the integer 0 when no
            modality has decoder tokens.
        """
        preds_per_modality, masks_per_modality = predictions.flatten_tokens_and_masks(
            return_lists=True
        )
        targets_per_modality = targets.flatten_tokens_and_masks(return_lists=True)[0]

        modality_losses = []
        for flat_preds, flat_masks, flat_targets in zip(
            preds_per_modality, masks_per_modality, targets_per_modality
        ):
            is_decoder = flat_masks == MaskValue.DECODER.value
            pred = flat_preds[is_decoder].unsqueeze(dim=0)
            target = flat_targets[is_decoder].unsqueeze(dim=0)
            bs, nt, _ = pred.shape
            if nt == 0:
                # Nothing to decode for this modality, so it contributes no loss.
                logger.warning("No decoded values for this modality")
                continue

            if self.pred2unit:
                centered = pred - pred.mean(1, keepdim=True)
                pred = centered / (pred.std(1, keepdim=True) + 1e-4)
            pred = F.normalize(pred, p=2, dim=-1)
            target = F.normalize(target, p=2, dim=-1)

            scores = torch.einsum("npd,nqd->npq", pred, target) / self.tau
            labels = torch.arange(nt, dtype=torch.long, device=pred.device).repeat(bs)
            per_token = F.cross_entropy(
                scores.flatten(0, 1), labels, reduction="none"
            ) * (self.tau * 2)

            # Emulate averaging across the batch dimension.
            token_weight = self._expand_and_reciprocate(is_decoder.sum(dim=-1))
            # bs == 1 after the unsqueeze, so use the flattened predictions.
            modality_losses.append(
                (per_token * token_weight).sum() / flat_preds.shape[0]
            )
        return sum(modality_losses)
@LOSS_REGISTRY.register("patch_discrimination_new")
class PatchDiscriminationLossNew(Loss):
    """Loss function for patch discrimination task.

    This has lower memory consumption than the old patch discrimination loss
    because the similarity matrix is built one sample at a time.
    It does not support all discrimination loss.
    """

    name = "PatchDisc"

    def __init__(self, tau: float = 0.1, pred2unit: bool = False, weight: float = 1.0):
        """Initialize patch discrimination loss.

        Args:
            tau: the softmax temperature
            pred2unit: whether to standardize the predictions using batch statistics
            weight: the weight to apply to this loss
        """
        self.tau = tau
        self.pred2unit = pred2unit
        self.weight = weight

    def compute(
        self, predictions: TokensAndMasks, targets: TokensAndMasks, **kwargs: Any
    ) -> Tensor:
        """Compute patch discrimination loss between predictions and targets.

        Args:
            predictions: Model predictions.
            targets: Ground truth targets.
            **kwargs: Additional keyword arguments.

        Returns:
            The weighted mean of the per-sample losses. Samples without
            decoder tokens are skipped; if every sample is skipped the result
            is a zero tensor instead of an error.
        """
        all_preds, all_masks = predictions.flatten_tokens_and_masks()
        all_targets = targets.flatten_tokens_and_masks()[0]
        is_decoder = all_masks == MaskValue.DECODER.value

        # Samples may have different numbers of decoder tokens, so the tokens
        # are packed into one long sequence and sliced per sample below.
        pred = all_preds[is_decoder].unsqueeze(dim=0)
        target = all_targets[is_decoder].unsqueeze(dim=0)

        if self.pred2unit:
            pred_mu = pred.mean(1, keepdim=True)
            pred_std = pred.std(1, keepdim=True)
            pred = (pred - pred_mu) / (pred_std + 1e-4)
        pred = F.normalize(pred, p=2, dim=-1)
        target = F.normalize(target, p=2, dim=-1)

        # One host transfer for all counts instead of one sync per sample.
        counts = is_decoder.sum(dim=-1).tolist()

        losses = []
        start = 0
        for c in counts:
            if c == 0:
                # We occasionally get a sample with no decoded values due to
                # missing data; skip it.
                logger.warning("No decoded values for this sample")
                continue
            end = start + c
            pred_sample = pred[:, start:end, :]
            target_sample = target[:, start:end, :]
            score_sample = (
                torch.einsum("npd,nqd->npq", pred_sample, target_sample) / self.tau
            )
            # Token i has to pick target i among this sample's tokens.
            labels = torch.arange(c, dtype=torch.long, device=pred.device)
            loss = F.cross_entropy(
                score_sample.flatten(0, 1),
                labels,
                reduction="none",
            ) * (self.tau * 2)
            losses.append(loss.mean())
            start = end

        if not losses:
            # torch.stack rejects an empty list. `pred` is empty here, so its
            # sum is exactly 0 while staying attached to the graph.
            return self.weight * (pred.sum() * 0.0)
        return self.weight * torch.stack(losses).mean()
@LOSS_REGISTRY.register("modality_patch_discrimination_new")
class ModalityPatchDiscriminationLossNew(Loss):
    """Patch discrimination loss computed separately for every modality.

    This has lower memory consumption than the old patch discrimination loss.
    It does not support all discrimination loss.
    """

    name = "ModalityPatchDisc"

    def __init__(
        self,
        tau: float = 0.1,
        pred2unit: bool = False,
        weight: float = 1.0,
        modality_weights: dict[str, float] | None = None,
    ):
        """Initialize patch discrimination loss.

        Args:
            tau: the softmax temperature
            pred2unit: whether to standardize the predictions using batch statistics
            weight: the weight to apply to this loss
            modality_weights: the weights to apply to each modality, keyed by
                the entries of ``targets.modalities``
        """
        self.tau = tau
        self.pred2unit = pred2unit
        self.weight = weight
        self.modality_weights = modality_weights

    def compute(
        self, predictions: TokensAndMasks, targets: TokensAndMasks, **kwargs: Any
    ) -> Tensor:
        """Compute patch discrimination loss between predictions and targets.

        Args:
            predictions: Model predictions.
            targets: Ground truth targets.
            **kwargs: Additional keyword arguments.

        Returns:
            ``weight`` times the sum of the (optionally weighted) per-modality
            losses.
        """
        preds_per_modality, masks_per_modality = predictions.flatten_tokens_and_masks(
            return_lists=True
        )
        targets_per_modality = targets.flatten_tokens_and_masks(return_lists=True)[0]

        modality_losses = []
        for flat_preds, flat_masks, flat_targets, modality in zip(
            preds_per_modality,
            masks_per_modality,
            targets_per_modality,
            targets.modalities,
        ):
            is_decoder = flat_masks == MaskValue.DECODER.value
            # Samples may have different numbers of decoder tokens, so they are
            # packed into one sequence and sliced per sample below.
            pred = flat_preds[is_decoder].unsqueeze(dim=0)
            target = flat_targets[is_decoder].unsqueeze(dim=0)

            if self.pred2unit:
                centered = pred - pred.mean(1, keepdim=True)
                pred = centered / (pred.std(1, keepdim=True) + 1e-4)
            pred = F.normalize(pred, p=2, dim=-1)
            target = F.normalize(target, p=2, dim=-1)

            sample_losses = []
            offset = 0
            for c in is_decoder.sum(dim=-1):
                if c == 0:
                    # A sample can have no decoded values due to missing data;
                    # it contributes nothing and leaves the offset untouched.
                    continue
                stop = offset + c
                sample_scores = (
                    torch.einsum(
                        "npd,nqd->npq",
                        pred[:, offset:stop, :],
                        target[:, offset:stop, :],
                    )
                    / self.tau
                )
                labels = torch.arange(c, dtype=torch.long, device=pred.device)[None]
                per_token = F.cross_entropy(
                    sample_scores.flatten(0, 1),
                    labels.flatten(0, 1),
                    reduction="none",
                ) * (self.tau * 2)
                sample_losses.append(per_token.mean())
                offset = stop

            if not sample_losses:
                # No sample had decoder tokens for this modality.
                continue

            modality_loss = torch.stack(sample_losses).mean()
            if self.modality_weights is not None:
                modality_loss = modality_loss * self.modality_weights[modality]
            modality_losses.append(modality_loss)

        return self.weight * sum(modality_losses)
@LOSS_REGISTRY.register("patch_discrimination")
class PatchDiscriminationLoss(Loss):
    """Patch discrimination loss built from one batch-wide similarity matrix."""

    name = "PatchDisc"

    def __init__(
        self,
        tau: float = 0.1,
        pred2unit: bool = False,
        mask_other_samples: bool = True,
    ):
        """Initialize patch discrimination loss.

        Args:
            tau: the softmax temperature
            pred2unit: whether to standardize the predictions using batch statistics
            mask_other_samples: whether to apply the contrastive loss drawing samples
                from within a sample (True) or using all other instances in a batch (False).
                If this is False, then this is the AllDisc loss from the Galileo paper
        """
        self.tau = tau
        self.pred2unit = pred2unit
        self.mask_other_samples = mask_other_samples

    def compute(
        self, predictions: TokensAndMasks, targets: TokensAndMasks, **kwargs: Any
    ) -> Tensor:
        """Compute patch discrimination loss between predictions and targets.

        Args:
            predictions: Model predictions.
            targets: Ground truth targets.
            **kwargs: Additional keyword arguments.

        Returns:
            The computed loss value.
        """
        flat_preds, flat_masks = predictions.flatten_tokens_and_masks()
        flat_targets = targets.flatten_tokens_and_masks()[0]
        is_decoder = flat_masks == MaskValue.DECODER.value

        pred = flat_preds[is_decoder].unsqueeze(dim=0)
        target = flat_targets[is_decoder].unsqueeze(dim=0)
        bs, nt, _ = pred.shape

        if self.pred2unit:
            centered = pred - pred.mean(1, keepdim=True)
            pred = centered / (pred.std(1, keepdim=True) + 1e-4)
        pred = F.normalize(pred, p=2, dim=-1)
        target = F.normalize(target, p=2, dim=-1)

        scores = torch.einsum("npd,nqd->npq", pred, target) / self.tau
        per_sample_count = is_decoder.sum(dim=-1)

        if self.mask_other_samples:
            # Everything is suppressed except the diagonal blocks that pair
            # tokens of the same sample.
            logit_mask = torch.full_like(scores, -torch.finfo(scores.dtype).max)
            offset = 0
            for c in per_sample_count:
                block = slice(offset, offset + c)
                logit_mask[:, block, block] = 0
                offset = offset + c
            scores = scores + logit_mask

        labels = torch.arange(nt, dtype=torch.long, device=pred.device).repeat(bs)
        per_token = F.cross_entropy(
            scores.flatten(0, 1), labels, reduction="none"
        ) * (self.tau * 2)

        # Emulate averaging across the batch dimension.
        token_weight = self._expand_and_reciprocate(per_sample_count)
        # bs == 1 after the unsqueeze, so use the flattened predictions.
        return (per_token * token_weight).sum() / flat_preds.shape[0]
@LOSS_REGISTRY.register("adjusted_patch_discrimination")
class AdjustedPatchDiscriminationLoss(Loss):
    """Loss function for adjusted patch discrimination task.

    Negatives are re-weighted with a Gaussian of their similarity before the
    cross entropy is taken.

    Reference: https://proceedings.neurips.cc/paper_files/paper/2023/file/48aaa5ea741ae8430bd58e25917d267d-Paper-Conference.pdf
    """

    name = "AdjustedPatchDisc"

    def __init__(
        self,
        tau: float = 0.1,
        mu: float = 0.7,
        sigma: float = 1.0,
        pred2unit: bool = False,
    ):
        """Initialize adjusted patch discrimination loss.

        Args:
            tau: the softmax temperature
            mu: the mean of the Gaussian distribution
            sigma: the standard deviation of the Gaussian distribution
            pred2unit: whether to standardize the predictions using batch statistics
        """
        self.tau = tau
        self.mu = mu
        self.sigma = sigma
        self.pred2unit = pred2unit

    def compute(
        self, predictions: TokensAndMasks, targets: TokensAndMasks, **kwargs: Any
    ) -> Tensor:
        """Compute patch discrimination loss between predictions and targets.

        Args:
            predictions: Model predictions.
            targets: Ground truth targets.
            **kwargs: Additional keyword arguments.

        Returns:
            The mean of the per-sample losses. Samples without decoder tokens
            are skipped; if every sample is skipped the result is a zero
            tensor instead of an error.
        """
        all_preds, all_masks = predictions.flatten_tokens_and_masks()
        all_targets = targets.flatten_tokens_and_masks()[0]
        is_decoder = all_masks == MaskValue.DECODER.value

        pred = all_preds[is_decoder].unsqueeze(dim=0)
        target = all_targets[is_decoder].unsqueeze(dim=0)

        if self.pred2unit:
            pred_mu = pred.mean(1, keepdim=True)
            pred_std = pred.std(1, keepdim=True)
            pred = (pred - pred_mu) / (pred_std + 1e-4)
        pred = F.normalize(pred, p=2, dim=-1)
        target = F.normalize(target, p=2, dim=-1)

        # Plain ints keep the slicing, eye() and view() calls below free of
        # per-sample device syncs.
        counts = is_decoder.sum(dim=-1).tolist()
        gaussian_norm = 1.0 / (self.sigma * math.sqrt(2 * math.pi))
        two_sigma_sq = 2 * math.pow(self.sigma, 2)

        losses = []
        start = 0
        for c in counts:
            if c == 0:
                # A sample without decoded values has no positives or
                # negatives, and view(1, 0, -1) below would raise.
                logger.warning("No decoded values for this sample")
                continue
            end = start + c
            pred_sample = pred[:, start:end, :]  # (1, c, d)
            target_sample = target[:, start:end, :]  # (1, c, d)
            sim_matrix = torch.einsum(
                "npd,nqd->npq", pred_sample, target_sample
            )  # (1, c, c)

            pos_scores = torch.diagonal(sim_matrix, dim1=-2, dim2=-1)  # (1, c)
            pos_scores = pos_scores / self.tau

            # Mask out diagonal (positives) to get negatives
            mask = ~torch.eye(c, dtype=torch.bool, device=pred.device)
            neg_scores = sim_matrix.masked_select(mask).view(1, c, c - 1)
            neg_scores = neg_scores / self.tau  # (1, c, c-1)

            # Gaussian weight of each negative, evaluated on the similarity
            # without the temperature (hence the multiplication by tau).
            weight = gaussian_norm * torch.exp(
                -((neg_scores * self.tau - self.mu) ** 2) / two_sigma_sq
            )  # (1, c, c-1)
            # Normalize the weights per query; they act as constants in the
            # backward pass.
            weight = weight / weight.mean(dim=-1, keepdim=True)
            neg_scores = neg_scores * weight.detach()

            # Rebuild the (1, c, c) logit matrix from positives and
            # re-weighted negatives.
            sim_matrix = torch.zeros(
                1, c, c, device=pred.device, dtype=neg_scores.dtype
            )
            sim_matrix.diagonal(dim1=-2, dim2=-1).copy_(pos_scores)
            sim_matrix.masked_scatter_(mask, neg_scores)

            labels = torch.arange(c, dtype=torch.long, device=pred.device)
            loss = F.cross_entropy(
                sim_matrix.flatten(0, 1),
                labels,
                reduction="none",
            ) * (self.tau * 2)
            losses.append(loss.mean())
            start = end

        if not losses:
            # torch.stack rejects an empty list. `pred` is empty here, so its
            # sum is exactly 0 while staying attached to the graph.
            return pred.sum() * 0.0
        return torch.stack(losses).mean()
@LOSS_REGISTRY.register("l1")
class L1Loss(Loss):
    """L1 (mean absolute error) loss over the decoder tokens."""

    name = "L1"

    def compute(
        self, predictions: TokensAndMasks, targets: TokensAndMasks, **kwargs: Any
    ) -> Tensor:
        """Compute L1 loss between predictions and targets.

        Args:
            predictions: Model predictions.
            targets: Ground truth targets.
            **kwargs: Additional keyword arguments.

        Returns:
            The computed loss value.
        """
        flat_preds, flat_masks = predictions.flatten_tokens_and_masks()
        flat_targets = targets.flatten_tokens_and_masks()[0]
        # Only tokens marked as DECODER in the prediction masks are compared.
        is_decoder = flat_masks == MaskValue.DECODER.value
        return F.l1_loss(flat_preds[is_decoder], flat_targets[is_decoder])
@LOSS_REGISTRY.register("l2")
class L2Loss(Loss):
    """L2 (mean squared error) loss over the decoder tokens."""

    name = "L2"

    def compute(
        self, predictions: TokensAndMasks, targets: TokensAndMasks, **kwargs: Any
    ) -> Tensor:
        """Compute L2 loss between predictions and targets.

        Args:
            predictions: Model predictions.
            targets: Ground truth targets.
            **kwargs: Additional keyword arguments.

        Returns:
            The computed loss value.
        """
        flat_preds, flat_masks = predictions.flatten_tokens_and_masks()
        flat_targets = targets.flatten_tokens_and_masks()[0]
        # Only tokens marked as DECODER in the prediction masks are compared.
        is_decoder = flat_masks == MaskValue.DECODER.value
        return F.mse_loss(flat_preds[is_decoder], flat_targets[is_decoder])
@LOSS_REGISTRY.register("mae")
class MAELoss(Loss):
    """Loss function masked auto-encoding (reconstruction)."""

    name = "MAE"

    def __init__(
        self,
        loss_function: str = "MSELoss",
        only_decode: bool = True,
        weight: float = 1.0,
        **kwargs: Any,
    ):
        """Initialize MAE loss.

        Args:
            loss_function: name of the ``torch.nn`` loss class to use; it is
                instantiated with ``reduction="sum"``
            only_decode: only calculate loss on DECODER masked tokens, otherwise
                on everything that is not MISSING
            weight: the weight to apply to this loss
            **kwargs: arguments for pytorch loss constructor
        """
        self.only_decode = only_decode
        self.loss = getattr(torch.nn, loss_function)(reduction="sum", **kwargs)
        self.weight = weight

    def _flatten_helios_data(self, data: TokensAndMasks) -> tuple[Tensor, Tensor]:
        """Flatten every present modality into one (batch, values) tensor pair.

        Each per-band-set mask is repeated over that band set's channels so
        the returned mask lines up element-wise with the returned data.

        Args:
            data: container whose modality tensors are laid out as
                [B, H, W, T, C].

        Returns:
            ``(values, masks)``, both concatenated along dim 1 over all
            modalities and band sets.
        """
        masks = []
        datas = []
        for modality in data.modalities:
            pred = getattr(data, modality)
            if pred is None:
                continue
            modality_spec = Modality.get(modality)
            mask = getattr(data, data.get_masked_modality_name(modality))
            for idx, channel_set_idxs in enumerate(
                modality_spec.bandsets_as_indices()
            ):
                # One mask value per band set, expanded to its channels.
                bs_mask = repeat(
                    mask[..., idx], "b h w t -> b h w t c", c=len(channel_set_idxs)
                )
                masks.append(rearrange(bs_mask, "b h w t c -> b (h w t c)"))
                bs_data = pred[..., channel_set_idxs]
                datas.append(rearrange(bs_data, "b h w t c -> b (h w t c)"))
        return torch.cat(datas, dim=1), torch.cat(masks, dim=1)

    def compute(
        self, predictions: TokensAndMasks, targets: MaskedHeliosSample, **kwargs: Any
    ) -> Tensor:
        """Compute MAE loss between predictions and targets.

        Args:
            predictions: Model predictions.
            targets: Ground truth targets; the selection masks are read from
                here.
            **kwargs: Additional keyword arguments.

        Returns:
            The weighted, summed loss divided by the number of selected
            values. The result is 0 rather than NaN when nothing is selected.
        """
        data, _ = self._flatten_helios_data(predictions)

        # Restrict the targets to the modalities that were actually predicted.
        valid_dict = {}
        for modality in predictions.modalities:
            if getattr(predictions, modality) is not None:
                masked_name = predictions.get_masked_modality_name(modality)
                valid_dict[modality] = getattr(targets, modality)
                valid_dict[masked_name] = getattr(targets, masked_name)
        valid_targets = TokensAndMasks(**valid_dict)
        labels, label_masks = self._flatten_helios_data(valid_targets)

        if self.only_decode:
            decode = label_masks == MaskValue.DECODER.value
        else:
            decode = label_masks != MaskValue.MISSING.value

        # Zero out unselected values on both sides so they add nothing to the
        # summed loss.
        data = data * decode
        labels = labels * decode
        # Clamp so an empty selection gives 0 / 1 instead of 0 / 0 (NaN).
        num_selected = torch.count_nonzero(decode).clamp(min=1)
        return self.weight * self.loss(data, labels) / num_selected
@LOSS_REGISTRY.register("cross_entropy")
class CrossEntropyLoss(Loss):
    """Loss function for cross entropy."""

    name = "CrossEntropy"

    def compute(
        self, predictions: TokensAndMasks, targets: TokensAndMasks, **kwargs: Any
    ) -> Tensor:
        """Compute cross entropy between predictions and targets.

        Args:
            predictions: Model predictions.
            targets: Ground truth targets.
            **kwargs: Additional keyword arguments.

        Returns:
            The computed loss value.
        """
        all_preds, all_masks = predictions.flatten_tokens_and_masks()
        all_targets = targets.flatten_tokens_and_masks()[0]
        is_decoder = all_masks == MaskValue.DECODER.value
        pred = all_preds[is_decoder]
        target = all_targets[is_decoder]

        squeezed = target.squeeze()
        if target.dim() > 0 and target.shape[0] == 1:
            # With exactly one selected token, squeeze() also removes the
            # leading token dimension, so the target no longer lines up with
            # `pred`. Restore that dimension.
            squeezed = squeezed.unsqueeze(0)
        return F.cross_entropy(pred, squeezed)
@LOSS_REGISTRY.register("InfoNCE")
class InfoNCELoss(Loss):
"""Loss function for InfoNCE."""
name = "InfoNCE"
def __init__(self, tau: float = 0.1, weight: float = 1):
"""Initialize InfoNCE loss.