Commit 9526cd6: remove g-mlps

1 parent 6aff102

File tree: 2 files changed (+2, -11)


dalle_pytorch/transformer.py (+1, -9)
@@ -12,7 +12,6 @@
 from dalle_pytorch.attention import Attention, SparseAttention, SparseConvCausalAttention, SparseAxialCausalAttention
 
 from rotary_embedding_torch import RotaryEmbedding, broadcat
-from g_mlp_pytorch import gMLPBlock
 
 # helpers
 
@@ -261,17 +260,12 @@ def __init__(
                 attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 1, image_size = image_fmap_size, stable = stable)
             elif attn_type == 'conv_like':
                 attn_class = partial(SparseConvCausalAttention, seq_len = seq_len, image_size = image_fmap_size, stable = stable)
-            elif attn_type == 'mlp':
-                attn_class = partial(gMLPBlock, seq_len = seq_len)
             else:
                 raise ValueError(f'attention type "{attn_type}" is not valid')
 
             attn, reused_attn_type = shared_attn_layers.get(attn_id, (None, None))
             if not exists(attn):
-                if attn_type != 'mlp':
-                    attn = attn_class(dim, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)
-                else:
-                    attn = attn_class(dim = dim, causal = causal, dim_ff = dim * 4)
+                attn = attn_class(dim, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)
                 shared_attn_layers[attn_id] = (attn, attn_type)
             elif attn_type != reused_attn_type:
                 raise ValueError('attn_types do not match shared_attn_ids '
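
The branch removed in this hunk existed only because gMLPBlock's constructor (dim, causal, dim_ff) did not match the (dim, causal, seq_len, heads, dim_head, dropout) signature shared by the attention classes. With the 'mlp' type gone, every attn_class resolves through one uniform call. A minimal sketch of the resulting dispatch, with a hypothetical StubAttention standing in for the real attention classes:

from functools import partial

# Stand-in for Attention / SparseConvCausalAttention / etc.; only the shared
# constructor signature matters for this sketch.
class StubAttention:
    def __init__(self, dim, causal = False, seq_len = None, heads = 8, dim_head = 64, dropout = 0.):
        self.dim, self.causal, self.seq_len = dim, causal, seq_len

def build_attn(attn_type, dim, seq_len, heads = 8, dim_head = 64, attn_dropout = 0.):
    if attn_type == 'full':
        attn_class = partial(StubAttention)
    else:
        raise ValueError(f'attention type "{attn_type}" is not valid')
    # after this commit, every valid attn_type is built through the same call
    return attn_class(dim, causal = True, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)

attn = build_attn('full', dim = 512, seq_len = 1024)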
@@ -309,8 +303,6 @@ def __init__(
 
         pos_emb = None
         if rotary_emb:
-            assert 'mlp' not in attn_types, 'you cannot use gMLPs if rotary embedding is turned on'
-
             rot_dim = dim_head // 3
             img_seq_len = (image_fmap_size ** 2)
             text_len = seq_len - img_seq_len + 1
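
The assertion dropped in this hunk guarded against enabling rotary embeddings together with gMLP layers, presumably because a gMLP block exposes no per-head queries and keys for the rotation to act on; once the 'mlp' type is removed, the guard is unreachable. A rough sketch of what a rotary embedding does to a head's queries (shapes and the even rotation dim are illustrative, not this file's exact wiring, which derives rot_dim = dim_head // 3 and combines axes with broadcat):

import torch
from rotary_embedding_torch import RotaryEmbedding

rotary = RotaryEmbedding(dim = 32)    # illustrative even rotation dim
q = torch.randn(1, 8, 1024, 64)       # (batch, heads, seq_len, dim_head)
q = rotary.rotate_queries_or_keys(q)  # rotates channel pairs by position-dependent angles; shape unchanged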

setup.py (+1, -2)
@@ -4,7 +4,7 @@
   name = 'dalle-pytorch',
   packages = find_packages(),
   include_package_data = True,
-  version = '1.4.0',
+  version = '1.4.1',
   license='MIT',
   description = 'DALL-E - Pytorch',
   author = 'Phil Wang',
@@ -21,7 +21,6 @@
     'DALL-E',
     'einops>=0.3.2',
     'ftfy',
-    'g-mlp-pytorch',
     'pillow',
     'regex',
     'rotary-embedding-torch',
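
With g-mlp-pytorch dropped from the dependencies and the version bumped to 1.4.1, a quick post-upgrade sanity check could look like this (a sketch, assuming a standard pip install of the package):

import importlib.metadata

# the installed distribution should report the new version, and the package
# should import without g_mlp_pytorch present in the environment
assert importlib.metadata.version('dalle-pytorch') == '1.4.1'
import dalle_pytorch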
