Skip to content

Commit 8bc85e8

Browse files
committed
fix mask shift
1 parent 46a45a4 commit 8bc85e8

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

open_lm/utils/transformers/hf_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def forward(
140140
shift_logits = shift_logits.view(-1, shift_logits.size(-1))
141141
shift_labels = shift_labels.view(-1).to(shift_logits.device)
142142
if loss_mask is not None:
143-
shift_mask = loss_mask[..., :-1].contiguous()
143+
shift_mask = loss_mask[..., 1:].contiguous()
144144
loss_fct = nn.CrossEntropyLoss(reduction="none")
145145
loss = loss_fct(shift_logits, shift_labels)
146146
shift_mask = torch.logical_and(shift_mask.view(-1), shift_labels != -100)

0 commit comments

Comments
 (0)