@@ -205,7 +205,7 @@ def fit_predict(
205205 # print(f"Main function returned: {y_preds}")
206206 except :
207207 print ("No return value received from main function." )
208- return y_preds
208+ return y_preds
209209 else :
210210 return self .fit_predict_wo_ddp (
211211 model ,
@@ -341,7 +341,7 @@ def fit_predict_with_ddp(
341341 """
342342 self .init_ddp (local_rank )
343343 model = model .to (local_rank )
344- model = DistributedDataParallel (model , device_ids = [local_rank ])
344+ model = DistributedDataParallel (model , device_ids = [local_rank ], find_unused_parameters = True )
345345 train_dataloader = NNDataLoader (
346346 feature_name = feature_name ,
347347 dataset = train_dataset ,
@@ -719,7 +719,7 @@ def inference_with_ddp(
719719 """
720720 self .init_ddp (local_rank )
721721 model = model .to (local_rank )
722- model = DistributedDataParallel (model , device_ids = [local_rank ])
722+ model = DistributedDataParallel (model , device_ids = [local_rank ], find_unused_parameters = True )
723723 dataloader = NNDataLoader (
724724 feature_name = feature_name ,
725725 dataset = dataset ,
@@ -870,6 +870,7 @@ def __init__(self, patience, dump_dir, fold, metrics, metrics_str):
870870 self .metrics_str = metrics_str
871871 self .wait = 0
872872 self .min_loss = float ("inf" )
873+ self .max_loss = float ("-inf" )
873874 self .is_early_stop = False
874875
875876 def early_stop_choice (self , model , epoch , loss , metric_score = None ):
@@ -890,16 +891,22 @@ def early_stop_choice(self, model, epoch, loss, metric_score=None):
890891 ]:
891892 return self ._judge_early_stop_loss (loss , model , epoch )
892893 else :
893- return self .metrics ._early_stop_choice (
894+ is_early_stop , min_score , wait , max_score = self .metrics ._early_stop_choice (
894895 self .wait ,
895896 self .min_loss ,
896897 metric_score ,
898+ self .max_loss ,
897899 model ,
898900 self .dump_dir ,
899901 self .fold ,
900902 self .patience ,
901903 epoch ,
902904 )
905+ self .min_loss = min_score
906+ self .max_loss = max_score
907+ self .wait = wait
908+ self .is_early_stop = is_early_stop
909+ return self .is_early_stop
903910
904911 def _judge_early_stop_loss (self , loss , model , epoch ):
905912 """
0 commit comments