Skip to content

Commit 8158f46

Browse files
authored
Merge pull request #280 from kAIto47802/pytorch-checkpoint-artifact
Introduce `optuna.artifacts` to the PyTorch checkpoint example
2 parents c039f99 + 3fc41e8 commit 8158f46

File tree

1 file changed

+32
-20
lines changed

1 file changed

+32
-20
lines changed

pytorch/pytorch_checkpoint.py

Lines changed: 32 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -14,9 +14,11 @@
1414
"""
1515

1616
import os
17-
import shutil
1817

1918
import optuna
19+
from optuna.artifacts import download_artifact
20+
from optuna.artifacts import FileSystemArtifactStore
21+
from optuna.artifacts import upload_artifact
2022
from optuna.storages import RetryFailedTrialCallback
2123
import torch
2224
import torch.nn as nn
@@ -37,6 +39,10 @@
3739
N_VALID_EXAMPLES = BATCHSIZE * 10
3840
CHECKPOINT_DIR = "pytorch_checkpoint"
3941

42+
base_path = "./artifacts"
43+
os.makedirs(base_path, exist_ok=True)
44+
artifact_store = FileSystemArtifactStore(base_path=base_path)
45+
4046

4147
def define_model(trial):
4248
# We optimize the number of layers, hidden units and dropout ratio in each layer.
@@ -83,36 +89,36 @@ def objective(trial):
8389
lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
8490
optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
8591

86-
trial_number = RetryFailedTrialCallback.retried_trial_number(trial)
87-
trial_checkpoint_dir = os.path.join(CHECKPOINT_DIR, str(trial_number))
88-
checkpoint_path = os.path.join(trial_checkpoint_dir, "model.pt")
89-
checkpoint_exists = os.path.isfile(checkpoint_path)
90-
91-
if trial_number is not None and checkpoint_exists:
92-
checkpoint = torch.load(checkpoint_path)
92+
artifact_id = None
93+
retry_history = RetryFailedTrialCallback.retry_history(trial)
94+
for trial_number in reversed(retry_history):
95+
artifact_id = trial.study.trials[trial_number].user_attrs.get("artifact_id")
96+
if artifact_id is not None:
97+
retry_trial_number = trial_number
98+
break
99+
100+
if artifact_id is not None:
101+
download_artifact(
102+
artifact_store=artifact_store,
103+
file_path=f"./tmp_model_{trial.number}.pt",
104+
artifact_id=artifact_id,
105+
)
106+
checkpoint = torch.load(f"./tmp_model_{trial.number}.pt")
107+
os.remove(f"./tmp_model_{trial.number}.pt")
93108
epoch = checkpoint["epoch"]
94109
epoch_begin = epoch + 1
95110

96-
print(f"Loading a checkpoint from trial {trial_number} in epoch {epoch}.")
111+
print(f"Loading a checkpoint from trial {retry_trial_number} in epoch {epoch}.")
97112

98113
model.load_state_dict(checkpoint["model_state_dict"])
99114
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
100115
accuracy = checkpoint["accuracy"]
101116
else:
102-
trial_checkpoint_dir = os.path.join(CHECKPOINT_DIR, str(trial.number))
103-
checkpoint_path = os.path.join(trial_checkpoint_dir, "model.pt")
104117
epoch_begin = 0
105118

106119
# Get the FashionMNIST dataset.
107120
train_loader, valid_loader = get_mnist()
108121

109-
os.makedirs(trial_checkpoint_dir, exist_ok=True)
110-
# A checkpoint may be corrupted when the process is killed during `torch.save`.
111-
# Reduce the risk by first calling `torch.save` to a temporary file, then copy.
112-
tmp_checkpoint_path = os.path.join(trial_checkpoint_dir, "tmp_model.pt")
113-
114-
print(f"Checkpoint path for trial is '{checkpoint_path}'.")
115-
116122
# Training of the model.
117123
for epoch in range(epoch_begin, EPOCHS):
118124
model.train()
@@ -159,9 +165,15 @@ def objective(trial):
159165
"optimizer_state_dict": optimizer.state_dict(),
160166
"accuracy": accuracy,
161167
},
162-
tmp_checkpoint_path,
168+
f"./tmp_model_{trial.number}.pt",
169+
)
170+
artifact_id = upload_artifact(
171+
artifact_store=artifact_store,
172+
file_path=f"./tmp_model_{trial.number}.pt",
173+
study_or_trial=trial,
163174
)
164-
shutil.move(tmp_checkpoint_path, checkpoint_path)
175+
trial.set_user_attr("artifact_id", artifact_id)
176+
os.remove(f"./tmp_model_{trial.number}.pt")
165177

166178
# Handle pruning based on the intermediate value.
167179
if trial.should_prune():

0 commit comments

Comments (0)