@@ -49,7 +49,7 @@ def main():
4949 parser .add_argument ("--lora_alpha" , default = 1.0 , type = float )
5050 parser .add_argument ("--activation_checkpointing" , action = "store_true" )
5151 parser .add_argument ("--eval_interval" , default = 1 , type = int )
52- parser .add_argument ("--eval_ratio " , default = 0.1 , type = float )
52+ parser .add_argument ("--num_eval_samples " , default = 2048 , type = int )
5353 parser .add_argument ("--checkpoint_interval" , default = 1 , type = int )
5454 parser .add_argument (
5555 "--checkpoint_path" , default = "./checkpoints/checkpoint.pt" , type = str
@@ -89,11 +89,6 @@ def main():
8989 f"Eval interval must be greater than 0, { args .eval_interval } given."
9090 )
9191
92- if args .eval_ratio < 0 or args .eval_ratio > 1 :
93- raise ValueError (
94- f"Eval ratio must be between 0 and 1, { args .eval_ratio } given."
95- )
96-
9792 if args .checkpoint_interval < 1 :
9893 raise ValueError (
9994 f"Checkpoint interval must be greater than 0, { args .checkpoint_interval } given."
@@ -151,9 +146,9 @@ def main():
151146
152147 dataset = ConcatDataset (datasets )
153148
154- training_ratio = 1.0 - args .eval_ratio
149+ n_train_samples = len ( dataset ) - args .num_eval_samples
155150
156- training , testing = random_split (dataset , ( training_ratio , args .eval_ratio ) )
151+ training , testing = random_split (dataset , [ n_train_samples , args .num_eval_samples ] )
157152
158153 right_pad_collate = partial (
159154 pad_collate ,
@@ -198,7 +193,7 @@ def main():
198193
199194 model .add_lora_parameters (** lora_args )
200195
201- print ("LoRA parameters added " )
196+ print ("Added LoRA adapters " )
202197
203198 print (f"Model has { model .num_trainable_params :,} trainable parameters" )
204199
0 commit comments