diff --git a/model/__init__.py b/model/__init__.py
index e69de29..8a25b0d 100644
--- a/model/__init__.py
+++ b/model/__init__.py
@@ -0,0 +1,4 @@
+from model.blocks import Encoder, Decoder, Bottleneck
+from model.embedding import TimeEmbedding
+from model.scheduler import LinearNoiseScheduler
+from model.unet import UNet, UNetConfig
diff --git a/model/embedding.py b/model/embedding.py
index 81fb20c..499ba90 100644
--- a/model/embedding.py
+++ b/model/embedding.py
@@ -12,6 +12,6 @@ def __init__(self, embedding_dim: int):
@torch.no_grad()
def forward(self, x: torch.Tensor) -> torch.Tensor:
exponents = (
- torch.arange(0, self.half_embedding_dim, dtype=torch.float32) / self.halv_embedding_dim
+ torch.arange(0, self.half_embedding_dim, dtype=torch.float32) / self.half_embedding_dim
).to(x.device)
factors = 10_000 ** exponents
diff --git a/model/test_unet.py b/model/test_unet.py
new file mode 100644
index 0000000..eac3ca2
--- /dev/null
+++ b/model/test_unet.py
@@ -0,0 +1,34 @@
+import unittest
+
+import torch
+
+from model import UNet, UNetConfig
+
+
+class TestUNet(unittest.TestCase):
+
+ def test_shapes(self):
+        # Build a small UNet from a config and check that the output shape matches the input
+
+ config = UNetConfig(
+ in_channels=3,
+ embedding_dim=128,
+ encoder_channels=[32, 64, 128, 256],
+ encoder_down_sample=[True, True, False],
+ encoder_num_layers=2,
+ bottleneck_channels=[256, 256, 128],
+ bottleneck_num_layers=2,
+ decoder_num_layers=2
+ )
+ unet = UNet(config)
+
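+        # Toy batch: 16 RGB images at 256x256 with random integer time steps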
+ x = torch.randn(16, 3, 256, 256)
+ time_steps = torch.randint(0, 10, (16,))
+ y = unet(x, time_steps)
+
+        self.assertEqual(y.shape, (16, 3, 256, 256))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/model/unet.py b/model/unet.py
index e69de29..18a7341 100644
--- a/model/unet.py
+++ b/model/unet.py
@@ -0,0 +1,95 @@
+from dataclasses import dataclass
+
+import torch
+from torch import nn
+
+from model.blocks import Encoder, Decoder, Bottleneck
+from model.embedding import TimeEmbedding
+
+
+@dataclass
+class UNetConfig:
+ in_channels: int
+ embedding_dim: int
+
+ encoder_channels: list[int]
+ encoder_down_sample: list[bool]
+ encoder_num_layers: int
+
+ bottleneck_channels: list[int]
+ bottleneck_num_layers: int
+
+ decoder_num_layers: int
+
+
+class UNet(nn.Module):
+
+ def __init__(self, config: UNetConfig):
+ super().__init__()
+ self.embedding_dim = config.embedding_dim
+
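+        # Sinusoidal time-step embedding followed by a two-layer MLP projection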
+ self.embedding = TimeEmbedding(embedding_dim=config.embedding_dim)
+ self.embedding_proj = nn.Sequential(
+ nn.Linear(in_features=config.embedding_dim, out_features=config.embedding_dim),
+ nn.SiLU(),
+ nn.Linear(in_features=config.embedding_dim, out_features=config.embedding_dim)
+ )
+
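+        # Initial 3x3 conv lifts the input image to the first encoder channel width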
+ self.conv_input = nn.Conv2d(in_channels=config.in_channels, out_channels=config.encoder_channels[0],
+ kernel_size=3, padding=1)
+
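+        # Reversed down-sampling flags let the decoder mirror the encoder's resolution schedule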
+ self.up_sample = list(reversed(config.encoder_down_sample))
+ self.encoder = nn.ModuleList([])
+ for block_idx in range(len(config.encoder_channels) - 1):
+ self.encoder.append(Encoder(in_channels=config.encoder_channels[block_idx],
+ out_channels=config.encoder_channels[block_idx + 1],
+ embedding_dim=config.embedding_dim, num_layers=config.encoder_num_layers,
+ reduce_size=config.encoder_down_sample[block_idx], num_heads=4))
+
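+        # Bottleneck blocks process features at the lowest spatial resolution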
+ self.bottleneck = nn.ModuleList([])
+ for block_idx in range(len(config.bottleneck_channels) - 1):
+ self.bottleneck.append(Bottleneck(in_channels=config.bottleneck_channels[block_idx],
+ out_channels=config.bottleneck_channels[block_idx + 1],
+ embedding_dim=config.embedding_dim,
+ num_layers=config.bottleneck_num_layers,
+ num_heads=4))
+
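+        # in_channels are doubled for the concatenated skip; the last block emits 16 channels for out_proj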
+ self.decoder = nn.ModuleList([])
+ for block_idx in reversed(range(len(config.encoder_channels) - 1)):
+ self.decoder.append(Decoder(in_channels=config.encoder_channels[block_idx] * 2,
+ out_channels=config.encoder_channels[block_idx - 1] if block_idx != 0 else 16,
+ embedding_dim=config.embedding_dim, num_layers=config.decoder_num_layers,
+ increase_size=self.up_sample[block_idx], num_heads=4))
+
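+        # Normalize and project the 16-channel feature map back to the input channel count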
+ self.out_proj = nn.Sequential(
+ nn.GroupNorm(num_groups=8, num_channels=16),
+ nn.SiLU(),
+ nn.Conv2d(in_channels=16, out_channels=config.in_channels, kernel_size=3, padding=1)
+ )
+
+    def forward(self, x: torch.Tensor, time_steps: torch.Tensor) -> torch.Tensor:
+ x = self.conv_input(x)
+
+        embedding = self.embedding(torch.as_tensor(time_steps).long())
+ embedding = self.embedding_proj(embedding)
+
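+        # Save each encoder block's input as a skip connection for the decoder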
+ encoder_outs = []
+        for encoder_layer in self.encoder:
+ encoder_outs.append(x)
+ x = encoder_layer(x, embedding)
+
+ for bottleneck_layer in self.bottleneck:
+ x = bottleneck_layer(x, embedding)
+
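+        # pop() returns skips deepest-first, matching the decoder block order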
+ for decoder_layer in self.decoder:
+ x = decoder_layer(x, encoder_outs.pop(), embedding)
+
+ return self.out_proj(x)