We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent e0e70c9 commit de8dad3Copy full SHA for de8dad3
1 file changed
src/grouping_trainer/loss.py
@@ -48,7 +48,8 @@ def _mrl_loss(
48
# 1. Detach embeddings from graph
49
# 2. Loop over dims, compute loss, backward in loop
50
# 3. Backprop detached embeddings' gradients to model
51
- return loss_total / len(dim_indices)
+ weight_total = sum(mrl_dim_to_weight[mrl_dims[idx]] for idx in dim_indices)
52
+ return loss_total / weight_total
53
54
55
class PairwiseLoss(torch.nn.Module, ABC):
0 commit comments