@@ -451,7 +451,7 @@ def __init__( # pylint: disable=dangerous-default-value
451451 self .epochs_done = 0
452452 self .restore_step = 0
453453 self .restore_epoch = 0
454- self .best_loss = { "train_loss" : float ("inf" ), "eval_loss" : float ( "inf" ) if self . config . run_eval else None }
454+ self .best_loss = float ("inf" )
455455 self .train_loader = None
456456 self .test_loader = None
457457 self .eval_loader = None
@@ -1724,15 +1724,8 @@ def _restore_best_loss(self):
17241724 logger .info (" > Restoring best loss from %s ..." , os .path .basename (self .args .best_path ))
17251725 ch = load_fsspec (self .args .restore_path , map_location = "cpu" )
17261726 if "model_loss" in ch :
1727- if isinstance (ch ["model_loss" ], dict ):
1728- self .best_loss = ch ["model_loss" ]
1729- # For backwards-compatibility:
1730- elif isinstance (ch ["model_loss" ], float ):
1731- if self .config .run_eval :
1732- self .best_loss = {"train_loss" : None , "eval_loss" : ch ["model_loss" ]}
1733- else :
1734- self .best_loss = {"train_loss" : ch ["model_loss" ], "eval_loss" : None }
1735- logger .info (" > Starting with loaded last best loss %s" , self .best_loss )
1727+ self .best_loss = ch ["model_loss" ]
1728+ logger .info (" > Starting with loaded last best loss %f" , self .best_loss )
17361729
17371730 def test (self , model = None , test_samples = None ) -> None :
17381731 """Run evaluation steps on the test data split. You can either provide the model and the test samples
@@ -1914,7 +1907,7 @@ def save_best_model(self) -> None:
19141907
19151908 # save the model and update the best_loss
19161909 self .best_loss = save_best_model (
1917- { "train_loss" : train_loss , " eval_loss" : eval_loss } ,
1910+ eval_loss if eval_loss else train_loss ,
19181911 self .best_loss ,
19191912 self .config ,
19201913 self .model ,
0 commit comments