Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions ice_station_zebra/config/model/encode_unet_decode.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Encode-process-decode model that uses a UNet as the latent-space processor
_target_: ice_station_zebra.models.EncodeProcessDecode

name: encode-unet-decode

# Each dataset will be encoded into a latent space with these properties
latent_space:
  channels: 20
  shape: [128, 128]

# Encodes each input dataset into the shared latent space
encoder:
  _target_: ice_station_zebra.models.encoders.NaiveLatentSpaceEncoder

# Processes tensors entirely within the latent space
processor:
  _target_: ice_station_zebra.models.processors.UNetProcessor
  filter_size: 3 # Size of the kernel for convolutional layers
  start_out_channels: 64 # Initial number of channels for the first convolutional layer

# Decodes from the latent space back to the output space
decoder:
  _target_: ice_station_zebra.models.decoders.NaiveLatentSpaceDecoder
28 changes: 28 additions & 0 deletions ice_station_zebra/models/common/bottleneckblock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import torch.nn as nn
from torch import Tensor


class BottleneckBlock(nn.Module):
    """Bottleneck stage for a UNet: two same-padded convolutions, each
    followed by a ReLU, then batch normalisation.

    Spatial dimensions are preserved ("same" padding); only the channel
    count changes from ``in_channels`` to ``out_channels``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        filter_size: int,
    ) -> None:
        super().__init__()

        # conv -> ReLU -> conv -> ReLU -> BatchNorm
        conv_kwargs = {"kernel_size": filter_size, "padding": "same"}
        self.model = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, **conv_kwargs),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, **conv_kwargs),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(num_features=out_channels),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Apply the block to a (batch, in_channels, H, W) tensor."""
        return self.model(x)
45 changes: 45 additions & 0 deletions ice_station_zebra/models/common/convblock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import torch.nn as nn
from torch import Tensor


class ConvBlock(nn.Module):
    """Double-convolution block for the UNet encoder/decoder paths.

    Each convolution uses "same" padding, so spatial dimensions are
    preserved. By default the two conv+ReLU pairs are followed by batch
    normalisation; with ``final=True`` a third conv+ReLU pair is appended
    instead and the batch normalisation is omitted.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        filter_size: int,
        final: bool = False,
    ) -> None:
        super().__init__()

        def conv_relu(n_in: int, n_out: int) -> list[nn.Module]:
            # One same-padded convolution followed by an in-place ReLU
            return [
                nn.Conv2d(n_in, n_out, kernel_size=filter_size, padding="same"),
                nn.ReLU(inplace=True),
            ]

        layers = conv_relu(in_channels, out_channels)
        layers += conv_relu(out_channels, out_channels)
        if final:
            # Final block: an extra conv+ReLU instead of batch normalisation
            layers += conv_relu(out_channels, out_channels)
        else:
            layers.append(nn.BatchNorm2d(num_features=out_channels))

        self.model = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the block to a (batch, in_channels, H, W) tensor."""
        return self.model(x)
16 changes: 16 additions & 0 deletions ice_station_zebra/models/common/upconvblock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import torch.nn as nn
from torch import Tensor


class UpconvBlock(nn.Module):
    """Upsampling block: nearest-neighbour 2x upsample, then a same-padded
    convolution with ReLU.

    Doubles the spatial dimensions and maps ``in_channels`` to
    ``out_channels``.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()

        upsample = nn.Upsample(scale_factor=2, mode="nearest")
        # kernel_size=2 with "same" padding keeps the upsampled spatial size
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=2, padding="same")
        self.model = nn.Sequential(upsample, conv, nn.ReLU(inplace=True))

    def forward(self, x: Tensor) -> Tensor:
        """Map a (batch, in_channels, H, W) tensor to (batch, out_channels, 2H, 2W)."""
        return self.model(x)
2 changes: 2 additions & 0 deletions ice_station_zebra/models/processors/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from .null import NullProcessor
from .unet import UNetProcessor

# Public API of the processors subpackage
__all__ = [
    "NullProcessor",
    "UNetProcessor",
]
82 changes: 82 additions & 0 deletions ice_station_zebra/models/processors/unet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import torch.nn as nn
from torch import Tensor
import torch

from ice_station_zebra.models.common.convblock import ConvBlock
from ice_station_zebra.models.common.bottleneckblock import BottleneckBlock
from ice_station_zebra.models.common.upconvblock import UpconvBlock


class UNetProcessor(nn.Module):
    """UNet model that processes input through a UNet architecture.

    Structure based on Andersson et al. (2021) Nature Communications
    https://doi.org/10.1038/s41467-021-25257-4

    Maps a latent tensor of shape (batch_size, n_latent_channels, H, W) to a
    tensor of the same shape. H and W must be divisible by 16, since the
    encoder applies four 2x2 max-pooling steps.
    """

    def __init__(
        self,
        n_latent_channels: int,
        filter_size: int,
        start_out_channels: int,
    ) -> None:
        """Initialise the UNet processor.

        Args:
            n_latent_channels: Number of channels in the latent input/output.
            filter_size: Kernel size for the convolutional layers.
            start_out_channels: Channels produced at the first encoder level;
                doubled at subsequent levels.
        """
        super().__init__()

        # Channel progression across the four resolution levels
        # (avoid `pow`, which shadows the builtin)
        channels = [start_out_channels * 2**level for level in range(4)]

        # Encoder: ConvBlock then 2x2 max-pool at each level.
        # NOTE: the fourth level keeps channels[2] (no doubling), matching
        # the Andersson et al. architecture; the skip-concatenation channel
        # counts in the decoder below rely on this.
        self.conv1 = ConvBlock(n_latent_channels, channels[0], filter_size=filter_size)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)
        self.conv2 = ConvBlock(channels[0], channels[1], filter_size=filter_size)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)
        self.conv3 = ConvBlock(channels[1], channels[2], filter_size=filter_size)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2)
        self.conv4 = ConvBlock(channels[2], channels[2], filter_size=filter_size)
        self.maxpool4 = nn.MaxPool2d(kernel_size=2)

        # Bottleneck at the lowest resolution
        self.conv5 = BottleneckBlock(channels[2], channels[3], filter_size=filter_size)

        # Decoder: upsample, concatenate the encoder skip, then convolve.
        # Each ConvBlock's input channels equal the upsampled channels plus
        # the corresponding encoder skip channels.
        self.up6 = UpconvBlock(channels[3], channels[2])
        self.up7 = UpconvBlock(channels[2], channels[2])
        self.up8 = UpconvBlock(channels[2], channels[1])
        self.up9 = UpconvBlock(channels[1], channels[0])

        self.up6b = ConvBlock(channels[3], channels[2], filter_size=filter_size)
        self.up7b = ConvBlock(channels[3], channels[2], filter_size=filter_size)
        self.up8b = ConvBlock(channels[2], channels[1], filter_size=filter_size)
        self.up9b = ConvBlock(
            channels[1], channels[0], filter_size=filter_size, final=True
        )

        # Final 1x1 convolution projecting back to the latent channel count
        self.final_layer = nn.Conv2d(
            channels[0], n_latent_channels, kernel_size=1, padding="same"
        )

    def forward(self, x: Tensor) -> Tensor:
        """Process a latent tensor through the UNet.

        Args:
            x: Tensor with shape (batch_size, n_latent_channels, H, W), where
                H and W are divisible by 16.

        Returns:
            Tensor with the same shape as the input.
        """
        # Encoder: keep each pre-pool activation for the skip connections
        bn1 = self.conv1(x)
        conv1 = self.maxpool1(bn1)
        bn2 = self.conv2(conv1)
        # BUGFIX: this previously reused maxpool1; use the dedicated maxpool2
        # (behaviourally identical since MaxPool2d is stateless, but maxpool2
        # was defined and never used)
        conv2 = self.maxpool2(bn2)
        bn3 = self.conv3(conv2)
        conv3 = self.maxpool3(bn3)
        bn4 = self.conv4(conv3)
        conv4 = self.maxpool4(bn4)

        # Bottleneck
        bn5 = self.conv5(conv4)

        # Decoder: upsample, concatenate with the encoder skip, convolve
        up6 = self.up6b(torch.cat([bn4, self.up6(bn5)], dim=1))
        up7 = self.up7b(torch.cat([bn3, self.up7(up6)], dim=1))
        up8 = self.up8b(torch.cat([bn2, self.up8(up7)], dim=1))
        up9 = self.up9b(torch.cat([bn1, self.up9(up8)], dim=1))

        # Final 1x1 projection back to the latent channel count
        return self.final_layer(up9)
Loading