Skip to content

Commit 4852e9d

Browse files
committed
Update to check the entire retry history
1 parent 262c48d commit 4852e9d

File tree

1 file changed

+15
-10
lines changed

1 file changed

+15
-10
lines changed

pytorch/pytorch_checkpoint.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -89,20 +89,25 @@ def objective(trial):
8989
lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
9090
optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
9191

92-
trial_number = RetryFailedTrialCallback.retried_trial_number(trial)
93-
94-
artifact_id = trial_number and trial.study.trials[trial_number].user_attrs.get("artifact_id")
95-
if trial_number is not None and artifact_id is not None:
92+
artifact_id = None
93+
retry_history = RetryFailedTrialCallback.retry_history(trial)
94+
for trial_number in reversed(retry_history):
95+
artifact_id = trial.study.trials[trial_number].user_attrs.get("artifact_id")
96+
if artifact_id is not None:
97+
retry_trial_number = trial_number
98+
break
99+
100+
if artifact_id is not None:
96101
download_artifact(
97102
artifact_store=artifact_store,
98-
file_path=f"./tmp_model_{trial_number}.pt",
103+
file_path=f"./tmp_model_{trial.number}.pt",
99104
artifact_id=artifact_id,
100105
)
101-
checkpoint = torch.load(f"./tmp_model_{trial_number}.pt")
106+
checkpoint = torch.load(f"./tmp_model_{trial.number}.pt")
102107
epoch = checkpoint["epoch"]
103108
epoch_begin = epoch + 1
104109

105-
print(f"Loading a checkpoint from trial {trial_number} in epoch {epoch}.")
110+
print(f"Loading a checkpoint from trial {retry_trial_number} in epoch {epoch}.")
106111

107112
model.load_state_dict(checkpoint["model_state_dict"])
108113
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
@@ -159,15 +164,15 @@ def objective(trial):
159164
"optimizer_state_dict": optimizer.state_dict(),
160165
"accuracy": accuracy,
161166
},
162-
f"./tmp_model_{trial_number}.pt",
167+
f"./tmp_model_{trial.number}.pt",
163168
)
164169
artifact_id = upload_artifact(
165170
artifact_store=artifact_store,
166-
file_path=f"./tmp_model_{trial_number}.pt",
171+
file_path=f"./tmp_model_{trial.number}.pt",
167172
study_or_trial=trial,
168173
)
169174
trial.set_user_attr("artifact_id", artifact_id)
170-
os.remove(f"./tmp_model_{trial_number}.pt")
175+
os.remove(f"./tmp_model_{trial.number}.pt")
171176

172177
# Handle pruning based on the intermediate value.
173178
if trial.should_prune():

0 commit comments

Comments
 (0)