4646 setup_torch_training_env ,
4747)
4848from trainer .utils .cuda_memory import cuda_meminfo , should_reduce_batch_size
49- from trainer .utils .distributed import init_distributed
49+ from trainer .utils .distributed import init_distributed , rank_zero_only
5050
5151logger = logging .getLogger ("trainer" )
5252
@@ -111,6 +111,9 @@ class TrainerConfig(Coqpit):
111111 default = "tensorboard" , metadata = {"help" : "Logger to use for the tracking dashboard. Defaults to 'tensorboard'" }
112112 )
113113 # Fields for checkpointing
114+ save_on_interrupt : bool = field (
115+ default = True , metadata = {"help" : "Save checkpoint on interrupt (Ctrl+C). Defaults to True" }
116+ )
114117 log_model_step : int = field (
115118 default = None ,
116119 metadata = {
@@ -455,7 +458,7 @@ def __init__( # pylint: disable=dangerous-default-value
455458 self .eval_samples = None
456459 self .test_samples = None
457460
458- #define custom train and eval loader
461+ # define custom train and eval loader
459462 self .train_loader = train_loader
460463 self .eval_loader = eval_loader
461464
@@ -1295,47 +1298,11 @@ def train_step(self, batch: Dict, batch_n_steps: int, step: int, loader_start_ti
12951298 if self .total_steps_done % self .config .save_step == 0 and self .total_steps_done != 0 :
12961299 if self .config .save_checkpoints :
12971300 # checkpoint the model
1298- target_avg_loss = self ._pick_target_avg_loss (self .keep_avg_train )
1299- save_checkpoint (
1300- self .config ,
1301- self .model ,
1302- self .optimizer ,
1303- self .scaler if self .use_amp_scaler else None ,
1304- self .total_steps_done ,
1305- self .epochs_done ,
1306- self .output_path ,
1307- model_loss = target_avg_loss ,
1308- save_n_checkpoints = self .config .save_n_checkpoints ,
1309- save_func = self .dashboard_logger .save_model ,
1310- )
1301+ self .save_checkpoint ()
13111302
1312- if self .total_steps_done % self .config .log_model_step == 0 :
1313- # log checkpoint as artifact
1314- aliases = [
1315- f"epoch-{ self .epochs_done } " ,
1316- f"step-{ self .total_steps_done } " ,
1317- ]
1318- self .dashboard_logger .add_artifact (
1319- file_or_dir = self .output_path , name = "checkpoint" , artifact_type = "model" , aliases = aliases
1320- )
1321-
1322- # training visualizations
1323- if hasattr (self .model , "module" ) and isimplemented (self .model .module , "train_log" ):
1324- self .model .module .train_log (
1325- batch ,
1326- outputs ,
1327- self .dashboard_logger ,
1328- self .training_assets ,
1329- self .total_steps_done ,
1330- )
1331- elif isimplemented (self .model , "train_log" ):
1332- self .model .train_log (
1333- batch ,
1334- outputs ,
1335- self .dashboard_logger ,
1336- self .training_assets ,
1337- self .total_steps_done ,
1338- )
1303+ if self .total_steps_done % self .config .log_model_step == 0 :
1304+ # log checkpoint as artifact and update the training visualizations
1305+ self .update_training_dashboard_logger (batch = batch , outputs = outputs )
13391306
13401307 self .dashboard_logger .flush ()
13411308
@@ -1683,6 +1650,14 @@ def fit(self) -> None:
16831650 if self .args .rank == 0 :
16841651 self .dashboard_logger .finish ()
16851652 except KeyboardInterrupt :
1653+ logger .info (" > Keyboard interrupt detected." )
1654+ if self .config .save_on_interrupt :
1655+ logger .info (" > Saving model before exiting..." )
1656+ # save the model on keyboard interrupt
1657+ self .save_checkpoint ()
1658+ # update the training dashboard logger
1659+ self .update_training_dashboard_logger ()
1660+ # call the keyboard interrupt callback
16861661 self .callbacks .on_keyboard_interrupt (self )
16871662 # if the output folder is empty remove the run.
16881663 remove_experiment_folder (self .output_path )
@@ -1694,9 +1669,9 @@ def fit(self) -> None:
16941669 self .dashboard_logger .finish ()
16951670 # exit with a non-zero status so callers can tell the run was interrupted
16961671 try :
1697- sys .exit (0 )
1672+ sys .exit (1 )
16981673 except SystemExit :
1699- os ._exit (0 ) # pylint: disable=protected-access
1674+ os ._exit (1 ) # pylint: disable=protected-access
17001675 except BaseException : # pylint: disable=broad-except
17011676 remove_experiment_folder (self .output_path )
17021677 traceback .print_exc ()
@@ -1746,6 +1721,7 @@ def profile_fit(self, torch_profiler, epochs=None, small_run=None):
17461721 self .torch_profiler .stop ()
17471722 return self .torch_profiler
17481723
1724+ @rank_zero_only
17491725 def save_best_model (self ) -> None :
17501726 """Save the best model. It only saves if the current target loss is smaller than the previous."""
17511727
@@ -1768,6 +1744,52 @@ def save_best_model(self) -> None:
17681744 save_func = self .dashboard_logger .save_model ,
17691745 )
17701746
1747+ @rank_zero_only
1748+ def save_checkpoint (self ) -> None :
1749+ """Save the current model checkpoint."""
1750+ target_avg_loss = self ._pick_target_avg_loss (self .keep_avg_train )
1751+ save_checkpoint (
1752+ self .config ,
1753+ self .model ,
1754+ self .optimizer ,
1755+ self .scaler if self .use_amp_scaler else None ,
1756+ self .total_steps_done ,
1757+ self .epochs_done ,
1758+ self .output_path ,
1759+ model_loss = target_avg_loss ,
1760+ save_n_checkpoints = self .config .save_n_checkpoints ,
1761+ save_func = self .dashboard_logger .save_model ,
1762+ )
1763+
1764+ @rank_zero_only
1765+ def update_training_dashboard_logger (self , batch = None , outputs = None ):
1766+ aliases = [
1767+ f"epoch-{ self .epochs_done } " ,
1768+ f"step-{ self .total_steps_done } " ,
1769+ ]
1770+ self .dashboard_logger .add_artifact (
1771+ file_or_dir = self .output_path , name = "checkpoint" , artifact_type = "model" , aliases = aliases
1772+ )
1773+
1774+ # training visualizations
1775+ if batch is not None and outputs is not None :
1776+ if hasattr (self .model , "module" ) and isimplemented (self .model .module , "train_log" ):
1777+ self .model .module .train_log (
1778+ batch ,
1779+ outputs ,
1780+ self .dashboard_logger ,
1781+ self .training_assets ,
1782+ self .total_steps_done ,
1783+ )
1784+ elif isimplemented (self .model , "train_log" ):
1785+ self .model .train_log (
1786+ batch ,
1787+ outputs ,
1788+ self .dashboard_logger ,
1789+ self .training_assets ,
1790+ self .total_steps_done ,
1791+ )
1792+
17711793 #####################
17721794 # GET FUNCTIONS
17731795 #####################
@@ -1921,7 +1943,12 @@ def _pick_target_avg_loss(self, keep_avg_target: KeepAverage) -> Dict:
19211943 if "target_loss" in self .config and self .config .target_loss :
19221944 if f"avg_{ self .config .target_loss } " in keep_avg_target .avg_values .keys ():
19231945 return keep_avg_target [f"avg_{ self .config .target_loss } " ]
1924- return keep_avg_target ["avg_loss_1" ]
1946+ target_loss = keep_avg_target ["avg_loss_1" ]
1947+ if target_loss is None :
1948+ raise ValueError (
1949+ " [!] Target loss not found in the keep_avg_target. You might be exiting the training loop before it is computed or set the target_loss in the model config incorrectly."
1950+ )
1951+ return target_loss
19251952
19261953 # take the average of loss_{optimizer_idx} as the target loss when there are multiple optimizers
19271954 if isinstance (self .optimizer , list ):
0 commit comments