@@ -722,7 +722,7 @@ def _postprocess_sequence(self, example, actual_example_num):
         oral_tokens = tokens
         tokens = oral_tokens[:-1]
         labels = oral_tokens[1:]
-        loss_mask = loss_mask[1:]
+        loss_mask = loss_mask[:-1]
         if len(tokens) > self.max_seq_len:
             raise RuntimeError(f"token_ids is too long: {len(tokens)}")
 
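This hunk changes which end of the mask is dropped after the next-token shift: slicing `loss_mask[:-1]` instead of `loss_mask[1:]` keeps `mask[i]` paired with input position `i` while still giving `tokens`, `labels`, and `loss_mask` the same length. A minimal sketch with made-up ids (the variable names follow the diff; the real mask construction lives elsewhere in the dataset code):

```python
# Illustrative values only -- shows the alignment, not the real mask logic.
oral_tokens = [101, 7, 8, 9, 102]   # hypothetical token ids
loss_mask   = [0, 0, 1, 1, 1]       # hypothetical per-token mask, same length

tokens = oral_tokens[:-1]           # inputs:  [101, 7, 8, 9]
labels = oral_tokens[1:]            # targets: [7, 8, 9, 102]
loss_mask = loss_mask[:-1]          # [0, 0, 1, 1] -- drops the last entry,
                                    # so mask[i] stays paired with tokens[i]

assert len(tokens) == len(labels) == len(loss_mask)
```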
@@ -750,9 +750,15 @@ def print_debug_info(tokenizer, data, label):
         logger.info("[dataset debug] Debug mode enabled")
 
         if hasattr(self, "tokenizer"):
+            print("========================================")
             print_debug_info(self.tokenizer, tokens, "input")
-            labels = [x for x in labels if x != -100]  # remove -100
-            print_debug_info(self.tokenizer, labels, "labels")
+            print("========================================\n")
+
+            filtered_labels = [label if mask == 1 else -100 for label, mask in zip(labels, loss_mask)]
+            filtered_labels = [x for x in filtered_labels if x != -100]  # remove -100
+            print("========================================")
+            print_debug_info(self.tokenizer, filtered_labels, "labels")
+            print("========================================\n")
             logger.info(f"[dataset debug] loss mask: {loss_mask}")
         else:
             logger.info("[dataset debug] Tokenizer not available")
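The second hunk changes what the debug path prints: instead of mutating `labels` in place, it builds a separate `filtered_labels` list that keeps only the ids whose mask entry is 1 and hands that list to the `print_debug_info` helper defined earlier in the file. A small self-contained sketch of that filtering, reusing the hypothetical values from the sketch above:

```python
# Only label ids whose mask entry is 1 survive; the rest are set to -100
# and then dropped before being decoded/printed by the debug helper.
labels    = [7, 8, 9, 102]   # hypothetical target ids (oral_tokens[1:])
loss_mask = [0, 0, 1, 1]     # hypothetical mask (loss_mask[:-1])

filtered_labels = [label if mask == 1 else -100
                   for label, mask in zip(labels, loss_mask)]
filtered_labels = [x for x in filtered_labels if x != -100]  # remove -100

print(filtered_labels)  # -> [9, 102]
```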