
Commit 349aaca

add yet another transformer stability measure
1 parent 3ee3c56 commit 349aaca

File tree

dalle2_pytorch/dalle2_pytorch.py
dalle2_pytorch/train_configs.py
dalle2_pytorch/version.py

3 files changed: +7 -1 lines changed

3 files changed

+7
-1
lines changed

dalle2_pytorch/dalle2_pytorch.py (+5)

@@ -760,6 +760,7 @@ def __init__(
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
+        norm_in = False,
         norm_out = True,
         attn_dropout = 0.,
         ff_dropout = 0.,
@@ -768,6 +769,8 @@ def __init__(
         rotary_emb = True
     ):
         super().__init__()
+        self.init_norm = LayerNorm(dim) if norm_in else nn.Identity() # from latest BLOOM model and Yandex's YaLM
+
         self.rel_pos_bias = RelPosBias(heads = heads)

         rotary_emb = RotaryEmbedding(dim = min(32, dim_head)) if rotary_emb else None
@@ -785,6 +788,8 @@ def __init__(
     def forward(self, x):
         n, device = x.shape[1], x.device

+        x = self.init_norm(x)
+
         attn_bias = self.rel_pos_bias(n, n + 1, device = device)

         for attn, ff in self.layers:

dalle2_pytorch/train_configs.py (+1)

@@ -137,6 +137,7 @@ class DiffusionPriorNetworkConfig(BaseModel):
     dim_head: int = 64
     heads: int = 8
     ff_mult: int = 4
+    norm_in: bool = False
     norm_out: bool = True
     attn_dropout: float = 0.
     ff_dropout: float = 0.
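
Since DiffusionPriorNetworkConfig is a pydantic model, the new field can simply be supplied alongside the existing hyperparameters when building a training config. A minimal sketch follows, assuming dim and depth are the config's other required fields (they sit earlier in the class and are not part of this diff); the values are illustrative only.

from dalle2_pytorch.train_configs import DiffusionPriorNetworkConfig

# hypothetical hyperparameter values; only `norm_in` is the field added in this commit
net_config = DiffusionPriorNetworkConfig(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8,
    norm_in = True,   # enable the new input LayerNorm
)

print(net_config.dict())  # pydantic v1-style dump, as used by this repo's configs
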

dalle2_pytorch/version.py (+1 -1)

@@ -1 +1 @@
-__version__ = '0.23.1'
+__version__ = '0.23.2'
