
Commit 83fa147

just give knn attention its own relative positional bias
1 parent aaf9a0a commit 83fa147

2 files changed: +4 −6 lines

memorizing_transformers_pytorch/memorizing_transformers_pytorch.py

Lines changed: 3 additions & 5 deletions
@@ -191,8 +191,6 @@ def __init__(
         super().__init__()
         self.heads = heads
         self.scale = nn.Parameter(torch.ones(heads, 1, 1) * math.log(attn_scale_init))
-        self.local_attn_bias = nn.Parameter(torch.zeros(heads, 1, 1))
-        self.knn_attn_bias = nn.Parameter(torch.zeros(heads, 1, 1))
 
         inner_dim = heads * dim_head
         self.xl_max_memories = xl_max_memories
@@ -242,7 +240,6 @@ def forward(
         if exists(rel_pos_bias):
             sim = rel_pos_bias[..., -i:, -j:] + sim
 
-        sim = sim + self.local_attn_bias
         mask_value = -torch.finfo(sim.dtype).max
 
         causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
@@ -254,7 +251,6 @@ def forward(
             mem_k, mem_v = mem_kv.unbind(dim = -2)
 
             sim_mem = einsum('b h i d, b h i j d -> b h i j', q, mem_k) * scale
-            sim_mem = sim_mem + self.knn_attn_bias
             sim_mem = sim_mem.masked_fill(~mem_mask, mask_value)
 
             # calculate new XL memories, as well as memories to be discarded
@@ -356,6 +352,7 @@ def __init__(
         # relative positional bias
 
         self.rel_pos_bias = T5RelativePositionBias(scale = dim_head ** 0.5, heads = heads)
+        self.knn_rel_pos_bias = T5RelativePositionBias(scale = dim_head ** 0.5, heads = heads)
 
         # layers
 
@@ -481,6 +478,7 @@ def forward(
         max_context_len = max([seq_len, *map(lambda t: (t.shape[-3] if exists(t) else 0) + seq_len, xl_memories)])
 
         rel_pos_bias = self.rel_pos_bias(seq_len, max_context_len, device = device)
+        knn_rel_pos_bias = self.knn_rel_pos_bias(seq_len, max_context_len, device = device)
 
         # keep track of new xl memories
 
@@ -494,7 +492,7 @@ def forward(
             is_memorizing_layer = layer_num in self.memorizing_layers
             is_xl_memory_layer = layer_num in self.xl_memory_layers
 
-            attn_kwargs = dict(rel_pos_bias = rel_pos_bias)
+            attn_kwargs = dict(rel_pos_bias = rel_pos_bias if not is_memorizing_layer else knn_rel_pos_bias)
 
             if is_memorizing_layer:
                 attn_kwargs = {**attn_kwargs, 'knn_memory': next(knn_memories_iter), 'add_knn_memory': add_knn_memory}

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'memorizing-transformers-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.3.6',
+  version = '0.3.7',
   license='MIT',
   description = 'Memorizing Transformer - Pytorch',
   author = 'Phil Wang',
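
For context, this commit drops the two learned per-head scalar offsets (local_attn_bias, knn_attn_bias) and instead gives the kNN/memorizing layers their own T5RelativePositionBias instance, selected per layer through attn_kwargs as shown in the diff above. The sketch below illustrates that selection pattern only; SimpleRelPosBias is a hypothetical, simplified stand-in (plain clamped relative distances rather than T5's log-spaced buckets), not the repository's T5RelativePositionBias.

```python
import torch
from torch import nn

class SimpleRelPosBias(nn.Module):
    # Hypothetical stand-in for T5RelativePositionBias: clamps signed relative
    # distances into a fixed range and looks up one learned bias per head.
    def __init__(self, heads, max_distance = 32):
        super().__init__()
        self.max_distance = max_distance
        self.bias = nn.Embedding(2 * max_distance + 1, heads)

    def forward(self, i, j, device = None):
        q_pos = torch.arange(i, device = device)
        k_pos = torch.arange(j, device = device)
        rel = k_pos[None, :] - q_pos[:, None]                          # (i, j) signed distances
        rel = rel.clamp(-self.max_distance, self.max_distance) + self.max_distance
        return self.bias(rel).permute(2, 0, 1)                         # (heads, i, j)

heads, i, j = 8, 16, 16

# one bias for ordinary attention layers, a separate one for kNN attention layers,
# mirroring the intent of this commit
rel_pos_bias     = SimpleRelPosBias(heads)
knn_rel_pos_bias = SimpleRelPosBias(heads)

# per-layer selection, as in the attn_kwargs change above
is_memorizing_layer = True
bias = knn_rel_pos_bias if is_memorizing_layer else rel_pos_bias

sim = torch.randn(2, heads, i, j)                                      # (batch, heads, i, j) local attention logits
sim = sim + bias(i, j)                                                 # add the layer's own relative position bias
```

Note that after this change the retrieved-memory logits (sim_mem) receive no additive bias at all; only the choice of relative positional bias applied to the layer's local attention differs between memorizing and non-memorizing layers.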
