Skip to content
Draft
Show file tree
Hide file tree
Changes from 77 commits
Commits
Show all changes
92 commits
Select commit Hold shift + click to select a range
fa5c320
multifile pred, avg_psnr
CatEek Dec 6, 2024
9a35653
Merge branch 'main' into splits_prediction_refac
CatEek Dec 10, 2024
7675ea7
inference mode lvae
CatEek Dec 15, 2024
36def88
inference fix
CatEek Dec 16, 2024
ec7f0b2
lvae pred func upd
CatEek Dec 18, 2024
78fa078
Merge remote-tracking branch 'origin/main' into splits_prediction_refac
CatEek Dec 18, 2024
0eef8c7
reduce data fix
CatEek Dec 19, 2024
cbb29a8
hdn init configs
CatEek Dec 23, 2024
f514d7f
basic config fixture + test
CatEek Dec 24, 2024
1f2b0e9
out channels test
CatEek Dec 24, 2024
de7a939
hdn lightning init tests
CatEek Dec 24, 2024
4f7de87
hdn trainstep wip
CatEek Dec 25, 2024
5b71750
hdn trainloop test
CatEek Dec 25, 2024
12b42e4
hdn logvar test
CatEek Dec 26, 2024
b9598b0
train/val steps test
CatEek Dec 26, 2024
10e1919
tests pass
CatEek Dec 26, 2024
147c280
wip
CatEek Jan 3, 2025
2bc3525
3d check config
CatEek Jan 4, 2025
9be50db
wip
CatEek Jan 6, 2025
201eea5
Merge remote-tracking branch 'origin/main' into hdn_config
CatEek Jan 22, 2025
9f91543
hdn_configs
CatEek Jan 22, 2025
6fe54a0
conf fixes wip
CatEek Jan 23, 2025
ff50695
hdn conf wip
CatEek Jan 23, 2025
39bd601
hdn conf factory
CatEek Jan 24, 2025
5d9d534
hdn conf factory
CatEek Feb 4, 2025
0c5572f
wip
CatEek Feb 5, 2025
131f9c4
batch unpack fix
CatEek Feb 5, 2025
c2559e0
input shape fix
CatEek Feb 10, 2025
0fcfbdc
train vae test pass
CatEek Feb 11, 2025
4f7aa06
ds patch_transform unpack fix
CatEek Feb 11, 2025
77b5b59
wip
CatEek Feb 11, 2025
197cbfa
Merge remote-tracking branch 'origin/main' into hdn_config
CatEek Feb 11, 2025
51f0e5f
style(pre-commit.ci): auto fixes [...]
pre-commit-ci[bot] Feb 11, 2025
bdc43b3
config wip
CatEek Feb 13, 2025
e69a656
Merge remote-tracking branch 'origin/hdn_config' into hdn_config
CatEek Feb 13, 2025
4c8570a
Merge remote-tracking branch 'origin/main' into hdn_config
CatEek Feb 13, 2025
814ba84
post-merge fixes
CatEek Feb 13, 2025
9e1d95f
style(pre-commit.ci): auto fixes [...]
pre-commit-ci[bot] Feb 13, 2025
fb570a6
rnd device fix for cpu tests
CatEek Feb 13, 2025
e0ede24
ds tests fix
CatEek Feb 13, 2025
1b852dd
pred ds tests fix
CatEek Feb 13, 2025
4e10d1c
careamist train array vae test
CatEek Feb 13, 2025
c93a131
hdn conf test
CatEek Feb 13, 2025
491c7d3
Merge remote-tracking branch 'origin/hdn_config' into hdn_config
CatEek Feb 13, 2025
88f2188
style(pre-commit.ci): auto fixes [...]
pre-commit-ci[bot] Feb 13, 2025
00d9eb2
hdn config test
CatEek Feb 17, 2025
3ce707a
Merge remote-tracking branch 'origin/hdn_config' into hdn_config
CatEek Feb 17, 2025
2824cb7
Merge remote-tracking branch 'origin/main' into hdn_config
CatEek Feb 17, 2025
ff2827b
hdn conf ll and input_sz
CatEek Feb 18, 2025
4f3f4ac
hdn ll test remove
CatEek Feb 18, 2025
3346b2e
mmse prediction draft
CatEek Feb 19, 2025
5cb3260
tests pass
CatEek Feb 19, 2025
8265f61
style(pre-commit.ci): auto fixes [...]
pre-commit-ci[bot] Feb 19, 2025
990abfd
Update src/careamics/config/configuration_model.py
CatEek Feb 19, 2025
ebc415b
style(pre-commit.ci): auto fixes [...]
pre-commit-ci[bot] Feb 19, 2025
4f5180b
imports fix
CatEek Feb 19, 2025
92cc4a1
style(pre-commit.ci): auto fixes [...]
pre-commit-ci[bot] Feb 19, 2025
a6e7af5
Merge remote-tracking branch 'origin/main' into hdn_config
CatEek Feb 19, 2025
97eca98
Merge remote-tracking branch 'origin/hdn_config' into hdn_config
CatEek Feb 19, 2025
e9899fb
rm old file
CatEek Feb 19, 2025
33f9f67
Merge branch 'main' into hdn_config
jdeschamps Feb 25, 2025
92e9057
cvs logger test fix
CatEek Feb 26, 2025
3164ad3
ref folder remove
CatEek Feb 26, 2025
9294cb3
not implemented err for vae in careamist
CatEek Feb 26, 2025
aac24ae
vae test skip
CatEek Feb 26, 2025
2cb5f86
Merge remote-tracking branch 'origin/main' into hdn_config
CatEek Feb 26, 2025
273ad43
tests fix
CatEek Feb 26, 2025
55504fe
Merge remote-tracking branch 'origin/hdn_config' into hdn_config
CatEek Feb 26, 2025
b716eb7
bmz not implemented vae
CatEek Feb 27, 2025
8e8c59e
Merge branch 'main' into hdn_config
jdeschamps Feb 27, 2025
e2f79ce
ds unpacking fix + hdn conf
CatEek Feb 27, 2025
bda115c
Merge remote-tracking branch 'origin/hdn_config' into hdn_config
CatEek Feb 27, 2025
9c7c304
pre-coms fix
CatEek Feb 27, 2025
3f566b3
unused mmse count rm
CatEek Feb 27, 2025
55e2e77
lvae pydantic shape fix
CatEek Feb 27, 2025
db57ec2
mypy fixes
CatEek Feb 27, 2025
cab3884
ds test fix + loss comment
CatEek Feb 27, 2025
323e3e5
add mmse_count and is_supervised parameters; add lr logging
veegalinova Jun 9, 2025
ca0fb77
Merge remote-tracking branch 'origin/main' into hdn_config
CatEek Jun 10, 2025
4a89256
merge
CatEek Jun 10, 2025
831467d
tests fix
CatEek Jun 10, 2025
958dd77
style(pre-commit.ci): auto fixes [...]
pre-commit-ci[bot] Jun 10, 2025
830323d
wip
CatEek Jun 12, 2025
704b3e5
Merge remote-tracking branch 'origin/hdn_config' into hdn_config
CatEek Jun 16, 2025
8619b36
tile unpack fix
CatEek Jun 16, 2025
a61770f
style(pre-commit.ci): auto fixes [...]
pre-commit-ci[bot] Jun 16, 2025
bd1b1ec
Merge branch 'main' into hdn_config
jdeschamps Jun 17, 2025
58e46c1
Merge remote-tracking branch 'origin/main' into hdn_config
CatEek Jul 3, 2025
636e60e
loss types
CatEek Jul 3, 2025
7a93c91
dl params
CatEek Jul 9, 2025
7278cad
wip
CatEek Jul 16, 2025
3680f5f
fix hdn loss parameter types
veegalinova Aug 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion src/careamics/careamist.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,17 @@
from pytorch_lightning.callbacks import (
Callback,
EarlyStopping,
LearningRateMonitor,
ModelCheckpoint,
)
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger, WandbLogger

from careamics.config import Configuration, UNetBasedAlgorithm, load_configuration
from careamics.config import (
Configuration,
UNetBasedAlgorithm,
VAEBasedAlgorithm,
load_configuration,
)
from careamics.config.support import (
SupportedAlgorithm,
SupportedArchitecture,
Expand All @@ -28,6 +34,7 @@
PredictDataModule,
ProgressBarCallback,
TrainDataModule,
VAEModule,
create_predict_datamodule,
)
from careamics.model_io import export_to_bmz, load_pretrained
Expand Down Expand Up @@ -141,6 +148,11 @@ def __init__(
self.model = FCNModule(
algorithm_config=self.cfg.algorithm_config,
)
elif isinstance(self.cfg.algorithm_config, VAEBasedAlgorithm):
self.model = VAEModule(
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Model should not even be instantiated.

algorithm_config=self.cfg.algorithm_config,
)
raise NotImplementedError("VAE based algorithms are not implemented.")
else:
raise NotImplementedError("Architecture not supported.")

Expand Down Expand Up @@ -252,6 +264,7 @@ def _define_callbacks(self, callbacks: Optional[list[Callback]] = None) -> None:
**self.cfg.training_config.checkpoint_callback.model_dump(),
),
ProgressBarCallback(),
LearningRateMonitor(),
]
)

Expand Down
4 changes: 4 additions & 0 deletions src/careamics/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"Configuration",
"DataConfig",
"GaussianMixtureNMConfig",
"HDNAlgorithm",
"InferenceConfig",
"LVAELossConfig",
"MultiChannelNMConfig",
Expand All @@ -22,6 +23,7 @@
"VAEBasedAlgorithm",
"algorithm_factory",
"create_care_configuration",
"create_hdn_configuration",
"create_n2n_configuration",
"create_n2v_configuration",
"load_configuration",
Expand All @@ -30,6 +32,7 @@

from .algorithms import (
CAREAlgorithm,
HDNAlgorithm,
N2NAlgorithm,
N2VAlgorithm,
UNetBasedAlgorithm,
Expand All @@ -40,6 +43,7 @@
from .configuration_factories import (
algorithm_factory,
create_care_configuration,
create_hdn_configuration,
create_n2n_configuration,
create_n2v_configuration,
)
Expand Down
2 changes: 2 additions & 0 deletions src/careamics/config/algorithms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

__all__ = [
"CAREAlgorithm",
"HDNAlgorithm",
"N2NAlgorithm",
"N2VAlgorithm",
"UNetBasedAlgorithm",
"VAEBasedAlgorithm",
]

from .care_algorithm_model import CAREAlgorithm
from .hdn_algorithm_model import HDNAlgorithm
from .n2n_algorithm_model import N2NAlgorithm
from .n2v_algorithm_model import N2VAlgorithm
from .unet_algorithm_model import UNetBasedAlgorithm
Expand Down
97 changes: 97 additions & 0 deletions src/careamics/config/algorithms/hdn_algorithm_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
"""HDN algorithm configuration."""

from typing import Literal

from bioimageio.spec.generic.v0_3 import CiteEntry
from pydantic import ConfigDict

from careamics.config.algorithms.vae_algorithm_model import VAEBasedAlgorithm
from careamics.config.architectures import LVAEModel
from careamics.config.loss_model import LVAELossConfig

HDN = "HDN"
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's actually "Hierarchical DivNoising"


HDN_DESCRIPTION = ""
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a TODO here?

HDN_REF = CiteEntry(
text='Prakash, M., Delbracio, M., Milanfar, P., Jug, F. 2022. "Interpretable '
'Unsupervised Diversity Denoising and Artefact Removal." The International '
"Conference on Learning Representations (ICLR).",
doi="10.1561/2200000056",
)


class HDNAlgorithm(VAEBasedAlgorithm):
"""HDN algorithm configuration."""

model_config = ConfigDict(validate_assignment=True)

algorithm: Literal["hdn"] = "hdn"

loss: LVAELossConfig

model: LVAEModel # TODO add validators
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you open an issue and state what these validators should be? Otherwise you will have to figure it out again.


def get_algorithm_friendly_name(self) -> str:
"""
Get the algorithm friendly name.

Returns
-------
str
Friendly name of the algorithm.
"""
return HDN

def get_algorithm_keywords(self) -> list[str]:
"""
Get algorithm keywords.

Returns
-------
list[str]
List of keywords.
"""
return [
"restoration",
"VAE",
"3D" if self.model.is_3D() else "2D",
"CAREamics",
"pytorch",
]

def get_algorithm_references(self) -> str:
"""
Get the algorithm references.

This is used to generate the README of the BioImage Model Zoo export.

Returns
-------
str
Algorithm references.
"""
return HDN_REF.text + " doi: " + HDN_REF.doi

def get_algorithm_citations(self) -> list[CiteEntry]:
"""
Return a list of citation entries of the current algorithm.

This is used to generate the model description for the BioImage Model Zoo.

Returns
-------
List[CiteEntry]
List of citation entries.
"""
return [HDN_REF]

def get_algorithm_description(self) -> str:
"""
Get the algorithm description.

Returns
-------
str
Algorithm description.
"""
return HDN_DESCRIPTION
25 changes: 24 additions & 1 deletion src/careamics/config/algorithms/vae_algorithm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class VAEBasedAlgorithm(BaseModel):
# defined in SupportedAlgorithm
# TODO: Use supported Enum classes for typing?
# - values can still be passed as strings and they will be cast to Enum
algorithm: Literal["musplit", "denoisplit"]
algorithm: Literal["hdn", "musplit", "denoisplit"]

# NOTE: these are all configs (pydantic models)
loss: LVAELossConfig
Expand All @@ -64,6 +64,14 @@ def algorithm_cross_validation(self: Self) -> Self:
Self
The validated model.
"""
# hdn
if self.algorithm == SupportedAlgorithm.HDN:
if self.loss.loss_type != SupportedLoss.HDN:
raise ValueError(
f"Algorithm {self.algorithm} only supports loss `hdn`."
)
if self.model.multiscale_count > 1:
raise ValueError("Algorithm `hdn` does not support multiscale models.")
Comment on lines +70 to +77
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like there should be a way to have these validations in the new child class HDNAlgorithm

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, will change

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will also add a relevant test

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added(replaced) tests. On a second thought, there're also microsplit related checks in this module, and moving some and leaving others makes little sense

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm I see there should probably be separate child classes for MuSplit, DenoiSplit, but I guess this can wait and we create an issue

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes they should be moved to child classes. Can you open an issue (maybe first a general MicroSplit issue, then this as a sub-issue)?

# musplit
if self.algorithm == SupportedAlgorithm.MUSPLIT:
if self.loss.loss_type != SupportedLoss.MUSPLIT:
Expand Down Expand Up @@ -108,6 +116,12 @@ def output_channels_validation(self: Self) -> Self:
f"Number of output channels ({self.model.output_channels}) must match "
f"the number of noise models ({len(self.noise_model.noise_models)})."
)

if self.algorithm == SupportedAlgorithm.HDN:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this a fundamental limitation of HDN?

assert self.model.output_channels == 1, (
f"Number of output channels ({self.model.output_channels}) must be 1 "
"for algorithm `hdn`."
)
return self

@model_validator(mode="after")
Expand All @@ -127,6 +141,15 @@ def predict_logvar_validation(self: Self) -> Self:
"Gaussian likelihood model `predict_logvar` "
f"({self.gaussian_likelihood.predict_logvar}).",
)
# if self.algorithm == SupportedAlgorithm.HDN:
# assert (
# self.model.predict_logvar is None
# ), "Model `predict_logvar` must be `None` for algorithm `hdn`."
# if self.gaussian_likelihood is not None:
# assert self.gaussian_likelihood.predict_logvar is None, (
# "Gaussian likelihood model `predict_logvar` must be `None` "
# "for algorithm `hdn`."
# )
return self

def __str__(self) -> str:
Expand Down
15 changes: 10 additions & 5 deletions src/careamics/config/architectures/lvae_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,9 @@ class LVAEModel(ArchitectureModel):
model_config = ConfigDict(validate_assignment=True, validate_default=True)

architecture: Literal["LVAE"]
"""Name of the architecture."""

input_shape: list[int] = Field(default=[64, 64], validate_default=True)
"""Shape of the input patch (C, Z, Y, X) or (C, Y, X) if the data is 2D."""

input_shape: tuple[int, ...] = Field(default=(64, 64), validate_default=True)
"""Shape of the input patch (Z, Y, X) or (Y, X) if the data is 2D."""
Comment on lines +19 to +20
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changing this to tuple has some serialization issues with the current way we do model dump.

Note, it can actually be solved doing .model_dump(mode="json") which should automatically cast iterable python types to list.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting, I didn't know. And I guess when reading out the list, it gets casted into tuple without any issue?

Should we open an issue for refactoring the way we export the configuration? It would be nice to support tuples, that would also allow immutable defaults in functions signatures.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah default mode argument is "python", which will keep objects as python types, using "json" will convert python types to json serializable objects, and I guess there is overlap with yaml, model_dump API

encoder_conv_strides: list = Field(default=[2, 2], validate_default=True)

# TODO make this per hierarchy step ?
Expand Down Expand Up @@ -126,6 +124,13 @@ def validate_input_shape(cls, input_shape: list) -> list:
f"Input shape must be greater than 1 in all dimensions"
f"(got {input_shape})."
)

if any(s < 64 for s in input_shape[-2:]):
raise ValueError(
f"Input shape must be greater or equal to 64 in XY dimensions"
f"(got {input_shape})."
)

return input_shape

@field_validator("encoder_n_filters")
Expand Down Expand Up @@ -255,4 +260,4 @@ def is_3D(self) -> bool:
bool
Whether the model is 3D or not.
"""
return self.conv_dims == 3
return len(self.input_shape) == 3
2 changes: 2 additions & 0 deletions src/careamics/config/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from careamics.config.algorithms import (
CAREAlgorithm,
HDNAlgorithm,
N2NAlgorithm,
N2VAlgorithm,
)
Expand All @@ -22,6 +23,7 @@
CAREAlgorithm,
N2NAlgorithm,
N2VAlgorithm,
HDNAlgorithm,
]


Expand Down
Loading