
Saving a checkpoint after prediction does not save the data config in datamodule_hyper_parameters #939

@mese79

Description


The Lightning `Trainer` has a `datamodule` property, which is set to the datamodule passed to `fit` or `predict`:

```python
trainer.fit(model, datamodule=data_module)
# or
trainer.predict(model, datamodule=inf_data_module)
```

When a checkpoint is saved, the data config is taken from whichever datamodule is currently attached to the trainer. In CAREamics, however, the data config is saved in the datamodule's hparams only for training/validation: here. As a result, a checkpoint saved after `trainer.predict(...)` has no `datamodule_hyper_parameters` entry.
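The behaviour described above can be modelled in a few lines of plain Python. This is a simplified sketch, not Lightning's actual code: `FakeTrainer`, `FakeDataModule` and `dump_checkpoint` are hypothetical stand-ins, and the sketch assumes the rule that only the hparams of the datamodule *currently* attached to the trainer are written under `datamodule_hyper_parameters`, and only when they are non-empty.

```python
from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class FakeDataModule:
    # Hypothetical stand-in for a LightningDataModule: `hparams` stays empty
    # unless the datamodule explicitly saved its hyperparameters.
    hparams: dict[str, Any] = field(default_factory=dict)


@dataclass
class FakeTrainer:
    # Hypothetical, simplified trainer: every fit()/predict() call that
    # receives a datamodule overwrites `self.datamodule`.
    datamodule: Optional[FakeDataModule] = None

    def fit(self, datamodule: FakeDataModule) -> None:
        self.datamodule = datamodule

    def predict(self, datamodule: FakeDataModule) -> None:
        self.datamodule = datamodule

    def dump_checkpoint(self) -> dict[str, Any]:
        checkpoint: dict[str, Any] = {"state_dict": {}}
        # Assumed rule: only the currently attached datamodule is consulted,
        # and its hparams are written only if they are non-empty.
        dm = self.datamodule
        if dm is not None and dm.hparams:
            checkpoint["datamodule_hyper_parameters"] = dict(dm.hparams)
        return checkpoint


trainer = FakeTrainer()
trainer.fit(FakeDataModule(hparams={"data_config": {"mode": "training"}}))
after_training = trainer.dump_checkpoint()

trainer.predict(FakeDataModule())  # the prediction datamodule saved no hparams
after_prediction = trainer.dump_checkpoint()

print("datamodule_hyper_parameters" in after_training)    # True
print("datamodule_hyper_parameters" in after_prediction)  # False
```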

A basic example to reproduce the issue (works with dev/v0.2):

```python
from pathlib import Path

import numpy as np
import torch
from pytorch_lightning import Trainer

from careamics.config.factories import create_advanced_n2v_config
from careamics.lightning import (
    CareamicsDataModule,
    N2VModule,
)


my_path = Path(".")

config = create_advanced_n2v_config(
    experiment_name="na",
    data_type="array",
    axes="YX",
    patch_size=(32, 32),
    batch_size=2,
    num_epochs=1,
    num_workers=0,
)

# create the Lightning module
model = N2VModule(config.algorithm_config)

# training and validation data
rng = np.random.default_rng(42)
train_array = rng.integers(0, 255, (64, 64)).astype(np.float32)
val_array = rng.integers(0, 255, (64, 64)).astype(np.float32)

data_module = CareamicsDataModule(
    data_config=config.data_config,
    train_data=train_array,
    val_data=val_array,
)

trainer = Trainer(
    enable_progress_bar=True,
    **config.training_config.trainer_params,
)

# train: trainer.datamodule is now the training datamodule
trainer.fit(model, datamodule=data_module)

# save a checkpoint after training
trainer.save_checkpoint(my_path / "after_training.ckpt")

# create an inference data config
pred_config = config.data_config.convert_mode(
    new_mode="predicting",
    new_patch_size=(64, 64),
    overlap_size=(32, 32),
    new_batch_size=1,
)

inf_data_module = CareamicsDataModule(
    data_config=pred_config,
    pred_data=val_array,
)

# run inference: trainer.datamodule is now the prediction datamodule
tiled_predictions = trainer.predict(model, datamodule=inf_data_module)

# save a checkpoint after prediction
trainer.save_checkpoint(my_path / "after_prediction.ckpt")

# load both checkpoints; weights_only=False because Lightning checkpoints may
# contain non-tensor objects (recent PyTorch versions default to weights_only=True)
ckpt1 = torch.load(my_path / "after_training.ckpt", weights_only=False)
ckpt2 = torch.load(my_path / "after_prediction.ckpt", weights_only=False)

data_hparam_key = "datamodule_hyper_parameters"

# the key is present after training but missing after prediction
print(f"\nafter training: {data_hparam_key in ckpt1}, {ckpt1.keys()}\n")
print(f"after prediction: {data_hparam_key in ckpt2}, {ckpt2.keys()}")
```

Calling `self._save_hparams()` for the prediction datamodule as well would not fix this: the saved data config would then be the prediction config rather than the training one.
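One possible direction, sketched here only under stated assumptions: remember the training datamodule's hyperparameters once and write them back into every later checkpoint, so that a checkpoint saved after prediction still carries the training data config. `TrainingDataConfigKeeper`, `remember` and `patch` are hypothetical names (not part of CAREamics or Lightning), and plain dicts stand in for real hparams and checkpoints.

```python
import copy
from typing import Any, Optional

DATAMODULE_HPARAMS_KEY = "datamodule_hyper_parameters"


class TrainingDataConfigKeeper:
    """Hypothetical helper: remembers the training datamodule's hparams and
    writes them back into every checkpoint saved afterwards."""

    def __init__(self) -> None:
        self._train_hparams: Optional[dict[str, Any]] = None

    def remember(self, hparams: dict[str, Any]) -> None:
        # Store a deep copy so later changes to the datamodule do not leak in.
        self._train_hparams = copy.deepcopy(dict(hparams))

    def patch(self, checkpoint: dict[str, Any]) -> dict[str, Any]:
        # Overwrite whatever the currently attached (possibly prediction)
        # datamodule contributed, or add the entry if it is missing.
        if self._train_hparams is not None:
            checkpoint[DATAMODULE_HPARAMS_KEY] = copy.deepcopy(self._train_hparams)
        return checkpoint


keeper = TrainingDataConfigKeeper()
keeper.remember({"data_config": {"mode": "training", "patch_size": [32, 32]}})

ckpt = {"state_dict": {}}  # e.g. a checkpoint dumped after predict()
keeper.patch(ckpt)
print(ckpt[DATAMODULE_HPARAMS_KEY]["data_config"]["mode"])  # training
```

In Lightning, `remember` could plausibly be called from a `Callback.on_fit_start(trainer, pl_module)` hook with `trainer.datamodule.hparams`, and `patch` from `Callback.on_save_checkpoint(trainer, pl_module, checkpoint)`, which receives the checkpoint dict before it is written. This assumes the hook runs after Lightning has written the datamodule hyperparameters, which is worth confirming for the Lightning version in use.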
