Skip to content

Commit 7b3bb8c

Browse files
committed
fix typo in input for masked lm loss function
1 parent 257a351 commit 7b3bb8c

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

pytorch_pretrained_bert/modeling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -678,7 +678,7 @@ def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm
678678

679679
if masked_lm_labels is not None and next_sentence_label is not None:
680680
loss_fct = CrossEntropyLoss(ignore_index=-1)
681-
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels(-1))
681+
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
682682
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
683683
total_loss = masked_lm_loss + next_sentence_loss
684684
return total_loss

0 commit comments

Comments
 (0)