
Commit 79e2a3b

only use the stable layernorm for final output norm in transformer

1 parent 544cdd0

File tree

2 files changed: +13 -7 lines

dalle2_pytorch/dalle2_pytorch.py (+12 -6)
@@ -527,25 +527,31 @@ def p2_reweigh_loss(self, loss, times):
 # diffusion prior

 class LayerNorm(nn.Module):
-    def __init__(self, dim, eps = 1e-5):
+    def __init__(self, dim, eps = 1e-5, stable = False):
         super().__init__()
         self.eps = eps
+        self.stable = stable
         self.g = nn.Parameter(torch.ones(dim))

     def forward(self, x):
-        x = x / x.amax(dim = -1, keepdim = True).detach()
+        if self.stable:
+            x = x / x.amax(dim = -1, keepdim = True).detach()
+
         var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
         mean = torch.mean(x, dim = -1, keepdim = True)
         return (x - mean) * (var + self.eps).rsqrt() * self.g

 class ChanLayerNorm(nn.Module):
-    def __init__(self, dim, eps = 1e-5):
+    def __init__(self, dim, eps = 1e-5, stable = False):
         super().__init__()
         self.eps = eps
+        self.stable = stable
         self.g = nn.Parameter(torch.ones(1, dim, 1, 1))

     def forward(self, x):
-        x = x / x.amax(dim = 1, keepdim = True).detach()
+        if self.stable:
+            x = x / x.amax(dim = 1, keepdim = True).detach()
+
         var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
         mean = torch.mean(x, dim = 1, keepdim = True)
         return (x - mean) * (var + self.eps).rsqrt() * self.g
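For reference, the rescale gated behind stable is safe because layer norm is (up to eps) invariant to multiplying its input by a positive constant: dividing by the detached amax changes nothing mathematically while keeping the variance computation inside fp16 range. A minimal sketch demonstrating this, illustrative only and not code from this repo:

import torch

def ln(x, eps = 1e-5):
    # plain layer norm without the learned gain
    var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
    mean = torch.mean(x, dim = -1, keepdim = True)
    return (x - mean) * (var + eps).rsqrt()

x = torch.randn(2, 16) * 1e3  # large activations
# assumes a positive row max, as these random inputs almost surely have
x_stable = x / x.amax(dim = -1, keepdim = True).detach()

# outputs agree up to the tiny effect of eps on the rescaled variance
print(torch.allclose(ln(x), ln(x_stable), atol = 1e-3))  # True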
@@ -669,7 +675,7 @@ def __init__(
         dropout = 0.,
         causal = False,
         rotary_emb = None,
-        pb_relax_alpha = 32 ** 2
+        pb_relax_alpha = 128
     ):
         super().__init__()
         self.pb_relax_alpha = pb_relax_alpha
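For context, pb_relax_alpha comes from the PB-Relax trick introduced in CogView for keeping fp16 attention logits in range. A hedged sketch of the general technique follows; the function name and exact placement are illustrative, not necessarily how this repo wires it in:

import torch

def pb_relax_attention(q, k, alpha = 128):
    # PB-Relax (CogView): scale queries down by alpha before the matmul so
    # the raw logits cannot overflow fp16, subtract the detached row max,
    # then multiply alpha back in; softmax is shift-invariant, so the
    # probabilities match the unscaled computation
    scale = q.shape[-1] ** -0.5
    sim = torch.einsum('... i d, ... j d -> ... i j', q * scale / alpha, k)
    sim = (sim - sim.amax(dim = -1, keepdim = True).detach()) * alpha
    return sim.softmax(dim = -1)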
@@ -782,7 +788,7 @@ def __init__(
             FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
         ]))

-        self.norm = LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
+        self.norm = LayerNorm(dim, stable = True) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
         self.project_out = nn.Linear(dim, dim, bias = False) if final_proj else nn.Identity()

     def forward(self, x):

dalle2_pytorch/version.py (+1 -1)

@@ -1 +1 @@
-__version__ = '0.23.2'
+__version__ = '0.23.3'
