Commit fabbe14

allow the network to pay more attention to memory later into training, if need be
1 parent c43995b commit fabbe14

3 files changed: +4 -2 lines changed

README.md
Lines changed: 1 addition & 1 deletion

@@ -4,7 +4,7 @@
  Implementation of <a href="https://arxiv.org/abs/2203.08913">Memorizing Transformers</a> (ICLR 2022), attention net augmented with indexing and retrieval of memories using approximate nearest neighbors, in Pytorch

- This repository deviates from the paper slightly, using a hybrid attention across attention logits local and distant (knn). It also uses cosine similarity attention (with learned temperature) for the KNN attention layer.
+ This repository deviates from the paper slightly, using a hybrid attention across attention logits local and distant (rather than the sigmoid gate setup). It also uses cosine similarity attention (with learned temperature) for the KNN attention layer.

  ## Install
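The README sentence above describes two deviations from the paper: a single softmax over the concatenated local and retrieved (KNN) logits instead of the paper's learned sigmoid gate, and cosine-similarity attention with a learned temperature for the KNN branch. A minimal sketch of that idea, assuming illustrative tensor shapes and omitting masking, is below; the function name, the separate `local_scale` / `mem_scale` arguments, and the shapes are assumptions for clarity, not the repository's actual forward pass.

```python
import torch
import torch.nn.functional as F
from torch import einsum

def hybrid_attention(q, k, v, mem_k, mem_v, local_scale, mem_scale):
    # q, k, v: (batch, heads, seq, dim); mem_k, mem_v: (batch, heads, seq, knn, dim)
    # causal / padding masks are omitted for brevity

    # local attention logits, the usual scaled dot product
    sim_local = einsum('b h i d, b h j d -> b h i j', q, k) * local_scale

    # memory (KNN) logits: l2-normalized queries and keys give a cosine
    # similarity, scaled by a learned temperature (mem_scale)
    q_unit, mem_k_unit = map(lambda t: F.normalize(t, dim = -1), (q, mem_k))
    sim_mem = einsum('b h i d, b h i j d -> b h i j', q_unit, mem_k_unit) * mem_scale

    # hybrid attention: one softmax across the concatenated memory + local
    # logits, instead of the paper's sigmoid gate between two attention outputs
    attn = torch.cat((sim_mem, sim_local), dim = -1).softmax(dim = -1)

    num_mem = sim_mem.shape[-1]
    attn_mem, attn_local = attn[..., :num_mem], attn[..., num_mem:]

    out = einsum('b h i j, b h i j d -> b h i d', attn_mem, mem_v)
    out = out + einsum('b h i j, b h j d -> b h i d', attn_local, v)
    return out
```

Because both sets of logits compete inside one softmax, any per-head shift applied to the memory logits directly trades attention mass between the local and memory branches, which is what the change in the next file exploits.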

memorizing_transformers_pytorch/memorizing_transformers_pytorch.py
Lines changed: 2 additions & 0 deletions

@@ -191,6 +191,7 @@ def __init__(
          super().__init__()
          self.heads = heads
          self.scale = nn.Parameter(torch.ones(heads, 1, 1) * math.log(attn_scale_init))
+         self.knn_attn_bias = nn.Parameter(torch.zeros(heads, 1, 1))

          inner_dim = heads * dim_head
          self.xl_max_memories = xl_max_memories

@@ -251,6 +252,7 @@ def forward(
          mem_k, mem_v = mem_kv.unbind(dim = -2)

          sim_mem = einsum('b h i d, b h i j d -> b h i j', q, mem_k) * scale
+         sim_mem = sim_mem + self.knn_attn_bias
          sim_mem = sim_mem.masked_fill(~mem_mask, mask_value)

          # calculate new XL memories, as well as memories to be discarded
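The change here is just the two added lines: a per-head `knn_attn_bias` parameter, initialized to zero, added to the memory logits before masking. A zero init leaves the attention unchanged at the start of training, and since the memory and local logits share one softmax, the bias can grow over training to push more attention mass toward retrieved memories, matching the commit message. A standalone sketch of how the bias broadcasts over the memory logits follows; the batch size, sequence length, knn count, and the `scale.exp()` temperature usage are illustrative assumptions.

```python
import math
import torch
from torch import nn, einsum

heads, dim_head, attn_scale_init = 8, 64, 20

# learned temperature for the cosine-sim KNN attention (log-parameterized, as above)
scale = nn.Parameter(torch.ones(heads, 1, 1) * math.log(attn_scale_init))

# new per-head bias on the memory logits, starting at zero so nothing changes
# at initialization; the optimizer can raise it later if memories prove useful
knn_attn_bias = nn.Parameter(torch.zeros(heads, 1, 1))

# toy shapes: batch 2, sequence 1024, 32 retrieved memories per query
b, n, k = 2, 1024, 32
q = torch.randn(b, heads, n, dim_head)
mem_k = torch.randn(b, heads, n, k, dim_head)

# memory logits, scaled by the exponentiated learned temperature,
# then shifted by the per-head bias (broadcasts over batch, seq, knn)
sim_mem = einsum('b h i d, b h i j d -> b h i j', q, mem_k) * scale.exp()
sim_mem = sim_mem + knn_attn_bias
print(sim_mem.shape)  # torch.Size([2, 8, 1024, 32])
```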

setup.py
Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
  setup(
      name = 'memorizing-transformers-pytorch',
      packages = find_packages(exclude=[]),
-     version = '0.3.4',
+     version = '0.3.5',
      license='MIT',
      description = 'Memorizing Transformer - Pytorch',
      author = 'Phil Wang',
