Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
c4639bc
feat: make default norm in config mean-std with auto stats calc
melisande-c Apr 16, 2026
ca53749
fix(typing): use explicit default arg in pydantic field for type chec…
melisande-c Apr 16, 2026
ba4f2fa
feat: WIP config builder base class
melisande-c Apr 16, 2026
d589284
feat: WIP add DataConfigMixin for ConfigBuilder
melisande-c Apr 16, 2026
90264b2
feat: WIP add unet params mixin
melisande-c Apr 16, 2026
ebeb9a8
style: WIP function order minor fixes
melisande-c Apr 16, 2026
898c53d
feat: WIP add TrainingParamMixin
melisande-c Apr 16, 2026
8858e4b
feat: N2VConfigBuilder + fixes
melisande-c Apr 16, 2026
f9e8439
fix: mixin self typing some config_dict errors
melisande-c Apr 16, 2026
7afe072
fix: seed propagation; incorrect names
melisande-c Apr 17, 2026
7e01aac
feat: add care & n2n builders; n2v propagate monitor metric
melisande-c Apr 17, 2026
0247627
Merge branch 'dev/v0.2' into mc/feat/config-builder
melisande-c Apr 17, 2026
44d2c52
refac: make hook before_build private
melisande-c Apr 17, 2026
f3c8721
refac: split mixins into seperate modules
melisande-c Apr 17, 2026
bcf289f
feat: instantiate default unet config for care, n2n, n2v algorithm co…
melisande-c Apr 17, 2026
85e9362
feat: make BaseConfigBuilder work alone\n- mv data defaults to __init…
melisande-c Apr 17, 2026
1fd3262
Revert "feat: make BaseConfigBuilder work alone\n- mv data defaults t…
melisande-c Apr 17, 2026
6808a1c
fix: set care and n2n algorithm in builder
melisande-c Apr 17, 2026
a078c50
Revert "feat: instantiate default unet config for care, n2n, n2v algo…
melisande-c Apr 17, 2026
e362d57
fix(n2vConfigBuilder): structn2v parameter names
melisande-c Apr 17, 2026
bec96b9
fix: typing for augmentations
melisande-c Apr 17, 2026
da3e7ef
refac: make all mixins pure (no __init__); it is the responsibility o…
melisande-c Apr 20, 2026
bb8f85b
fix: use stratified patching as default
melisande-c Apr 20, 2026
358f9e1
feat: ability to turn off early stopping
melisande-c Apr 21, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,7 @@ follow_imports = skip
follow_imports = skip

[mypy-careamics.config.likelihood_model]
follow_imports = skip
follow_imports = skip

[mypy-careamics.config.builder.*]
disable_error_code = typeddict-unknown-key, typeddict-item
6 changes: 3 additions & 3 deletions src/careamics/config/augmentations/xy_flip_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@ class XYFlipConfig(BaseModel):

name: Literal["XYFlip"] = "XYFlip"
flip_x: bool = Field(
True,
default=True,
description="Whether to flip along the X axis.",
)
flip_y: bool = Field(
True,
default=True,
description="Whether to flip along the Y axis.",
)
p: float = Field(
0.5,
default=0.5,
description="Probability of applying the transform.",
ge=0,
le=1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class XYRandomRotate90Config(BaseModel):

name: Literal["XYRandomRotate90"] = "XYRandomRotate90"
p: float = Field(
0.5,
default=0.5,
description="Probability of applying the transform.",
ge=0,
le=1,
Expand Down
290 changes: 290 additions & 0 deletions src/careamics/config/builder/builders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,290 @@
from collections.abc import Sequence
from dataclasses import asdict
from typing import Any, Literal, Self

from careamics.config.data.data_config import _is_3D
from careamics.config.factories.training_factory import update_trainer_params
from careamics.config.lightning.training_configuration import (
SelfSupervisedCheckpointing,
SupervisedCheckpointing,
)
from careamics.config.support import SupportedData

from .config_builder import BaseConfigBuilder, ConfigDict
from .mixins import (
DataParamsMixin,
OptimizerParamsMixin,
TrainingParamsMixin,
UnetParamsMixin,
)


def minimum_unet_config_dict(
algorithm: Literal["n2v", "care", "n2n"],
experiment_name: str,
data_type: Literal["array", "tiff", "zarr", "czi", "custom"],
axes: str,
patch_size: Sequence[int],
batch_size: int,
# optional
num_epochs: int = 30,
num_steps: int | None = None,
n_channels_in: int = 1,
n_channels_out: int = 1,
seed: int | None = None,
) -> ConfigDict:
config_dict: ConfigDict = {
"experiment_name": experiment_name,
"data_config": {
"mode": "training",
"axes": axes,
"data_type": SupportedData(data_type),
"patching": {"name": "stratified", "patch_size": patch_size},
"batch_size": batch_size,
},
"algorithm_config": {
"algorithm": algorithm,
"model": default_unet_config(
_is_3D(axes, SupportedData(data_type)), n_channels_in, n_channels_out
),
},
"training_config": {
"trainer_params": update_trainer_params({}, num_epochs, num_steps)
},
}
if seed is not None:
config_dict["data_config"]["seed"] = seed
return config_dict


def default_unet_config(
is_3D: bool,
n_channels_in: int,
n_channels_out: int,
) -> dict[str, Any]:
return {
"architecture": "UNet",
"conv_dims": 3 if is_3D else 2,
"in_channels": n_channels_in,
"num_classes": n_channels_out,
}


class CAREConfigBuilder(
TrainingParamsMixin,
DataParamsMixin,
UnetParamsMixin,
OptimizerParamsMixin,
BaseConfigBuilder,
):
def __init__(
self,
experiment_name: str,
data_type: Literal["array", "tiff", "zarr", "czi", "custom"],
axes: str,
patch_size: Sequence[int],
batch_size: int,
# optional
num_epochs: int = 30,
num_steps: int | None = None,
n_channels_in: int = 1,
n_channels_out: int = 1,
seed: int | None = None,
):
self.seed = seed
self.config_dict = minimum_unet_config_dict(
algorithm="care",
experiment_name=experiment_name,
data_type=data_type,
axes=axes,
patch_size=patch_size,
batch_size=batch_size,
num_epochs=num_epochs,
num_steps=num_steps,
n_channels_in=n_channels_in,
n_channels_out=n_channels_out,
)

# set default checkpointing params
# (can be overwritten with set_checkpoint_params from TrainingParamMixin)
self.config_dict["training_config"]["checkpoint_params"] = asdict(
SupervisedCheckpointing()
)

self.config_dict["training_config"]["early_stopping_params"] = {
"monitor": "val_loss",
"mode": "min",
}

def set_loss(self, loss: Literal["mae", "mse"]) -> Self:
self.config_dict["algorithm_config"]["loss"] = loss
return self


class N2NConfigBuilder(
TrainingParamsMixin,
DataParamsMixin,
UnetParamsMixin,
OptimizerParamsMixin,
BaseConfigBuilder,
):
def __init__(
self,
experiment_name: str,
data_type: Literal["array", "tiff", "zarr", "czi", "custom"],
axes: str,
patch_size: Sequence[int],
batch_size: int,
# optional
num_epochs: int = 30,
num_steps: int | None = None,
n_channels_in: int = 1,
n_channels_out: int = 1,
seed: int | None = None,
):
self.seed = seed
self.config_dict = minimum_unet_config_dict(
algorithm="n2n",
experiment_name=experiment_name,
data_type=data_type,
axes=axes,
patch_size=patch_size,
batch_size=batch_size,
num_epochs=num_epochs,
num_steps=num_steps,
n_channels_in=n_channels_in,
n_channels_out=n_channels_out,
)

# set default checkpointing params (n2n self supervised)
# (can be overwritten with set_checkpoint_params from TrainingParamMixin)
self.config_dict["training_config"]["checkpoint_params"] = asdict(
SelfSupervisedCheckpointing()
)

# no early stopping by default
self.config_dict["training_config"]["early_stopping_params"] = None
Comment on lines +159 to +166
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't it be better to not add a training_config, only adding it if the specific methods are called, and leave it to the default constructor to set the defaults?


def set_loss(self, loss: Literal["mae", "mse"]) -> Self:
self.config_dict["algorithm_config"]["loss"] = loss
return self


class N2VConfigBuilder(
TrainingParamsMixin,
DataParamsMixin,
UnetParamsMixin,
OptimizerParamsMixin,
BaseConfigBuilder,
):
def __init__(
self,
experiment_name: str,
data_type: Literal["array", "tiff", "zarr", "czi", "custom"],
axes: str,
patch_size: Sequence[int],
batch_size: int,
# optional
num_epochs: int = 30,
num_steps: int | None = None,
n_channels: int = 1,
seed: int | None = None,
):
self.seed = seed
self.config_dict = minimum_unet_config_dict(
algorithm="n2v",
experiment_name=experiment_name,
data_type=data_type,
axes=axes,
patch_size=patch_size,
batch_size=batch_size,
num_epochs=num_epochs,
num_steps=num_steps,
n_channels_in=n_channels,
n_channels_out=n_channels,
)

# this will be used to propagate the monitor metric before building the config
# we have to wait for after set_checkpoint_params and set_early_stopping_params
# it can be changed using the set_monitor_metric method
self.monitor_metric: Literal["train_loss", "train_loss_epoch", "val_loss"] = (
"val_loss"
)

# set default checkpointing params
# (can be overwritten with set_checkpoint_params from TrainingParamMixin)
self.config_dict["training_config"]["checkpoint_params"] = asdict(
SelfSupervisedCheckpointing()
)

# no early stopping by default
self.config_dict["training_config"]["early_stopping_params"] = None

# propagate seed
self.config_dict["algorithm_config"]["n2v_config"] = {}
if self.seed is not None:
self.config_dict["algorithm_config"]["n2v_config"]["seed"] = self.seed

def set_n2v_params(
self,
use_n2v2: bool | None = None,
roi_size: int | None = None,
masked_pixel_percentage: float | None = None,
# - structN2V specific
struct_n2v_axis: Literal["horizontal", "vertical", "none"] | None = None,
struct_n2v_span: int | None = None,
) -> Self:
n2v_manipulate_config: dict[str, Any] = {}
if roi_size is not None:
n2v_manipulate_config["roi_size"] = roi_size

if masked_pixel_percentage is not None:
n2v_manipulate_config["masked_pixel_percentage"] = masked_pixel_percentage

if struct_n2v_axis is not None:
n2v_manipulate_config["struct_mask_axis"] = struct_n2v_axis

if struct_n2v_span is not None:
n2v_manipulate_config["struct_mask_span"] = struct_n2v_span

if use_n2v2 is not None:
# already added by UnetParamMixin
assert isinstance(self.config_dict["algorithm_config"]["model"], dict)
self.config_dict["algorithm_config"]["model"]["n2v2"] = use_n2v2

n2v_manipulate_config["strategy"] = "median" if use_n2v2 else "uniform"

assert isinstance(self.config_dict["algorithm_config"]["n2v_config"], dict)
self.config_dict["algorithm_config"]["n2v_config"].update(n2v_manipulate_config)
return self

def set_monitor_metric(
self, monitor_metric: Literal["train_loss", "train_loss_epoch", "val_loss"]
) -> Self:
self.monitor_metric = monitor_metric
self.config_dict["algorithm_config"]["monitor_metric"] = monitor_metric
return self

def _propagate_monitor_to_callbacks(self):
# only overwrite monitor if it not explicitly set
if "checkpoint_params" not in self.config_dict["training_config"]:
self.config_dict["training_config"]["checkpoint_params"] = {}
checkpoint_params = self.config_dict["training_config"]["checkpoint_params"]
if "monitor" not in checkpoint_params:
checkpoint_params["monitor"] = self.monitor_metric

# TODO: default value in the config is dict so we need to propagate in that case
# probably we want a mechanism to propagate monitor metric to default
if "early_stopping_params" not in self.config_dict["training_config"]:
self.config_dict["training_config"]["early_stopping_params"] = {}
early_stopping_params = self.config_dict["training_config"][
"early_stopping_params"
]
has_early_stopping = early_stopping_params is not None
if has_early_stopping and "monitor" not in early_stopping_params:
assert isinstance(early_stopping_params, dict)
early_stopping_params["monitor"] = self.monitor_metric

def _before_build(self) -> None:
super()._before_build()
self._propagate_monitor_to_callbacks()
Loading
Loading