Feat: Proposal config builder system #904
Draft
melisande-c wants to merge 24 commits into main from mc/feat/config-builder
24 commits
c4639bc feat: make default norm in config mean-std with auto stats calc (melisande-c)
ca53749 fix(typing): use explicit default arg in pydantic field for type chec… (melisande-c)
ba4f2fa feat: WIP config builder base class (melisande-c)
d589284 feat: WIP add DataConfigMixin for ConfigBuilder (melisande-c)
90264b2 feat: WIP add unet params mixin (melisande-c)
ebeb9a8 style: WIP function order minor fixes (melisande-c)
898c53d feat: WIP add TrainingParamMixin (melisande-c)
8858e4b feat: N2VConfigBuilder + fixes (melisande-c)
f9e8439 fix: mixin self typing some config_dict errors (melisande-c)
7afe072 fix: seed propagation; incorrect names (melisande-c)
7e01aac feat: add care & n2n builders; n2v propagate monitor metric (melisande-c)
0247627 Merge branch 'dev/v0.2' into mc/feat/config-builder (melisande-c)
44d2c52 refac: make hook before_build private (melisande-c)
f3c8721 refac: split mixins into seperate modules (melisande-c)
bcf289f feat: instantiate default unet config for care, n2n, n2v algorithm co… (melisande-c)
85e9362 feat: make BaseConfigBuilder work alone\n- mv data defaults to __init… (melisande-c)
1fd3262 Revert "feat: make BaseConfigBuilder work alone\n- mv data defaults t… (melisande-c)
6808a1c fix: set care and n2n algorithm in builder (melisande-c)
a078c50 Revert "feat: instantiate default unet config for care, n2n, n2v algo… (melisande-c)
e362d57 fix(n2vConfigBuilder): structn2v parameter names (melisande-c)
bec96b9 fix: typing for augmentations (melisande-c)
da3e7ef refac: make all mixins pure (no __init__); it is the responsibility o… (melisande-c)
bb8f85b fix: use stratified patching as default (melisande-c)
358f9e1 feat: ability to turn off early stopping (melisande-c)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,290 @@ | ||
| from collections.abc import Sequence | ||
| from dataclasses import asdict | ||
| from typing import Any, Literal, Self | ||
| | ||
| from careamics.config.data.data_config import _is_3D | ||
| from careamics.config.factories.training_factory import update_trainer_params | ||
| from careamics.config.lightning.training_configuration import ( | ||
|     SelfSupervisedCheckpointing, | ||
|     SupervisedCheckpointing, | ||
| ) | ||
| from careamics.config.support import SupportedData | ||
| | ||
| from .config_builder import BaseConfigBuilder, ConfigDict | ||
| from .mixins import ( | ||
|     DataParamsMixin, | ||
|     OptimizerParamsMixin, | ||
|     TrainingParamsMixin, | ||
|     UnetParamsMixin, | ||
| ) | ||
| | ||
| | ||
| def minimum_unet_config_dict( | ||
|     algorithm: Literal["n2v", "care", "n2n"], | ||
|     experiment_name: str, | ||
|     data_type: Literal["array", "tiff", "zarr", "czi", "custom"], | ||
|     axes: str, | ||
|     patch_size: Sequence[int], | ||
|     batch_size: int, | ||
|     # optional | ||
|     num_epochs: int = 30, | ||
|     num_steps: int | None = None, | ||
|     n_channels_in: int = 1, | ||
|     n_channels_out: int = 1, | ||
|     seed: int | None = None, | ||
| ) -> ConfigDict: | ||
|     config_dict: ConfigDict = { | ||
|         "experiment_name": experiment_name, | ||
|         "data_config": { | ||
|             "mode": "training", | ||
|             "axes": axes, | ||
|             "data_type": SupportedData(data_type), | ||
|             "patching": {"name": "stratified", "patch_size": patch_size}, | ||
|             "batch_size": batch_size, | ||
|         }, | ||
|         "algorithm_config": { | ||
|             "algorithm": algorithm, | ||
|             "model": default_unet_config( | ||
|                 _is_3D(axes, SupportedData(data_type)), n_channels_in, n_channels_out | ||
|             ), | ||
|         }, | ||
|         "training_config": { | ||
|             "trainer_params": update_trainer_params({}, num_epochs, num_steps) | ||
|         }, | ||
|     } | ||
|     if seed is not None: | ||
|         config_dict["data_config"]["seed"] = seed | ||
|     return config_dict | ||
| | ||
| | ||
| def default_unet_config( | ||
|     is_3D: bool, | ||
|     n_channels_in: int, | ||
|     n_channels_out: int, | ||
| ) -> dict[str, Any]: | ||
|     return { | ||
|         "architecture": "UNet", | ||
|         "conv_dims": 3 if is_3D else 2, | ||
|         "in_channels": n_channels_in, | ||
|         "num_classes": n_channels_out, | ||
|     } | ||
| | ||
| | ||
| class CAREConfigBuilder( | ||
|     TrainingParamsMixin, | ||
|     DataParamsMixin, | ||
|     UnetParamsMixin, | ||
|     OptimizerParamsMixin, | ||
|     BaseConfigBuilder, | ||
| ): | ||
|     def __init__( | ||
|         self, | ||
|         experiment_name: str, | ||
|         data_type: Literal["array", "tiff", "zarr", "czi", "custom"], | ||
|         axes: str, | ||
|         patch_size: Sequence[int], | ||
|         batch_size: int, | ||
|         # optional | ||
|         num_epochs: int = 30, | ||
|         num_steps: int | None = None, | ||
|         n_channels_in: int = 1, | ||
|         n_channels_out: int = 1, | ||
|         seed: int | None = None, | ||
|     ): | ||
|         self.seed = seed | ||
|         self.config_dict = minimum_unet_config_dict( | ||
|             algorithm="care", | ||
|             experiment_name=experiment_name, | ||
|             data_type=data_type, | ||
|             axes=axes, | ||
|             patch_size=patch_size, | ||
|             batch_size=batch_size, | ||
|             num_epochs=num_epochs, | ||
|             num_steps=num_steps, | ||
|             n_channels_in=n_channels_in, | ||
|             n_channels_out=n_channels_out, | ||
|         ) | ||
| | ||
|         # set default checkpointing params | ||
|         # (can be overwritten with set_checkpoint_params from TrainingParamMixin) | ||
|         self.config_dict["training_config"]["checkpoint_params"] = asdict( | ||
|             SupervisedCheckpointing() | ||
|         ) | ||
| | ||
|         self.config_dict["training_config"]["early_stopping_params"] = { | ||
|             "monitor": "val_loss", | ||
|             "mode": "min", | ||
|         } | ||
| | ||
|     def set_loss(self, loss: Literal["mae", "mse"]) -> Self: | ||
|         self.config_dict["algorithm_config"]["loss"] = loss | ||
|         return self | ||
| | ||
| | ||
| class N2NConfigBuilder( | ||
|     TrainingParamsMixin, | ||
|     DataParamsMixin, | ||
|     UnetParamsMixin, | ||
|     OptimizerParamsMixin, | ||
|     BaseConfigBuilder, | ||
| ): | ||
|     def __init__( | ||
|         self, | ||
|         experiment_name: str, | ||
|         data_type: Literal["array", "tiff", "zarr", "czi", "custom"], | ||
|         axes: str, | ||
|         patch_size: Sequence[int], | ||
|         batch_size: int, | ||
|         # optional | ||
|         num_epochs: int = 30, | ||
|         num_steps: int | None = None, | ||
|         n_channels_in: int = 1, | ||
|         n_channels_out: int = 1, | ||
|         seed: int | None = None, | ||
|     ): | ||
|         self.seed = seed | ||
|         self.config_dict = minimum_unet_config_dict( | ||
|             algorithm="n2n", | ||
|             experiment_name=experiment_name, | ||
|             data_type=data_type, | ||
|             axes=axes, | ||
|             patch_size=patch_size, | ||
|             batch_size=batch_size, | ||
|             num_epochs=num_epochs, | ||
|             num_steps=num_steps, | ||
|             n_channels_in=n_channels_in, | ||
|             n_channels_out=n_channels_out, | ||
|         ) | ||
| | ||
|         # set default checkpointing params (n2n self supervised) | ||
|         # (can be overwritten with set_checkpoint_params from TrainingParamMixin) | ||
|         self.config_dict["training_config"]["checkpoint_params"] = asdict( | ||
|             SelfSupervisedCheckpointing() | ||
|         ) | ||
| | ||
|         # no early stopping by default | ||
|         self.config_dict["training_config"]["early_stopping_params"] = None | ||
| | ||
|     def set_loss(self, loss: Literal["mae", "mse"]) -> Self: | ||
|         self.config_dict["algorithm_config"]["loss"] = loss | ||
|         return self | ||
| | ||
| | ||
| class N2VConfigBuilder( | ||
|     TrainingParamsMixin, | ||
|     DataParamsMixin, | ||
|     UnetParamsMixin, | ||
|     OptimizerParamsMixin, | ||
|     BaseConfigBuilder, | ||
| ): | ||
|     def __init__( | ||
|         self, | ||
|         experiment_name: str, | ||
|         data_type: Literal["array", "tiff", "zarr", "czi", "custom"], | ||
|         axes: str, | ||
|         patch_size: Sequence[int], | ||
|         batch_size: int, | ||
|         # optional | ||
|         num_epochs: int = 30, | ||
|         num_steps: int | None = None, | ||
|         n_channels: int = 1, | ||
|         seed: int | None = None, | ||
|     ): | ||
|         self.seed = seed | ||
|         self.config_dict = minimum_unet_config_dict( | ||
|             algorithm="n2v", | ||
|             experiment_name=experiment_name, | ||
|             data_type=data_type, | ||
|             axes=axes, | ||
|             patch_size=patch_size, | ||
|             batch_size=batch_size, | ||
|             num_epochs=num_epochs, | ||
|             num_steps=num_steps, | ||
|             n_channels_in=n_channels, | ||
|             n_channels_out=n_channels, | ||
|         ) | ||
| | ||
|         # this will be used to propagate the monitor metric before building the config | ||
|         # we have to wait for after set_checkpoint_params and set_early_stopping_params | ||
|         # it can be changed using the set_monitor_metric method | ||
|         self.monitor_metric: Literal["train_loss", "train_loss_epoch", "val_loss"] = ( | ||
|             "val_loss" | ||
|         ) | ||
| | ||
|         # set default checkpointing params | ||
|         # (can be overwritten with set_checkpoint_params from TrainingParamMixin) | ||
|         self.config_dict["training_config"]["checkpoint_params"] = asdict( | ||
|             SelfSupervisedCheckpointing() | ||
|         ) | ||
| | ||
|         # no early stopping by default | ||
|         self.config_dict["training_config"]["early_stopping_params"] = None | ||
| | ||
|         # propagate seed | ||
|         self.config_dict["algorithm_config"]["n2v_config"] = {} | ||
|         if self.seed is not None: | ||
|             self.config_dict["algorithm_config"]["n2v_config"]["seed"] = self.seed | ||
| | ||
|     def set_n2v_params( | ||
|         self, | ||
|         use_n2v2: bool | None = None, | ||
|         roi_size: int | None = None, | ||
|         masked_pixel_percentage: float | None = None, | ||
|         # - structN2V specific | ||
|         struct_n2v_axis: Literal["horizontal", "vertical", "none"] | None = None, | ||
|         struct_n2v_span: int | None = None, | ||
|     ) -> Self: | ||
|         n2v_manipulate_config: dict[str, Any] = {} | ||
|         if roi_size is not None: | ||
|             n2v_manipulate_config["roi_size"] = roi_size | ||
| | ||
|         if masked_pixel_percentage is not None: | ||
|             n2v_manipulate_config["masked_pixel_percentage"] = masked_pixel_percentage | ||
| | ||
|         if struct_n2v_axis is not None: | ||
|             n2v_manipulate_config["struct_mask_axis"] = struct_n2v_axis | ||
| | ||
|         if struct_n2v_span is not None: | ||
|             n2v_manipulate_config["struct_mask_span"] = struct_n2v_span | ||
| | ||
|         if use_n2v2 is not None: | ||
|             # already added by UnetParamMixin | ||
|             assert isinstance(self.config_dict["algorithm_config"]["model"], dict) | ||
|             self.config_dict["algorithm_config"]["model"]["n2v2"] = use_n2v2 | ||
| | ||
|             n2v_manipulate_config["strategy"] = "median" if use_n2v2 else "uniform" | ||
| | ||
|         assert isinstance(self.config_dict["algorithm_config"]["n2v_config"], dict) | ||
|         self.config_dict["algorithm_config"]["n2v_config"].update(n2v_manipulate_config) | ||
|         return self | ||
| | ||
|     def set_monitor_metric( | ||
|         self, monitor_metric: Literal["train_loss", "train_loss_epoch", "val_loss"] | ||
|     ) -> Self: | ||
|         self.monitor_metric = monitor_metric | ||
|         self.config_dict["algorithm_config"]["monitor_metric"] = monitor_metric | ||
|         return self | ||
| | ||
|     def _propagate_monitor_to_callbacks(self): | ||
|         # only overwrite monitor if it not explicitly set | ||
|         if "checkpoint_params" not in self.config_dict["training_config"]: | ||
|             self.config_dict["training_config"]["checkpoint_params"] = {} | ||
|         checkpoint_params = self.config_dict["training_config"]["checkpoint_params"] | ||
|         if "monitor" not in checkpoint_params: | ||
|             checkpoint_params["monitor"] = self.monitor_metric | ||
| | ||
|         # TODO: default value in the config is dict so we need to propagate in that case | ||
|         # probably we want a mechanism to propagate monitor metric to default | ||
|         if "early_stopping_params" not in self.config_dict["training_config"]: | ||
|             self.config_dict["training_config"]["early_stopping_params"] = {} | ||
|         early_stopping_params = self.config_dict["training_config"][ | ||
|             "early_stopping_params" | ||
|         ] | ||
|         has_early_stopping = early_stopping_params is not None | ||
|         if has_early_stopping and "monitor" not in early_stopping_params: | ||
|             assert isinstance(early_stopping_params, dict) | ||
|             early_stopping_params["monitor"] = self.monitor_metric | ||
| | ||
|     def _before_build(self) -> None: | ||
|         super()._before_build() | ||
|         self._propagate_monitor_to_callbacks() | ||
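
The monitor propagation done by `_propagate_monitor_to_callbacks` in the diff can be sketched in isolation. This is a minimal standalone approximation on plain dictionaries, not the PR's implementation: `propagate_monitor` is a hypothetical helper name, and it only mirrors the behaviour visible in the diff, where an explicitly set `monitor` is kept, a missing one is filled in, and `early_stopping_params = None` (early stopping turned off) is left untouched.

```python
from typing import Any


def propagate_monitor(training_config: dict[str, Any], monitor_metric: str) -> None:
    """Fill in a missing `monitor` key for the checkpoint and early-stopping params."""
    # setdefault returns the existing value, or stores and returns a new empty dict
    checkpoint_params = training_config.setdefault("checkpoint_params", {})
    checkpoint_params.setdefault("monitor", monitor_metric)

    early_stopping_params = training_config.setdefault("early_stopping_params", {})
    # None means early stopping was disabled, so there is nothing to update
    if early_stopping_params is not None:
        early_stopping_params.setdefault("monitor", monitor_metric)


# Case 1: explicit checkpoint monitor, early stopping disabled
explicit = {"checkpoint_params": {"monitor": "val_loss"}, "early_stopping_params": None}
propagate_monitor(explicit, "train_loss")

# Case 2: nothing set yet, both entries pick up the builder-level metric
empty: dict[str, Any] = {}
propagate_monitor(empty, "train_loss")
```

In the first case the explicit `"val_loss"` survives and early stopping stays disabled; in the second, both entries are created with `"train_loss"`.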
Wouldn't it be better to not add a `training_config`, only adding it if the specific methods are called, and leave it to the default constructor to set the defaults?
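
One way to read this suggestion, as a rough sketch only: the builder starts without a `training_config` key and creates it the first time a training-related setter is called. `LazyTrainingBuilder` is a hypothetical class written for illustration; the `set_checkpoint_params` signature and the plain-dict `build()` are assumptions, not the API in this PR.

```python
from typing import Any


class LazyTrainingBuilder:
    """Hypothetical builder that adds `training_config` only on demand."""

    def __init__(self, experiment_name: str) -> None:
        # No "training_config" key yet: defaults are left to the config constructor
        self.config_dict: dict[str, Any] = {"experiment_name": experiment_name}

    def set_checkpoint_params(self, **params: Any) -> "LazyTrainingBuilder":
        # Create the nested dicts only when a training-related setter is called
        training_config = self.config_dict.setdefault("training_config", {})
        training_config.setdefault("checkpoint_params", {}).update(params)
        return self

    def build(self) -> dict[str, Any]:
        # A real builder would validate this dict into a configuration object
        return self.config_dict


untouched = LazyTrainingBuilder("demo").build()
customised = LazyTrainingBuilder("demo").set_checkpoint_params(save_top_k=3).build()
```

With this shape, `untouched` carries no `training_config` at all, so whatever consumes it decides the defaults. The algorithm-specific defaults in the diff (supervised vs. self-supervised checkpointing, early stopping on or off) would presumably then need to live somewhere other than the builder's `__init__`.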