Skip to content

Commit 8f69913

Browse files
authored
[fix] early stop error when specifying a metric in training (#333)
* [fix] early stop error when specifying a metric in training
* [fix] enable find_unused_parameters=True for DDP to avoid an exception in unimolv2
* [typo] fix typo
1 parent 21d1bba commit 8f69913

File tree

2 files changed

+13
-6
lines changed

2 files changed

+13
-6
lines changed

unimol_tools/setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55

66
setup(
77
name="unimol_tools",
8-
version="0.1.3",
9-
description=("unimol_tools is a Python package for property prediciton with Uni-Mol in molecule, materials and protein."),
8+
version="0.1.3.post1",
9+
description=("unimol_tools is a Python package for property prediction with Uni-Mol in molecule, materials and protein."),
1010
long_description=open('README.md').read(),
1111
long_description_content_type='text/markdown',
1212
author="DP Technology",

unimol_tools/unimol_tools/tasks/trainer.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ def fit_predict(
205205
# print(f"Main function returned: {y_preds}")
206206
except:
207207
print("No return value received from main function.")
208-
return y_preds
208+
return y_preds
209209
else:
210210
return self.fit_predict_wo_ddp(
211211
model,
@@ -341,7 +341,7 @@ def fit_predict_with_ddp(
341341
"""
342342
self.init_ddp(local_rank)
343343
model = model.to(local_rank)
344-
model = DistributedDataParallel(model, device_ids=[local_rank])
344+
model = DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=True)
345345
train_dataloader = NNDataLoader(
346346
feature_name=feature_name,
347347
dataset=train_dataset,
@@ -719,7 +719,7 @@ def inference_with_ddp(
719719
"""
720720
self.init_ddp(local_rank)
721721
model = model.to(local_rank)
722-
model = DistributedDataParallel(model, device_ids=[local_rank])
722+
model = DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=True)
723723
dataloader = NNDataLoader(
724724
feature_name=feature_name,
725725
dataset=dataset,
@@ -870,6 +870,7 @@ def __init__(self, patience, dump_dir, fold, metrics, metrics_str):
870870
self.metrics_str = metrics_str
871871
self.wait = 0
872872
self.min_loss = float("inf")
873+
self.max_loss = float("-inf")
873874
self.is_early_stop = False
874875

875876
def early_stop_choice(self, model, epoch, loss, metric_score=None):
@@ -890,16 +891,22 @@ def early_stop_choice(self, model, epoch, loss, metric_score=None):
890891
]:
891892
return self._judge_early_stop_loss(loss, model, epoch)
892893
else:
893-
return self.metrics._early_stop_choice(
894+
is_early_stop, min_score, wait, max_score = self.metrics._early_stop_choice(
894895
self.wait,
895896
self.min_loss,
896897
metric_score,
898+
self.max_loss,
897899
model,
898900
self.dump_dir,
899901
self.fold,
900902
self.patience,
901903
epoch,
902904
)
905+
self.min_loss = min_score
906+
self.max_loss = max_score
907+
self.wait = wait
908+
self.is_early_stop = is_early_stop
909+
return self.is_early_stop
903910

904911
def _judge_early_stop_loss(self, loss, model, epoch):
905912
"""

0 commit comments

Comments
 (0)