@@ -1,5 +1,7 @@
-import torch
+from functools import partial
 from inspect import isfunction
+
+import torch
 from torch import nn, einsum
 import torch.nn.functional as F
 from einops import rearrange, repeat
@@ -11,6 +13,9 @@
 def exists(val):
     return val is not None
 
+def uniq(arr):
+    return{el: True for el in arr}.keys()
+
 def default(val, d):
     if exists(val):
         return val
@@ -89,15 +94,18 @@ def forward(self, x, mask = None):
         return out
 
 class SparseAttention(Attention):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, sparse_attn_global_indices = [], block_size = 16, **kwargs):
         super().__init__(*args, **kwargs)
         from deepspeed.ops.sparse_attention import SparseSelfAttention, VariableSparsityConfig
-        self.block_size = 16
+
+        self.block_size = block_size
+        global_blocks = uniq(map(lambda t: t // self.block_size, sparse_attn_global_indices))
 
         self.attn_fn = SparseSelfAttention(
             sparsity_config = VariableSparsityConfig(
                 num_heads = self.heads,
                 block = self.block_size,
+                global_block_indices = global_blocks,
                 attention = 'unidirectional' if self.causal else 'bidirectional'
             ),
             max_seq_length = self.seq_len,
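A note on the new global-index plumbing above: each entry of sparse_attn_global_indices is a token position, which is mapped to its block id by integer division with block_size, and uniq (added in the first hunk) deduplicates the block ids while preserving first-seen order. A minimal, dependency-free sketch of that mapping (the example positions are made up for illustration):

def uniq(arr):
    # dict keys preserve insertion order (Python 3.7+) and drop duplicates
    return {el: True for el in arr}.keys()

block_size = 16
sparse_attn_global_indices = [0, 1, 2, 512]  # hypothetical token positions

# token position -> block id via integer division, then deduplicate
global_blocks = uniq(map(lambda t: t // block_size, sparse_attn_global_indices))
print(list(global_blocks))  # [0, 32]: blocks 0 and 32 get full (global) attention

DeepSpeed's VariableSparsityConfig then receives these ids via global_block_indices, so tokens inside those blocks attend across the full sequence (and are attended to by it).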
@@ -148,14 +156,15 @@ def __init__(
         ff_mult = 4,
         attn_dropout = 0.,
         ff_dropout = 0.,
-        sparse_attn = True
+        sparse_attn = True,
+        sparse_attn_global_indices = []
     ):
         super().__init__()
         layers = nn.ModuleList([])
         sparse_layer = cast_tuple(sparse_attn, depth)
 
         for _, sparse_attn in zip(range(depth), sparse_layer):
-            attn_class = Attention if not sparse_attn else SparseAttention
+            attn_class = Attention if not sparse_attn else partial(SparseAttention, sparse_attn_global_indices = sparse_attn_global_indices)
 
             layers.append(nn.ModuleList([
                 PreNorm(dim, attn_class(dim, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)),
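In the last hunk, functools.partial pre-binds sparse_attn_global_indices onto SparseAttention, so the per-layer construction call stays identical for both attention classes. Assuming the constructor shown there is the Transformer from dalle_pytorch.transformer (the class name is not visible in the hunk header) and that DeepSpeed's sparse attention ops are installed, the new keyword could be exercised roughly as below; parameter names are inferred from the visible hunks, so treat this as a sketch rather than the repository's documented API:

import torch
from dalle_pytorch.transformer import Transformer  # assumed module path

model = Transformer(
    dim = 512,
    depth = 2,
    seq_len = 1024,              # should be a multiple of the sparse block size (16)
    heads = 8,
    dim_head = 64,
    causal = True,
    sparse_attn = True,
    sparse_attn_global_indices = range(256)  # e.g. let the first 256 positions attend globally
)

x = torch.randn(1, 1024, 512)
out = model(x)  # (1, 1024, 512); the sparse kernels need a GPU and deepspeed[sparse_attn]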