Commit 3173b77 — "allow for dilation in sparse convolutional causal attention"
1 parent: eedaea9

File tree

2 files changed: +10 additions, −6 deletions

dalle_pytorch/attention.py

+9-5
--- a/dalle_pytorch/attention.py
+++ b/dalle_pytorch/attention.py
@@ -72,7 +72,7 @@ def forward(self, x, mask = None):
         return out
 
 class SparseConvCausalAttention(nn.Module):
-    def __init__(self, dim, seq_len, image_size = 32, kernel_size = 5, heads = 8, dim_head = 64, dropout = 0., **kwargs):
+    def __init__(self, dim, seq_len, image_size = 32, kernel_size = 5, dilation = 0, heads = 8, dim_head = 64, dropout = 0., **kwargs):
         super().__init__()
         assert kernel_size % 2 == 1, 'kernel size must be odd'
 
@@ -81,6 +81,7 @@ def __init__(self, dim, seq_len, image_size = 32, kernel_size = 5, heads = 8, di
         self.scale = dim_head ** -0.5
         self.image_size = image_size
         self.kernel_size = kernel_size
+        self.dilation = dilation
 
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
 
@@ -90,7 +91,7 @@ def __init__(self, dim, seq_len, image_size = 32, kernel_size = 5, heads = 8, di
         )
 
     def forward(self, x, mask = None):
-        b, n, _, h, img_size, kernel_size, device = *x.shape, self.heads, self.image_size, self.kernel_size, x.device
+        b, n, _, h, img_size, kernel_size, dilation, device = *x.shape, self.heads, self.image_size, self.kernel_size, self.dilation, x.device
         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)
 
@@ -114,8 +115,11 @@ def forward(self, x, mask = None):
 
         # image attention
 
+        effective_kernel_size = (kernel_size - 1) * dilation + 1
+        padding = effective_kernel_size // 2
+
         k_img, v_img = map(lambda t: rearrange(t, 'b (h w) c -> b c h w', h = img_size), (k_img, v_img))
-        k_img, v_img = map(lambda t: F.unfold(t, kernel_size, padding = (kernel_size // 2)), (k_img, v_img))
+        k_img, v_img = map(lambda t: F.unfold(t, kernel_size, padding = padding, dilation = dilation), (k_img, v_img))
         k_img, v_img = map(lambda t: rearrange(t, 'b (j d) i -> b i j d', j = kernel_size ** 2), (k_img, v_img))
 
         k_text, v_text = map(lambda t: repeat(t, 'b j d -> b i j d', i = img_seq_len), (k_text, v_text))
@@ -132,8 +136,8 @@ def forward(self, x, mask = None):
         i, j = dots_image.shape[-2:]
         img_seq = torch.arange(img_seq_len, device = device)
         k_img_indices = rearrange(img_seq.float(), '(h w) -> () () h w', h = img_size)
-        k_img_indices = F.pad(k_img_indices, (kernel_size // 2,) * 4, value = img_seq_len)
-        k_img_indices = F.unfold(k_img_indices, kernel_size)
+        k_img_indices = F.pad(k_img_indices, (padding,) * 4, value = img_seq_len) # padding set to be max, so it is never attended to
+        k_img_indices = F.unfold(k_img_indices, kernel_size, dilation = dilation)
         k_img_indices = rearrange(k_img_indices, 'b j i -> b i j')
 
         # mask image attention

--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'dalle-pytorch',
   packages = find_packages(),
-  version = '0.0.51',
+  version = '0.0.52',
   license='MIT',
   description = 'DALL-E - Pytorch',
   author = 'Phil Wang',

Comments (0)