|
14 | 14 | """ |
15 | 15 |
|
16 | 16 | import os |
17 | | -import shutil |
18 | 17 |
|
19 | 18 | import optuna |
| 19 | +from optuna.artifacts import download_artifact |
| 20 | +from optuna.artifacts import FileSystemArtifactStore |
| 21 | +from optuna.artifacts import upload_artifact |
20 | 22 | from optuna.storages import RetryFailedTrialCallback |
21 | 23 | import torch |
22 | 24 | import torch.nn as nn |
|
37 | 39 | N_VALID_EXAMPLES = BATCHSIZE * 10 |
38 | 40 | CHECKPOINT_DIR = "pytorch_checkpoint" |
39 | 41 |
|
| 42 | +base_path = "./artifacts" |
| 43 | +os.makedirs(base_path, exist_ok=True) |
| 44 | +artifact_store = FileSystemArtifactStore(base_path=base_path) |
| 45 | + |
40 | 46 |
|
41 | 47 | def define_model(trial): |
42 | 48 | # We optimize the number of layers, hidden units and dropout ratio in each layer. |
@@ -83,36 +89,36 @@ def objective(trial): |
83 | 89 | lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True) |
84 | 90 | optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr) |
85 | 91 |
|
86 | | - trial_number = RetryFailedTrialCallback.retried_trial_number(trial) |
87 | | - trial_checkpoint_dir = os.path.join(CHECKPOINT_DIR, str(trial_number)) |
88 | | - checkpoint_path = os.path.join(trial_checkpoint_dir, "model.pt") |
89 | | - checkpoint_exists = os.path.isfile(checkpoint_path) |
90 | | - |
91 | | - if trial_number is not None and checkpoint_exists: |
92 | | - checkpoint = torch.load(checkpoint_path) |
| 92 | + artifact_id = None |
| 93 | + retry_history = RetryFailedTrialCallback.retry_history(trial) |
| 94 | + for trial_number in reversed(retry_history): |
| 95 | + artifact_id = trial.study.trials[trial_number].user_attrs.get("artifact_id") |
| 96 | + if artifact_id is not None: |
| 97 | + retry_trial_number = trial_number |
| 98 | + break |
| 99 | + |
| 100 | + if artifact_id is not None: |
| 101 | + download_artifact( |
| 102 | + artifact_store=artifact_store, |
| 103 | + file_path=f"./tmp_model_{trial.number}.pt", |
| 104 | + artifact_id=artifact_id, |
| 105 | + ) |
| 106 | + checkpoint = torch.load(f"./tmp_model_{trial.number}.pt") |
| 107 | + os.remove(f"./tmp_model_{trial.number}.pt") |
93 | 108 | epoch = checkpoint["epoch"] |
94 | 109 | epoch_begin = epoch + 1 |
95 | 110 |
|
96 | | - print(f"Loading a checkpoint from trial {trial_number} in epoch {epoch}.") |
| 111 | + print(f"Loading a checkpoint from trial {retry_trial_number} in epoch {epoch}.") |
97 | 112 |
|
98 | 113 | model.load_state_dict(checkpoint["model_state_dict"]) |
99 | 114 | optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) |
100 | 115 | accuracy = checkpoint["accuracy"] |
101 | 116 | else: |
102 | | - trial_checkpoint_dir = os.path.join(CHECKPOINT_DIR, str(trial.number)) |
103 | | - checkpoint_path = os.path.join(trial_checkpoint_dir, "model.pt") |
104 | 117 | epoch_begin = 0 |
105 | 118 |
|
106 | 119 | # Get the FashionMNIST dataset. |
107 | 120 | train_loader, valid_loader = get_mnist() |
108 | 121 |
|
109 | | - os.makedirs(trial_checkpoint_dir, exist_ok=True) |
110 | | - # A checkpoint may be corrupted when the process is killed during `torch.save`. |
111 | | - # Reduce the risk by first calling `torch.save` to a temporary file, then copy. |
112 | | - tmp_checkpoint_path = os.path.join(trial_checkpoint_dir, "tmp_model.pt") |
113 | | - |
114 | | - print(f"Checkpoint path for trial is '{checkpoint_path}'.") |
115 | | - |
116 | 122 | # Training of the model. |
117 | 123 | for epoch in range(epoch_begin, EPOCHS): |
118 | 124 | model.train() |
@@ -159,9 +165,15 @@ def objective(trial): |
159 | 165 | "optimizer_state_dict": optimizer.state_dict(), |
160 | 166 | "accuracy": accuracy, |
161 | 167 | }, |
162 | | - tmp_checkpoint_path, |
| 168 | + f"./tmp_model_{trial.number}.pt", |
| 169 | + ) |
| 170 | + artifact_id = upload_artifact( |
| 171 | + artifact_store=artifact_store, |
| 172 | + file_path=f"./tmp_model_{trial.number}.pt", |
| 173 | + study_or_trial=trial, |
163 | 174 | ) |
164 | | - shutil.move(tmp_checkpoint_path, checkpoint_path) |
| 175 | + trial.set_user_attr("artifact_id", artifact_id) |
| 176 | + os.remove(f"./tmp_model_{trial.number}.pt") |
165 | 177 |
|
166 | 178 | # Handle pruning based on the intermediate value. |
167 | 179 | if trial.should_prune(): |
|
0 commit comments