# Setting up a UNet model (#46)
**New file** (`@@ -0,0 +1,19 @@`) — Hydra-style model configuration for the encode-UNet-decode model:

```yaml
_target_: ice_station_zebra.models.EncodeProcessDecode

name: encode-unet-decode

# Each dataset will be encoded into a latent space with these properties
latent_space:
  channels: 20
  shape: [128, 128]

encoder:
  _target_: ice_station_zebra.models.encoders.NaiveLatentSpaceEncoder

processor:
  _target_: ice_station_zebra.models.processors.UNetProcessor
  filter_size: 3  # Size of the kernel for convolutional layers
  n_filters_factor: 1.0  # Factor to scale the size of the UNet

decoder:
  _target_: ice_station_zebra.models.decoders.NaiveLatentSpaceDecoder
```
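The `_target_` keys follow Hydra's instantiation convention, so a config like this can presumably be turned into live objects with `hydra.utils.instantiate`. A minimal sketch, assuming Hydra is in use — the file path is hypothetical, and since `n_latent_channels` does not appear in the `processor` section, it is supplied here from `latent_space.channels`:

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.load("encode_unet_decode.yaml")  # hypothetical path to the config above

# Build just the processor; n_latent_channels is passed explicitly because the
# config only provides filter_size and n_filters_factor
processor = instantiate(cfg.processor, n_latent_channels=cfg.latent_space.channels)
```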
**Modified file** (`@@ -1,5 +1,7 @@`) — the processors package `__init__` re-exports the new class:

```python
from .null import NullProcessor
from .unet import UNetProcessor

__all__ = [
    "NullProcessor",
    "UNetProcessor",
]
```
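With the re-export in place, downstream code (and Hydra's `_target_` resolution) can import the class from the package path used in the config above:

```python
from ice_station_zebra.models.processors import UNetProcessor
```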
**New file** (`@@ -0,0 +1,159 @@`) — the UNet processor implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class UNetProcessor(nn.Module):
    """UNet model that processes input through a UNet architecture"""

    def __init__(
        self,
        n_latent_channels: int,
        filter_size: int,
        n_filters_factor: float,
    ) -> None:
        super().__init__()

        start_out_channels = 64
        reduced_channels = int(start_out_channels * n_filters_factor)
```
> **Member:** Wouldn't it be easier to make …
>
> **Contributor (Author):** I think it's actually neatest to just let people set the `start_out_channels` argument in the config file; then there is no need for `n_filters_factor` or `reduced_channels`.
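A hypothetical sketch of what the author's suggestion could look like — this is illustrative, not the code that was merged — replacing the constructor above:

```python
def __init__(
    self,
    n_latent_channels: int,
    filter_size: int,
    start_out_channels: int = 64,  # hypothetical: set directly from the config file
) -> None:
    super().__init__()
    # No n_filters_factor or reduced_channels needed: the channel widths of the
    # four encoder stages follow directly from start_out_channels
    channels = [start_out_channels * 2**pow for pow in range(4)]
```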
The implementation as committed then continues:

```python
        channels = [reduced_channels * 2**pow for pow in range(4)]

        # Encoder
        self.conv1 = ConvBlock(n_latent_channels, channels[0], filter_size=filter_size)
        self.conv2 = ConvBlock(channels[0], channels[1], filter_size=filter_size)
        self.conv3 = ConvBlock(channels[1], channels[2], filter_size=filter_size)
        self.conv4 = ConvBlock(channels[2], channels[2], filter_size=filter_size)

        # Bottleneck
        self.conv5 = BottleneckBlock(channels[2], channels[3], filter_size=filter_size)

        # Decoder
        self.up6 = UpconvBlock(channels[3], channels[2])
        self.up7 = UpconvBlock(channels[2], channels[2])
        self.up8 = UpconvBlock(channels[2], channels[1])
        self.up9 = UpconvBlock(channels[1], channels[0])

        self.up6b = ConvBlock(channels[3], channels[2], filter_size=filter_size)
        self.up7b = ConvBlock(channels[3], channels[2], filter_size=filter_size)
        self.up8b = ConvBlock(channels[2], channels[1], filter_size=filter_size)
        self.up9b = ConvBlock(
            channels[1], channels[0], filter_size=filter_size, final=True
        )

        # Final layer
        self.final_layer = nn.Conv2d(
            channels[0], n_latent_channels, kernel_size=1, padding="same"
        )

    def forward(self, x: Tensor) -> Tensor:
        # Process in latent space: tensor with
        # (batch_size, all_variables, latent_height, latent_width)

        # Encoder
        bn1 = self.conv1(x)
        conv1 = F.max_pool2d(bn1, kernel_size=2)
        bn2 = self.conv2(conv1)
        conv2 = F.max_pool2d(bn2, kernel_size=2)
        bn3 = self.conv3(conv2)
        conv3 = F.max_pool2d(bn3, kernel_size=2)
        bn4 = self.conv4(conv3)
        conv4 = F.max_pool2d(bn4, kernel_size=2)

        # Bottleneck
        bn5 = self.conv5(conv4)

        # Decoder: upsample, then concatenate with the matching encoder
        # output (skip connections) before each convolution block
        up6 = self.up6b(torch.cat([bn4, self.up6(bn5)], dim=1))
        up7 = self.up7b(torch.cat([bn3, self.up7(up6)], dim=1))
        up8 = self.up8b(torch.cat([bn2, self.up8(up7)], dim=1))
        up9 = self.up9b(torch.cat([bn1, self.up9(up8)], dim=1))

        # Final layer
        output = self.final_layer(up9)

        return output
```
```python
class ConvBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        filter_size: int,
        final: bool = False,
    ) -> None:
        super().__init__()

        layers = [
            nn.Conv2d(
                in_channels, out_channels, kernel_size=filter_size, padding="same"
            ),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                out_channels, out_channels, kernel_size=filter_size, padding="same"
            ),
            nn.ReLU(inplace=True),
        ]
        if final:
            layers += [
                nn.Conv2d(
                    out_channels,
                    out_channels,
                    kernel_size=filter_size,
                    padding="same",
                ),
                nn.ReLU(inplace=True),
            ]
        else:
            layers.append(
                nn.BatchNorm2d(num_features=out_channels),
            )

        self.model = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        return self.model(x)


class BottleneckBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        filter_size: int,
    ) -> None:
        super().__init__()

        self.model = nn.Sequential(
            nn.Conv2d(
                in_channels, out_channels, kernel_size=filter_size, padding="same"
            ),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                out_channels, out_channels, kernel_size=filter_size, padding="same"
            ),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(num_features=out_channels),
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.model(x)


class UpconvBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()

        self.model = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(in_channels, out_channels, kernel_size=2, padding="same"),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.model(x)
```
> **Reviewer:** I think we don't use this anymore.
>
> **Author:** Which 'this'? I think we do need to import those four things.
>
> **Reviewer:** Sorry, I was trying to point to `import torch.nn.functional as F`.