5454 TTrainConfig ,
5555)
5656from mblm .train .core .iter import epoch_cycler
57- from mblm .utils .cuda import cuda_memory_snapshot , cuda_properties
57+ from mblm .utils .cuda import IS_BF16_AVAILABLE , cuda_memory_snapshot , cuda_properties
5858from mblm .utils .distributed import ElasticRunVars
59- from mblm .utils .io import CSVWriter , StateDict , dump_yml , load_model_state , save_model_state
59+ from mblm .utils .io import (
60+ CSVWriter ,
61+ StateDict ,
62+ dump_yml ,
63+ load_model_state ,
64+ save_model_state ,
65+ )
6066from mblm .utils .logging import create_logger
6167from mblm .utils .misc import retry
6268from mblm .utils .top_n import TopN
@@ -76,7 +82,7 @@ class CoreTrainerOptions:
7682 train_prog_min_interval_seconds : int = 1
7783 valid_prog_min_interval_seconds : int = 1
7884 track_first_fw_bw_exec_times : int | None = 30 # track fw/bw time for the first 30 passes
79- amp_dtype : torch .dtype = torch .half # may use bfloat16
85+ amp_dtype : torch .dtype = torch .bfloat16 if IS_BF16_AVAILABLE else torch . half
8086
8187
8288class CoreTrainer (ABC , Generic [TModel , TBatch , TModelParams , TTrainConfig , TIoConfig ]):
@@ -159,7 +165,9 @@ def __init__(
159165 )
160166
161167 assert config .io .validate_amount > 0 , "Validate amount must be strictly positive"
162- assert config .io .num_models_to_save > 0 , "Must save at least 1 model"
168+ assert config .io .num_models_to_save >= 0 , "num_models_to_save cant be negative"
169+ if config .io .num_models_to_save == 0 :
170+ self ._log .warning ("No model of this training will be saved!" )
163171
164172 if config .io .validate_amount < config .io .num_models_to_save :
165173 self ._log .warning (
@@ -963,7 +971,7 @@ def before_new_epoch(epoch: int) -> None:
963971
964972 best_model = self ._unpack_distributed_model (self ._model_dist )
965973
966- if self ._is_main_worker :
974+ if self ._is_main_worker and self . config . io . num_models_to_save > 0 :
967975 # on the main worker, populate the model with the best state;
968976 # non-main workers will simply return the latest model, which won't
969977 # be used anyway because testing happens only on the main worker
@@ -1003,4 +1011,3 @@ def test(
10031011 avg_grad_clipped = - 1 ,
10041012 )
10051013 self ._log .info ("Finished testing" )
1006- return None
0 commit comments