Setting up a UNet model #46
Merged
16 commits:

- `72ee8db` Initial commit for setting up a UNet model (IFenton)
- `547c1f1` Linting with pre-commit (IFenton)
- `7a8a11e` Remove redundant code lines (IFenton)
- `eafc34a` Remove redundant class (IFenton)
- `661d668` Remove unnecessary import (IFenton)
- `2042d6e` Remove print statement (IFenton)
- `fdc2a69` Converting channels from dict to list (IFenton)
- `eb6c2d9` Linting (IFenton)
- `53d6d54` Removing redundant argument (IFenton)
- `e0d961d` :label: Add type hints and shift defaults to yaml (IFenton)
- `b9e4218` :truck: Moving common code out of the unet file (IFenton)
- `56d5258` :rotating_light: Linting (IFenton)
- `41aa26f` :recycle: Replace n_filters_factor / reduced channels with start_out_… (IFenton)
- `2dd0ffb` Merge branch 'main' into 43-add-unet (IFenton)
- `c1010e5` :bulb: Adding comment on origin of the UNet model (IFenton)
- `5e8bc13` :alien: Use `nn.MaxPool2d` instead of `functional.max_pool2d` (jemrobinson)
New model configuration file (YAML; path not shown in this view):

```yaml
_target_: ice_station_zebra.models.EncodeProcessDecode
name: encode-unet-decode

# Each dataset will be encoded into a latent space with these properties
latent_space:
  channels: 20
  shape: [128, 128]

encoder:
  _target_: ice_station_zebra.models.encoders.NaiveLatentSpaceEncoder

processor:
  _target_: ice_station_zebra.models.processors.UNetProcessor
  filter_size: 3  # Size of the kernel for convolutional layers
  start_out_channels: 64  # Initial number of channels for the first convolutional layer

decoder:
  _target_: ice_station_zebra.models.decoders.NaiveLatentSpaceDecoder
```
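The `_target_` keys follow the Hydra instantiation convention: each config node names the class to construct as a dotted path, and the remaining keys become constructor arguments. A minimal stdlib-only sketch of that mechanism (not this project's actual loader; `collections.Counter` is a stand-in target for illustration):

```python
import importlib


def instantiate(config: dict):
    """Build an object from a Hydra-style config node: `_target_` names the
    class as a dotted path; the remaining keys are constructor kwargs."""
    kwargs = dict(config)
    target = kwargs.pop("_target_")
    module_path, _, class_name = target.rpartition(".")
    cls = getattr(importlib.import_module(module_path), class_name)
    return cls(**kwargs)


# Stand-in target; the real configs point at ice_station_zebra classes
# such as ice_station_zebra.models.processors.UNetProcessor
counter = instantiate({"_target_": "collections.Counter", "a": 2, "b": 1})
```

Hydra's real `hydra.utils.instantiate` also handles nested nodes and interpolation; this sketch only covers the flat case shown in the config above.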
`ice_station_zebra/models/common/bottleneckblock.py` (new file):

```python
import torch.nn as nn
from torch import Tensor


class BottleneckBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        filter_size: int,
    ) -> None:
        super().__init__()

        self.model = nn.Sequential(
            nn.Conv2d(
                in_channels, out_channels, kernel_size=filter_size, padding="same"
            ),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                out_channels, out_channels, kernel_size=filter_size, padding="same"
            ),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(num_features=out_channels),
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.model(x)
```
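For a rough sense of scale, a `Conv2d` layer with bias holds `out_channels * (in_channels * k * k + 1)` parameters. A quick pure-Python check (not code from this repository) for the bottleneck's first convolution under the config above, where `start_out_channels: 64` means the bottleneck maps 256 to 512 channels with `filter_size: 3`:

```python
def conv2d_param_count(in_channels: int, out_channels: int, kernel_size: int) -> int:
    # One k*k weight kernel per (in, out) channel pair, plus one bias per output channel
    return out_channels * (in_channels * kernel_size**2 + 1)


# Bottleneck's first convolution: 256 -> 512 channels, 3x3 kernels
n_params = conv2d_param_count(256, 512, 3)
```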
`ice_station_zebra/models/common/convblock.py` (new file):

```python
import torch.nn as nn
from torch import Tensor


class ConvBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        filter_size: int,
        final: bool = False,
    ) -> None:
        super().__init__()

        layers = [
            nn.Conv2d(
                in_channels, out_channels, kernel_size=filter_size, padding="same"
            ),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                out_channels, out_channels, kernel_size=filter_size, padding="same"
            ),
            nn.ReLU(inplace=True),
        ]
        if final:
            layers += [
                nn.Conv2d(
                    out_channels,
                    out_channels,
                    kernel_size=filter_size,
                    padding="same",
                ),
                nn.ReLU(inplace=True),
            ]
        else:
            layers.append(
                nn.BatchNorm2d(num_features=out_channels),
            )

        self.model = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        return self.model(x)
```
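These blocks rely on `padding="same"`, which for stride-1 convolutions with an odd kernel pads by `(k - 1) // 2` so spatial dimensions are unchanged. The standard convolution output-size formula makes this easy to verify (a pure-Python check, not code from this repository):

```python
def conv_out_size(n: int, kernel_size: int, padding: int, stride: int = 1) -> int:
    # Standard convolution output-size formula
    return (n + 2 * padding - kernel_size) // stride + 1


k = 3  # filter_size from the config above
same_padding = (k - 1) // 2  # what padding="same" amounts to for odd kernels at stride 1
out = conv_out_size(128, k, same_padding)  # 128x128 latent shape from the config
```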
`ice_station_zebra/models/common/upconvblock.py` (new file):

```python
import torch.nn as nn
from torch import Tensor


class UpconvBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()

        self.model = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(in_channels, out_channels, kernel_size=2, padding="same"),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.model(x)
```
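`nn.Upsample(scale_factor=2, mode="nearest")` simply repeats each pixel into a 2x2 block before the convolution refines the result. A torch-free sketch of that behaviour on a nested list:

```python
def upsample_nearest_2x(grid: list[list[int]]) -> list[list[int]]:
    """Repeat every element into a 2x2 block, as nearest-neighbour upsampling does."""
    out = []
    for row in grid:
        wide = [value for value in row for _ in range(2)]  # duplicate each column
        out.append(wide)
        out.append(list(wide))  # duplicate each row
    return out


doubled = upsample_nearest_2x([[1, 2], [3, 4]])
```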
`ice_station_zebra/models/processors/__init__.py` (updated):

```python
from .null import NullProcessor
from .unet import UNetProcessor

__all__ = [
    "NullProcessor",
    "UNetProcessor",
]
```
`ice_station_zebra/models/processors/unet.py` (new file):

```python
import torch
import torch.nn as nn
from torch import Tensor

from ice_station_zebra.models.common.bottleneckblock import BottleneckBlock
from ice_station_zebra.models.common.convblock import ConvBlock
from ice_station_zebra.models.common.upconvblock import UpconvBlock


class UNetProcessor(nn.Module):
    """UNet model that processes input through a UNet architecture.

    Structure based on Andersson et al. (2021) Nature Communications
    https://doi.org/10.1038/s41467-021-25257-4
    """

    def __init__(
        self,
        n_latent_channels: int,
        filter_size: int,
        start_out_channels: int,
    ) -> None:
        super().__init__()

        channels = [start_out_channels * 2**power for power in range(4)]

        # Encoder
        self.conv1 = ConvBlock(n_latent_channels, channels[0], filter_size=filter_size)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)
        self.conv2 = ConvBlock(channels[0], channels[1], filter_size=filter_size)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)
        self.conv3 = ConvBlock(channels[1], channels[2], filter_size=filter_size)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2)
        self.conv4 = ConvBlock(channels[2], channels[2], filter_size=filter_size)
        self.maxpool4 = nn.MaxPool2d(kernel_size=2)

        # Bottleneck
        self.conv5 = BottleneckBlock(channels[2], channels[3], filter_size=filter_size)

        # Decoder
        self.up6 = UpconvBlock(channels[3], channels[2])
        self.up7 = UpconvBlock(channels[2], channels[2])
        self.up8 = UpconvBlock(channels[2], channels[1])
        self.up9 = UpconvBlock(channels[1], channels[0])

        self.up6b = ConvBlock(channels[3], channels[2], filter_size=filter_size)
        self.up7b = ConvBlock(channels[3], channels[2], filter_size=filter_size)
        self.up8b = ConvBlock(channels[2], channels[1], filter_size=filter_size)
        self.up9b = ConvBlock(
            channels[1], channels[0], filter_size=filter_size, final=True
        )

        # Final layer
        self.final_layer = nn.Conv2d(
            channels[0], n_latent_channels, kernel_size=1, padding="same"
        )

    def forward(self, x: Tensor) -> Tensor:
        # Process in latent space: tensor with
        # (batch_size, all_variables, latent_height, latent_width)

        # Encoder
        bn1 = self.conv1(x)
        conv1 = self.maxpool1(bn1)
        bn2 = self.conv2(conv1)
        conv2 = self.maxpool2(bn2)
        bn3 = self.conv3(conv2)
        conv3 = self.maxpool3(bn3)
        bn4 = self.conv4(conv3)
        conv4 = self.maxpool4(bn4)

        # Bottleneck
        bn5 = self.conv5(conv4)

        # Decoder: concatenate skip connections along the channel dimension
        up6 = self.up6b(torch.cat([bn4, self.up6(bn5)], dim=1))
        up7 = self.up7b(torch.cat([bn3, self.up7(up6)], dim=1))
        up8 = self.up8b(torch.cat([bn2, self.up8(up7)], dim=1))
        up9 = self.up9b(torch.cat([bn1, self.up9(up8)], dim=1))

        # Final layer
        return self.final_layer(up9)
```
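With the config above (`start_out_channels: 64`, latent `shape: [128, 128]`), the channel and spatial bookkeeping through the network can be checked with a few lines of arithmetic:

```python
start_out_channels = 64  # value from the config above
channels = [start_out_channels * 2**power for power in range(4)]

# Each MaxPool2d(kernel_size=2) halves the spatial size, starting from 128x128
size = 128
encoder_sizes = [size := size // 2 for _ in range(4)]

# Skip connections: concatenating bn4 (channels[2]) with the upsampled bottleneck
# output (channels[2]) yields channels[3], matching up6b's expected input channels
skip_concat = channels[2] + channels[2]
```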