diff --git a/pytorch/pytorch_checkpoint.py b/pytorch/pytorch_checkpoint.py
index bc2e9808..70b458ce 100644
--- a/pytorch/pytorch_checkpoint.py
+++ b/pytorch/pytorch_checkpoint.py
@@ -14,9 +14,11 @@
 """
 
 import os
-import shutil
 
 import optuna
+from optuna.artifacts import download_artifact
+from optuna.artifacts import FileSystemArtifactStore
+from optuna.artifacts import upload_artifact
 from optuna.storages import RetryFailedTrialCallback
 import torch
 import torch.nn as nn
@@ -37,6 +39,10 @@
 N_VALID_EXAMPLES = BATCHSIZE * 10
 CHECKPOINT_DIR = "pytorch_checkpoint"
 
+base_path = "./artifacts"
+os.makedirs(base_path, exist_ok=True)
+artifact_store = FileSystemArtifactStore(base_path=base_path)
+
 
 def define_model(trial):
     # We optimize the number of layers, hidden units and dropout ratio in each layer.
@@ -83,36 +89,36 @@ def objective(trial):
     lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
     optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
 
-    trial_number = RetryFailedTrialCallback.retried_trial_number(trial)
-    trial_checkpoint_dir = os.path.join(CHECKPOINT_DIR, str(trial_number))
-    checkpoint_path = os.path.join(trial_checkpoint_dir, "model.pt")
-    checkpoint_exists = os.path.isfile(checkpoint_path)
-
-    if trial_number is not None and checkpoint_exists:
-        checkpoint = torch.load(checkpoint_path)
+    artifact_id = None
+    retry_history = RetryFailedTrialCallback.retry_history(trial)
+    for trial_number in reversed(retry_history):
+        artifact_id = trial.study.trials[trial_number].user_attrs.get("artifact_id")
+        if artifact_id is not None:
+            retry_trial_number = trial_number
+            break
+
+    if artifact_id is not None:
+        download_artifact(
+            artifact_store=artifact_store,
+            file_path=f"./tmp_model_{trial.number}.pt",
+            artifact_id=artifact_id,
+        )
+        checkpoint = torch.load(f"./tmp_model_{trial.number}.pt")
+        os.remove(f"./tmp_model_{trial.number}.pt")
 
         epoch = checkpoint["epoch"]
         epoch_begin = epoch + 1
 
-        print(f"Loading a checkpoint from trial {trial_number} in epoch {epoch}.")
+        print(f"Loading a checkpoint from trial {retry_trial_number} in epoch {epoch}.")
         model.load_state_dict(checkpoint["model_state_dict"])
         optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
         accuracy = checkpoint["accuracy"]
     else:
-        trial_checkpoint_dir = os.path.join(CHECKPOINT_DIR, str(trial.number))
-        checkpoint_path = os.path.join(trial_checkpoint_dir, "model.pt")
         epoch_begin = 0
 
     # Get the FashionMNIST dataset.
     train_loader, valid_loader = get_mnist()
 
-    os.makedirs(trial_checkpoint_dir, exist_ok=True)
-    # A checkpoint may be corrupted when the process is killed during `torch.save`.
-    # Reduce the risk by first calling `torch.save` to a temporary file, then copy.
-    tmp_checkpoint_path = os.path.join(trial_checkpoint_dir, "tmp_model.pt")
-
-    print(f"Checkpoint path for trial is '{checkpoint_path}'.")
-
     # Training of the model.
     for epoch in range(epoch_begin, EPOCHS):
         model.train()
@@ -159,9 +165,15 @@
                 "optimizer_state_dict": optimizer.state_dict(),
                 "accuracy": accuracy,
             },
-            tmp_checkpoint_path,
+            f"./tmp_model_{trial.number}.pt",
+        )
+        artifact_id = upload_artifact(
+            artifact_store=artifact_store,
+            file_path=f"./tmp_model_{trial.number}.pt",
+            study_or_trial=trial,
         )
-        shutil.move(tmp_checkpoint_path, checkpoint_path)
+        trial.set_user_attr("artifact_id", artifact_id)
+        os.remove(f"./tmp_model_{trial.number}.pt")
 
         # Handle pruning based on the intermediate value.
         if trial.should_prune():
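
The hunks above only change the objective; resuming from a retried trial still relies on the study being created with heartbeat monitoring and a RetryFailedTrialCallback as the storage's failed_trial_callback, which the example's __main__ block (untouched by this diff) is expected to set up already. Below is a minimal driver sketch of that wiring for reference; the SQLite URL, max_retry, heartbeat settings, and trial counts are illustrative assumptions, not values taken from this diff.

import optuna
from optuna.storages import RDBStorage
from optuna.storages import RetryFailedTrialCallback

# Heartbeat monitoring marks a trial FAILED if its process dies, and the retry
# callback re-enqueues it so that RetryFailedTrialCallback.retry_history() in the
# objective can locate the checkpoint artifact uploaded by the failed attempt.
storage = RDBStorage(
    url="sqlite:///pytorch_checkpoint.db",  # illustrative URL; any RDB backend works
    heartbeat_interval=60,
    grace_period=120,
    failed_trial_callback=RetryFailedTrialCallback(max_retry=3),
)

study = optuna.create_study(
    storage=storage,
    study_name="pytorch_checkpoint",
    direction="maximize",
    load_if_exists=True,
)
study.optimize(objective, n_trials=10, timeout=600)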