1 parent 6277b33 · commit ec0732a
open_lm/utils/transformers/hf_model.py
@@ -140,7 +140,7 @@ def forward(
             shift_logits = shift_logits.view(-1, shift_logits.size(-1))
             shift_labels = shift_labels.view(-1).to(shift_logits.device)
             if loss_mask is not None:
-                shift_mask = loss_mask[..., :-1].contiguous()
+                shift_mask = loss_mask[..., 1:].contiguous()
                 loss_fct = nn.CrossEntropyLoss(reduction="none")
                 loss = loss_fct(shift_logits, shift_labels)
                 shift_mask = torch.logical_and(shift_mask.view(-1), shift_labels != -100)
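For context, a minimal sketch of why the slice changes (hypothetical tensor shapes and values; assumes the standard causal-LM shift where `shift_labels = labels[..., 1:]`, which the hunk above does not show): the mask flags which *label* tokens count toward the loss, so it must be shifted the same way as the labels, not truncated like the logits.

```python
import torch
import torch.nn as nn

# Hypothetical example: batch of 2 sequences, length 5, vocab of 11.
logits = torch.randn(2, 5, 11)
labels = torch.randint(0, 11, (2, 5))
loss_mask = torch.tensor([[1, 1, 0, 1, 1],
                          [1, 0, 1, 1, 1]], dtype=torch.bool)

# Standard causal-LM shift: the logit at position t predicts token t+1.
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

# The fix: shift the mask like the labels ([..., 1:]), not like the
# logits ([..., :-1]), so mask[t] lines up with shift_labels[t].
shift_mask = loss_mask[..., 1:].contiguous()

shift_logits = shift_logits.view(-1, shift_logits.size(-1))
shift_labels = shift_labels.view(-1).to(shift_logits.device)

loss_fct = nn.CrossEntropyLoss(reduction="none")
loss = loss_fct(shift_logits, shift_labels)
shift_mask = torch.logical_and(shift_mask.view(-1), shift_labels != -100)

# Hypothetical reduction (not part of the diff): average over kept tokens.
loss = (loss * shift_mask).sum() / shift_mask.sum()
```

With the old `[..., :-1]` slice the mask stays aligned with the input positions, so a masked-out token silently keeps or drops the wrong element of the per-token loss; `[..., 1:]` restores the one-to-one alignment with `shift_labels`.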