Commit 3064403

switch to causal convolution for conv-like attention using unfold
1 parent 0cf55f2

2 files changed: +9 -8 lines changed

2 files changed

+9
-8
lines changed

dalle_pytorch/attention.py (+8 -7)
@@ -164,10 +164,12 @@ def forward(self, x, mask = None, rotary_pos_emb = None):
         # image attention

         effective_kernel_size = (kernel_size - 1) * dilation + 1
-        padding = effective_kernel_size // 2
+        same_padding = effective_kernel_size // 2
+        causal_padding = (same_padding * 2, 0, same_padding * 2, 0)

         k_img, v_img = map(lambda t: rearrange(t, 'b (h w) c -> b c h w', h = img_size), (k_img, v_img))
-        k_img, v_img = map(lambda t: F.unfold(t, kernel_size, padding = padding, dilation = dilation), (k_img, v_img))
+        k_img, v_img = map(lambda t: F.pad(t, causal_padding), (k_img, v_img))
+        k_img, v_img = map(lambda t: F.unfold(t, kernel_size, dilation = dilation), (k_img, v_img))
         k_img, v_img = map(lambda t: rearrange(t, 'b (d j) i -> b i j d', j = kernel_size ** 2), (k_img, v_img))

         # let image attend to all of text
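
A minimal shape sketch of the pad-then-unfold step in the hunk above (not part of the commit; the toy sizes b, c, img_size, kernel_size, dilation are assumed here). Because F.pad takes padding as (left, right, top, bottom), padding only the left and top edges makes each unfolded window end at its query pixel instead of centering on it:

# hypothetical standalone sketch of the causal pad + unfold step
import torch
import torch.nn.functional as F
from einops import rearrange

b, c, img_size = 2, 8, 4
kernel_size, dilation = 3, 1

effective_kernel_size = (kernel_size - 1) * dilation + 1
same_padding = effective_kernel_size // 2
causal_padding = (same_padding * 2, 0, same_padding * 2, 0)  # pad left and top only

k_img = torch.randn(b, img_size * img_size, c)                  # keys as (b, h*w, d)
k_img = rearrange(k_img, 'b (h w) c -> b c h w', h = img_size)  # back to a feature map
k_img = F.pad(k_img, causal_padding)                            # shift every window up and left
k_img = F.unfold(k_img, kernel_size, dilation = dilation)       # (b, c * k*k, h*w)
k_img = rearrange(k_img, 'b (d j) i -> b i j d', j = kernel_size ** 2)

print(k_img.shape)  # torch.Size([2, 16, 9, 8]): 9 causal neighbors per image position
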
@@ -180,20 +182,19 @@ def forward(self, x, mask = None, rotary_pos_emb = None):
         i, j = dots_image.shape[-2:]
         img_seq = torch.arange(img_seq_len, device = device)
         k_img_indices = rearrange(img_seq.float(), '(h w) -> () () h w', h = img_size)
-        k_img_indices = F.pad(k_img_indices, (padding,) * 4, value = img_seq_len) # padding set to be max, so it is never attended to
+        k_img_indices = F.pad(k_img_indices, causal_padding, value = img_seq_len) # padding set to be max, so it is never attended to
         k_img_indices = F.unfold(k_img_indices, kernel_size, dilation = dilation)
         k_img_indices = rearrange(k_img_indices, 'b j i -> b i j')

         # mask image attention

-        q_img_indices = rearrange(img_seq, 'i -> () i ()')
-        causal_mask = q_img_indices < k_img_indices
+        padding_mask = k_img_indices == img_seq_len

         # concat text mask with image causal mask

-        causal_mask = repeat(causal_mask, '() i j -> b i j', b = b * h)
+        padding_mask = repeat(padding_mask, '() i j -> b i j', b = b * h)
         mask = repeat(mask, 'b j -> (b h) i j', i = i, h = h)
-        mask = torch.cat((~mask, causal_mask), dim = -1)
+        mask = torch.cat((~mask, padding_mask), dim = -1)

         # image can attend to all of text
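
Why the q_img_indices < k_img_indices comparison can be dropped: with left/top-only padding, every real key in a window precedes or equals its query in raster order, so the only positions that still need masking are the padded slots, which the new padding_mask catches via the img_seq_len sentinel. A self-contained check of this claim (a sketch under the same assumed toy sizes as above, not repository code):

# hypothetical check that left/top padding yields raster-order causality
import torch
import torch.nn.functional as F
from einops import rearrange

img_size, kernel_size, dilation = 4, 3, 1
img_seq_len = img_size ** 2

effective_kernel_size = (kernel_size - 1) * dilation + 1
same_padding = effective_kernel_size // 2
causal_padding = (same_padding * 2, 0, same_padding * 2, 0)

img_seq = torch.arange(img_seq_len)
k_img_indices = rearrange(img_seq.float(), '(h w) -> () () h w', h = img_size)
k_img_indices = F.pad(k_img_indices, causal_padding, value = img_seq_len)  # sentinel value
k_img_indices = F.unfold(k_img_indices, kernel_size, dilation = dilation)
k_img_indices = rearrange(k_img_indices, 'b j i -> b i j')

q_img_indices = rearrange(img_seq, 'i -> () i ()')
padding_mask = k_img_indices == img_seq_len  # same mask the diff builds

# every non-sentinel key index is <= its query index, i.e. never a future position
assert ((k_img_indices <= q_img_indices) | padding_mask).all()

If the assertion holds, no window ever contains a future position, so masking the sentinel slots alone preserves causality.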

setup.py (+1 -1)
@@ -4,7 +4,7 @@
   name = 'dalle-pytorch',
   packages = find_packages(),
   include_package_data = True,
-  version = '1.4.2',
+  version = '1.5.0',
   license='MIT',
   description = 'DALL-E - Pytorch',
   author = 'Phil Wang',
