Skip to content

Commit 54b0824

Browse files
committed
support rectangular images
1 parent 2151af6 commit 54b0824

File tree

3 files changed

+25
-3
lines changed

3 files changed

+25
-3
lines changed

README.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,25 @@ img = torch.randn(1, 3, 256, 256)
3333
pred = model(img) # (1, 1000)
3434
```
3535

36+
Rectangular image
37+
38+
```python
39+
import torch
40+
from mlp_mixer_pytorch import MLPMixer
41+
42+
model = MLPMixer(
43+
image_size = (256, 128),
44+
channels = 3,
45+
patch_size = 16,
46+
dim = 512,
47+
depth = 12,
48+
num_classes = 1000
49+
)
50+
51+
img = torch.randn(1, 3, 256, 128)
52+
pred = model(img) # (1, 1000)
53+
```
54+
3655
## Citations
3756

3857
```bibtex

mlp_mixer_pytorch/mlp_mixer_pytorch.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from functools import partial
33
from einops.layers.torch import Rearrange, Reduce
44

5+
pair = lambda x: x if isinstance(x, tuple) else (x, x)
6+
57
class PreNormResidual(nn.Module):
68
def __init__(self, dim, fn):
79
super().__init__()
@@ -22,8 +24,9 @@ def FeedForward(dim, expansion_factor = 4, dropout = 0., dense = nn.Linear):
2224
)
2325

2426
def MLPMixer(*, image_size, channels, patch_size, dim, depth, num_classes, expansion_factor = 4, expansion_factor_token = 0.5, dropout = 0.):
25-
assert (image_size % patch_size) == 0, 'image must be divisible by patch size'
26-
num_patches = (image_size // patch_size) ** 2
27+
image_h, image_w = pair(image_size)
28+
assert (image_h % patch_size) == 0 and (image_w % patch_size) == 0, 'image must be divisible by patch size'
29+
num_patches = (image_h // patch_size) * (image_w // patch_size)
2730
chan_first, chan_last = partial(nn.Conv1d, kernel_size = 1), nn.Linear
2831

2932
return nn.Sequential(

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'mlp-mixer-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.1.0',
6+
version = '0.1.1',
77
license='MIT',
88
description = 'MLP Mixer - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)