Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions ice_station_zebra/config/model/encode_unet_decode.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
_target_: ice_station_zebra.models.EncodeProcessDecode

name: encode-unet-decode

# Each dataset will be encoded into a latent space with these properties
latent_space:
  channels: 20
  shape: [128, 128] # NOTE: must be divisible by 16 for the UNet pool/upsample stages

# Maps each input dataset into the shared latent space
encoder:
  _target_: ice_station_zebra.models.encoders.NaiveLatentSpaceEncoder

# Transforms tensors within the latent space
processor:
  _target_: ice_station_zebra.models.processors.UNetProcessor
  filter_size: 3 # Size of the kernel for convolutional layers
  n_filters_factor: 1.0 # Factor to scale the size of the UNet

# Maps from the latent space back to the output space
decoder:
  _target_: ice_station_zebra.models.decoders.NaiveLatentSpaceDecoder
2 changes: 2 additions & 0 deletions ice_station_zebra/models/processors/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from .null import NullProcessor
from .unet import UNetProcessor

# Public processors exported by this subpackage
__all__ = [
    "NullProcessor",
    "UNetProcessor",
]
159 changes: 159 additions & 0 deletions ice_station_zebra/models/processors/unet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
import torch.nn as nn
from torch import Tensor
import torch
import torch.nn.functional as F
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we don't use this anymore

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which 'this'? I think we do need to import those four things

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I was trying to point to import torch.nn.functional as F

Comment thread
jemrobinson marked this conversation as resolved.
Outdated


class UNetProcessor(nn.Module):
    """UNet model that processes input through a UNet architecture.

    The input is downsampled through four conv/max-pool stages, passed through
    a bottleneck, then upsampled back with skip connections from the matching
    encoder stages. A final 1x1 convolution restores the latent channel count,
    so the output tensor has the same shape as the input.
    """

    def __init__(
        self,
        n_latent_channels: int,
        filter_size: int,
        n_filters_factor: float,
        start_out_channels: int = 64,
    ) -> None:
        """Construct a UNetProcessor.

        Args:
            n_latent_channels: Number of channels in the latent-space input/output.
            filter_size: Kernel size used by all convolutional blocks.
            n_filters_factor: Factor to scale the width (channel count) of the UNet.
            start_out_channels: Base channel count before scaling by
                n_filters_factor. Defaults to 64 for backwards compatibility
                with configs that only set n_filters_factor.
        """
        super().__init__()

        # Channel widths for the four resolution levels: [c, 2c, 4c, 8c].
        # `level` (not `pow`) as loop variable to avoid shadowing the builtin.
        base_channels = int(start_out_channels * n_filters_factor)
        channels = [base_channels * 2**level for level in range(4)]

        # Encoder: each block is followed by a 2x max-pool in forward()
        self.conv1 = ConvBlock(n_latent_channels, channels[0], filter_size=filter_size)
        self.conv2 = ConvBlock(channels[0], channels[1], filter_size=filter_size)
        self.conv3 = ConvBlock(channels[1], channels[2], filter_size=filter_size)
        self.conv4 = ConvBlock(channels[2], channels[2], filter_size=filter_size)

        # Bottleneck at the lowest resolution
        self.conv5 = BottleneckBlock(channels[2], channels[3], filter_size=filter_size)

        # Decoder: upsample, then convolve the skip-connection concatenation
        self.up6 = UpconvBlock(channels[3], channels[2])
        self.up7 = UpconvBlock(channels[2], channels[2])
        self.up8 = UpconvBlock(channels[2], channels[1])
        self.up9 = UpconvBlock(channels[1], channels[0])

        # Input widths are doubled because forward() concatenates each
        # upsampled tensor with the corresponding encoder output
        self.up6b = ConvBlock(channels[3], channels[2], filter_size=filter_size)
        self.up7b = ConvBlock(channels[3], channels[2], filter_size=filter_size)
        self.up8b = ConvBlock(channels[2], channels[1], filter_size=filter_size)
        self.up9b = ConvBlock(
            channels[1], channels[0], filter_size=filter_size, final=True
        )

        # Final 1x1 convolution back to the latent channel count
        self.final_layer = nn.Conv2d(
            channels[0], n_latent_channels, kernel_size=1, padding="same"
        )

    def forward(self, x: Tensor) -> Tensor:
        """Process a latent tensor of shape (batch_size, all_variables, latent_height, latent_width).

        Height and width must be divisible by 16 so that the four pooling /
        upsampling stages produce skip tensors whose sizes match for torch.cat.
        """
        # Encoder
        bn1 = self.conv1(x)
        conv1 = F.max_pool2d(bn1, kernel_size=2)
        bn2 = self.conv2(conv1)
        conv2 = F.max_pool2d(bn2, kernel_size=2)
        bn3 = self.conv3(conv2)
        conv3 = F.max_pool2d(bn3, kernel_size=2)
        bn4 = self.conv4(conv3)
        conv4 = F.max_pool2d(bn4, kernel_size=2)

        # Bottleneck
        bn5 = self.conv5(conv4)

        # Decoder with skip connections from the encoder stages
        up6 = self.up6b(torch.cat([bn4, self.up6(bn5)], dim=1))
        up7 = self.up7b(torch.cat([bn3, self.up7(up6)], dim=1))
        up8 = self.up8b(torch.cat([bn2, self.up8(up7)], dim=1))
        up9 = self.up9b(torch.cat([bn1, self.up9(up8)], dim=1))

        # Final layer
        return self.final_layer(up9)


class ConvBlock(nn.Module):
    """Two (or, if final, three) conv+ReLU stages at constant resolution.

    Non-final blocks end with a BatchNorm2d layer; the final block swaps the
    batch norm for an extra conv+ReLU stage. All convolutions use "same"
    padding, so spatial dimensions are preserved.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        filter_size: int,
        final: bool = False,
    ) -> None:
        super().__init__()

        def conv_relu(c_in: int, c_out: int) -> list[nn.Module]:
            """One convolution followed by an in-place ReLU."""
            return [
                nn.Conv2d(c_in, c_out, kernel_size=filter_size, padding="same"),
                nn.ReLU(inplace=True),
            ]

        stages = conv_relu(in_channels, out_channels)
        stages += conv_relu(out_channels, out_channels)
        if final:
            # The final block uses a third conv+ReLU instead of a batch norm
            stages += conv_relu(out_channels, out_channels)
        else:
            stages.append(nn.BatchNorm2d(num_features=out_channels))

        self.model = nn.Sequential(*stages)

    def forward(self, x: Tensor) -> Tensor:
        return self.model(x)


class BottleneckBlock(nn.Module):
    """Two conv+ReLU stages followed by a batch norm, at constant resolution.

    Used at the lowest-resolution point of the UNet; "same" padding keeps the
    spatial dimensions unchanged.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        filter_size: int,
    ) -> None:
        super().__init__()

        # First stage maps in_channels -> out_channels; second stays at out_channels
        layers: list[nn.Module] = []
        for stage_in in (in_channels, out_channels):
            layers.append(
                nn.Conv2d(
                    stage_in, out_channels, kernel_size=filter_size, padding="same"
                )
            )
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.BatchNorm2d(num_features=out_channels))

        self.model = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        return self.model(x)


class UpconvBlock(nn.Module):
    """Double the spatial resolution, then convolve.

    Nearest-neighbour upsampling by a factor of 2, followed by a 2x2
    "same"-padded convolution and an in-place ReLU.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()

        upsample = nn.Upsample(scale_factor=2, mode="nearest")
        convolve = nn.Conv2d(in_channels, out_channels, kernel_size=2, padding="same")
        activate = nn.ReLU(inplace=True)
        self.model = nn.Sequential(upsample, convolve, activate)

    def forward(self, x: Tensor) -> Tensor:
        return self.model(x)
Loading