Skip to content

Commit 40f4119

Browse files
committed
fix causal masking in sparse conv attention
1 parent 503139a commit 40f4119

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

dalle_pytorch/attention.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,8 @@ def forward(self, x, mask = None):
143143

144144
# mask image attention
145145

146-
mask = rearrange(img_seq, 'i -> () i ()') <= k_img_indices
146+
q_img_indices = rearrange(img_seq, 'i -> () i ()')
147+
mask = q_img_indices >= k_img_indices
147148

148149
# image can attend to all of text
149150

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'dalle-pytorch',
55
packages = find_packages(),
6-
version = '0.0.52',
6+
version = '0.0.53',
77
license='MIT',
88
description = 'DALL-E - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)