Skip to content

Commit 4d008eb

Browse files
committed
always use mask if sparse axial or conv attention needs to be padded
1 parent de732e8 commit 4d008eb

File tree

2 files changed

+5
-6
lines changed

2 files changed

+5
-6
lines changed

dalle_pytorch/attention.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,9 @@ def forward(self, x, mask = None):
9797

9898
if n < seq_len:
9999
padding = seq_len - n
100+
mask = default(mask, lambda: torch.ones(b, n, device = device).bool())
100101
x = F.pad(x, (0, 0, 0, padding), value = 0)
101-
if exists(mask):
102-
mask = F.pad(x, (0, padding), value = False)
102+
mask = F.pad(x, (0, padding), value = False)
103103

104104
qkv = self.to_qkv(x).chunk(3, dim = -1)
105105
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)
@@ -196,10 +196,9 @@ def forward(self, x, mask = None):
196196

197197
if n < seq_len:
198198
padding = seq_len - n
199+
mask = default(mask, lambda: torch.ones(b, n, device = device).bool())
199200
x = F.pad(x, (0, 0, 0, padding), value = 0)
200-
201-
if exists(mask):
202-
mask = F.pad(x, (0, padding), value = False)
201+
mask = F.pad(x, (0, padding), value = False)
203202

204203
qkv = self.to_qkv(x).chunk(3, dim = -1)
205204
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'dalle-pytorch',
55
packages = find_packages(),
6-
version = '0.0.55',
6+
version = '0.0.56',
77
license='MIT',
88
description = 'DALL-E - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments (0)