@@ -1538,6 +1538,38 @@ def forward(self, x):
         fmaps = tuple(map(lambda conv: conv(x), self.convs))
         return torch.cat(fmaps, dim = 1)
 
+class UpsampleCombiner(nn.Module):
+    def __init__(
+        self,
+        dim,
+        *,
+        enabled = False,
+        dim_ins = tuple(),
+        dim_outs = tuple()
+    ):
+        super().__init__()
+        assert len(dim_ins) == len(dim_outs)
+        self.enabled = enabled
+
+        if not self.enabled:
+            self.dim_out = dim
+            return
+
+        self.fmap_convs = nn.ModuleList([Block(dim_in, dim_out) for dim_in, dim_out in zip(dim_ins, dim_outs)])
+        self.dim_out = dim + (sum(dim_outs) if len(dim_outs) > 0 else 0)
+
+    def forward(self, x, fmaps = None):
+        target_size = x.shape[-1]
+
+        fmaps = default(fmaps, tuple())
+
+        if not self.enabled or len(fmaps) == 0 or len(self.fmap_convs) == 0:
+            return x
+
+        fmaps = [resize_image_to(fmap, target_size) for fmap in fmaps]
+        outs = [conv(fmap) for fmap, conv in zip(fmaps, self.fmap_convs)]
+        return torch.cat((x, *outs), dim = 1)
+
 class Unet(nn.Module):
     def __init__(
         self,
@@ -1575,6 +1607,7 @@ def __init__(
         scale_skip_connection = False,
         pixel_shuffle_upsample = True,
         final_conv_kernel_size = 1,
+        combine_upsample_fmaps = False,  # whether to combine the outputs of all upsample blocks, as in unet squared paper
         **kwargs
     ):
         super().__init__()
@@ -1710,7 +1743,8 @@ def __init__(
         self.ups = nn.ModuleList([])
         num_resolutions = len(in_out)
 
-        skip_connect_dims = [] # keeping track of skip connection dimensions
+        skip_connect_dims = []      # keeping track of skip connection dimensions
+        upsample_combiner_dims = [] # keeping track of dimensions for final upsample feature map combiner
 
         for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks, layer_self_attn) in enumerate(zip(in_out, resnet_groups, num_resnet_blocks, self_attn)):
             is_first = ind == 0
@@ -1752,14 +1786,27 @@ def __init__(
             elif sparse_attn:
                 attention = Residual(LinearAttention(dim_out, **attn_kwargs))
 
+            upsample_combiner_dims.append(dim_out)
+
             self.ups.append(nn.ModuleList([
                 ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
                 nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
                 attention,
                 upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else nn.Identity()
             ]))
 
-        self.final_resnet_block = ResnetBlock(dim * 2, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group)
+        # whether to combine outputs from all upsample blocks for final resnet block
+
+        self.upsample_combiner = UpsampleCombiner(
+            dim = dim,
+            enabled = combine_upsample_fmaps,
+            dim_ins = upsample_combiner_dims,
+            dim_outs = (dim,) * len(upsample_combiner_dims)
+        )
+
+        # a final resnet block
+
+        self.final_resnet_block = ResnetBlock(self.upsample_combiner.dim_out + dim, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group)
 
         out_dim_in = dim + (channels if lowres_cond else 0)
 
@@ -1953,7 +2000,8 @@ def forward(
 
         # go through the layers of the unet, down and up
 
-        hiddens = []
+        down_hiddens = []
+        up_hiddens = []
 
         for pre_downsample, init_block, resnet_blocks, attn, post_downsample in self.downs:
             if exists(pre_downsample):
@@ -1963,10 +2011,10 @@ def forward(
 
             for resnet_block in resnet_blocks:
                 x = resnet_block(x, t, c)
-                hiddens.append(x)
+                down_hiddens.append(x.contiguous())
 
             x = attn(x)
-            hiddens.append(x.contiguous())
+            down_hiddens.append(x.contiguous())
 
             if exists(post_downsample):
                 x = post_downsample(x)
@@ -1978,7 +2026,7 @@ def forward(
 
         x = self.mid_block2(x, t, mid_c)
 
-        connect_skip = lambda fmap: torch.cat((fmap, hiddens.pop() * self.skip_connect_scale), dim = 1)
+        connect_skip = lambda fmap: torch.cat((fmap, down_hiddens.pop() * self.skip_connect_scale), dim = 1)
 
         for init_block, resnet_blocks, attn, upsample in self.ups:
             x = connect_skip(x)
@@ -1989,8 +2037,12 @@ def forward(
                 x = resnet_block(x, t, c)
 
             x = attn(x)
+
+            up_hiddens.append(x.contiguous())
             x = upsample(x)
 
+        x = self.upsample_combiner(x, up_hiddens)
+
         x = torch.cat((x, r), dim = 1)
 
         x = self.final_resnet_block(x, t)
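
Note on the feature this diff wires up: when combine_upsample_fmaps is enabled, the output of every upsample stage is saved into up_hiddens, resized to the final resolution, projected, and concatenated onto the last feature map before the final resnet block, whose input width therefore grows from dim to upsample_combiner.dim_out + dim. The snippet below is a minimal standalone sketch of that shape behaviour, not the library code: it stands in resize_image_to with F.interpolate and Block with a plain 3x3 conv, and the TinyUpsampleCombiner name and all sizes are made up for illustration.

    # minimal sketch of the channel/size bookkeeping done by UpsampleCombiner
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class TinyUpsampleCombiner(nn.Module):
        def __init__(self, dim, *, enabled = False, dim_ins = (), dim_outs = ()):
            super().__init__()
            assert len(dim_ins) == len(dim_outs)
            self.enabled = enabled
            if not enabled:
                self.dim_out = dim
                return
            # one projection per saved upsample feature map
            self.fmap_convs = nn.ModuleList([nn.Conv2d(i, o, 3, padding = 1) for i, o in zip(dim_ins, dim_outs)])
            self.dim_out = dim + sum(dim_outs)

        def forward(self, x, fmaps = None):
            fmaps = fmaps or ()
            if not self.enabled or len(fmaps) == 0:
                return x
            # resize every saved feature map to x's spatial size, project, concat on channels
            size = x.shape[-1]
            outs = [conv(F.interpolate(f, size)) for f, conv in zip(fmaps, self.fmap_convs)]
            return torch.cat((x, *outs), dim = 1)

    combiner = TinyUpsampleCombiner(dim = 64, enabled = True, dim_ins = (256, 128), dim_outs = (64, 64))
    x = torch.randn(1, 64, 64, 64)                                       # output of the last upsample stage
    fmaps = [torch.randn(1, 256, 16, 16), torch.randn(1, 128, 32, 32)]   # saved up_hiddens from coarser stages
    out = combiner(x, fmaps)
    print(out.shape, combiner.dim_out)                                   # torch.Size([1, 192, 64, 64]) 192

With enabled = False the module is a no-op and dim_out stays equal to dim, which is why the final resnet block's input width in the diff is written as self.upsample_combiner.dim_out + dim.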