
Commit 406e750

add upsample combiner feature for the unets

1 parent 9646dfc

File tree

3 files changed (+60, -8 lines)


README.md (+1, -1)

@@ -1112,7 +1112,7 @@ For detailed information on training the diffusion prior, please refer to the [d
 - [x] allow for unet to be able to condition non-cross attention style as well
 - [x] speed up inference, read up on papers (ddim)
 - [x] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865
-- [ ] try out the nested unet from https://arxiv.org/abs/2005.09007 after hearing several positive testimonies from researchers, for segmentation anyhow
+- [x] add the final combination of upsample feature maps, used in unet squared, seems to have an effect in local experiments
 - [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2

 ## Citations

dalle2_pytorch/dalle2_pytorch.py (+58, -6)
@@ -1538,6 +1538,38 @@ def forward(self, x):
         fmaps = tuple(map(lambda conv: conv(x), self.convs))
         return torch.cat(fmaps, dim = 1)

+class UpsampleCombiner(nn.Module):
+    def __init__(
+        self,
+        dim,
+        *,
+        enabled = False,
+        dim_ins = tuple(),
+        dim_outs = tuple()
+    ):
+        super().__init__()
+        assert len(dim_ins) == len(dim_outs)
+        self.enabled = enabled
+
+        if not self.enabled:
+            self.dim_out = dim
+            return
+
+        self.fmap_convs = nn.ModuleList([Block(dim_in, dim_out) for dim_in, dim_out in zip(dim_ins, dim_outs)])
+        self.dim_out = dim + (sum(dim_outs) if len(dim_outs) > 0 else 0)
+
+    def forward(self, x, fmaps = None):
+        target_size = x.shape[-1]
+
+        fmaps = default(fmaps, tuple())
+
+        if not self.enabled or len(fmaps) == 0 or len(self.fmap_convs) == 0:
+            return x
+
+        fmaps = [resize_image_to(fmap, target_size) for fmap in fmaps]
+        outs = [conv(fmap) for fmap, conv in zip(fmaps, self.fmap_convs)]
+        return torch.cat((x, *outs), dim = 1)
+
 class Unet(nn.Module):
     def __init__(
         self,
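For intuition, here is a small self-contained sketch of what the new combiner does. It is not the library code itself (the real class uses the `Block`, `default`, and `resize_image_to` helpers defined elsewhere in this file), and the tensor shapes are hypothetical: each saved upsample-stage feature map is resized to the final resolution, projected to `dim` channels, and concatenated onto the final feature map along the channel dimension.

```python
# conceptual sketch only (hypothetical shapes), not the dalle2_pytorch API
import torch
import torch.nn.functional as F
from torch import nn

dim = 128
x = torch.randn(1, dim, 64, 64)         # output of the last upsample stage
fmaps = [
    torch.randn(1, 512, 16, 16),        # hidden saved from an earlier upsample stage
    torch.randn(1, 256, 32, 32),        # hidden saved from a later upsample stage
]

# one projection per stage, each mapping to `dim` channels (mirrors dim_outs = (dim,) * len(fmaps))
projs = nn.ModuleList([nn.Conv2d(f.shape[1], dim, 3, padding = 1) for f in fmaps])

resized = [F.interpolate(f, size = x.shape[-1]) for f in fmaps]   # bring every fmap up to 64x64
outs = [proj(f) for f, proj in zip(resized, projs)]               # project each to dim channels

combined = torch.cat((x, *outs), dim = 1)                         # dim + dim * len(fmaps) channels
print(combined.shape)                                             # torch.Size([1, 384, 64, 64])
```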
@@ -1575,6 +1607,7 @@ def __init__(
         scale_skip_connection = False,
         pixel_shuffle_upsample = True,
         final_conv_kernel_size = 1,
+        combine_upsample_fmaps = False,      # whether to combine the outputs of all upsample blocks, as in unet squared paper
         **kwargs
     ):
         super().__init__()
@@ -1710,7 +1743,8 @@ def __init__(
         self.ups = nn.ModuleList([])
         num_resolutions = len(in_out)

-        skip_connect_dims = [] # keeping track of skip connection dimensions
+        skip_connect_dims = []      # keeping track of skip connection dimensions
+        upsample_combiner_dims = [] # keeping track of dimensions for final upsample feature map combiner

         for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks, layer_self_attn) in enumerate(zip(in_out, resnet_groups, num_resnet_blocks, self_attn)):
             is_first = ind == 0
@@ -1752,14 +1786,27 @@ def __init__(
             elif sparse_attn:
                 attention = Residual(LinearAttention(dim_out, **attn_kwargs))

+            upsample_combiner_dims.append(dim_out)
+
             self.ups.append(nn.ModuleList([
                 ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
                 nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
                 attention,
                 upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else nn.Identity()
             ]))

-        self.final_resnet_block = ResnetBlock(dim * 2, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group)
+        # whether to combine outputs from all upsample blocks for final resnet block
+
+        self.upsample_combiner = UpsampleCombiner(
+            dim = dim,
+            enabled = combine_upsample_fmaps,
+            dim_ins = upsample_combiner_dims,
+            dim_outs = (dim,) * len(upsample_combiner_dims)
+        )
+
+        # a final resnet block
+
+        self.final_resnet_block = ResnetBlock(self.upsample_combiner.dim_out + dim, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group)

         out_dim_in = dim + (channels if lowres_cond else 0)
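A quick sanity check on the channel bookkeeping (hypothetical numbers, assuming `dim = 128` and four resolution stages as in the README's `dim_mults = (1, 2, 4, 8)` example): the final resnet block keeps its old input width when the combiner is disabled, and grows by `dim` per upsample stage when it is enabled.

```python
dim = 128
num_stages = 4                                         # len(upsample_combiner_dims) in this example

# disabled: UpsampleCombiner.dim_out falls back to dim, so the final resnet block
# still sees dim + dim = 256 input channels, same as the previous `dim * 2`
disabled_in_channels = dim + dim                       # 256

# enabled: each upsample stage contributes a dim-channel projection
enabled_in_channels = (dim + dim * num_stages) + dim   # 768
```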

@@ -1953,7 +2000,8 @@ def forward(

         # go through the layers of the unet, down and up

-        hiddens = []
+        down_hiddens = []
+        up_hiddens = []

         for pre_downsample, init_block, resnet_blocks, attn, post_downsample in self.downs:
             if exists(pre_downsample):
@@ -1963,10 +2011,10 @@ def forward(

             for resnet_block in resnet_blocks:
                 x = resnet_block(x, t, c)
-                hiddens.append(x)
+                down_hiddens.append(x.contiguous())

             x = attn(x)
-            hiddens.append(x.contiguous())
+            down_hiddens.append(x.contiguous())

             if exists(post_downsample):
                 x = post_downsample(x)
@@ -1978,7 +2026,7 @@ def forward(

         x = self.mid_block2(x, t, mid_c)

-        connect_skip = lambda fmap: torch.cat((fmap, hiddens.pop() * self.skip_connect_scale), dim = 1)
+        connect_skip = lambda fmap: torch.cat((fmap, down_hiddens.pop() * self.skip_connect_scale), dim = 1)

         for init_block, resnet_blocks, attn, upsample in self.ups:
             x = connect_skip(x)
@@ -1989,8 +2037,12 @@ def forward(
                 x = resnet_block(x, t, c)

             x = attn(x)
+
+            up_hiddens.append(x.contiguous())
             x = upsample(x)

+        x = self.upsample_combiner(x, up_hiddens)
+
         x = torch.cat((x, r), dim = 1)

         x = self.final_resnet_block(x, t)
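Enabling the feature from user code is just the new constructor flag. A minimal sketch, assuming the usual `Unet` arguments shown in the repository README (`dim`, `image_embed_dim`, `cond_dim`, `channels`, `dim_mults`):

```python
from dalle2_pytorch import Unet

unet = Unet(
    dim = 128,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8),
    combine_upsample_fmaps = True   # new flag from this commit; defaults to False
)
```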

dalle2_pytorch/version.py (+1, -1)

@@ -1 +1 @@
-__version__ = '1.0.6'
+__version__ = '1.1.0'
