44import sys
55import json
66import math
7+ import tempfile
78
89from dataclasses import dataclass , field
910from collections import Counter , defaultdict
2122 DataCollatorWithPadding ,
2223 EvalPrediction ,
2324 HfArgumentParser ,
24- PretrainedConfig ,
2525 Trainer ,
2626 TrainingArguments ,
2727 default_data_collator ,
@@ -328,7 +328,7 @@ def has_custom_dataset(self):
328328 return (
329329 self ._dataset is not None
330330 or self ._train_dataset is not None
331- or self ._test_datast is not None
331+ or self ._test_dataset is not None
332332 or self ._validation_dataset is not None
333333 )
334334
@@ -351,9 +351,20 @@ def set_args(self, args_dict=None):
351351 training_args ,
352352 ) = parser .parse_args_into_dataclasses ()
353353 else :
354- args_dict ["do_train" ] = True
355- args_dict ["do_eval" ] = True
356- args_dict ["do_predict" ] = True
354+ if self ._has_train_dataset :
355+ args_dict ["do_train" ] = True
356+ else :
357+ raise ValueError (
358+ "You must initialize the FastFitTrainer with train_dataset or train_file."
359+ )
360+ if self ._has_validation_dataset :
361+ args_dict ["do_eval" ] = True
362+ if self ._has_test_dataset :
363+ args_dict ["do_predict" ] = True
364+ if "output_dir" not in args_dict :
365+ args_dict ["save_strategy" ] = "no"
366+ args_dict ["output_dir" ] = tempfile .gettempdir ()
367+ args_dict ["overwrite_output_dir" ] = True
357368 if self .has_custom_dataset ():
358369 args_dict ["task_name" ] = "custom"
359370 config_args , model_args , data_args , training_args = parser .parse_dict (
@@ -457,7 +468,7 @@ def set_data(
457468
458469 # Get the test dataset: you can provide your own CSV/JSON test file (see below)
459470 # when you use `do_predict` without specifying a GLUE benchmark task.
460- if self .training_args .do_predict :
471+ if self .training_args .do_predict and self . is_command_line_mode :
461472 if self .data_args .test_file is not None :
462473 train_extension = self .data_args .train_file .split ("." )[- 1 ]
463474 test_extension = self .data_args .test_file .split ("." )[- 1 ]
@@ -663,28 +674,7 @@ def preprocess_data(self):
663674
664675 # Some models have set the order of the labels to use, so let's make sure we do use it.
665676 self .label_to_id = None
666- if (
667- self .model .config .label2id
668- != PretrainedConfig (num_labels = self .num_labels ).label2id
669- and self .data_args .task_name is not None
670- and not self .is_regression
671- ):
672- # Some have all caps in their config, some don't.
673- self .label_name_to_id = {
674- k .lower (): v for k , v in self .model .config .label2id .items ()
675- }
676- if list (sorted (self .label_name_to_id .keys ())) == list (sorted (self .labels )):
677- self .label_to_id = {
678- i : int (self .label_name_to_id [self .labels [i ]])
679- for i in range (self .num_labels )
680- }
681- else :
682- logger .warning (
683- "Your model seems to have been trained with labels, but they don't match the dataset: " ,
684- f"model labels: { list (sorted (self .label_name_to_id .keys ()))} , dataset labels: { list (sorted (self .labels ))} ."
685- "\n Ignoring the model labels as a result." ,
686- )
687- elif self .data_args .task_name is None and not self .is_regression :
677+ if not self .is_regression :
688678 self .label_to_id = {v : i for i , v in enumerate (self .labels )}
689679
690680 if self .label_to_id is not None :
@@ -936,12 +926,28 @@ def __init__(
936926 train_dataset = None ,
937927 validation_dataset = None ,
938928 test_dataset = None ,
929+ is_command_line_mode = False ,
939930 ** kwargs ,
940931 ):
932+ self ._has_train_dataset = train_dataset is not None or "train_file" in kwargs
933+ if dataset is not None :
934+ self ._has_train_dataset = "train" in dataset
935+
936+ self ._has_validation_dataset = (
937+ validation_dataset is not None or "validation_file" in kwargs
938+ )
939+ if dataset is not None :
940+ self ._has_validation_dataset = "validation" in dataset
941+
942+ self ._has_test_dataset = test_dataset is not None or "test_file" in kwargs
943+ if dataset is not None :
944+ self ._has_test_dataset = "test" in dataset
945+
941946 self ._dataset = dataset
942947 self ._train_dataset = train_dataset
943948 self ._validation_dataset = validation_dataset
944949 self ._test_dataset = test_dataset
950+ self .is_command_line_mode = is_command_line_mode
945951 self .set_args (kwargs )
946952 self .set_logger ()
947953 self .set_last_checkpoint ()
@@ -1049,20 +1055,15 @@ def test(self):
10491055 def push_to_hub (self ):
10501056 kwargs = {
10511057 "finetuned_from" : self .model_args .model_name_or_path ,
1052- "tasks" : "text-classification" ,
1058+ "tasks" : self . data_args . task_name ,
10531059 }
1054- if self .data_args .task_name is not None :
1055- kwargs ["language" ] = "en"
1056- kwargs ["dataset_tags" ] = "glue"
1057- kwargs ["dataset_args" ] = self .data_args .task_name
1058- kwargs ["dataset" ] = f"GLUE { self .data_args .task_name .upper ()} "
10591060
10601061 if self .training_args .push_to_hub :
10611062 self .trainer .push_to_hub (** kwargs )
10621063
10631064
10641065def main ():
1065- trainer = FastFitTrainer ()
1066+ trainer = FastFitTrainer (is_command_line_mode = True )
10661067 trainer .train ()
10671068 trainer .evaluate ()
10681069 trainer .test ()
0 commit comments