@@ -20,6 +20,22 @@ def _value_or_list(x, n, msg):
2020 return [x ] * n
2121
2222
23+ def _non_overlapping_sliding_windows (x , shape , window_shape ):
24+ # Compute the intermediate shape
25+ new_shape = [shape [0 ]]
26+ for s , w in zip (shape [1 :], window_shape ):
27+ new_shape .append (s // w )
28+ new_shape .append (w )
29+ new_shape .append (shape [- 1 ])
30+
31+ last_axis = len (new_shape ) - 1
32+ axis_order = [0 , * range (1 , last_axis , 2 ), * range (2 , last_axis , 2 ), last_axis ]
33+
34+ x = x .reshape (new_shape )
35+ x = x .transpose (axis_order )
36+ return x
37+
38+
2339def _sliding_windows (x , window_shape , window_strides ):
2440 if x .ndim < 3 :
2541 raise ValueError (
@@ -37,6 +53,12 @@ def _sliding_windows(x, window_shape, window_strides):
3753 )
3854
3955 shape = x .shape
56+ if all (
57+ window == stride and size % window == 0
58+ for size , window , stride in zip (spatial_dims , window_shape , window_strides )
59+ ):
60+ return _non_overlapping_sliding_windows (x , shape , window_shape )
61+
4062 strides = list (reversed (list (accumulate (reversed (shape + (1 ,)), operator .mul ))))[1 :]
4163
4264 # Compute the output shape
0 commit comments