Skip to content

#7-create unet #16

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions model/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from model.blocks import Encoder, Decoder, Bottleneck
from model.embedding import TimeEmbedding
from model.scheduler import LinearNoiseScheduler
from model.unet import UNet, UNetConfig
2 changes: 1 addition & 1 deletion model/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def __init__(self, embedding_dim: int):
@torch.no_grad()
def forward(self, x: torch.Tensor) -> torch.Tensor:
exponents = (
torch.arange(0, self.half_embedding_dim, dtype=torch.float32) / self.halv_embedding_dim
torch.arange(0, self.half_embedding_dim, dtype=torch.float32) / self.half_embedding_dim
).to(x.device)
factors = 10_000 ** exponents

Expand Down
33 changes: 33 additions & 0 deletions model/test_unet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import unittest

import torch

from model import UNet, UNetConfig


class TestUNet(unittest.TestCase):

def test_shapes(self):
# Test the building and output shapes of the UNet model

config = UNetConfig(
in_channels=3,
embedding_dim=128,
encoder_channels=[32, 64, 128, 256],
encoder_down_sample=[True, True, False],
encoder_num_layers=2,
bottleneck_channels=[256, 256, 128],
bottleneck_num_layers=2,
decoder_num_layers=2
)
unet = UNet(config)

x = torch.randn(16, 3, 256, 256)
time_steps = torch.randint(0, 10, (16,))
y = unet(x, time_steps)

self.assertEquals(y.shape, (16, 3, 256, 256))


if __name__ == '__main__':
unittest.main()
86 changes: 86 additions & 0 deletions model/unet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from dataclasses import dataclass

import torch
from torch import nn

from model import TimeEmbedding, Encoder, Decoder, Bottleneck


@dataclass
class UNetConfig:
in_channels: int
embedding_dim: int

encoder_channels: list[int]
encoder_down_sample: list[bool]
encoder_num_layers: int

bottleneck_channels: list[int]
bottleneck_num_layers: int

decoder_num_layers: int


class UNet(nn.Module):

def __init__(self, config: UNetConfig):
super().__init__()
self.embedding_dim = config.embedding_dim

self.embedding = TimeEmbedding(embedding_dim=config.embedding_dim)
self.embedding_proj = nn.Sequential(
nn.Linear(in_features=config.embedding_dim, out_features=config.embedding_dim),
nn.SiLU(),
nn.Linear(in_features=config.embedding_dim, out_features=config.embedding_dim)
)

self.conv_input = nn.Conv2d(in_channels=config.in_channels, out_channels=config.encoder_channels[0],
kernel_size=3, padding=1)

self.up_sample = list(reversed(config.encoder_down_sample))
self.encoder = nn.ModuleList([])
for block_idx in range(len(config.encoder_channels) - 1):
self.encoder.append(Encoder(in_channels=config.encoder_channels[block_idx],
out_channels=config.encoder_channels[block_idx + 1],
embedding_dim=config.embedding_dim, num_layers=config.encoder_num_layers,
reduce_size=config.encoder_down_sample[block_idx], num_heads=4))

self.bottleneck = nn.ModuleList([])
for block_idx in range(len(config.bottleneck_channels) - 1):
self.bottleneck.append(Bottleneck(in_channels=config.bottleneck_channels[block_idx],
out_channels=config.bottleneck_channels[block_idx + 1],
embedding_dim=config.embedding_dim,
num_layers=config.bottleneck_num_layers,
num_heads=4))

self.decoder = nn.ModuleList([])
for block_idx in reversed(range(len(config.encoder_channels) - 1)):
self.decoder.append(Decoder(in_channels=config.encoder_channels[block_idx] * 2,
out_channels=config.encoder_channels[block_idx - 1] if block_idx != 0 else 16,
embedding_dim=config.embedding_dim, num_layers=config.decoder_num_layers,
increase_size=self.up_sample[block_idx], num_heads=4))

self.out_proj = nn.Sequential(
nn.GroupNorm(num_groups=8, num_channels=16),
nn.SiLU(),
nn.Conv2d(in_channels=16, out_channels=config.in_channels, kernel_size=3, padding=1)
)

def forward(self, x: torch.Tensor, time_embedding: torch.Tensor) -> torch.Tensor:
x = self.conv_input(x)

embedding = self.embedding(torch.as_tensor(time_embedding).long())
embedding = self.embedding_proj(embedding)

encoder_outs = []
for layer_idx, encoder_layer in enumerate(self.encoder):
encoder_outs.append(x)
x = encoder_layer(x, embedding)

for bottleneck_layer in self.bottleneck:
x = bottleneck_layer(x, embedding)

for decoder_layer in self.decoder:
x = decoder_layer(x, encoder_outs.pop(), embedding)

return self.out_proj(x)
Loading