@@ -19,6 +19,7 @@
 from torch.nn.parallel import DistributedDataParallel as DDP_th
 from torch.utils.data import DataLoader
 
+from trainer.analytics import ping_training_run
 from trainer.callbacks import TrainerCallback
 from trainer.generic_utils import (
     KeepAverage,
@@ -241,6 +242,10 @@ class TrainerArgs(Coqpit):
         default=False,
         metadata={"help": "Skip training and only run evaluation and test."},
     )
+    start_with_eval: bool = field(
+        default=False,
+        metadata={"help": "Start with evaluation and test."},
+    )
     small_run: int = field(
         default=None,
         metadata={
@@ -388,6 +393,7 @@ def __init__(  # pylint: disable=dangerous-default-value
         self.grad_accum_steps = args.grad_accum_steps
         self.overfit_batch = args.overfit_batch
         self.skip_train_epoch = args.skip_train_epoch
+        self.start_with_eval = args.start_with_eval
 
         assert self.grad_accum_steps > 0, " [!] grad_accum_steps must be greater than 0."
@@ -519,6 +525,7 @@ def __init__(  # pylint: disable=dangerous-default-value
         self.callbacks.on_init_end(self)
         self.dashboard_logger.add_config(config)
         self.save_training_script()
+        ping_training_run()
 
     def save_training_script(self):
         """Save the training script to tracking dashboard and output path."""
@@ -1519,7 +1526,7 @@ def _fit(self) -> None:
             self.keep_avg_eval = KeepAverage() if self.config.run_eval else None
             self.epochs_done = epoch
             self.c_logger.print_epoch_start(epoch, self.config.epochs, self.output_path)
-            if not self.skip_train_epoch:
+            if not self.skip_train_epoch and not self.start_with_eval:
                 self.train_epoch()
             if self.config.run_eval:
                 self.eval_epoch()
@@ -1532,6 +1539,7 @@ def _fit(self) -> None:
             if self.args.rank in [None, 0]:
                 self.save_best_model()
             self.callbacks.on_epoch_end(self)
+            self.start_with_eval = False
 
     def fit_with_largest_batch_size(self, starting_batch_size=2048) -> None:
         cuda_meminfo()
@@ -1552,7 +1560,7 @@ def fit_with_largest_batch_size(self, starting_batch_size=2048) -> None:
                     torch.cuda.empty_cache()
                 else:
                     raise
-            except Exception as exception:   # pylint: disable=broad-except
+            except Exception as exception:  # pylint: disable=broad-except
                 # catches the torch.cuda.OutOfMemoryError
                 if bs > 1 and should_reduce_batch_size(exception):
                     bs //= 2
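Taken together, the diff threads a new start_with_eval option through TrainerArgs, skips the first train_epoch() call when it is set, and clears the flag at the end of the first epoch so later epochs train normally. A minimal usage sketch, assuming the usual Trainer/TrainerArgs entry point; config, model, and the sample lists are placeholders you must supply:

    from trainer import Trainer, TrainerArgs

    # Placeholders: substitute your own Coqpit-based config, TrainerModel, and data.
    args = TrainerArgs(start_with_eval=True)  # run eval/test before the first train epoch
    trainer = Trainer(
        args,
        config,
        output_path="runs/",
        model=model,
        train_samples=train_samples,
        eval_samples=eval_samples,
    )
    trainer.fit()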
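The ping_training_run() call added at the end of __init__ is a one-shot analytics ping fired once setup completes. Its implementation is not shown in this diff; the sketch below only illustrates what such a helper commonly looks like. The endpoint URL and the TRAINER_TELEMETRY opt-out variable are assumptions, not the library's confirmed API:

    import os
    import threading
    import urllib.request

    def ping_training_run() -> None:
        """Best-effort, anonymous 'a training run started' ping (hypothetical sketch)."""
        if os.environ.get("TRAINER_TELEMETRY", "1") == "0":  # assumed opt-out switch
            return

        def _ping():
            try:
                # Fire and forget; network errors must never affect training.
                urllib.request.urlopen("https://example.org/trainer-run", timeout=2)
            except Exception:
                pass

        # Daemon thread so a slow or offline network cannot block __init__.
        threading.Thread(target=_ping, daemon=True).start()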