
Commit aaf9a0a

give knn attention layer one more way to tune out local if need be
1 parent fabbe14

2 files changed: +3 −1 lines changed


memorizing_transformers_pytorch/memorizing_transformers_pytorch.py

Lines changed: 2 additions & 0 deletions
@@ -191,6 +191,7 @@ def __init__(
         super().__init__()
         self.heads = heads
         self.scale = nn.Parameter(torch.ones(heads, 1, 1) * math.log(attn_scale_init))
+        self.local_attn_bias = nn.Parameter(torch.zeros(heads, 1, 1))
         self.knn_attn_bias = nn.Parameter(torch.zeros(heads, 1, 1))

         inner_dim = heads * dim_head

@@ -241,6 +242,7 @@ def forward(
         if exists(rel_pos_bias):
             sim = rel_pos_bias[..., -i:, -j:] + sim

+        sim = sim + self.local_attn_bias
         mask_value = -torch.finfo(sim.dtype).max

         causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
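
The commit adds a learnable per-head bias on the local attention logits, complementing the existing knn_attn_bias, so an individual head can learn to down-weight ("tune out") local attention relative to the kNN memory attention. Below is a minimal sketch of how such a per-head scalar bias folds into attention logits; the module name LocalBiasedAttention and its hyperparameters are hypothetical, the kNN-memory branch is omitted, and this is not the repository's actual class:

import torch
from torch import nn, einsum

class LocalBiasedAttention(nn.Module):
    # hypothetical simplified module: standard softmax attention plus a
    # per-head learnable scalar bias on the local attention logits,
    # mirroring the local_attn_bias parameter added in this commit
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = heads * dim_head

        # one scalar per head, broadcast over all (query, key) positions
        self.local_attn_bias = nn.Parameter(torch.zeros(heads, 1, 1))

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, x):
        b, n, _ = x.shape
        h = self.heads

        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: t.reshape(b, n, h, -1).transpose(1, 2), (q, k, v))

        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        # the bias shifts each head's local logits up or down; training it to a
        # strongly negative value lets that head effectively tune out local attention
        sim = sim + self.local_attn_bias

        attn = sim.softmax(dim = -1)
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = out.transpose(1, 2).reshape(b, n, -1)
        return self.to_out(out)

# usage sketch
attn = LocalBiasedAttention(dim = 512, heads = 8)
out = attn(torch.randn(2, 1024, 512))   # -> (2, 1024, 512)

With local_attn_bias initialized at zero this behaves like plain softmax attention. In the full KNNAttention layer the analogous knn_attn_bias sits on the memory logits, so the two biases together let each head rebalance local versus memory attention during training.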

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'memorizing-transformers-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.3.5',
+  version = '0.3.6',
   license='MIT',
   description = 'Memorizing Transformer - Pytorch',
   author = 'Phil Wang',
