
Commit 5c00ebc

when using sparse attention, make sure text receives global attention

1 parent 164d176 commit 5c00ebc

3 files changed, +17 -7 lines changed


dalle_pytorch/dalle_pytorch.py

Lines changed: 2 additions & 1 deletion
@@ -292,7 +292,8 @@ def __init__(
             reversible = reversible,
             attn_dropout = attn_dropout,
             ff_dropout = ff_dropout,
-            sparse_attn = sparse_attn
+            sparse_attn = sparse_attn,
+            sparse_attn_global_indices = range(text_seq_len)
         )

         self.to_logits = nn.Sequential(
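
The net effect of the new argument: every position of the text prompt is flagged for global attention, so the block-sparse layout keeps the text visible to the rest of the sequence. Below is a minimal sketch of the resulting block arithmetic, assuming the default block size of 16 and an illustrative text_seq_len of 256; both values are assumptions for the example, not taken from the commit.

# Sketch only: block_size = 16 is SparseAttention's default; text_seq_len = 256
# is an illustrative prompt length, not a value asserted by this commit.
block_size = 16
text_seq_len = 256

sparse_attn_global_indices = range(text_seq_len)

# Map each global position to the sparse block that contains it, dropping duplicates.
global_blocks = sorted({i // block_size for i in sparse_attn_global_indices})
print(global_blocks)  # blocks 0 through 15, i.e. every block covering the text prompt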

dalle_pytorch/transformer.py

Lines changed: 14 additions & 5 deletions
@@ -1,5 +1,7 @@
-import torch
+from functools import partial
 from inspect import isfunction
+
+import torch
 from torch import nn, einsum
 import torch.nn.functional as F
 from einops import rearrange, repeat
@@ -11,6 +13,9 @@
 def exists(val):
     return val is not None

+def uniq(arr):
+    return{el: True for el in arr}.keys()
+
 def default(val, d):
     if exists(val):
         return val
@@ -89,15 +94,18 @@ def forward(self, x, mask = None):
         return out

 class SparseAttention(Attention):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, sparse_attn_global_indices = [], block_size = 16, **kwargs):
         super().__init__(*args, **kwargs)
         from deepspeed.ops.sparse_attention import SparseSelfAttention, VariableSparsityConfig
-        self.block_size = 16
+
+        self.block_size = block_size
+        global_blocks = uniq(map(lambda t: t // self.block_size, sparse_attn_global_indices))

         self.attn_fn = SparseSelfAttention(
             sparsity_config = VariableSparsityConfig(
                 num_heads = self.heads,
                 block = self.block_size,
+                global_block_indices = global_blocks,
                 attention = 'unidirectional' if self.causal else 'bidirectional'
             ),
             max_seq_length = self.seq_len,
@@ -148,14 +156,15 @@ def __init__(
         ff_mult = 4,
         attn_dropout = 0.,
         ff_dropout = 0.,
-        sparse_attn = True
+        sparse_attn = True,
+        sparse_attn_global_indices = []
     ):
         super().__init__()
         layers = nn.ModuleList([])
         sparse_layer = cast_tuple(sparse_attn, depth)

         for _, sparse_attn in zip(range(depth), sparse_layer):
-            attn_class = Attention if not sparse_attn else SparseAttention
+            attn_class = Attention if not sparse_attn else partial(SparseAttention, sparse_attn_global_indices = sparse_attn_global_indices)

             layers.append(nn.ModuleList([
                 PreNorm(dim, attn_class(dim, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)),
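
A side note on the wiring above: rather than changing the layer-construction call, the sparse branch pre-binds the new keyword with functools.partial, so Attention and SparseAttention end up being instantiated with identical arguments further down. A stripped-down sketch of that pattern follows; the class bodies and the numbers are placeholders, not the library's real implementations.

from functools import partial

class Attention:
    # Placeholder stand-in for the real attention module.
    def __init__(self, dim, causal = False):
        self.dim, self.causal = dim, causal

class SparseAttention(Attention):
    # Accepts the sparse-only keyword but otherwise shares the parent's signature.
    def __init__(self, *args, sparse_attn_global_indices = (), **kwargs):
        super().__init__(*args, **kwargs)
        self.global_indices = list(sparse_attn_global_indices)

# Pre-bind the extra keyword; the construction call is now identical
# for the dense and the sparse branch.
attn_class = partial(SparseAttention, sparse_attn_global_indices = range(256))
attn = attn_class(512, causal = True)
print(attn.global_indices[:3], attn.causal)  # [0, 1, 2] True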

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'dalle-pytorch',
   packages = find_packages(),
-  version = '0.0.36',
+  version = '0.0.37',
   license='MIT',
   description = 'DALL-E - Pytorch',
   author = 'Phil Wang',
