Skip to content

Commit 0949699

Browse files
Update TRL pin to 0.9.3+ (#213)
Add SFT config hack for TRL upgrade; run code formatter. Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
1 parent 8ddd68c commit 0949699

File tree

2 files changed

+21
-3
lines changed

2 files changed

+21
-3
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ dependencies = [
3333
"sentencepiece>=0.1.99,<0.3",
3434
"tokenizers>=0.13.3,<1.0",
3535
"tqdm>=4.66.2,<5.0",
36-
"trl==0.8.6",
36+
"trl>=0.9.3,<1.0",
3737
"peft>=0.8.0,<0.13",
3838
"datasets>=2.15.0,<3.0",
3939
"fire>=0.5.0,<1.0",

tuning/sft_trainer.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
# Standard
1616
from typing import Dict, List, Optional, Union
17+
import dataclasses
1718
import json
1819
import sys
1920
import time
@@ -33,7 +34,7 @@
3334
TrainerCallback,
3435
)
3536
from transformers.utils import is_accelerate_available, logging
36-
from trl import DataCollatorForCompletionOnlyLM, SFTTrainer
37+
from trl import DataCollatorForCompletionOnlyLM, SFTConfig, SFTTrainer
3738
import datasets
3839
import fire
3940
import transformers
@@ -315,6 +316,23 @@ def train(
315316
model, train_args, modifiable_args=(peft_config,)
316317
)
317318

319+
# HACK - The SFT Trainer has internal validation which inspects the name of the class
320+
# being used for the HF training args; if it's a TrainingArguments class, which is
321+
# presumably from transformers, it tries to build it into an SFT Config.
322+
#
323+
# This is unfortunately a naming collision with one of our own classes, which has extra
324+
# fields, and therefore can't be used to initialize the SFT Config. For now, to sidestep
325+
# this validation, we just drop the things that aren't part of the SFT Config and build one
326+
# from our object directly. In the future, we should consider renaming this class and / or
327+
# not adding things that are not directly used by the trainer instance to it.
328+
transformer_train_arg_fields = [x.name for x in dataclasses.fields(SFTConfig)]
329+
transformer_kwargs = {
330+
k: v
331+
for k, v in train_args.to_dict().items()
332+
if k in transformer_train_arg_fields
333+
}
334+
training_args = SFTConfig(**transformer_kwargs)
335+
318336
trainer = SFTTrainer(
319337
model=model,
320338
tokenizer=tokenizer,
@@ -323,7 +341,7 @@ def train(
323341
packing=packing,
324342
data_collator=data_collator,
325343
dataset_text_field=data_args.dataset_text_field,
326-
args=train_args,
344+
args=training_args,
327345
max_seq_length=max_seq_length,
328346
callbacks=trainer_callbacks,
329347
peft_config=peft_config,

0 commit comments

Comments
 (0)