Skip to content

Commit de8dad3

Browse files
committed
s
1 parent e0e70c9 commit de8dad3

1 file changed

Lines changed: 2 additions & 1 deletion

File tree

src/grouping_trainer/loss.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ def _mrl_loss(
4848
# 1. Detach embeddings from graph
4949
# 2. Loop over dims, compute loss, backward in loop
5050
# 3. Backprop detached embeddings' gradients to model
51-
return loss_total / len(dim_indices)
51+
weight_total = sum(mrl_dim_to_weight[mrl_dims[idx]] for idx in dim_indices)
52+
return loss_total / weight_total
5253

5354

5455
class PairwiseLoss(torch.nn.Module, ABC):

0 commit comments

Comments
 (0)