File tree Expand file tree Collapse file tree 1 file changed +3
-2
lines changed
open_lm/utils/transformers Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Original file line number Diff line number Diff line change @@ -140,10 +140,11 @@ def forward(
140
140
shift_logits = shift_logits .view (- 1 , shift_logits .size (- 1 ))
141
141
shift_labels = shift_labels .view (- 1 ).to (shift_logits .device )
142
142
if loss_mask is not None :
143
- shift_mask = loss_mask [..., 1 : ].contiguous ()
143
+ shift_mask = loss_mask [..., : - 1 ].contiguous ()
144
144
loss_fct = nn .CrossEntropyLoss (reduction = "none" )
145
145
loss = loss_fct (shift_logits , shift_labels )
146
- loss = loss [shift_mask .view (- 1 )].sum () / shift_mask .sum ()
146
+ shift_mask = torch .logical_and (shift_mask .view (- 1 ), shift_labels != - 100 )
147
+ loss = loss [shift_mask .view (- 1 )].sum ()/ shift_mask .sum ()
147
148
else :
148
149
loss_fct = nn .CrossEntropyLoss ()
149
150
loss = loss_fct (shift_logits , shift_labels )
You can’t perform that action at this time.
0 commit comments