Introduce optuna.artifacts to the PyTorch checkpoint example #280
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@nabenabe0928 Could you review this PR?
This pull request has not seen any recent activity.
not522 left a comment:
Thank you for your PR! Could you check my comments?
pytorch/pytorch_checkpoint.py (outdated)
```python
checkpoint = torch.load(checkpoint_path)
if trial_number is not None:
    study = optuna.load_study(study_name="pytorch_checkpoint", storage="sqlite:///example.db")
    artifact_id = study.trials[trial_number].user_attrs["artifact_id"]
```
If the process is terminated before the first checkpoint, the artifact will not be saved, so check if it exists.
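A minimal sketch of that guard, reusing the names from the diff above; `dict.get` returns None when the attribute was never written:

```python
# Trials that were killed before their first upload have no "artifact_id" user attr.
artifact_id = study.trials[trial_number].user_attrs.get("artifact_id")
if artifact_id is None:
    print(f"Trial {trial_number} has no saved checkpoint; starting from scratch.")
else:
    ...  # download the artifact and torch.load it as in the diff above
```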
pytorch/pytorch_checkpoint.py (outdated)
| "accuracy": accuracy, | ||
| }, | ||
| tmp_checkpoint_path, | ||
| "./tmp_model.pt", |
Could you change the checkpoint path for each trial? If we run this script with multiple processes, the saved models can be overwritten by other processes.
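One way to satisfy this, and the fix the PR adopts in a later commit, is to key the temporary path on the trial number:

```python
# Each trial writes its checkpoint to its own file, so concurrent worker
# processes cannot clobber one another's saved models.
tmp_checkpoint_path = f"./tmp_model_{trial.number}.pt"
torch.save(
    {"accuracy": accuracy},  # the real checkpoint holds more fields; only this one appears in the diff
    tmp_checkpoint_path,
)
```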
This pull request has not seen any recent activity.
Co-authored-by: Naoto Mizuno <[email protected]>
The fix could be like this.
Thank you for your review! I have fixed it according to your suggestion.
not522 left a comment:
Thank you for your update. It's almost LGTM. Could you check my comment?
```python
    file_path=f"./tmp_model_{trial.number}.pt",
    artifact_id=artifact_id,
)
checkpoint = torch.load(f"./tmp_model_{trial.number}.pt")
```
Could you remove the temporary file here?
```python
os.remove(f"./tmp_model_{trial.number}.pt")
```
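Put together with the diff above, the resume path would then read roughly as follows; the `artifact_store` variable is assumed to be defined earlier in the script:

```python
local_path = f"./tmp_model_{trial.number}.pt"
download_artifact(artifact_store=artifact_store, artifact_id=artifact_id, file_path=local_path)
checkpoint = torch.load(local_path)
os.remove(local_path)  # delete the temporary file once the checkpoint is in memory
```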
Thank you for your comment. I have fixed this.
not522 left a comment:
LGTM!
Motivation
Currently, the PyTorch checkpoint example uses the local file system to save and manage checkpoints, and does not yet reflect the recent optuna.artifacts functionality.
Description of the changes
Update the example to save and restore checkpoints with optuna.artifacts.
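For readers skimming the thread, here is a minimal end-to-end sketch of the pattern this PR introduces. The store location, the toy model, and the `load_checkpoint` helper are illustrative assumptions; the study name, storage URL, per-trial paths, and artifact calls follow the diffs above:

```python
import os

import optuna
import torch
from optuna.artifacts import FileSystemArtifactStore, download_artifact, upload_artifact

base_path = "./artifacts"  # assumed location; any writable directory works
os.makedirs(base_path, exist_ok=True)
artifact_store = FileSystemArtifactStore(base_path=base_path)


def objective(trial):
    model = torch.nn.Linear(4, 1)  # stand-in for the example's real network
    checkpoint_path = f"./tmp_model_{trial.number}.pt"  # per-trial path, per the review
    accuracy = 0.0
    for _ in range(3):  # stand-in training loop
        # ... train and evaluate here, updating `accuracy` ...
        torch.save({"model_state_dict": model.state_dict(), "accuracy": accuracy}, checkpoint_path)
        # Upload this epoch's checkpoint and record its id so it can be found later.
        artifact_id = upload_artifact(trial, checkpoint_path, artifact_store)
        trial.set_user_attr("artifact_id", artifact_id)
    os.remove(checkpoint_path)  # per the review: do not leave per-trial files behind
    return accuracy


def load_checkpoint(trial_number):
    # Mirrors the resume path discussed above: look up the artifact id in storage,
    # download it to a per-trial temporary file, load it, then delete the file.
    study = optuna.load_study(study_name="pytorch_checkpoint", storage="sqlite:///example.db")
    artifact_id = study.trials[trial_number].user_attrs.get("artifact_id")
    if artifact_id is None:  # the trial died before its first checkpoint was uploaded
        return None
    path = f"./tmp_model_{trial_number}.pt"
    download_artifact(artifact_store=artifact_store, artifact_id=artifact_id, file_path=path)
    checkpoint = torch.load(path)
    os.remove(path)
    return checkpoint
```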