Skip to content

Commit 8c7267f

Browse files
authored
Merge pull request #70 from deepset-ai/fix_lm_loss
fix typo in input for masked lm loss function
2 parents 257a351 + 7b3bb8c commit 8c7267f

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

pytorch_pretrained_bert/modeling.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -678,7 +678,7 @@ def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm
678678

679679
if masked_lm_labels is not None and next_sentence_label is not None:
680680
loss_fct = CrossEntropyLoss(ignore_index=-1)
681-
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels(-1))
681+
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
682682
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
683683
total_loss = masked_lm_loss + next_sentence_loss
684684
return total_loss

0 commit comments

Comments
 (0)