diff --git a/model/__init__.py b/model/__init__.py index e69de29..8a25b0d 100644 --- a/model/__init__.py +++ b/model/__init__.py @@ -0,0 +1,4 @@ +from model.blocks import Encoder, Decoder, Bottleneck +from model.embedding import TimeEmbedding +from model.scheduler import LinearNoiseScheduler +from model.unet import UNet, UNetConfig diff --git a/model/embedding.py b/model/embedding.py index 81fb20c..499ba90 100644 --- a/model/embedding.py +++ b/model/embedding.py @@ -12,7 +12,7 @@ def __init__(self, embedding_dim: int): @torch.no_grad() def forward(self, x: torch.Tensor) -> torch.Tensor: exponents = ( - torch.arange(0, self.half_embedding_dim, dtype=torch.float32) / self.halv_embedding_dim + torch.arange(0, self.half_embedding_dim, dtype=torch.float32) / self.half_embedding_dim ).to(x.device) factors = 10_000 ** exponents diff --git a/model/test_unet.py b/model/test_unet.py new file mode 100644 index 0000000..eac3ca2 --- /dev/null +++ b/model/test_unet.py @@ -0,0 +1,33 @@ +import unittest + +import torch + +from model import UNet, UNetConfig + + +class TestUNet(unittest.TestCase): + + def test_shapes(self): + # Test the building and output shapes of the UNet model + + config = UNetConfig( + in_channels=3, + embedding_dim=128, + encoder_channels=[32, 64, 128, 256], + encoder_down_sample=[True, True, False], + encoder_num_layers=2, + bottleneck_channels=[256, 256, 128], + bottleneck_num_layers=2, + decoder_num_layers=2 + ) + unet = UNet(config) + + x = torch.randn(16, 3, 256, 256) + time_steps = torch.randint(0, 10, (16,)) + y = unet(x, time_steps) + + self.assertEquals(y.shape, (16, 3, 256, 256)) + + +if __name__ == '__main__': + unittest.main() diff --git a/model/unet.py b/model/unet.py index e69de29..18a7341 100644 --- a/model/unet.py +++ b/model/unet.py @@ -0,0 +1,86 @@ +from dataclasses import dataclass + +import torch +from torch import nn + +from model import TimeEmbedding, Encoder, Decoder, Bottleneck + + +@dataclass +class UNetConfig: + in_channels: int + embedding_dim: int + + encoder_channels: list[int] + encoder_down_sample: list[bool] + encoder_num_layers: int + + bottleneck_channels: list[int] + bottleneck_num_layers: int + + decoder_num_layers: int + + +class UNet(nn.Module): + + def __init__(self, config: UNetConfig): + super().__init__() + self.embedding_dim = config.embedding_dim + + self.embedding = TimeEmbedding(embedding_dim=config.embedding_dim) + self.embedding_proj = nn.Sequential( + nn.Linear(in_features=config.embedding_dim, out_features=config.embedding_dim), + nn.SiLU(), + nn.Linear(in_features=config.embedding_dim, out_features=config.embedding_dim) + ) + + self.conv_input = nn.Conv2d(in_channels=config.in_channels, out_channels=config.encoder_channels[0], + kernel_size=3, padding=1) + + self.up_sample = list(reversed(config.encoder_down_sample)) + self.encoder = nn.ModuleList([]) + for block_idx in range(len(config.encoder_channels) - 1): + self.encoder.append(Encoder(in_channels=config.encoder_channels[block_idx], + out_channels=config.encoder_channels[block_idx + 1], + embedding_dim=config.embedding_dim, num_layers=config.encoder_num_layers, + reduce_size=config.encoder_down_sample[block_idx], num_heads=4)) + + self.bottleneck = nn.ModuleList([]) + for block_idx in range(len(config.bottleneck_channels) - 1): + self.bottleneck.append(Bottleneck(in_channels=config.bottleneck_channels[block_idx], + out_channels=config.bottleneck_channels[block_idx + 1], + embedding_dim=config.embedding_dim, + num_layers=config.bottleneck_num_layers, + num_heads=4)) + + self.decoder = nn.ModuleList([]) + for block_idx in reversed(range(len(config.encoder_channels) - 1)): + self.decoder.append(Decoder(in_channels=config.encoder_channels[block_idx] * 2, + out_channels=config.encoder_channels[block_idx - 1] if block_idx != 0 else 16, + embedding_dim=config.embedding_dim, num_layers=config.decoder_num_layers, + increase_size=self.up_sample[block_idx], num_heads=4)) + + self.out_proj = nn.Sequential( + nn.GroupNorm(num_groups=8, num_channels=16), + nn.SiLU(), + nn.Conv2d(in_channels=16, out_channels=config.in_channels, kernel_size=3, padding=1) + ) + + def forward(self, x: torch.Tensor, time_embedding: torch.Tensor) -> torch.Tensor: + x = self.conv_input(x) + + embedding = self.embedding(torch.as_tensor(time_embedding).long()) + embedding = self.embedding_proj(embedding) + + encoder_outs = [] + for layer_idx, encoder_layer in enumerate(self.encoder): + encoder_outs.append(x) + x = encoder_layer(x, embedding) + + for bottleneck_layer in self.bottleneck: + x = bottleneck_layer(x, embedding) + + for decoder_layer in self.decoder: + x = decoder_layer(x, encoder_outs.pop(), embedding) + + return self.out_proj(x)