2 files changed: +11 −4 lines changed

@@ -1100,7 +1100,7 @@ def test_sft_trainer_only_train_packing(self):
         eval_dataset=self.conversational_lm_dataset["test"],
     )
 
-    self.assertEqual(len(trainer.train_dataset["input_ids"]), 46)  # w/ this dataset, we end up with 46 seqs
+    self.assertEqual(len(trainer.train_dataset["input_ids"]), 47)  # w/ this dataset, we end up with 47 seqs
     self.assertEqual(len(trainer.eval_dataset["input_ids"]), len(self.conversational_lm_dataset["test"]))
 
 def test_sft_trainer_eval_packing(self):
@@ -1125,8 +1125,8 @@ def test_sft_trainer_eval_packing(self):
         eval_dataset=self.conversational_lm_dataset["test"],
     )
 
-    self.assertEqual(len(trainer.train_dataset["input_ids"]), 46)  # w/ this dataset, we end up with 46 seqs
-    self.assertEqual(len(trainer.eval_dataset["input_ids"]), 6)  # w/ this dataset, we end up with 6 seqs
+    self.assertEqual(len(trainer.train_dataset["input_ids"]), 47)  # w/ this dataset, we end up with 47 seqs
+    self.assertEqual(len(trainer.eval_dataset["input_ids"]), 7)  # w/ this dataset, we end up with 7 seqs
 
 def test_sft_trainer_no_packing(self):
     with tempfile.TemporaryDirectory() as tmp_dir:

@@ -422,7 +422,14 @@ def concat_prompt_completion(example):
         map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
 
     def tokenize(example, processing_class, dataset_text_field):
-        return processing_class(text=example[dataset_text_field])
+        processed = processing_class(text=example[dataset_text_field])
+        if (
+            processing_class.eos_token_id is not None
+            and processed["input_ids"][-1] != processing_class.eos_token_id
+        ):
+            processed["input_ids"] = processed["input_ids"] + [processing_class.eos_token_id]
+            processed["attention_mask"] = processed["attention_mask"] + [1]
+        return processed
 
     dataset = dataset.map(
         tokenize,
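In short, the updated tokenize() appends the tokenizer's EOS token (and a matching attention-mask entry) whenever the tokenized text does not already end with it, which is presumably why the expected packed-sequence counts in the tests grow by one. Below is a minimal standalone sketch of the same logic; the tokenizer choice ("gpt2") and the example text are illustrative assumptions, not taken from the diff:

from transformers import AutoTokenizer  # assumed available; any tokenizer with eos_token_id works

def tokenize_with_eos(example, processing_class, dataset_text_field):
    # Tokenize the raw text column, mirroring the updated tokenize() in the diff.
    processed = processing_class(text=example[dataset_text_field])
    # Append EOS only if the tokenizer defines one and it is not already the last token.
    if (
        processing_class.eos_token_id is not None
        and processed["input_ids"][-1] != processing_class.eos_token_id
    ):
        processed["input_ids"] = processed["input_ids"] + [processing_class.eos_token_id]
        processed["attention_mask"] = processed["attention_mask"] + [1]
    return processed

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative model; gpt2 does not append EOS on its own
out = tokenize_with_eos({"text": "Hello world"}, tokenizer, "text")
assert out["input_ids"][-1] == tokenizer.eos_token_id
assert len(out["input_ids"]) == len(out["attention_mask"])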