-
Notifications
You must be signed in to change notification settings - Fork 21
(Feature) HDN minimum example #396
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 60 commits
fa5c320
9a35653
7675ea7
36def88
ec7f0b2
78fa078
0eef8c7
cbb29a8
f514d7f
1f2b0e9
de7a939
4f7de87
5b71750
12b42e4
b9598b0
10e1919
147c280
2bc3525
9be50db
201eea5
9f91543
6fe54a0
ff50695
39bd601
5d9d534
0c5572f
131f9c4
c2559e0
0fcfbdc
4f7aa06
77b5b59
197cbfa
51f0e5f
bdc43b3
e69a656
4c8570a
814ba84
9e1d95f
fb570a6
e0ede24
1b852dd
4e10d1c
c93a131
491c7d3
88f2188
00d9eb2
3ce707a
2824cb7
ff2827b
4f3f4ac
3346b2e
5cb3260
8265f61
990abfd
ebc415b
4f5180b
92cc4a1
a6e7af5
97eca98
e9899fb
33f9f67
92e9057
3164ad3
9294cb3
aac24ae
2cb5f86
273ad43
55504fe
b716eb7
8e8c59e
e2f79ce
bda115c
9c7c304
3f566b3
55e2e77
db57ec2
cab3884
323e3e5
ca0fb77
4a89256
831467d
958dd77
830323d
704b3e5
8619b36
a61770f
bd1b1ec
58e46c1
636e60e
7a93c91
7278cad
3680f5f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,97 @@ | ||
| """HDN algorithm configuration.""" | ||
|
|
||
| from typing import Literal | ||
|
|
||
| from bioimageio.spec.generic.v0_3 import CiteEntry | ||
| from pydantic import ConfigDict | ||
|
|
||
| from careamics.config.algorithms.vae_algorithm_model import VAEBasedAlgorithm | ||
| from careamics.config.architectures import LVAEModel | ||
| from careamics.config.loss_model import LVAELossConfig | ||
|
|
||
| HDN = "HDN" | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's actually "Hierarchical DivNoising" |
||
|
|
||
| HDN_DESCRIPTION = "" | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add a TODO here? |
||
| HDN_REF = CiteEntry( | ||
| text='Prakash, M., Delbracio, M., Milanfar, P., Jug, F. 2022. "Interpretable ' | ||
| 'Unsupervised Diversity Denoising and Artefact Removal." The International ' | ||
| "Conference on Learning Representations (ICLR).", | ||
| doi="10.1561/2200000056", | ||
| ) | ||
|
|
||
|
|
||
| class HDNAlgorithm(VAEBasedAlgorithm): | ||
| """HDN algorithm configuration.""" | ||
|
|
||
| model_config = ConfigDict(validate_assignment=True) | ||
|
|
||
| algorithm: Literal["hdn"] = "hdn" | ||
|
|
||
| loss: LVAELossConfig | ||
|
|
||
| model: LVAEModel # TODO add validators | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you open an issue and state what these validators should be? Otherwise you will have to figure it out again. |
||
|
|
||
| def get_algorithm_friendly_name(self) -> str: | ||
| """ | ||
| Get the algorithm friendly name. | ||
|
|
||
| Returns | ||
| ------- | ||
| str | ||
| Friendly name of the algorithm. | ||
| """ | ||
| return HDN | ||
|
|
||
| def get_algorithm_keywords(self) -> list[str]: | ||
| """ | ||
| Get algorithm keywords. | ||
|
|
||
| Returns | ||
| ------- | ||
| list[str] | ||
| List of keywords. | ||
| """ | ||
| return [ | ||
| "restoration", | ||
| "VAE", | ||
| "3D" if self.model.is_3D() else "2D", | ||
| "CAREamics", | ||
| "pytorch", | ||
| ] | ||
|
|
||
| def get_algorithm_references(self) -> str: | ||
| """ | ||
| Get the algorithm references. | ||
|
|
||
| This is used to generate the README of the BioImage Model Zoo export. | ||
|
|
||
| Returns | ||
| ------- | ||
| str | ||
| Algorithm references. | ||
| """ | ||
| return HDN_REF.text + " doi: " + HDN_REF.doi | ||
|
|
||
| def get_algorithm_citations(self) -> list[CiteEntry]: | ||
| """ | ||
| Return a list of citation entries of the current algorithm. | ||
|
|
||
| This is used to generate the model description for the BioImage Model Zoo. | ||
|
|
||
| Returns | ||
| ------- | ||
| List[CiteEntry] | ||
| List of citation entries. | ||
| """ | ||
| return [HDN_REF] | ||
|
|
||
| def get_algorithm_description(self) -> str: | ||
| """ | ||
| Get the algorithm description. | ||
|
|
||
| Returns | ||
| ------- | ||
| str | ||
| Algorithm description. | ||
| """ | ||
| return HDN_DESCRIPTION | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -40,7 +40,7 @@ class VAEBasedAlgorithm(BaseModel): | |
| # defined in SupportedAlgorithm | ||
| # TODO: Use supported Enum classes for typing? | ||
| # - values can still be passed as strings and they will be cast to Enum | ||
| algorithm: Literal["musplit", "denoisplit"] | ||
| algorithm: Literal["hdn", "musplit", "denoisplit"] | ||
|
|
||
| # NOTE: these are all configs (pydantic models) | ||
| loss: LVAELossConfig | ||
|
|
@@ -64,6 +64,14 @@ def algorithm_cross_validation(self: Self) -> Self: | |
| Self | ||
| The validated model. | ||
| """ | ||
| # hdn | ||
| if self.algorithm == SupportedAlgorithm.HDN: | ||
| if self.loss.loss_type != SupportedLoss.HDN: | ||
| raise ValueError( | ||
| f"Algorithm {self.algorithm} only supports loss `hdn`." | ||
| ) | ||
| if self.model.multiscale_count > 1: | ||
| raise ValueError("Algorithm `hdn` does not support multiscale models.") | ||
|
Comment on lines
+70
to
+77
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I feel like there should be a way to have these validations in the new child class
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point, will change
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I will also add a relevant test
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added(replaced) tests. On a second thought, there're also microsplit related checks in this module, and moving some and leaving others makes little sense
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm I see there should probably be separate child classes for MuSplit, DenoiSplit, but I guess this can wait and we create an issue
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes they should be moved to child classes. Can you open an issue (maybe first a general MicroSplit issue, then this as a sub-issue)? |
||
| # musplit | ||
| if self.algorithm == SupportedAlgorithm.MUSPLIT: | ||
| if self.loss.loss_type != SupportedLoss.MUSPLIT: | ||
|
|
@@ -108,6 +116,12 @@ def output_channels_validation(self: Self) -> Self: | |
| f"Number of output channels ({self.model.output_channels}) must match " | ||
| f"the number of noise models ({len(self.noise_model.noise_models)})." | ||
| ) | ||
|
|
||
| if self.algorithm == SupportedAlgorithm.HDN: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this a fundamental limitation of HDN? |
||
| assert self.model.output_channels == 1, ( | ||
| f"Number of output channels ({self.model.output_channels}) must be 1 " | ||
| "for algorithm `hdn`." | ||
| ) | ||
| return self | ||
|
|
||
| @model_validator(mode="after") | ||
|
|
@@ -127,6 +141,15 @@ def predict_logvar_validation(self: Self) -> Self: | |
| "Gaussian likelihood model `predict_logvar` " | ||
| f"({self.gaussian_likelihood.predict_logvar}).", | ||
| ) | ||
| # if self.algorithm == SupportedAlgorithm.HDN: | ||
| # assert ( | ||
| # self.model.predict_logvar is None | ||
| # ), "Model `predict_logvar` must be `None` for algorithm `hdn`." | ||
| # if self.gaussian_likelihood is not None: | ||
| # assert self.gaussian_likelihood.predict_logvar is None, ( | ||
| # "Gaussian likelihood model `predict_logvar` must be `None` " | ||
| # "for algorithm `hdn`." | ||
| # ) | ||
| return self | ||
|
|
||
| def __str__(self) -> str: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,11 +15,9 @@ class LVAEModel(ArchitectureModel): | |
| model_config = ConfigDict(validate_assignment=True, validate_default=True) | ||
|
|
||
| architecture: Literal["LVAE"] | ||
| """Name of the architecture.""" | ||
|
|
||
| input_shape: list[int] = Field(default=[64, 64], validate_default=True) | ||
| """Shape of the input patch (C, Z, Y, X) or (C, Y, X) if the data is 2D.""" | ||
|
|
||
| input_shape: list[int] = Field(default=(64, 64), validate_default=True) | ||
| """Shape of the input patch (Z, Y, X) or (Y, X) if the data is 2D.""" | ||
|
CatEek marked this conversation as resolved.
Comment on lines
+19
to
+20
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. changing this to Note, it can actually be solved doing
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Interesting, I didn't know. And I guess when reading out the list, it gets casted into tuple without any issue? Should we open an issue for refactoring the way we export the configuration? It would be nice to support tuples, that would also allow immutable defaults in functions signatures.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah default |
||
| encoder_conv_strides: list = Field(default=[2, 2], validate_default=True) | ||
|
|
||
| # TODO make this per hierarchy step ? | ||
|
|
@@ -126,6 +124,13 @@ def validate_input_shape(cls, input_shape: list) -> list: | |
| f"Input shape must be greater than 1 in all dimensions" | ||
| f"(got {input_shape})." | ||
| ) | ||
|
|
||
| if any(s < 64 for s in input_shape[-2:]): | ||
| raise ValueError( | ||
| f"Input shape must be greater or equal to 64 in XY dimensions" | ||
| f"(got {input_shape})." | ||
| ) | ||
|
|
||
| return input_shape | ||
|
|
||
| @field_validator("encoder_n_filters") | ||
|
|
@@ -255,4 +260,4 @@ def is_3D(self) -> bool: | |
| bool | ||
| Whether the model is 3D or not. | ||
| """ | ||
| return self.conv_dims == 3 | ||
| return len(self.input_shape) == 3 | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Model should not even be instantiated.