Commit 39a55fa

address #16
1 parent bcd54af commit 39a55fa

3 files changed: 11 additions, 13 deletions

mlp_mixer_pytorch/mlp_mixer_3d_pytorch.py

Lines changed: 5 additions & 6 deletions

@@ -13,13 +13,12 @@ def __init__(self, dim, fn):
     def forward(self, x):
         return self.fn(self.norm(x)) + x
 
-def FeedForward(dim, expansion_factor = 4, dropout = 0., dense = nn.Linear):
-    inner_dim = int(dim * expansion_factor)
+def FeedForward(dim, dim_hidden, dropout = 0., dense = nn.Linear):
     return nn.Sequential(
-        dense(dim, inner_dim),
+        dense(dim, dim_hidden),
         nn.GELU(),
         nn.Dropout(dropout),
-        dense(inner_dim, dim),
+        dense(dim_hidden, dim),
         nn.Dropout(dropout)
     )
 
@@ -35,8 +34,8 @@ def MLPMixer3D(*, image_size, time_size, channels, patch_size, time_patch_size,
         Rearrange('b c (t pt) (h p1) (w p2) -> b (h w t) (p1 p2 pt c)', p1 = patch_size, p2 = patch_size, pt = time_patch_size),
         nn.Linear((time_patch_size * patch_size ** 2) * channels, dim),
         *[nn.Sequential(
-            PreNormResidual(dim, FeedForward(num_patches, expansion_factor, dropout, chan_first)),
-            PreNormResidual(dim, FeedForward(dim, expansion_factor_token, dropout, chan_last))
+            PreNormResidual(dim, FeedForward(num_patches, int(expansion_factor * dim), dropout, chan_first)),
+            PreNormResidual(dim, FeedForward(dim, int(expansion_factor_token * dim), dropout, chan_last))
         ) for _ in range(depth)],
         nn.LayerNorm(dim),
         Reduce('b n c -> b c', 'mean'),
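
For reference, a minimal sketch of the call pattern after this change: FeedForward no longer derives its hidden width from an expansion factor internally; the caller computes it, e.g. int(expansion_factor * dim). The example values below (dim = 512, expansion_factor = 4, a batch of one with 196 tokens) are illustrative assumptions, not taken from the diff.

import torch
from torch import nn

# FeedForward as refactored in this commit: the hidden width is an explicit
# argument instead of an expansion factor applied to `dim` inside the function.
def FeedForward(dim, dim_hidden, dropout = 0., dense = nn.Linear):
    return nn.Sequential(
        dense(dim, dim_hidden),
        nn.GELU(),
        nn.Dropout(dropout),
        dense(dim_hidden, dim),
        nn.Dropout(dropout)
    )

# Illustrative values (assumptions): the caller now resolves the hidden width.
dim, expansion_factor = 512, 4
ff = FeedForward(dim, int(expansion_factor * dim))

x = torch.randn(1, 196, dim)
assert ff(x).shape == (1, 196, dim)  # output width matches input, so the residual add still works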

mlp_mixer_pytorch/mlp_mixer_pytorch.py

Lines changed: 5 additions & 6 deletions

@@ -13,13 +13,12 @@ def __init__(self, dim, fn):
     def forward(self, x):
         return self.fn(self.norm(x)) + x
 
-def FeedForward(dim, expansion_factor = 4, dropout = 0., dense = nn.Linear):
-    inner_dim = int(dim * expansion_factor)
+def FeedForward(dim, dim_hidden, dropout = 0., dense = nn.Linear):
     return nn.Sequential(
-        dense(dim, inner_dim),
+        dense(dim, dim_hidden),
         nn.GELU(),
         nn.Dropout(dropout),
-        dense(inner_dim, dim),
+        dense(dim_hidden, dim),
         nn.Dropout(dropout)
     )
 
@@ -33,8 +32,8 @@ def MLPMixer(*, image_size, channels, patch_size, dim, depth, num_classes, expan
         Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
         nn.Linear((patch_size ** 2) * channels, dim),
         *[nn.Sequential(
-            PreNormResidual(dim, FeedForward(num_patches, expansion_factor, dropout, chan_first)),
-            PreNormResidual(dim, FeedForward(dim, expansion_factor_token, dropout, chan_last))
+            PreNormResidual(dim, FeedForward(num_patches, expansion_factor * dim, dropout, chan_first)),
+            PreNormResidual(dim, FeedForward(dim, expansion_factor_token * dim, dropout, chan_last))
         ) for _ in range(depth)],
         nn.LayerNorm(dim),
         Reduce('b n c -> b c', 'mean'),
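
Because the expansion factors are now resolved to a hidden width at the call sites inside MLPMixer, the constructor's keyword interface is unchanged by this refactor. A usage sketch follows; the hyperparameter values are examples only, and the keyword names past the truncated hunk header (expansion_factor_token, dropout) plus the top-level import are inferred from the diff body and the package name, not shown in this commit.

import torch
from mlp_mixer_pytorch import MLPMixer  # assumes the package exports MLPMixer at the top level

# Example hyperparameters (assumptions, not from the diff)
model = MLPMixer(
    image_size = 256,
    channels = 3,
    patch_size = 16,
    dim = 512,
    depth = 12,
    num_classes = 1000,
    expansion_factor = 4,
    expansion_factor_token = 0.5,
    dropout = 0.
)

img = torch.randn(1, 3, 256, 256)
logits = model(img)  # expected shape: (1, 1000)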

setup.py

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 setup(
   name = 'mlp-mixer-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.2.0',
+  version = '0.3.0',
   license='MIT',
   description = 'MLP Mixer - Pytorch',
   long_description_content_type = 'text/markdown',
