Commit 4b9caa6

#5-add down block (#9)
1 parent b19faf6 commit 4b9caa6

File tree

5 files changed: +75 −0 lines


model/blocks.py

+70

import torch
from torch import nn

from utils import repeat_layers


class DownBlock(nn.Module):

    def __init__(self, in_channels: int, out_channels: int, embedding_dim: int, num_layers: int, num_heads: int,
                 reduce_size: bool):
        super().__init__()
        self.num_layers = num_layers

        # Per-layer projection of the time embedding to out_channels so it can be
        # broadcast-added onto the feature maps.
        self.embedding_layers = repeat_layers(
            nn.Sequential(
                nn.SiLU(),
                nn.Linear(in_features=embedding_dim, out_features=out_channels)
            ),
            num_layers
        )

        # First conv of each residual layer; only layer 0 sees in_channels.
        self.conv1_layers = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(num_groups=8, num_channels=in_channels if layer_idx == 0 else out_channels),
                nn.SiLU(),
                nn.Conv2d(in_channels=in_channels if layer_idx == 0 else out_channels, out_channels=out_channels,
                          kernel_size=3, stride=1, padding=1)
            )
            for layer_idx in range(num_layers)
        ])
        self.conv2_layers = repeat_layers(
            nn.Sequential(
                nn.GroupNorm(num_groups=8, num_channels=out_channels),
                nn.SiLU(),
                nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
            ),
            num_layers
        )
        # 1x1 convolutions that project the skip input to out_channels for the residual add.
        self.conv_residuals = nn.ModuleList([
            nn.Conv2d(in_channels=in_channels if layer_idx == 0 else out_channels, out_channels=out_channels,
                      kernel_size=1)
            for layer_idx in range(num_layers)
        ])
        # A single strided convolution halves the spatial resolution once, after all layers.
        self.conv_out = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=4, stride=2,
                                  padding=1) if reduce_size else nn.Identity()

        self.attention_norms = repeat_layers(nn.GroupNorm(num_groups=8, num_channels=out_channels), num_layers)
        self.attentions = repeat_layers(
            nn.MultiheadAttention(embed_dim=out_channels, num_heads=num_heads, batch_first=True),
            num_layers
        )

    def forward(self, x: torch.Tensor, time_embedding: torch.Tensor) -> torch.Tensor:
        for layer_idx in range(self.num_layers):
            residual_input = x
            x = self.conv1_layers[layer_idx](x)
            x = x + self.embedding_layers[layer_idx](time_embedding)[:, :, None, None]
            x = self.conv2_layers[layer_idx](x)
            x = x + self.conv_residuals[layer_idx](residual_input)

            # Self-attention over spatial positions: flatten H x W into a sequence
            # of length h * w, with channels as the embedding dimension.
            batch_size, channels, h, w = x.shape
            x_att = x.reshape(batch_size, channels, h * w)
            x_att = self.attention_norms[layer_idx](x_att).transpose(1, 2)
            # nn.MultiheadAttention returns (output, attention weights); keep only the output.
            x_att, _ = self.attentions[layer_idx](x_att, x_att, x_att)
            x_att = x_att.transpose(1, 2).reshape(batch_size, channels, h, w)
            x = x + x_att

        return self.conv_out(x)
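
For a quick sanity check, here is a minimal usage sketch of the new block. The shapes and hyperparameters (64 -> 128 channels, a 32x32 feature map, a 256-dim time embedding, batch of 8) are illustrative assumptions, not values from this commit, and the import assumes model/ is importable as a package:

import torch
from model.blocks import DownBlock

block = DownBlock(in_channels=64, out_channels=128, embedding_dim=256,
                  num_layers=2, num_heads=4, reduce_size=True)
x = torch.randn(8, 64, 32, 32)   # (batch, channels, height, width)
t_emb = torch.randn(8, 256)      # one time-embedding vector per batch element
out = block(x, t_emb)
print(out.shape)                 # torch.Size([8, 128, 16, 16]): the stride-2 conv halves H and W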
File renamed without changes.

model/unet.py

Whitespace-only changes.

utils/__init__.py

+1

from utils.model import repeat_layers

utils/model.py

+4

import copy

import torch


def repeat_layers(module: torch.nn.Module, num_repeats: int) -> torch.nn.ModuleList:
    # Deep-copy the template module so every repeat owns independent parameters;
    # putting the same instance in the list num_repeats times would silently
    # share one set of weights across all layers.
    return torch.nn.ModuleList([copy.deepcopy(module) for _ in range(num_repeats)])
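
As a brief illustrative check (not part of the commit) that repeat_layers yields independent copies rather than one shared instance:

import torch
from utils import repeat_layers

layers = repeat_layers(torch.nn.Linear(16, 16), 3)
# Each repeat owns its own parameter storage, so the layers can train independently.
assert layers[0].weight.data_ptr() != layers[1].weight.data_ptr()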
