Commit 5cb390c

⛔ Add EOS token to processed input in SFT (#3091)
* Add EOS token to processed input
* Update sft_trainer.py
* fix test
1 parent fc4dae2 commit 5cb390c

File tree

2 files changed (+11, -4 lines)

tests/test_sft_trainer.py (+3, -3)

@@ -1100,7 +1100,7 @@ def test_sft_trainer_only_train_packing(self):
             eval_dataset=self.conversational_lm_dataset["test"],
         )
 
-        self.assertEqual(len(trainer.train_dataset["input_ids"]), 46)  # w/ this dataset, we end up with 46 seqs
+        self.assertEqual(len(trainer.train_dataset["input_ids"]), 47)  # w/ this dataset, we end up with 47 seqs
         self.assertEqual(len(trainer.eval_dataset["input_ids"]), len(self.conversational_lm_dataset["test"]))
 
     def test_sft_trainer_eval_packing(self):
@@ -1125,8 +1125,8 @@ def test_sft_trainer_eval_packing(self):
             eval_dataset=self.conversational_lm_dataset["test"],
         )
 
-        self.assertEqual(len(trainer.train_dataset["input_ids"]), 46)  # w/ this dataset, we end up with 46 seqs
-        self.assertEqual(len(trainer.eval_dataset["input_ids"]), 6)  # w/ this dataset, we end up with 6 seqs
+        self.assertEqual(len(trainer.train_dataset["input_ids"]), 47)  # w/ this dataset, we end up with 47 seqs
+        self.assertEqual(len(trainer.eval_dataset["input_ids"]), 7)  # w/ this dataset, we end up with 7 seqs
 
     def test_sft_trainer_no_packing(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
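
Why the expected counts grow by one: packing concatenates all tokenized examples and slices the result into fixed-length blocks, so appending one EOS token per example raises the total token count, which can yield an extra packed sequence. A toy sketch of that arithmetic, with made-up numbers rather than this dataset's real statistics, and assuming a packer that drops any trailing partial block:

# Invented numbers, for illustration only -- not the real stats of
# conversational_lm_dataset. Shows how one extra token per example
# can produce one more fixed-length packed sequence.
max_seq_len = 8
n_examples, tokens_per_example = 4, 7

total_without_eos = n_examples * tokens_per_example       # 28 tokens
total_with_eos = n_examples * (tokens_per_example + 1)    # 32 tokens

print(total_without_eos // max_seq_len)  # 3 packed sequences
print(total_with_eos // max_seq_len)     # 4 packed sequences: one more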

trl/trainer/sft_trainer.py (+8, -1)

@@ -422,7 +422,14 @@ def concat_prompt_completion(example):
                 map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
 
             def tokenize(example, processing_class, dataset_text_field):
-                return processing_class(text=example[dataset_text_field])
+                processed = processing_class(text=example[dataset_text_field])
+                if (
+                    processing_class.eos_token_id is not None
+                    and processed["input_ids"][-1] != processing_class.eos_token_id
+                ):
+                    processed["input_ids"] = processed["input_ids"] + [processing_class.eos_token_id]
+                    processed["attention_mask"] = processed["attention_mask"] + [1]
+                return processed
 
             dataset = dataset.map(
                 tokenize,
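
For reference, the new tokenize logic can be exercised outside the trainer. A minimal sketch, assuming a transformers tokenizer ("gpt2" is only a stand-in here) plays the role of processing_class, and a plain dict stands in for a dataset row; inside SFTTrainer these arguments are bound through dataset.map:

from transformers import AutoTokenizer

# "gpt2" is an arbitrary choice for this sketch; any tokenizer that
# defines an EOS token behaves the same way.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

def tokenize(example, processing_class, dataset_text_field):
    processed = processing_class(text=example[dataset_text_field])
    # Append EOS only when the tokenizer defines one and it is not
    # already the final token of the encoded text.
    if (
        processing_class.eos_token_id is not None
        and processed["input_ids"][-1] != processing_class.eos_token_id
    ):
        processed["input_ids"] = processed["input_ids"] + [processing_class.eos_token_id]
        processed["attention_mask"] = processed["attention_mask"] + [1]
    return processed

out = tokenize({"text": "Hello world"}, tokenizer, "text")
assert out["input_ids"][-1] == tokenizer.eos_token_id  # EOS now terminates the sequence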
