@@ -13,13 +13,12 @@ def __init__(self, dim, fn):
     def forward(self, x):
         return self.fn(self.norm(x)) + x
 
-def FeedForward(dim, expansion_factor = 4, dropout = 0., dense = nn.Linear):
-    inner_dim = int(dim * expansion_factor)
+def FeedForward(dim, dim_hidden, dropout = 0., dense = nn.Linear):
     return nn.Sequential(
-        dense(dim, inner_dim),
+        dense(dim, dim_hidden),
         nn.GELU(),
         nn.Dropout(dropout),
-        dense(inner_dim, dim),
+        dense(dim_hidden, dim),
         nn.Dropout(dropout)
     )
 
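FeedForward now takes its hidden width as an explicit dim_hidden argument instead of deriving it from its own input dim, so the caller decides how wide the inner projection is. A minimal usage sketch of the new signature, with hypothetical sizes (assumes the patched FeedForward above is in scope):

import torch

ff = FeedForward(dim = 512, dim_hidden = 2048)  # hidden width chosen by the caller
x = torch.randn(2, 196, 512)                    # (batch, tokens, dim) -- illustrative shape
assert ff(x).shape == x.shape                   # the block maps back to the input width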
@@ -35,8 +34,8 @@ def MLPMixer3D(*, image_size, time_size, channels, patch_size, time_patch_size,
         Rearrange('b c (t pt) (h p1) (w p2) -> b (h w t) (p1 p2 pt c)', p1 = patch_size, p2 = patch_size, pt = time_patch_size),
         nn.Linear((time_patch_size * patch_size ** 2) * channels, dim),
         *[nn.Sequential(
-            PreNormResidual(dim, FeedForward(num_patches, expansion_factor, dropout, chan_first)),
-            PreNormResidual(dim, FeedForward(dim, expansion_factor_token, dropout, chan_last))
+            PreNormResidual(dim, FeedForward(num_patches, int(expansion_factor * dim), dropout, chan_first)),
+            PreNormResidual(dim, FeedForward(dim, int(expansion_factor_token * dim), dropout, chan_last))
         ) for _ in range(depth)],
         nn.LayerNorm(dim),
         Reduce('b n c -> b c', 'mean'),
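With FeedForward no longer scaling its own input dim, the token-mixing MLP's hidden width stops growing with the number of patches and is tied to the model width instead, in line with the MLP-Mixer paper, where the token-mixing hidden dim is chosen independently of the number of input patches. The channel-mixing branch computes the same width as before, since its input dim already was dim. An illustrative before/after of the token-mixing hidden width (hypothetical numbers, not from the commit):

num_patches = 1024        # e.g. 8 x 8 spatial patches over 16 time patches
dim = 512
expansion_factor = 4

old_hidden = int(num_patches * expansion_factor)  # 4096: grew with sequence length
new_hidden = int(expansion_factor * dim)          # 2048: fixed by model width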