@@ -122,20 +122,25 @@ def __init__(
122122 super ().__init__ ()
123123 assert isinstance (enformer , Enformer )
124124 self .enformer = enformer
125+ enformer_hidden_dim = enformer .dim * 2
126+
127+ self .query_norm = nn .LayerNorm (enformer_hidden_dim )
128+ self .key_values_norm = nn .LayerNorm (context_dim )
125129
126130 self .scale = dim_head ** - 0.5
127131 self .heads = heads
128132 inner_dim = heads * dim_head
129- self .to_queries = nn .Linear (enformer . dim * 2 , inner_dim )
133+ self .to_queries = nn .Linear (enformer_hidden_dim , inner_dim )
130134
131135 self .null_key = nn .Parameter (torch .randn (inner_dim ))
132136 self .null_value = nn .Parameter (torch .randn (inner_dim ))
133137
134138 self .to_key_values = nn .Linear (context_dim , inner_dim * 2 , bias = False )
139+ self .to_out = nn .Linear (inner_dim , enformer_hidden_dim )
135140
136- self .to_out = nn .Sequential (
137- nn .Linear (inner_dim , 1 ),
138- Rearrange ('c ... 1 -> ... c' ),
141+ self .to_pred = nn .Sequential (
142+ nn .Linear (enformer_hidden_dim , 1 ),
143+ Rearrange ('b c ... 1 -> b ... c' ),
139144 nn .Softplus ()
140145 )
141146
@@ -155,8 +160,8 @@ def forward(
155160 if context .ndim == 2 :
156161 context = rearrange (context , 'b d -> b 1 d' )
157162
158- q = self .to_queries (embeddings )
159- k , v = self .to_key_values (context ).chunk (2 , dim = - 1 )
163+ q = self .to_queries (self . query_norm ( embeddings ) )
164+ k , v = self .to_key_values (self . key_values_norm ( context ) ).chunk (2 , dim = - 1 )
160165
161166 null_k , null_v = map (lambda t : repeat (t , 'd -> b 1 d' , b = context .shape [0 ]), (self .null_key , self .null_value ))
162167
@@ -174,13 +179,21 @@ def forward(
174179
175180 # aggregate
176181
177- out = einsum ('b c h i j, c h j d -> c h i d' , attn , v )
182+ out = einsum ('b c h i j, c h j d -> b c h i d' , attn , v )
183+
184+ out = rearrange (out , 'b c h n d -> b c n (h d)' , h = h )
185+
186+ # combine heads
187+
188+ branch_out = self .to_out (out )
189+
190+ # residual
178191
179- out = rearrange ( out , 'c h n d -> c n (h d)' , h = h )
192+ embeddings = embeddings + branch_out
180193
181- # combine heads and project / softplus
194+ # to prediction
182195
183- pred = self .to_out ( out )
196+ pred = self .to_pred ( embeddings )
184197
185198 if not exists (target ):
186199 return pred
0 commit comments