Skip to content

Commit 53e6a93

Browse files
authored
Use reshape and transpose for non-overlapping pooling windows (#867)
1 parent f5a1582 commit 53e6a93

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

python/mlx/nn/layers/pooling.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,22 @@ def _value_or_list(x, n, msg):
2020
return [x] * n
2121

2222

23+
def _non_overlapping_sliding_windows(x, shape, window_shape):
24+
# Compute the intermediate shape
25+
new_shape = [shape[0]]
26+
for s, w in zip(shape[1:], window_shape):
27+
new_shape.append(s // w)
28+
new_shape.append(w)
29+
new_shape.append(shape[-1])
30+
31+
last_axis = len(new_shape) - 1
32+
axis_order = [0, *range(1, last_axis, 2), *range(2, last_axis, 2), last_axis]
33+
34+
x = x.reshape(new_shape)
35+
x = x.transpose(axis_order)
36+
return x
37+
38+
2339
def _sliding_windows(x, window_shape, window_strides):
2440
if x.ndim < 3:
2541
raise ValueError(
@@ -37,6 +53,12 @@ def _sliding_windows(x, window_shape, window_strides):
3753
)
3854

3955
shape = x.shape
56+
if all(
57+
window == stride and size % window == 0
58+
for size, window, stride in zip(spatial_dims, window_shape, window_strides)
59+
):
60+
return _non_overlapping_sliding_windows(x, shape, window_shape)
61+
4062
strides = list(reversed(list(accumulate(reversed(shape + (1,)), operator.mul))))[1:]
4163

4264
# Compute the output shape

0 commit comments

Comments
 (0)