ModelCheckpoint must be defined in the config dict, not during the parsing. #454

Open: wants to merge 4 commits into main
37 changes: 21 additions & 16 deletions terratorch/cli_tools.py
@@ -96,6 +96,23 @@ def write_tiff(img_wrt, filename, metadata):
dest.write(img_wrt[i, :, :], i + 1)
return filename

def add_default_checkpointing_config(config):

subcommand = config["subcommand"]
enable_checkpointing = config[subcommand + ".trainer.enable_checkpointing"]

if enable_checkpointing:
print("Enabling ModelCheckpoint since the user defined enable_checkpointing=True.")
Collaborator: @Joao-L-S-Almeida Can we avoid prints and use loggers? It is better practice.
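The logger-based alternative the comment suggests could look like this (a minimal sketch using Python's stdlib `logging`; the logger name is an assumption, not part of the PR):

```python
import logging

# Hypothetical module-level logger; the name mirrors the module path and is
# an assumption for illustration.
logger = logging.getLogger("terratorch.cli_tools")

# Instead of:
#     print("Enabling ModelCheckpoint since the user defined enable_checkpointing=True.")
# emit the message through the logger, so verbosity can be controlled globally:
logger.info("Enabling ModelCheckpoint since the user defined enable_checkpointing=True.")
```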


config["ModelCheckpoint"] = StateDictAwareModelCheckpoint
Collaborator: I know this is the way Carlos implemented the checkpoints.

But in Lightning, checkpoints are part of the callbacks: https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html

I am not sure why we follow a different pattern; it might be easier for users to override the defaults if we follow the Lightning approach (being in line with their docs). Otherwise, we need to describe it well in the docs.
@CarlosGomes98, any arguments why the model checkpoint is not implemented as a callback?
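For comparison, the callback-style configuration described in the linked Lightning docs would look roughly like this in a Lightning CLI YAML file (a sketch; the values mirror the defaults set in this PR):

```yaml
trainer:
  callbacks:
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        filename: "{epoch}"
        monitor: val/loss
```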

Collaborator: Also, the user's setting is simply overwritten by this code, so there is no way for a user to change it.
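One way to address this concern would be to set only the defaults the user has not already provided, e.g. with `dict.setdefault` (a sketch under the assumption that the config behaves like a flat mapping; `set_checkpoint_defaults` is a hypothetical helper, not code from this PR):

```python
def set_checkpoint_defaults(config):
    """Apply checkpoint defaults without clobbering user-provided values."""
    defaults = {
        "ModelCheckpoint.filename": "{epoch}",
        "ModelCheckpoint.monitor": "val/loss",
    }
    for key, value in defaults.items():
        # setdefault only writes the key if the user has not set it already
        config.setdefault(key, value)
    return config

cfg = {"ModelCheckpoint.monitor": "val/accuracy"}
set_checkpoint_defaults(cfg)
# the user's monitor choice survives; only the missing filename default is added
```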

config["ModelCheckpoint.filename"] = "{epoch}"
config["ModelCheckpoint.monitor"] = "val/loss"
config["StateDictModelCheckpoint"] = StateDictAwareModelCheckpoint
config["StateDictModelCheckpoint.filename"] = "{epoch}_state_dict"
config["StateDictModelCheckpoint.save_weights_only"] = True
config["StateDictModelCheckpoint.monitor"] = "val/loss"

return config
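As a self-contained illustration, the flat-key behavior of this function can be exercised with a plain dict and a placeholder standing in for `StateDictAwareModelCheckpoint` (a sketch; in the PR the config is a parsed jsonargparse namespace, and only the placeholder class here is invented):

```python
class StateDictAwareModelCheckpoint:
    """Placeholder for terratorch's checkpoint class, for illustration only."""

def add_default_checkpointing_config(config):
    subcommand = config["subcommand"]
    if config[subcommand + ".trainer.enable_checkpointing"]:
        config["ModelCheckpoint"] = StateDictAwareModelCheckpoint
        config["ModelCheckpoint.filename"] = "{epoch}"
        config["ModelCheckpoint.monitor"] = "val/loss"
        config["StateDictModelCheckpoint"] = StateDictAwareModelCheckpoint
        config["StateDictModelCheckpoint.filename"] = "{epoch}_state_dict"
        config["StateDictModelCheckpoint.save_weights_only"] = True
        config["StateDictModelCheckpoint.monitor"] = "val/loss"
    return config

# With checkpointing enabled for the "fit" subcommand, the defaults are injected:
cfg = add_default_checkpointing_config(
    {"subcommand": "fit", "fit.trainer.enable_checkpointing": True}
)
```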

def save_prediction(prediction, input_file_name, out_dir, dtype:str="int16"):
mask, metadata = open_tiff(input_file_name)
@@ -375,24 +392,12 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
parser.add_argument("--deploy_config_file", type=bool, default=True)
parser.add_argument("--custom_modules_path", type=str, default=None)

# parser.set_defaults({"trainer.enable_checkpointing": False})

parser.add_lightning_class_args(StateDictAwareModelCheckpoint, "ModelCheckpoint")
parser.set_defaults({"ModelCheckpoint.filename": "{epoch}", "ModelCheckpoint.monitor": "val/loss"})

parser.add_lightning_class_args(StateDictAwareModelCheckpoint, "StateDictModelCheckpoint")
parser.set_defaults(
{
"StateDictModelCheckpoint.filename": "{epoch}_state_dict",
"StateDictModelCheckpoint.save_weights_only": True,
"StateDictModelCheckpoint.monitor": "val/loss",
}
)

parser.link_arguments("ModelCheckpoint.dirpath", "StateDictModelCheckpoint.dirpath")

def instantiate_classes(self) -> None:

# Adding default configuration for checkpoint saving when
# enable_checkpointing is True.
self.config = add_default_checkpointing_config(self.config)

super().instantiate_classes()
# get the predict_output_dir. Depending on the value of run, it may be in the subcommand
try: