-
Notifications
You must be signed in to change notification settings - Fork 40
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ModelCheckpoint must be defined in the config dict, not during the parsing. #454
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -96,6 +96,23 @@ def write_tiff(img_wrt, filename, metadata): | |
dest.write(img_wrt[i, :, :], i + 1) | ||
return filename | ||
|
||
def add_default_checkpointing_config(config): | ||
|
||
subcommand = config["subcommand"] | ||
enable_checkpointing = config[subcommand + ".trainer.enable_checkpointing"] | ||
|
||
if enable_checkpointing: | ||
print("Enabling ModelCheckpoint since the user defined enable_checkpointing=True.") | ||
|
||
config["ModelCheckpoint"] = StateDictAwareModelCheckpoint | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I know this is the way Carlos implemented the checkpoints. But in lightning, checkpoints are part of the callbacks: https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html I am not sure why follow a different pattern, it might be easier for users to overwrite the defaults if we follow the lightning approach (beeing in line with their docs). Otherwise, we need to describe it well in the docs. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also, the setting by the user is just overwritten with this code. So there is no possibility that a user can change this setting. |
||
config["ModelCheckpoint.filename"] = "{epoch}" | ||
config["ModelCheckpoint.monitor"] = "val/loss" | ||
config["StateDictModelCheckpoint"] = StateDictAwareModelCheckpoint | ||
config["StateDictModelCheckpoint.filename"] = "{epoch}_state_dict" | ||
config["StateDictModelCheckpoint.save_weights_only"] = True | ||
config["StateDictModelCheckpoint.monitor"] = "val/loss" | ||
|
||
return config | ||
|
||
def save_prediction(prediction, input_file_name, out_dir, dtype:str="int16"): | ||
mask, metadata = open_tiff(input_file_name) | ||
|
@@ -375,24 +392,12 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: | |
parser.add_argument("--deploy_config_file", type=bool, default=True) | ||
parser.add_argument("--custom_modules_path", type=str, default=None) | ||
|
||
# parser.set_defaults({"trainer.enable_checkpointing": False}) | ||
|
||
parser.add_lightning_class_args(StateDictAwareModelCheckpoint, "ModelCheckpoint") | ||
parser.set_defaults({"ModelCheckpoint.filename": "{epoch}", "ModelCheckpoint.monitor": "val/loss"}) | ||
|
||
parser.add_lightning_class_args(StateDictAwareModelCheckpoint, "StateDictModelCheckpoint") | ||
parser.set_defaults( | ||
{ | ||
"StateDictModelCheckpoint.filename": "{epoch}_state_dict", | ||
"StateDictModelCheckpoint.save_weights_only": True, | ||
"StateDictModelCheckpoint.monitor": "val/loss", | ||
} | ||
) | ||
|
||
parser.link_arguments("ModelCheckpoint.dirpath", "StateDictModelCheckpoint.dirpath") | ||
|
||
def instantiate_classes(self) -> None: | ||
|
||
# Adding default configuration for checkpoint saving when | ||
# enable_checkpointing is True. | ||
self.config = add_default_checkpointing_config(self.config) | ||
|
||
super().instantiate_classes() | ||
# get the predict_output_dir. Depending on the value of run, it may be in the subcommand | ||
try: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@Joao-L-S-Almeida Can we avoild prints and use loggers? It is better practice