Skip to content

Commit 0443f6b

Browse files
committed
Update based on review
1 parent 5f329a6 commit 0443f6b

File tree

1 file changed

+7
-8
lines changed

1 file changed

+7
-8
lines changed

pytorch/pytorch_checkpoint.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -91,13 +91,12 @@ def objective(trial):
9191

9292
trial_number = RetryFailedTrialCallback.retried_trial_number(trial)
9393

94-
if trial_number is not None:
95-
study = trial.study
96-
artifact_id = study.trials[trial_number].user_attrs["artifact_id"]
94+
artifact_id = trial.study.trials[trial_number].user_attrs.get("artifact_id")
95+
if trial_number is not None and artifact_id is not None:
9796
download_artifact(
98-
artifact_store=artifact_store, file_path="./tmp_model.pt", artifact_id=artifact_id
97+
artifact_store=artifact_store, file_path=f"./tmp_model_{trial_number}.pt", artifact_id=artifact_id
9998
)
100-
checkpoint = torch.load("./tmp_model.pt")
99+
checkpoint = torch.load(f"./tmp_model_{trial_number}.pt")
101100
epoch = checkpoint["epoch"]
102101
epoch_begin = epoch + 1
103102

@@ -158,15 +157,15 @@ def objective(trial):
158157
"optimizer_state_dict": optimizer.state_dict(),
159158
"accuracy": accuracy,
160159
},
161-
"./tmp_model.pt",
160+
f"./tmp_model_{trial_number}.pt",
162161
)
163162
artifact_id = upload_artifact(
164163
artifact_store=artifact_store,
165-
file_path="./tmp_model.pt",
164+
file_path=f"./tmp_model_{trial_number}.pt",
166165
study_or_trial=trial,
167166
)
168167
trial.set_user_attr("artifact_id", artifact_id)
169-
os.remove("./tmp_model.pt")
168+
os.remove(f"./tmp_model_{trial_number}.pt")
170169

171170
# Handle pruning based on the intermediate value.
172171
if trial.should_prune():

0 commit comments

Comments
 (0)