 from dalle_pytorch.attention import Attention, SparseAttention, SparseConvCausalAttention, SparseAxialCausalAttention
 
 from rotary_embedding_torch import RotaryEmbedding, broadcat
-from g_mlp_pytorch import gMLPBlock
 
 # helpers
 
@@ -261,17 +260,12 @@ def __init__(
                 attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 1, image_size = image_fmap_size, stable = stable)
             elif attn_type == 'conv_like':
                 attn_class = partial(SparseConvCausalAttention, seq_len = seq_len, image_size = image_fmap_size, stable = stable)
-            elif attn_type == 'mlp':
-                attn_class = partial(gMLPBlock, seq_len = seq_len)
             else:
                 raise ValueError(f'attention type "{attn_type}" is not valid')
 
             attn, reused_attn_type = shared_attn_layers.get(attn_id, (None, None))
             if not exists(attn):
-                if attn_type != 'mlp':
-                    attn = attn_class(dim, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)
-                else:
-                    attn = attn_class(dim = dim, causal = causal, dim_ff = dim * 4)
+                attn = attn_class(dim, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)
                 shared_attn_layers[attn_id] = (attn, attn_type)
             elif attn_type != reused_attn_type:
                 raise ValueError('attn_types do not match shared_attn_ids '
@@ -309,8 +303,6 @@ def __init__(
 
         pos_emb = None
         if rotary_emb:
-            assert 'mlp' not in attn_types, 'you cannot use gMLPs if rotary embedding is turned on'
-
             rot_dim = dim_head // 3
             img_seq_len = (image_fmap_size ** 2)
             text_len = seq_len - img_seq_len + 1
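Note on the change: the 'mlp' attention type is removed end to end, i.e. the g_mlp_pytorch import, the gMLPBlock branch of the attn_class dispatch, the special-cased gMLP constructor call, and the rotary-embedding assert, so every remaining attn_type is now built through the same partial(...) call with the shared (dim, causal, seq_len, heads, dim_head, dropout) signature. The snippet below is a minimal, self-contained sketch of that partial-based construction pattern, not the library code: FullAttention and AxialAttention are hypothetical stand-ins for the real attention classes, used only to show why dropping the gMLP special case lets the inner if/else collapse into one uniform call.

# Sketch only: hypothetical attention classes sharing one constructor signature.
from functools import partial

class FullAttention:
    def __init__(self, dim, causal = False, seq_len = None, heads = 8, dim_head = 64, dropout = 0.):
        self.kind = 'full'

class AxialAttention:
    def __init__(self, dim, causal = False, seq_len = None, heads = 8, dim_head = 64, dropout = 0., axis = 0):
        self.kind = f'axial (axis = {axis})'

def build_attention(attn_type, dim, seq_len):
    # pre-bind only the type-specific kwargs, mirroring the partial(...) calls in the diff
    if attn_type == 'full':
        attn_class = partial(FullAttention)
    elif attn_type == 'axial_row':
        attn_class = partial(AxialAttention, axis = 0)
    elif attn_type == 'axial_col':
        attn_class = partial(AxialAttention, axis = 1)
    else:
        raise ValueError(f'attention type "{attn_type}" is not valid')

    # with the gMLP branch gone, one uniform call covers every remaining type
    return attn_class(dim, causal = True, seq_len = seq_len, heads = 8, dim_head = 64, dropout = 0.1)

print(build_attention('axial_col', dim = 512, seq_len = 1024).kind)  # axial (axis = 1)

That uniform call is also what allows shared_attn_layers to cache a constructed layer under its attn_id and reuse it at a later depth regardless of which branch produced it, with the reused_attn_type check guarding against mismatched attn_types for the same shared id.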