@@ -1451,6 +1451,26 @@ def Downsample(dim, *, dim_out = None):
     dim_out = default(dim_out, dim)
     return nn.Conv2d(dim, dim_out, 4, 2, 1)
 
+class WeightStandardizedConv2d(nn.Conv2d):
+    """
+    https://arxiv.org/abs/1903.10520
+    weight standardization purportedly works synergistically with group normalization
+    """
+    def forward(self, x):
+        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
+
+        weight = self.weight
+        flattened_weights = rearrange(weight, 'o ... -> o (...)')
+
+        mean = reduce(weight, 'o ... -> o 1 1 1', 'mean')
+
+        var = torch.var(flattened_weights, dim = -1, unbiased = False)
+        var = rearrange(var, 'o -> o 1 1 1')
+
+        weight = (weight - mean) * (var + eps).rsqrt()
+
+        return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
+
 class SinusoidalPosEmb(nn.Module):
     def __init__(self, dim):
         super().__init__()
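As a quick sanity check on the new layer (not part of the commit, and assuming WeightStandardizedConv2d plus the file's einops imports are in scope), the sketch below redoes the standardization outside forward and confirms the kernel actually convolved with has zero mean per output channel; the name w_std and the shapes are illustrative only:

import torch
from einops import rearrange, reduce

conv = WeightStandardizedConv2d(3, 16, 3, padding = 1)

# replicate the standardization done inside forward()
w = conv.weight                                   # (16, 3, 3, 3)
mean = reduce(w, 'o ... -> o 1 1 1', 'mean')
flat = rearrange(w, 'o ... -> o (...)')
var = rearrange(torch.var(flat, dim = -1, unbiased = False), 'o -> o 1 1 1')
w_std = (w - mean) * (var + 1e-5).rsqrt()         # eps = 1e-5 for float32 inputs

# per output channel the standardized kernel has mean ~ 0 and variance ~ 1
assert torch.allclose(reduce(w_std, 'o ... -> o', 'mean'), torch.zeros(16), atol = 1e-5)

# output shape matches a plain 3x3 conv with padding 1
print(conv(torch.randn(1, 3, 32, 32)).shape)      # torch.Size([1, 16, 32, 32])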
@@ -1469,10 +1489,13 @@ def __init__(
         self,
         dim,
         dim_out,
-        groups = 8
+        groups = 8,
+        weight_standardization = False
     ):
         super().__init__()
-        self.project = nn.Conv2d(dim, dim_out, 3, padding = 1)
+        conv_klass = nn.Conv2d if not weight_standardization else WeightStandardizedConv2d
+
+        self.project = conv_klass(dim, dim_out, 3, padding = 1)
         self.norm = nn.GroupNorm(groups, dim_out)
         self.act = nn.SiLU()
 
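A note on the hunk above: only the 3x3 projection conv is swapped; the nn.GroupNorm immediately after it is untouched, which is exactly the conv + group norm pairing the linked paper argues for. A hypothetical usage sketch, assuming Block's existing forward (project, then norm, then SiLU) and illustrative dims:

block = Block(64, 128, weight_standardization = True)
assert isinstance(block.project, WeightStandardizedConv2d)   # flag picked the standardized conv

x = torch.randn(2, 64, 32, 32)
print(block(x).shape)                                        # torch.Size([2, 128, 32, 32])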
@@ -1496,6 +1519,7 @@ def __init__(
         cond_dim = None,
         time_cond_dim = None,
         groups = 8,
+        weight_standardization = False,
         cosine_sim_cross_attn = False
     ):
         super().__init__()
@@ -1521,8 +1545,8 @@ def __init__(
             )
         )
 
-        self.block1 = Block(dim, dim_out, groups = groups)
-        self.block2 = Block(dim_out, dim_out, groups = groups)
+        self.block1 = Block(dim, dim_out, groups = groups, weight_standardization = weight_standardization)
+        self.block2 = Block(dim_out, dim_out, groups = groups, weight_standardization = weight_standardization)
         self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
 
     def forward(self, x, time_emb = None, cond = None):
@@ -1747,6 +1771,7 @@ def __init__(
         init_dim = None,
         init_conv_kernel_size = 7,
         resnet_groups = 8,
+        resnet_weight_standardization = False,
         num_resnet_blocks = 2,
         init_cross_embed = True,
         init_cross_embed_kernel_sizes = (3, 7, 15),
@@ -1894,7 +1919,7 @@ def __init__(
 
         # prepare resnet klass
 
-        resnet_block = partial(ResnetBlock, cosine_sim_cross_attn = cosine_sim_cross_attn)
+        resnet_block = partial(ResnetBlock, cosine_sim_cross_attn = cosine_sim_cross_attn, weight_standardization = resnet_weight_standardization)
 
         # give memory efficient unet an initial resnet block
 
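To confirm the flag threads all the way down, a hypothetical check mirroring the partial above (other ResnetBlock arguments left at their defaults; dims illustrative):

from functools import partial

resnet_block = partial(ResnetBlock, cosine_sim_cross_attn = False, weight_standardization = True)

rb = resnet_block(64, 128)
assert isinstance(rb.block1.project, WeightStandardizedConv2d)
assert isinstance(rb.block2.project, WeightStandardizedConv2d)

From the public API, the same effect comes from constructing the Unet with resnet_weight_standardization = True.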