@@ -191,8 +191,6 @@ def __init__(
         super().__init__()
         self.heads = heads
         self.scale = nn.Parameter(torch.ones(heads, 1, 1) * math.log(attn_scale_init))
-        self.local_attn_bias = nn.Parameter(torch.zeros(heads, 1, 1))
-        self.knn_attn_bias = nn.Parameter(torch.zeros(heads, 1, 1))

         inner_dim = heads * dim_head
         self.xl_max_memories = xl_max_memories
@@ -242,7 +240,6 @@ def forward(
         if exists(rel_pos_bias):
             sim = rel_pos_bias[..., -i:, -j:] + sim

-        sim = sim + self.local_attn_bias
         mask_value = -torch.finfo(sim.dtype).max

         causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
@@ -254,7 +251,6 @@ def forward(
             mem_k, mem_v = mem_kv.unbind(dim = -2)

             sim_mem = einsum('b h i d, b h i j d -> b h i j', q, mem_k) * scale
-            sim_mem = sim_mem + self.knn_attn_bias
             sim_mem = sim_mem.masked_fill(~mem_mask, mask_value)

             # calculate new XL memories, as well as memories to be discarded
@@ -356,6 +352,7 @@ def __init__(
         # relative positional bias

         self.rel_pos_bias = T5RelativePositionBias(scale = dim_head ** 0.5, heads = heads)
+        self.knn_rel_pos_bias = T5RelativePositionBias(scale = dim_head ** 0.5, heads = heads)

         # layers
@@ -481,6 +478,7 @@ def forward(
         max_context_len = max([seq_len, *map(lambda t: (t.shape[-3] if exists(t) else 0) + seq_len, xl_memories)])

         rel_pos_bias = self.rel_pos_bias(seq_len, max_context_len, device = device)
+        knn_rel_pos_bias = self.knn_rel_pos_bias(seq_len, max_context_len, device = device)

         # keep track of new xl memories
@@ -494,7 +492,7 @@ def forward(
             is_memorizing_layer = layer_num in self.memorizing_layers
             is_xl_memory_layer = layer_num in self.xl_memory_layers

-            attn_kwargs = dict(rel_pos_bias = rel_pos_bias)
+            attn_kwargs = dict(rel_pos_bias = rel_pos_bias if not is_memorizing_layer else knn_rel_pos_bias)

             if is_memorizing_layer:
                 attn_kwargs = {**attn_kwargs, 'knn_memory': next(knn_memories_iter), 'add_knn_memory': add_knn_memory}
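
Taken together, the hunks drop the learned per-head scalar biases (local_attn_bias, knn_attn_bias) in favor of a second T5RelativePositionBias instance reserved for the memorizing (KNN) layers, so local attention and memory attention each get their own distance-dependent bias. Below is a minimal sketch of a T5-style relative position bias for readers unfamiliar with the mechanism; it assumes the causal bucketing scheme from the T5 paper, and the class name, bucket count, and max distance are illustrative rather than the repository's exact implementation.

import math
import torch
from torch import nn

class T5RelativePositionBiasSketch(nn.Module):
    # hypothetical stand-in for the repo's T5RelativePositionBias
    def __init__(self, scale, heads, num_buckets = 32, max_distance = 128):
        super().__init__()
        self.scale = scale
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    @staticmethod
    def _relative_position_bucket(relative_position, num_buckets, max_distance):
        # causal variant: only keys at or before the query matter, so distance n >= 0
        n = (-relative_position).clamp(min = 0)
        max_exact = num_buckets // 2
        is_small = n < max_exact

        # distances past max_exact share logarithmically sized buckets
        val_if_large = max_exact + (
            torch.log(n.float().clamp(min = 1) / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).long()
        val_if_large = val_if_large.clamp(max = num_buckets - 1)
        return torch.where(is_small, n, val_if_large)

    def forward(self, i, j, device = None):
        # queries occupy the last i positions of a j-length context, matching
        # the rel_pos_bias[..., -i:, -j:] slicing in the diff above
        q_pos = torch.arange(j - i, j, dtype = torch.long, device = device)
        k_pos = torch.arange(j, dtype = torch.long, device = device)
        rel_pos = k_pos[None, :] - q_pos[:, None]                  # (i, j)
        buckets = self._relative_position_bucket(rel_pos, self.num_buckets, self.max_distance)
        bias = self.relative_attention_bias(buckets)               # (i, j, heads)
        return bias.permute(2, 0, 1) * self.scale                  # (heads, i, j)

Since both biases are materialized once at max_context_len and merely sliced per layer, keeping two of them is cheap; the final hunk then routes knn_rel_pos_bias to memorizing layers and the plain rel_pos_bias everywhere else.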