1414
1515# Standard
1616from typing import Dict , List , Optional , Union
17+ import dataclasses
1718import json
1819import sys
1920import time
3334 TrainerCallback ,
3435)
3536from transformers .utils import is_accelerate_available , logging
36- from trl import DataCollatorForCompletionOnlyLM , SFTTrainer
37+ from trl import DataCollatorForCompletionOnlyLM , SFTConfig , SFTTrainer
3738import datasets
3839import fire
3940import transformers
@@ -315,6 +316,23 @@ def train(
315316 model , train_args , modifiable_args = (peft_config ,)
316317 )
317318
319+ # HACK - The SFTTrainer has internal validation which inspects the name of the class
320+ # being used for the HF training args; if it's a TrainingArguments class, which is
321+ # presumably from transformers, it tries to build it into an SFTConfig.
322+ #
323+ # This is unfortunately a naming collision with one of our own classes, which has extra
324+ # fields, and therefore can't be used to initialize the SFTConfig. For now, to sidestep
325+ # this validation, we just drop the things that aren't part of the SFTConfig and build one
326+ # from our object directly. In the future, we should consider renaming this class and/or
327+ # not adding things that are not directly used by the trainer instance to it.
328+ transformer_train_arg_fields = [x .name for x in dataclasses .fields (SFTConfig )]
329+ transformer_kwargs = {
330+ k : v
331+ for k , v in train_args .to_dict ().items ()
332+ if k in transformer_train_arg_fields
333+ }
334+ training_args = SFTConfig (** transformer_kwargs )
335+
318336 trainer = SFTTrainer (
319337 model = model ,
320338 tokenizer = tokenizer ,
@@ -323,7 +341,7 @@ def train(
323341 packing = packing ,
324342 data_collator = data_collator ,
325343 dataset_text_field = data_args .dataset_text_field ,
326- args = train_args ,
344+ args = training_args ,
327345 max_seq_length = max_seq_length ,
328346 callbacks = trainer_callbacks ,
329347 peft_config = peft_config ,
0 commit comments