
Commit 542bd23

Save on exit (#102)
* Save on keyboard interrupt
* Bump up to v0.0.23
1 parent ace9f13 commit 542bd23
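From the user's side, the change means that interrupting fit() with Ctrl+C now writes a checkpoint before the process exits. A minimal sketch, assuming the usual Trainer entry point; model, config, and the constructor arguments are placeholders, and only the interrupt handling comes from this commit.

# Illustrative sketch: only the Ctrl+C behavior is from this commit; the
# model/config objects and constructor arguments are placeholders.
from trainer import Trainer, TrainerArgs

trainer = Trainer(TrainerArgs(), config, output_path="runs/", model=model)
trainer.fit()
# Pressing Ctrl+C during fit() now logs " > Keyboard interrupt detected."
# and, if config.save_on_interrupt is True (the default), saves a
# checkpoint and updates the dashboard logger before exiting with status 1.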

3 files changed: +73 -47 lines


trainer/VERSION

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-v0.0.22
+v0.0.23

trainer/io.py

Lines changed: 0 additions & 1 deletion
@@ -1,6 +1,5 @@
 import datetime
 import json
-import sys
 import os
 import re
 import sys

trainer/trainer.py

Lines changed: 72 additions & 45 deletions
@@ -46,7 +46,7 @@
     setup_torch_training_env,
 )
 from trainer.utils.cuda_memory import cuda_meminfo, should_reduce_batch_size
-from trainer.utils.distributed import init_distributed
+from trainer.utils.distributed import init_distributed, rank_zero_only

 logger = logging.getLogger("trainer")

@@ -111,6 +111,9 @@ class TrainerConfig(Coqpit):
         default="tensorboard", metadata={"help": "Logger to use for the tracking dashboard. Defaults to 'tensorboard'"}
     )
     # Fields for checkpointing
+    save_on_interrupt: bool = field(
+        default=True, metadata={"help": "Save checkpoint on interrupt (Ctrl+C). Defaults to True"}
+    )
     log_model_step: int = field(
         default=None,
         metadata={
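A minimal sketch of opting out of the new field; save_checkpoints is taken from the surrounding code, while the import path and direct instantiation are assumptions (models normally subclass this config).

# Sketch: TrainerConfig is a Coqpit dataclass, so fields can be passed as
# keyword arguments. Import path assumed; usually a model config subclasses it.
from trainer import TrainerConfig

config = TrainerConfig(
    save_checkpoints=True,
    save_on_interrupt=False,  # skip the automatic checkpoint on Ctrl+C
)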
@@ -455,7 +458,7 @@ def __init__( # pylint: disable=dangerous-default-value
         self.eval_samples = None
         self.test_samples = None

-        #define custom train and eval loader
+        # define custom train and eval loader
         self.train_loader = train_loader
         self.eval_loader = eval_loader

@@ -1295,47 +1298,11 @@ def train_step(self, batch: Dict, batch_n_steps: int, step: int, loader_start_ti
         if self.total_steps_done % self.config.save_step == 0 and self.total_steps_done != 0:
             if self.config.save_checkpoints:
                 # checkpoint the model
-                target_avg_loss = self._pick_target_avg_loss(self.keep_avg_train)
-                save_checkpoint(
-                    self.config,
-                    self.model,
-                    self.optimizer,
-                    self.scaler if self.use_amp_scaler else None,
-                    self.total_steps_done,
-                    self.epochs_done,
-                    self.output_path,
-                    model_loss=target_avg_loss,
-                    save_n_checkpoints=self.config.save_n_checkpoints,
-                    save_func=self.dashboard_logger.save_model,
-                )
+                self.save_checkpoint()

-                if self.total_steps_done % self.config.log_model_step == 0:
-                    # log checkpoint as artifact
-                    aliases = [
-                        f"epoch-{self.epochs_done}",
-                        f"step-{self.total_steps_done}",
-                    ]
-                    self.dashboard_logger.add_artifact(
-                        file_or_dir=self.output_path, name="checkpoint", artifact_type="model", aliases=aliases
-                    )
-
-            # training visualizations
-            if hasattr(self.model, "module") and isimplemented(self.model.module, "train_log"):
-                self.model.module.train_log(
-                    batch,
-                    outputs,
-                    self.dashboard_logger,
-                    self.training_assets,
-                    self.total_steps_done,
-                )
-            elif isimplemented(self.model, "train_log"):
-                self.model.train_log(
-                    batch,
-                    outputs,
-                    self.dashboard_logger,
-                    self.training_assets,
-                    self.total_steps_done,
-                )
+                if self.total_steps_done % self.config.log_model_step == 0:
+                    # log checkpoint as artifact
+                    self.update_training_dashboard_logger(batch=batch, outputs=outputs)

         self.dashboard_logger.flush()

@@ -1683,6 +1650,14 @@ def fit(self) -> None:
             if self.args.rank == 0:
                 self.dashboard_logger.finish()
         except KeyboardInterrupt:
+            logger.info(" > Keyboard interrupt detected.")
+            if self.config.save_on_interrupt:
+                logger.info(" > Saving model before exiting...")
+                # save the model on keyboard interrupt
+                self.save_checkpoint()
+                # update the training dashboard logger
+                self.update_training_dashboard_logger()
+            # call the keyboard interrupt callback
             self.callbacks.on_keyboard_interrupt(self)
             # if the output folder is empty remove the run.
             remove_experiment_folder(self.output_path)
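Stripped of the Trainer internals, the handler added here follows the standard save-then-exit pattern sketched below; every name in the sketch is a placeholder, not a Trainer API.

# Simplified, self-contained illustration of the save-on-interrupt pattern.
import sys

def save_checkpoint():
    print("checkpoint saved")  # stand-in for the Trainer's save_checkpoint()

def train_forever():
    while True:
        pass  # stand-in for the epoch/step loops inside fit()

try:
    train_forever()
except KeyboardInterrupt:
    save_on_interrupt = True  # mirrors config.save_on_interrupt (default True)
    if save_on_interrupt:
        save_checkpoint()     # persist state before leaving
    sys.exit(1)               # non-zero exit status, as in the next hunk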
@@ -1694,9 +1669,9 @@ def fit(self) -> None:
                 self.dashboard_logger.finish()
             # stop without error signal
             try:
-                sys.exit(0)
+                sys.exit(1)
             except SystemExit:
-                os._exit(0)  # pylint: disable=protected-access
+                os._exit(1)  # pylint: disable=protected-access
         except BaseException:  # pylint: disable=broad-except
             remove_experiment_folder(self.output_path)
             traceback.print_exc()
@@ -1746,6 +1721,7 @@ def profile_fit(self, torch_profiler, epochs=None, small_run=None):
         self.torch_profiler.stop()
         return self.torch_profiler

+    @rank_zero_only
     def save_best_model(self) -> None:
         """Save the best model. It only saves if the current target loss is smaller then the previous."""

@@ -1768,6 +1744,52 @@ def save_best_model(self) -> None:
             save_func=self.dashboard_logger.save_model,
         )

+    @rank_zero_only
+    def save_checkpoint(self) -> None:
+        """Save the current model checkpoint."""
+        target_avg_loss = self._pick_target_avg_loss(self.keep_avg_train)
+        save_checkpoint(
+            self.config,
+            self.model,
+            self.optimizer,
+            self.scaler if self.use_amp_scaler else None,
+            self.total_steps_done,
+            self.epochs_done,
+            self.output_path,
+            model_loss=target_avg_loss,
+            save_n_checkpoints=self.config.save_n_checkpoints,
+            save_func=self.dashboard_logger.save_model,
+        )
+
+    @rank_zero_only
+    def update_training_dashboard_logger(self, batch=None, outputs=None):
+        aliases = [
+            f"epoch-{self.epochs_done}",
+            f"step-{self.total_steps_done}",
+        ]
+        self.dashboard_logger.add_artifact(
+            file_or_dir=self.output_path, name="checkpoint", artifact_type="model", aliases=aliases
+        )
+
+        # training visualizations
+        if batch is not None and outputs is not None:
+            if hasattr(self.model, "module") and isimplemented(self.model.module, "train_log"):
+                self.model.module.train_log(
+                    batch,
+                    outputs,
+                    self.dashboard_logger,
+                    self.training_assets,
+                    self.total_steps_done,
+                )
+            elif isimplemented(self.model, "train_log"):
+                self.model.train_log(
+                    batch,
+                    outputs,
+                    self.dashboard_logger,
+                    self.training_assets,
+                    self.total_steps_done,
+                )
+
     #####################
     # GET FUNCTIONS
     #####################
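rank_zero_only is imported from trainer.utils.distributed in the first hunk and applied to save_best_model, save_checkpoint, and update_training_dashboard_logger, but its implementation is not part of this diff. The sketch below shows what such a decorator typically looks like; it is an assumption, not the package's actual code.

# Assumed shape of a rank_zero_only decorator: in multi-GPU (DDP) runs only
# the rank-0 process performs the wrapped side effect, so checkpoints and
# dashboard artifacts are written exactly once.
import functools
import os

def rank_zero_only(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if int(os.environ.get("RANK", "0")) == 0:
            return fn(*args, **kwargs)
        return None
    return wrapper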
@@ -1921,7 +1943,12 @@ def _pick_target_avg_loss(self, keep_avg_target: KeepAverage) -> Dict:
         if "target_loss" in self.config and self.config.target_loss:
             if f"avg_{self.config.target_loss}" in keep_avg_target.avg_values.keys():
                 return keep_avg_target[f"avg_{self.config.target_loss}"]
-            return keep_avg_target["avg_loss_1"]
+            target_loss = keep_avg_target["avg_loss_1"]
+            if target_loss is None:
+                raise ValueError(
+                    " [!] Target loss not found in the keep_avg_target. You might be exiting the training loop before it is computed or set the target_loss in the model config incorrectly."
+                )
+            return target_loss

         # take the average of loss_{optimizer_idx} as the target loss when there are multiple optimizers
         if isinstance(self.optimizer, list):
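The new None check matters mostly for the interrupt path above: if training stops before a single loss value has been averaged, keep_avg_target["avg_loss_1"] comes back as None and the checkpoint would otherwise be written with a meaningless loss. A tiny stand-in for that failure mode, using a plain dict in place of KeepAverage:

# Stand-in sketch: an empty average store (e.g. Ctrl+C before the first
# save_step) yields None, which the new guard turns into an explicit error.
keep_avg = {}
target_loss = keep_avg.get("avg_loss_1")
if target_loss is None:
    raise ValueError("Target loss not found in the keep_avg_target.")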
