Skip to content

Commit cf912d5

Browse files
authored
Fix: EarlyStopping callback initialisation in CAREamist (#774)
## Description > [!NOTE] > **tldr**: Fixes a bug in CAREamist initialisation of the `EarlyStopping` callback. ### Background - why do we need this PR? Instead of dumping the parameters from the `EarlyStoppingConfig` pydantic model, it was passed as the first argument and this was interpreted as the value the callback was meant to monitor, later leading to an error. ### Overview - what changed? Fixed the callback initialisation and added a test. ## Changes Made <!-- This section highlights the important features and files that reviewers should pay attention to when reviewing. Only list important features or files, this is useful for reviewers to correctly assess how deeply the modifications impact the code base. ## How has this been tested? Added a test which runs training with CAREamist and has both the `EarlyStopping` and `Checkpoint` callback configuration. ## Related Issues <!-- Link to any related issues or discussions. Use keywords like "Fixes", "Resolves", or "Closes" to link to issues automatically. --> - Resolves #770 --- **Please ensure your PR meets the following requirements:** - [x] Code builds and passes tests locally, including doctests - [x] New tests have been added (for bug fixes/features) - [x] Pre-commit passes - [ ] PR to the documentation exists (for bug fixes / features)
1 parent 49f0b0f commit cf912d5

2 files changed

Lines changed: 35 additions & 8 deletions

File tree

src/careamics/careamist.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,9 @@ def _define_callbacks(
270270
# early stopping callback
271271
if self.cfg.training_config.early_stopping_callback is not None:
272272
self.callbacks.append(
273-
EarlyStopping(self.cfg.training_config.early_stopping_callback)
273+
EarlyStopping(
274+
**self.cfg.training_config.early_stopping_callback.model_dump()
275+
)
274276
)
275277

276278
def stop_training(self) -> None:

tests/test_careamist.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
import pytest
66
import tifffile
77
from numpy.typing import NDArray
8-
from pytorch_lightning import Trainer
98
from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint
109

1110
from careamics import CAREamist
1211
from careamics.config import Configuration, save_configuration
12+
from careamics.config.lightning import CheckpointConfig, EarlyStoppingConfig
1313
from careamics.config.support import SupportedAlgorithm, SupportedData
1414
from careamics.dataset.dataset_utils import reshape_array
1515
from careamics.lightning.callbacks import HyperParametersCallback, ProgressBarCallback
@@ -386,6 +386,36 @@ def test_train_tiff_files(tmp_path: Path, minimum_n2v_configuration: dict):
386386
assert (tmp_path / "model.zip").exists()
387387

388388

389+
@pytest.mark.mps_gh_fail
390+
def test_train_w_callbacks(tmp_path: Path, minimum_n2v_configuration: dict):
391+
"""
392+
Test that basic training with arrays runs without error with supported callbacks.
393+
"""
394+
# training data
395+
train_array = random_array((32, 32))
396+
val_array = random_array((32, 32))
397+
398+
# create configuration
399+
config = Configuration(**minimum_n2v_configuration)
400+
config.data_config.axes = "YX"
401+
config.data_config.batch_size = 2
402+
config.data_config.data_type = SupportedData.ARRAY.value
403+
config.data_config.patch_size = (8, 8)
404+
405+
# add supported callback configuration
406+
config.training_config.checkpoint_callback = CheckpointConfig()
407+
config.training_config.early_stopping_callback = EarlyStoppingConfig()
408+
409+
# set epochs 2 to make sure callbacks are accessed.
410+
config.training_config.lightning_trainer_config["max_epochs"] = 2
411+
412+
# instantiate CAREamist
413+
careamist = CAREamist(source=config, work_dir=tmp_path)
414+
415+
# train CAREamist
416+
careamist.train(train_source=train_array, val_source=val_array)
417+
418+
389419
@pytest.mark.mps_gh_fail
390420
def test_train_array_supervised(tmp_path: Path, minimum_supervised_configuration: dict):
391421
"""Test that CAREamics can be trained with arrays."""
@@ -1136,12 +1166,7 @@ def test_error_passing_careamics_callback(tmp_path, minimum_n2v_configuration):
11361166
with pytest.raises(ValueError):
11371167
CAREamist(source=config, work_dir=tmp_path, callbacks=[model_ckp])
11381168

1139-
early_stp = EarlyStopping(
1140-
Trainer(
1141-
max_epochs=1,
1142-
default_root_dir=tmp_path,
1143-
)
1144-
)
1169+
early_stp = EarlyStopping(monitor="val_loss")
11451170

11461171
with pytest.raises(ValueError):
11471172
CAREamist(source=config, work_dir=tmp_path, callbacks=[early_stp])

0 commit comments

Comments
 (0)