File tree: 3 files changed, +7 -1 lines changed
3 files changed +7
-1
lines changed Original file line number Diff line number Diff line change @@ -760,6 +760,7 @@ def __init__(
760
760
dim_head = 64 ,
761
761
heads = 8 ,
762
762
ff_mult = 4 ,
763
+ norm_in = False ,
763
764
norm_out = True ,
764
765
attn_dropout = 0. ,
765
766
ff_dropout = 0. ,
@@ -768,6 +769,8 @@ def __init__(
768
769
rotary_emb = True
769
770
):
770
771
super ().__init__ ()
772
+ self .init_norm = LayerNorm (dim ) if norm_in else nn .Identity () # optional LayerNorm applied to the input before the layers (off by default), following the latest BLOOM model and Yandex's YaLM
773
+
771
774
self .rel_pos_bias = RelPosBias (heads = heads )
772
775
773
776
rotary_emb = RotaryEmbedding (dim = min (32 , dim_head )) if rotary_emb else None
@@ -785,6 +788,8 @@ def __init__(
785
788
def forward (self , x ):
786
789
n , device = x .shape [1 ], x .device
787
790
791
+ x = self .init_norm (x )
792
+
788
793
attn_bias = self .rel_pos_bias (n , n + 1 , device = device )
789
794
790
795
for attn , ff in self .layers :
Original file line number Diff line number Diff line change @@ -137,6 +137,7 @@ class DiffusionPriorNetworkConfig(BaseModel):
137
137
dim_head : int = 64
138
138
heads : int = 8
139
139
ff_mult : int = 4
140
+ norm_in : bool = False
140
141
norm_out : bool = True
141
142
attn_dropout : float = 0.
142
143
ff_dropout : float = 0.
Original file line number Diff line number Diff line change 1
- __version__ = '0.23.1'
1
+ __version__ = '0.23.2'
You can’t perform that action at this time.
0 commit comments