
Commit de732e8

complete sparse attention integration with transformer, user can now define which types of sparse attention to cycle between down the layers
1 parent: 30a13eb


5 files changed: +87 -15 lines


README.md  (+29 -1)
````diff
@@ -152,6 +152,34 @@ dalle = DALLE(
 
 ## Sparse Attention
 
+The blogpost alluded to a mixture of different types of sparse attention, used mainly on the image (while the text presumably had full causal attention). I have done my best to replicate these types of sparse attention, based on the scant details released. Primarily, it seems as though they are doing causal axial row / column attention, combined with a causal convolution-like attention.
+
+By default `DALLE` will use full attention for all layers, but you can specify the attention types as follows.
+
+- `full` full attention
+
+- `axial_row` axial attention, along the rows of the image feature map
+
+- `axial_col` axial attention, along the columns of the image feature map
+
+- `conv_like` convolution-like attention, for the image feature map
+
+
+```python
+d = DALLE(
+    dim = 1024,
+    vae = vae,
+    num_text_tokens = 10000,
+    text_seq_len = 256,
+    depth = 64,
+    heads = 16,
+    reversible = True,
+    attn_types = ['full', 'axial_row', 'axial_col', 'conv_like'] # cycles between these four types of attention
+)
+```
+
+## Deepspeed Sparse Attention
+
 You can also train with Microsoft Deepspeed's <a href="https://www.deepspeed.ai/news/2020/09/08/sparse-attention.html">Sparse Attention</a>, with any combination of dense and sparse attention that you'd like. However, you will have to endure the installation process.
 
 First, you need to install Deepspeed with Sparse Attention
@@ -176,7 +204,7 @@ dalle = DALLE(
     text_seq_len = 256,
     depth = 64,
     heads = 8,
-    sparse_attn = (True, False) * 32 # interleave sparse and dense attention for 64 layers
+    attn_types = ('full', 'sparse') # interleave sparse and dense attention for 64 layers
 )
 ```
````
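
To make the cycling concrete, here is a minimal sketch (not part of the repository) of how a list of attention types maps onto layer depth, mirroring the `islice(cycle(...))` logic added in `dalle_pytorch/transformer.py`:

```python
from itertools import cycle, islice

# settings taken from the README example above
depth = 64
attn_types = ['full', 'axial_row', 'axial_col', 'conv_like']

# layer i is assigned attn_types[i % len(attn_types)]
layer_types = list(islice(cycle(attn_types), depth))

print(layer_types[:8])
# ['full', 'axial_row', 'axial_col', 'conv_like', 'full', 'axial_row', 'axial_col', 'conv_like']
```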

dalle_pytorch/attention.py  (+24 -6)

```diff
@@ -78,6 +78,7 @@ def __init__(self, dim, seq_len, image_size = 32, kernel_size = 5, dilation = 1,
         assert kernel_size % 2 == 1, 'kernel size must be odd'
 
         inner_dim = dim_head * heads
+        self.seq_len = seq_len
         self.heads = heads
         self.scale = dim_head ** -0.5
         self.image_size = image_size
@@ -92,14 +93,21 @@ def __init__(self, dim, seq_len, image_size = 32, kernel_size = 5, dilation = 1,
         )
 
     def forward(self, x, mask = None):
-        b, n, _, h, img_size, kernel_size, dilation, device = *x.shape, self.heads, self.image_size, self.kernel_size, self.dilation, x.device
+        b, n, _, h, img_size, kernel_size, dilation, seq_len, device = *x.shape, self.heads, self.image_size, self.kernel_size, self.dilation, self.seq_len, x.device
+
+        if n < seq_len:
+            padding = seq_len - n
+            x = F.pad(x, (0, 0, 0, padding), value = 0)
+            if exists(mask):
+                mask = F.pad(mask, (0, padding), value = False)
+
         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)
 
         q *= self.scale
 
         img_seq_len = img_size ** 2
-        text_len = n - img_seq_len
+        text_len = seq_len - img_seq_len
         ((q_text, q_img), (k_text, k_img), (v_text, v_img)) = map(lambda t: (t[:, img_seq_len:], t[:, -img_seq_len:]), (q, k, v))
 
         # text attention
@@ -160,7 +168,7 @@ def forward(self, x, mask = None):
 
         out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
         out = self.to_out(out)
-        return out
+        return out[:, :n]
 
 # sparse axial causal attention
 
@@ -171,6 +179,7 @@ def __init__(self, dim, seq_len, image_size = 32, axis = 0, heads = 8, dim_head
         self.axis = axis
 
         inner_dim = dim_head * heads
+        self.seq_len = seq_len
         self.heads = heads
         self.scale = dim_head ** -0.5
         self.image_size = image_size
@@ -183,14 +192,23 @@ def __init__(self, dim, seq_len, image_size = 32, axis = 0, heads = 8, dim_head
         )
 
     def forward(self, x, mask = None):
-        b, n, _, h, img_size, axis, device = *x.shape, self.heads, self.image_size, self.axis, x.device
+        b, n, _, h, img_size, axis, seq_len, device = *x.shape, self.heads, self.image_size, self.axis, self.seq_len, x.device
+
+        if n < seq_len:
+            padding = seq_len - n
+            x = F.pad(x, (0, 0, 0, padding), value = 0)
+
+            if exists(mask):
+                mask = F.pad(mask, (0, padding), value = False)
+
         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)
 
         q *= self.scale
 
         img_seq_len = img_size ** 2
-        text_len = n - img_seq_len
+        text_len = seq_len - img_seq_len
+
         ((q_text, q_img), (k_text, k_img), (v_text, v_img)) = map(lambda t: (t[:, img_seq_len:], t[:, -img_seq_len:]), (q, k, v))
 
         # text attention
@@ -245,7 +263,7 @@ def forward(self, x, mask = None):
 
         out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
         out = self.to_out(out)
-        return out
+        return out[:, :n]
 
 # microsoft sparse attention CUDA kernel
```
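
Both sparse attention classes gain the same pattern here: pad the input (and any mask) up to the fixed `seq_len` before computing attention, then slice the output back to the original length, which is what the new `return out[:, :n]` does. A minimal standalone sketch of that padding logic (the helper name is illustrative, not part of the library):

```python
import torch
import torch.nn.functional as F

def pad_to_seq_len(x, mask, seq_len):
    # x: (batch, n, dim) token embeddings, mask: optional (batch, n) boolean mask
    b, n, _ = x.shape
    if n >= seq_len:
        return x, mask, n

    padding = seq_len - n
    # F.pad pads trailing dimensions first: (0, 0) leaves the feature dim alone,
    # (0, padding) pads the sequence dimension on the right
    x = F.pad(x, (0, 0, 0, padding), value = 0.)

    if mask is not None:
        mask = F.pad(mask, (0, padding), value = False)
    return x, mask, n

x = torch.randn(2, 100, 512)
mask = torch.ones(2, 100).bool()

x, mask, n = pad_to_seq_len(x, mask, seq_len = 320)   # x is now (2, 320, 512)
# ... attention over the fixed-length sequence would go here ...
out = x[:, :n]                                        # slice back to the original 100 tokens
```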

dalle_pytorch/dalle_pytorch.py  (+6 -2)

```diff
@@ -257,14 +257,16 @@ def __init__(
         sparse_attn = False,
         noncausal_attn_len = 0,
         ignore_index = -100,
-        tie_codebook_image_emb = False
+        attn_types = None,
+        tie_codebook_image_emb = False,
     ):
         super().__init__()
         assert isinstance(vae, DiscreteVAE), 'vae must be an instance of DiscreteVAE'
 
         image_size = vae.image_size
         num_image_tokens = vae.num_tokens
-        image_seq_len = (vae.image_size // (2 ** vae.num_layers)) ** 2
+        image_fmap_size = (vae.image_size // (2 ** vae.num_layers))
+        image_seq_len = image_fmap_size ** 2
 
         self.text_emb = nn.Embedding(num_text_tokens, dim)
         self.image_emb = nn.Embedding(num_image_tokens, dim)
@@ -304,6 +306,8 @@ def __init__(
             attn_dropout = attn_dropout,
             ff_dropout = ff_dropout,
             noncausal_attn_len = (noncausal_attn_len + 1),
+            attn_types = attn_types,
+            image_fmap_size = image_fmap_size,
             sparse_attn = sparse_attn,
             sparse_attn_global_indices = range(text_seq_len)
         )
```
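
The refactor above splits the old one-liner so that the feature-map side length is available on its own; the new sparse attention classes need `image_fmap_size`, while the sequence logic still needs `image_seq_len`. A quick worked example with an assumed VAE configuration (256-pixel images, 3 downsampling layers):

```python
# assumed VAE configuration, for illustration only
image_size = 256        # vae.image_size
num_layers = 3          # vae.num_layers

image_fmap_size = image_size // (2 ** num_layers)   # 256 // 8 = 32
image_seq_len = image_fmap_size ** 2                # 32 * 32 = 1024 image tokens
```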

dalle_pytorch/transformer.py  (+27 -5)

```diff
@@ -1,15 +1,22 @@
 from functools import partial
+from itertools import islice, cycle
 
 import torch
 from torch import nn, einsum
 import torch.nn.functional as F
-from einops import rearrange, repeat
+from einops import rearrange
 
 from dalle_pytorch.reversible import ReversibleSequence, SequentialSequence
 from dalle_pytorch.attention import Attention, SparseAttention, SparseConvCausalAttention, SparseAxialCausalAttention
 
 # helpers
 
+def exists(val):
+    return val is not None
+
+def default(val, d):
+    return val if exists(val) else d
+
 def cast_tuple(val, depth):
     return val if isinstance(val, tuple) else (val,) * depth
 
@@ -57,18 +64,33 @@ def __init__(
         attn_dropout = 0.,
         ff_dropout = 0.,
         noncausal_attn_len = 0,
+        attn_types = None,
+        image_fmap_size = None,
         sparse_attn = False,
         sparse_attn_global_indices = []
     ):
         super().__init__()
         layers = nn.ModuleList([])
         sparse_layer = cast_tuple(sparse_attn, depth)
-
-        for _, sparse_attn in zip(range(depth), sparse_layer):
-            attn_class = Attention if not sparse_attn else partial(SparseAttention, sparse_attn_global_indices = sparse_attn_global_indices)
+        attn_types = default(attn_types, ('full',))
+        attn_type_layer = islice(cycle(attn_types), depth)
+
+        for _, sparse_attn, attn_type in zip(range(depth), sparse_layer, attn_type_layer):
+            if attn_type == 'full':
+                attn_class = partial(Attention, noncausal_attn_len = noncausal_attn_len)
+            elif attn_type == 'sparse':
+                attn_class = partial(SparseAttention, sparse_attn_global_indices = sparse_attn_global_indices)
+            elif attn_type == 'axial_row':
+                attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 0, image_size = image_fmap_size)
+            elif attn_type == 'axial_col':
+                attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 1, image_size = image_fmap_size)
+            elif attn_type == 'conv_like':
+                attn_class = partial(SparseConvCausalAttention, seq_len = seq_len, image_size = image_fmap_size)
+            else:
+                raise ValueError(f'attention type "{attn_type}" is not valid')
 
             layers.append(nn.ModuleList([
-                PreNorm(dim, attn_class(dim, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout, noncausal_attn_len = noncausal_attn_len)),
+                PreNorm(dim, attn_class(dim, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)),
                 PreNorm(dim, FeedForward(dim, mult = ff_mult, dropout = ff_dropout))
             ]))
```
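
A detail worth noting in the layer construction above: each branch pre-binds the type-specific constructor arguments with `functools.partial`, and the shared arguments are then supplied in one uniform call inside `layers.append(...)`. Keyword arguments bound in a `partial` can be overridden (or harmlessly repeated) at call time, which is why `seq_len` appearing in both places works. A small illustration with a stand-in class (not from the library):

```python
from functools import partial

class StubAttention:
    def __init__(self, dim, *, seq_len, axis = 0, heads = 8):
        self.config = dict(dim = dim, seq_len = seq_len, axis = axis, heads = heads)

# pre-bind the type-specific arguments, as the elif branches do
attn_class = partial(StubAttention, seq_len = 1280, axis = 1)

# the shared call repeats seq_len; call-time keywords simply take precedence
layer = attn_class(1024, seq_len = 1280, heads = 16)
print(layer.config)   # {'dim': 1024, 'seq_len': 1280, 'axis': 1, 'heads': 16}
```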

setup.py  (+1 -1)

```diff
@@ -3,7 +3,7 @@
 setup(
   name = 'dalle-pytorch',
   packages = find_packages(),
-  version = '0.0.54',
+  version = '0.0.55',
   license='MIT',
   description = 'DALL-E - Pytorch',
   author = 'Phil Wang',
```
