The Lightning `Trainer` has a `datamodule` property, which the `fit` and `predict` methods set to the datamodule passed to them: `trainer.fit(model, datamodule=data_module)` or `trainer.predict(model, datamodule=inf_data_module)`.
When a checkpoint is saved, the data config it stores comes from whichever datamodule is currently attached to the trainer. In `careamics`, the data config is saved in `hparams` only for training/validation: here.
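To make the mechanism concrete, here is a toy, pure-Python model of the behaviour described above. It does not use Lightning: `ToyTrainer`, `ToyDataModule`, `save_hparams`, and `dump_checkpoint` are hypothetical stand-ins. It assumes Lightning's checkpoint logic roughly amounts to "write `datamodule_hyper_parameters` only if the currently attached datamodule has saved hparams".

```python
from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class ToyDataModule:
    """Hypothetical stand-in for a datamodule; `hparams` stays empty unless saved."""
    data_config: dict
    hparams: dict = field(default_factory=dict)

    def save_hparams(self) -> None:
        # Mimics storing the data config as datamodule hyperparameters.
        self.hparams = {"data_config": self.data_config}


class ToyTrainer:
    """Hypothetical trainer: fit() and predict() both overwrite the attached datamodule."""

    def __init__(self) -> None:
        self.datamodule: Optional[ToyDataModule] = None

    def fit(self, datamodule: ToyDataModule) -> None:
        self.datamodule = datamodule

    def predict(self, datamodule: ToyDataModule) -> None:
        self.datamodule = datamodule

    def dump_checkpoint(self) -> dict[str, Any]:
        ckpt: dict[str, Any] = {"state_dict": {}}
        dm = self.datamodule
        # Only the *currently attached* datamodule is consulted.
        if dm is not None and dm.hparams:
            ckpt["datamodule_hyper_parameters"] = dict(dm.hparams)
        return ckpt


train_dm = ToyDataModule({"mode": "training"})
train_dm.save_hparams()  # hparams are saved only for train/validation
pred_dm = ToyDataModule({"mode": "predicting"})  # no hparams saved

trainer = ToyTrainer()
trainer.fit(train_dm)
ckpt1 = trainer.dump_checkpoint()
trainer.predict(pred_dm)
ckpt2 = trainer.dump_checkpoint()

print("after training:", "datamodule_hyper_parameters" in ckpt1)    # True
print("after prediction:", "datamodule_hyper_parameters" in ckpt2)  # False
```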
A basic example to reproduce the issue (works with dev/v0.2):
```python
from pathlib import Path

import numpy as np
import torch
from pytorch_lightning import Trainer

from careamics.config.factories import create_advanced_n2v_config
from careamics.lightning import (
    CareamicsDataModule,
    N2VModule,
)

my_path = Path(".")

config = create_advanced_n2v_config(
    experiment_name="na",
    data_type="array",
    axes="YX",
    patch_size=(32, 32),
    batch_size=2,
    num_epochs=1,
    num_workers=0,
)

# create lightning modules
model = N2VModule(config.algorithm_config)

# training data
rng = np.random.default_rng(42)
train_array = rng.integers(0, 255, (64, 64)).astype(np.float32)
val_array = rng.integers(0, 255, (64, 64)).astype(np.float32)

data_module = CareamicsDataModule(
    data_config=config.data_config,
    train_data=train_array,
    val_data=val_array,
)

trainer = Trainer(
    enable_progress_bar=True,
    **config.training_config.trainer_params,
)

# train (this attaches data_module to the trainer)
trainer.fit(model, datamodule=data_module)

# save a checkpoint while the training datamodule is attached
trainer.save_checkpoint(my_path / "after_training.ckpt")

# create an inference data config
pred_config = config.data_config.convert_mode(
    new_mode="predicting",
    new_patch_size=(64, 64),
    overlap_size=(32, 32),
    new_batch_size=1,
)

inf_data_module = CareamicsDataModule(
    data_config=pred_config,
    pred_data=val_array,
)

# run inference (this attaches inf_data_module to the trainer)
tiled_predictions = trainer.predict(model, datamodule=inf_data_module)

# save a checkpoint while the prediction datamodule is attached
trainer.save_checkpoint(my_path / "after_prediction.ckpt")

# load checkpoints; weights_only=False because PyTorch >= 2.6 defaults to True,
# which may refuse to unpickle a full Lightning checkpoint
ckpt1 = torch.load(my_path / "after_training.ckpt", weights_only=False)
ckpt2 = torch.load(my_path / "after_prediction.ckpt", weights_only=False)

data_hparam_key = "datamodule_hyper_parameters"
print(f"\nafter training: {data_hparam_key in ckpt1}, {ckpt1.keys()}\n")
print(f"after prediction: {data_hparam_key in ckpt2}, {ckpt2.keys()}")
```
If we also call `self._save_hparams()` for the prediction datamodule, the saved data config will be the prediction config rather than the training one.
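The trade-off in that paragraph can be illustrated with another toy, pure-Python model (no Lightning involved; `ToyTrainer`, `attach`, and `dump_checkpoint` are hypothetical). It assumes the checkpoint only records the hparams of the datamodule attached last, so saving hparams for the prediction datamodule makes the key reappear, but holding the prediction config. The patch sizes are taken from the reproduction script.

```python
from typing import Any, Optional


class ToyTrainer:
    """Hypothetical trainer: each fit()/predict() replaces the attached datamodule's hparams."""

    def __init__(self) -> None:
        self.dm_hparams: Optional[dict] = None

    def attach(self, dm_hparams: Optional[dict]) -> None:
        # Stands in for the datamodule attachment done by fit() and predict().
        self.dm_hparams = dm_hparams

    def dump_checkpoint(self) -> dict[str, Any]:
        ckpt: dict[str, Any] = {"state_dict": {}}
        if self.dm_hparams:
            ckpt["datamodule_hyper_parameters"] = dict(self.dm_hparams)
        return ckpt


train_hparams = {"data_config": {"mode": "training", "patch_size": (32, 32)}}
# Variant discussed above: the prediction datamodule also saves its hparams.
pred_hparams = {"data_config": {"mode": "predicting", "patch_size": (64, 64)}}

trainer = ToyTrainer()
trainer.attach(train_hparams)  # like trainer.fit(...)
trainer.attach(pred_hparams)   # like trainer.predict(...)
ckpt = trainer.dump_checkpoint()

saved = ckpt["datamodule_hyper_parameters"]["data_config"]
print(saved["mode"], saved["patch_size"])  # predicting (64, 64)
```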