Skip to content

Commit 2d8aff9

Browse files
authored
fix loss mask bug in dataflow when using no template (#2947)
1 parent 8156d71 commit 2d8aff9

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

paddleformers/datasets/finetuning.py

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -722,7 +722,7 @@ def _postprocess_sequence(self, example, actual_example_num):
722722
oral_tokens = tokens
723723
tokens = oral_tokens[:-1]
724724
labels = oral_tokens[1:]
725-
loss_mask = loss_mask[1:]
725+
loss_mask = loss_mask[:-1]
726726
if len(tokens) > self.max_seq_len:
727727
raise RuntimeError(f"token_ids is too long: {len(tokens)}")
728728

@@ -750,9 +750,15 @@ def print_debug_info(tokenizer, data, label):
750750
logger.info("[dataset debug] Debug mode enabled")
751751

752752
if hasattr(self, "tokenizer"):
753+
print("========================================")
753754
print_debug_info(self.tokenizer, tokens, "input")
754-
labels = [x for x in labels if x != -100] # remove -100
755-
print_debug_info(self.tokenizer, labels, "labels")
755+
print("========================================\n")
756+
757+
filtered_labels = [label if mask == 1 else -100 for label, mask in zip(labels, loss_mask)]
758+
filtered_labels = [x for x in filtered_labels if x != -100] # remove -100
759+
print("========================================")
760+
print_debug_info(self.tokenizer, filtered_labels, "labels")
761+
print("========================================\n")
756762
logger.info(f"[dataset debug] loss mask: {loss_mask}")
757763
else:
758764
logger.info("[dataset debug] Tokenizer not available")

0 commit comments

Comments (0)