Commit 272e39b

address #16
1 parent c44f755 commit 272e39b

2 files changed: +3 -7 lines changed

memorizing_transformers_pytorch/memorizing_transformers_pytorch.py

Lines changed: 2 additions & 6 deletions
@@ -34,10 +34,6 @@ def cast_tuple(val, length = 1):
 def l2norm(t):
     return F.normalize(t, dim = -1)
 
-def stable_softmax(t, dim = -1):
-    t = t - t.amax(dim = dim, keepdim = True).detach()
-    return F.softmax(t, dim = dim)
-
 # helper classes
 
 class PreNormResidual(nn.Module):
@@ -157,7 +153,7 @@ def forward(self, x, *, xl_memory = None, rel_pos_bias = None):
         causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
         sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
 
-        attn = stable_softmax(sim)
+        attn = sim.softmax(dim = -1)
         attn = self.dropout(attn)
 
         out = einsum('b h i j, b j d -> b h i d', attn, v)
@@ -273,7 +269,7 @@ def forward(
         # attention (combining local and distant)
 
         sim = torch.cat((sim_mem, sim), dim = -1)
-        attn = stable_softmax(sim)
+        attn = sim.softmax(dim = -1)
         attn = self.dropout(attn)
 
         local_attn, mem_attn = attn[..., self.num_retrieved_memories:], attn[..., :self.num_retrieved_memories]
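
The removed helper shifted the logits by their per-row max before calling F.softmax, a common guard against overflow. PyTorch's built-in softmax already applies this shift internally, so calling sim.softmax(dim = -1) directly produces the same attention weights. A minimal sketch (not part of the commit) illustrating the equivalence on large logits:

import torch
import torch.nn.functional as F

def stable_softmax(t, dim = -1):
    # the helper removed by this commit: shift logits by their max for stability
    t = t - t.amax(dim = dim, keepdim = True).detach()
    return F.softmax(t, dim = dim)

# large-magnitude logits that would overflow a naive exp-based softmax
sim = torch.randn(2, 8, 16, 16) * 1e4

# the built-in softmax is already numerically stable, so both paths agree
assert torch.allclose(stable_softmax(sim), sim.softmax(dim = -1))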

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'memorizing-transformers-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.4.0',
+  version = '0.4.1',
   license='MIT',
   description = 'Memorizing Transformer - Pytorch',
   long_description_content_type = 'text/markdown',
