-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsupconloss.py
44 lines (37 loc) · 1.62 KB
/
supconloss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from ..distances import CosineSimilarity
from ..reducers import AvgNonZeroReducer
from ..utils import common_functions as c_f
from ..utils import loss_and_miner_utils as lmu
from .generic_pair_loss import GenericPairLoss
class SupConLoss(GenericPairLoss):
def __init__(self, temperature=0.1, **kwargs):
super().__init__(mat_based_loss=True, **kwargs)
self.temperature = temperature
self.add_to_recordable_attributes(list_of_names=["temperature"], is_stat=False)
def _compute_loss(self, mat, pos_mask, neg_mask):
if pos_mask.bool().any() and neg_mask.bool().any():
# if dealing with actual distances, use negative distances
if not self.distance.is_inverted:
mat = -mat
mat = mat / self.temperature
mat_max, _ = mat.max(dim=1, keepdim=True)
mat = mat - mat_max.detach() # for numerical stability
denominator = lmu.logsumexp(
mat, keep_mask=(pos_mask + neg_mask).bool(), add_one=False, dim=1
)
log_prob = mat - denominator
mean_log_prob_pos = (pos_mask * log_prob).sum(dim=1) / (
pos_mask.sum(dim=1) + c_f.small_val(mat.dtype)
)
return {
"loss": {
"losses": -mean_log_prob_pos,
"indices": c_f.torch_arange_from_size(mat),
"reduction_type": "element",
}
}
return self.zero_losses()
def get_default_reducer(self):
return AvgNonZeroReducer()
def get_default_distance(self):
return CosineSimilarity()