Skip to content

Commit f3e7616

Browse files
IFenton, marianovitasari20, louisavz, and jemrobinson
authored
Setting up a UNet model (#46)
* Initial commit for setting up a UNet model Co-authored-by: Maria Novitasari <mnovitasari@turing.ac.uk> Co-authored-by: Louisa van Zeeland <louisa-ai2@protonmail.com> Co-authored-by: James Robinson <james.em.robinson@gmail.com> * Linting with pre-commit * Remove redundant code lines Co-authored-by: James Robinson <james.em.robinson@gmail.com> * Remove redundant class Co-authored-by: James Robinson <james.em.robinson@gmail.com> * Remove unnecessary import Co-authored-by: James Robinson <james.em.robinson@gmail.com> * Remove print statement Co-authored-by: James Robinson <james.em.robinson@gmail.com> * Converting channels from dict to list * Linting * Removing redundant argument Co-authored-by: James Robinson <james.em.robinson@gmail.com> * 🏷️ Add type hints and shift defaults to yaml * 🚚 Moving common code out of the unet file * 🚨 Linting * ♻️ Replace n_filters_factor / reduced channels with start_out_channels arg * 💡 Adding comment on origin of the UNet model * 👽 Use `nn.MaxPool2d` instead of `functional.max_pool2d` --------- Co-authored-by: Maria Novitasari <mnovitasari@turing.ac.uk> Co-authored-by: Louisa van Zeeland <louisa-ai2@protonmail.com> Co-authored-by: James Robinson <james.em.robinson@gmail.com>
1 parent e3c71fd commit f3e7616

6 files changed

Lines changed: 192 additions & 0 deletions

File tree

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
_target_: ice_station_zebra.models.EncodeProcessDecode
2+
3+
name: encode-unet-decode
4+
5+
# Each dataset will be encoded into a latent space with these properties
6+
latent_space:
7+
channels: 20
8+
shape: [128, 128]
9+
10+
encoder:
11+
_target_: ice_station_zebra.models.encoders.NaiveLatentSpaceEncoder
12+
13+
processor:
14+
_target_: ice_station_zebra.models.processors.UNetProcessor
15+
filter_size: 3 # Size of the kernel for convolutional layers
16+
start_out_channels: 64 # Initial number of channels for the first convolutional layer
17+
18+
decoder:
19+
_target_: ice_station_zebra.models.decoders.NaiveLatentSpaceDecoder
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import torch.nn as nn
2+
from torch import Tensor
3+
4+
5+
class BottleneckBlock(nn.Module):
    """Bottleneck block for the bottom of a UNet.

    Two same-padded convolutions, each followed by an in-place ReLU, with a
    final batch-normalisation layer. Spatial dimensions are unchanged.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        filter_size: int,
    ) -> None:
        super().__init__()

        # Build the layer list first, then wrap it in a single Sequential.
        layers: list[nn.Module] = []
        for conv_in_channels in (in_channels, out_channels):
            layers.append(
                nn.Conv2d(
                    conv_in_channels,
                    out_channels,
                    kernel_size=filter_size,
                    padding="same",
                )
            )
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.BatchNorm2d(num_features=out_channels))
        self.model = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the block to a (batch, channels, height, width) tensor."""
        return self.model(x)
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import torch.nn as nn
2+
from torch import Tensor
3+
4+
5+
class ConvBlock(nn.Module):
    """Double-convolution block for the UNet encoder and decoder.

    Two same-padded convolutions, each followed by an in-place ReLU.
    Non-final blocks end in batch normalisation; the final block instead
    appends a third convolution + ReLU pair. Spatial dimensions are
    unchanged.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        filter_size: int,
        final: bool = False,
    ) -> None:
        super().__init__()

        def conv_relu(n_in: int) -> list[nn.Module]:
            # One same-padded convolution followed by an in-place ReLU.
            return [
                nn.Conv2d(
                    n_in, out_channels, kernel_size=filter_size, padding="same"
                ),
                nn.ReLU(inplace=True),
            ]

        layers = conv_relu(in_channels) + conv_relu(out_channels)
        if final:
            # Final block: a third conv/ReLU pair instead of batch-norm.
            layers += conv_relu(out_channels)
        else:
            layers.append(nn.BatchNorm2d(num_features=out_channels))
        self.model = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the block to a (batch, channels, height, width) tensor."""
        return self.model(x)
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import torch.nn as nn
2+
from torch import Tensor
3+
4+
5+
class UpconvBlock(nn.Module):
    """Upsampling block for the UNet decoder.

    Doubles the spatial dimensions with nearest-neighbour upsampling, then
    applies a same-padded 2x2 convolution with an in-place ReLU.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()

        upsample = nn.Upsample(scale_factor=2, mode="nearest")
        convolve = nn.Conv2d(in_channels, out_channels, kernel_size=2, padding="same")
        activate = nn.ReLU(inplace=True)
        self.model = nn.Sequential(upsample, convolve, activate)

    def forward(self, x: Tensor) -> Tensor:
        """Map (batch, in, H, W) to (batch, out, 2H, 2W)."""
        return self.model(x)
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from .null import NullProcessor
2+
from .unet import UNetProcessor
23

34
__all__ = [
45
"NullProcessor",
6+
"UNetProcessor",
57
]
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import torch.nn as nn
2+
from torch import Tensor
3+
import torch
4+
5+
from ice_station_zebra.models.common.convblock import ConvBlock
6+
from ice_station_zebra.models.common.bottleneckblock import BottleneckBlock
7+
from ice_station_zebra.models.common.upconvblock import UpconvBlock
8+
9+
10+
class UNetProcessor(nn.Module):
    """UNet model that processes input through a UNet architecture.

    Structure based on Andersson et al. (2021) Nature Communications
    https://doi.org/10.1038/s41467-021-25257-4

    Args:
        n_latent_channels: number of channels in the latent-space input
            tensor; the output tensor has the same channel count.
        filter_size: kernel size used by every ConvBlock/BottleneckBlock.
        start_out_channels: output channels of the first convolutional
            block; doubled at each successive encoder level.
    """

    def __init__(
        self,
        n_latent_channels: int,
        filter_size: int,
        start_out_channels: int,
    ) -> None:
        super().__init__()

        # Channel widths per level: [c, 2c, 4c, 8c].
        # (loop variable renamed from `pow`, which shadowed the builtin)
        channels = [start_out_channels * 2**level for level in range(4)]

        # Encoder: four ConvBlock + 2x2 max-pool stages.
        # conv4 keeps channels[2] so the decoder's concatenated channel
        # counts line up with the skip connections.
        self.conv1 = ConvBlock(n_latent_channels, channels[0], filter_size=filter_size)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)
        self.conv2 = ConvBlock(channels[0], channels[1], filter_size=filter_size)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)
        self.conv3 = ConvBlock(channels[1], channels[2], filter_size=filter_size)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2)
        self.conv4 = ConvBlock(channels[2], channels[2], filter_size=filter_size)
        self.maxpool4 = nn.MaxPool2d(kernel_size=2)

        # Bottleneck
        self.conv5 = BottleneckBlock(channels[2], channels[3], filter_size=filter_size)

        # Decoder: upsampling blocks...
        self.up6 = UpconvBlock(channels[3], channels[2])
        self.up7 = UpconvBlock(channels[2], channels[2])
        self.up8 = UpconvBlock(channels[2], channels[1])
        self.up9 = UpconvBlock(channels[1], channels[0])

        # ...and the convolutions applied after each skip-connection concat.
        # Each input width is (skip channels + upsampled channels).
        self.up6b = ConvBlock(channels[3], channels[2], filter_size=filter_size)
        self.up7b = ConvBlock(channels[3], channels[2], filter_size=filter_size)
        self.up8b = ConvBlock(channels[2], channels[1], filter_size=filter_size)
        self.up9b = ConvBlock(
            channels[1], channels[0], filter_size=filter_size, final=True
        )

        # Final 1x1 convolution mapping back to the latent channel count
        self.final_layer = nn.Conv2d(
            channels[0], n_latent_channels, kernel_size=1, padding="same"
        )

    def forward(self, x: Tensor) -> Tensor:
        """Process a latent-space tensor.

        Args:
            x: tensor with shape (batch_size, all_variables, latent_height,
                latent_width). Spatial dimensions should be divisible by 16
                so the four pooling stages divide evenly — TODO confirm
                against the encoder's latent-space shape.

        Returns:
            Tensor with the same shape as the input.
        """
        # Encoder: keep each pre-pool activation for the skip connections
        bn1 = self.conv1(x)
        conv1 = self.maxpool1(bn1)
        bn2 = self.conv2(conv1)
        # BUGFIX: previously called self.maxpool1 here, leaving maxpool2
        # unused (behaviour was identical only because both pools share the
        # same configuration)
        conv2 = self.maxpool2(bn2)
        bn3 = self.conv3(conv2)
        conv3 = self.maxpool3(bn3)
        bn4 = self.conv4(conv3)
        conv4 = self.maxpool4(bn4)

        # Bottleneck
        bn5 = self.conv5(conv4)

        # Decoder: upsample, concatenate the skip connection along the
        # channel dimension, then convolve
        up6 = self.up6b(torch.cat([bn4, self.up6(bn5)], dim=1))
        up7 = self.up7b(torch.cat([bn3, self.up7(up6)], dim=1))
        up8 = self.up8b(torch.cat([bn2, self.up8(up7)], dim=1))
        up9 = self.up9b(torch.cat([bn1, self.up9(up8)], dim=1))

        # Final layer: restore the latent channel count
        return self.final_layer(up9)

0 commit comments

Comments
 (0)