
Commit a3855fc

make sure sparse conv-like attention and sparse axial attention take into account the <bos> token
1 parent 9bf1f0f commit a3855fc

File tree

2 files changed, +46 -24 lines changed


dalle_pytorch/attention.py

+45 -23
@@ -89,19 +89,24 @@ def __init__(self, dim, seq_len, image_size = 32, kernel_size = 5, dilation = 1,
     def forward(self, x, mask = None):
         b, n, _, h, img_size, kernel_size, dilation, seq_len, device = *x.shape, self.heads, self.image_size, self.kernel_size, self.dilation, self.seq_len, x.device

-        if n < seq_len:
-            padding = seq_len - n
-            mask = default(mask, lambda: torch.ones(b, n, device = device).bool())
-            x = F.pad(x, (0, 0, 0, padding), value = 0)
-            mask = F.pad(x, (0, padding), value = False)
+        img_seq_len = img_size ** 2
+        text_len = seq_len + 1 - img_seq_len
+
+        # padding
+
+        padding = seq_len - n + 1
+        mask = default(mask, lambda: torch.ones(b, text_len, device = device).bool())
+
+        x = F.pad(x, (0, 0, 0, padding), value = 0)
+        mask = mask[:, :text_len]
+
+        # derive query / keys / values

         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)

         q *= self.scale

-        img_seq_len = img_size ** 2
-        text_len = seq_len - img_seq_len
         ((q_text, q_img), (k_text, k_img), (v_text, v_img)) = map(lambda t: (t[:, img_seq_len:], t[:, -img_seq_len:]), (q, k, v))

         # text attention
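
The heart of the change is the new length bookkeeping: the sequence reaching attention now carries one extra <bos> position, so x is padded out to seq_len + 1 tokens, and text_len (which the text/image masks are sized against) counts <bos> along with the text tokens. A minimal sketch of that arithmetic, with illustrative values (the image size, text length and input length n below are assumptions, not taken from the commit):

# illustrative numbers only; image_size, text length and n are assumed
image_size  = 32
img_seq_len = image_size ** 2            # 1024 image tokens
seq_len     = 256 + img_seq_len          # e.g. 256 text tokens -> seq_len = 1280

text_len = seq_len + 1 - img_seq_len     # 257 positions: <bos> plus the 256 text tokens

n = 900                                  # assumed length of the incoming (unpadded) sequence
padding = seq_len - n + 1                # pad x up to <bos> + text + image = seq_len + 1
assert n + padding == text_len + img_seq_len == seq_len + 1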
@@ -110,8 +115,8 @@ def forward(self, x, mask = None):
         mask_value = max_neg_value(dots_text)

         i, j = dots_text.shape[-2:]
-        mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
-        dots_text.masked_fill(mask, mask_value)
+        text_causal_mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
+        dots_text.masked_fill(text_causal_mask, mask_value)

         attn_text = dots_text.softmax(dim = -1)
         out_text = einsum('b i j, b j d -> b i d', attn_text, v_text)
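
Renaming mask to text_causal_mask here keeps the incoming padding mask intact for the image block further down; the mask itself is unchanged, the usual upper-triangular causal mask over text positions. A small standalone sketch of what it evaluates to (shape chosen arbitrarily):

import torch

i = j = 4                                                   # arbitrary text length for illustration
text_causal_mask = torch.ones(i, j).triu_(j - i + 1).bool()
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])
# True marks the future positions that get filled with mask_value before the softmax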
@@ -146,11 +151,16 @@ def forward(self, x, mask = None):
         # mask image attention

         q_img_indices = rearrange(img_seq, 'i -> () i ()')
-        mask = q_img_indices >= k_img_indices
+        causal_mask = q_img_indices >= k_img_indices
+
+        # concat text mask with image causal mask
+
+        causal_mask = repeat(causal_mask, '() i j -> b i j', b = b * h)
+        mask = repeat(mask, 'b j -> (b h) i j', i = i, h = h)
+        mask = torch.cat((mask, causal_mask), dim = -1)

         # image can attend to all of text

-        mask = F.pad(mask, (text_len, 0), value = True)
         dots_image.masked_fill_(~mask, mask_value)

         attn_image = dots_image.softmax(dim = -1)
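
Previously the image-to-text part of the mask was just F.pad(..., value = True), i.e. the image could attend to every text slot including pure padding. Now the padding mask over <bos> + text is broadcast across heads and query rows and concatenated in front of the image causal mask, so padded text positions are excluded too. A shape-level sketch of that construction (all sizes are toy values, and the index tensors merely stand in for whatever img_seq / k_img_indices really hold):

import torch
from einops import rearrange, repeat

b, h, text_len, i, j = 2, 4, 5, 3, 7                     # toy sizes, purely illustrative

q_img_indices = rearrange(torch.arange(i), 'i -> () i ()')
k_img_indices = rearrange(torch.arange(j), 'j -> () () j')

causal_mask = q_img_indices >= k_img_indices             # (1, i, j), True where the key is not in the future
causal_mask = repeat(causal_mask, '() i j -> b i j', b = b * h)

text_mask = torch.ones(b, text_len).bool()               # padding mask over <bos> + text tokens
text_mask = repeat(text_mask, 'b j -> (b h) i j', i = i, h = h)

mask = torch.cat((text_mask, causal_mask), dim = -1)     # (b * h, i, text_len + j), True = allowed
print(mask.shape)                                         # torch.Size([8, 3, 12])
# dots_image.masked_fill_(~mask, mask_value) then blocks padded text and future image keys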
@@ -188,20 +198,24 @@ def __init__(self, dim, seq_len, image_size = 32, axis = 0, heads = 8, dim_head
     def forward(self, x, mask = None):
         b, n, _, h, img_size, axis, seq_len, device = *x.shape, self.heads, self.image_size, self.axis, self.seq_len, x.device

-        if n < seq_len:
-            padding = seq_len - n
-            mask = default(mask, lambda: torch.ones(b, n, device = device).bool())
-            x = F.pad(x, (0, 0, 0, padding), value = 0)
-            mask = F.pad(x, (0, padding), value = False)
+        img_seq_len = img_size ** 2
+        text_len = seq_len + 1 - img_seq_len
+
+        # padding
+
+        padding = seq_len - n + 1
+        mask = default(mask, lambda: torch.ones(b, text_len, device = device).bool())
+
+        x = F.pad(x, (0, 0, 0, padding), value = 0)
+        mask = mask[:, :text_len]
+
+        # derive queries / keys / values

         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)

         q *= self.scale

-        img_seq_len = img_size ** 2
-        text_len = seq_len - img_seq_len
-
         ((q_text, q_img), (k_text, k_img), (v_text, v_img)) = map(lambda t: (t[:, img_seq_len:], t[:, -img_seq_len:]), (q, k, v))

         # text attention
@@ -210,8 +224,8 @@ def forward(self, x, mask = None):
         mask_value = max_neg_value(dots_text)

         i, j = dots_text.shape[-2:]
-        mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
-        dots_text.masked_fill(mask, mask_value)
+        text_causal_mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
+        dots_text.masked_fill(text_causal_mask, mask_value)

         attn_text = dots_text.softmax(dim = -1)
         out_text = einsum('b i j, b j d -> b i d', attn_text, v_text)
@@ -237,13 +251,21 @@ def forward(self, x, mask = None):

         # mask so image has full attention to text, but causal along axis

-        i, j = dots_image.shape[-2:]
-        mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
+        bh, i, j = dots_image.shape
+        causal_mask = torch.ones(i, img_size, device = device).triu_(img_size - i + 1).bool()
+        causal_mask = repeat(causal_mask, 'i j -> b i j', b = bh)
+
+        mask = repeat(mask, 'b j -> (b r) i j', r = (bh // b), i = i)
+        mask = torch.cat((~mask, causal_mask), dim = -1)
+
         dots_image.masked_fill_(mask, mask_value)

         # attention.

         attn_image = dots_image.softmax(dim = -1)
+
+        # aggregate
+
         out_image = einsum('b i j, b j d -> b i d', attn_image, v_img)

         # merge back axis
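
The axial path builds the same combined mask, with one sign flip: it calls masked_fill_(mask, ...) directly, so True means blocked. The text padding mask is therefore inverted before being concatenated in front of the per-axis causal mask, which has img_size key columns (one per position along the attended axis). A shape-level sketch with toy sizes (every size below is an assumption; r stands for whatever head/axis factor is folded into the batch dimension of dots_image):

import torch
from einops import repeat

b, r, i, img_size, text_len = 2, 6, 4, 4, 5      # toy sizes, purely illustrative
bh = b * r                                        # folded batch dimension of dots_image

causal_mask = torch.ones(i, img_size).triu_(img_size - i + 1).bool()   # True above the diagonal = future axis positions
causal_mask = repeat(causal_mask, 'i j -> b i j', b = bh)              # (bh, i, img_size)

text_mask = torch.ones(b, text_len).bool()                             # padding mask over <bos> + text tokens
text_mask = repeat(text_mask, 'b j -> (b r) i j', r = r, i = i)        # (bh, i, text_len)

mask = torch.cat((~text_mask, causal_mask), dim = -1)                  # (bh, i, text_len + img_size), True = blocked
print(mask.shape)                                                       # torch.Size([12, 4, 9])
# dots_image.masked_fill_(mask, mask_value) then blocks padded text and future axis positions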

setup.py

+1 -1
@@ -3,7 +3,7 @@
 setup(
   name = 'dalle-pytorch',
   packages = find_packages(),
-  version = '0.0.61',
+  version = '0.0.62',
   license='MIT',
   description = 'DALL-E - Pytorch',
   author = 'Phil Wang',
