Skip to content

Commit b1507d8

Browse files
authored
Merge pull request #665 from robertknight/pooling-split-pad-region
Split pooling inner loops into pad and non-pad regions
2 parents 9197108 + 764c210 commit b1507d8

File tree

1 file changed

+45
-10
lines changed

1 file changed

+45
-10
lines changed

Diff for: src/ops/pooling.rs

+45 -10
Original file line number | Diff line number | Diff line change
@@ -145,19 +145,27 @@ where
145145
assert!(chans.into_iter().all(|c| c < out_chans && c < in_chans));
146146

147147
for out_y in 0..out_h {
148+
// Compute min/max input Y coordinates for this output position.
149+
let min_in_y = out_y * stride_h;
150+
let max_in_y = min_in_y + kernel_h.saturating_sub(1);
151+
let y_non_pad_region = min_in_y >= pad_top && max_in_y < in_h + pad_top;
152+
148153
for out_x in 0..out_w {
154+
// Compute min/max input X coordinates for this output position.
155+
let min_in_x = out_x * stride_w;
156+
let max_in_x = min_in_x + kernel_w.saturating_sub(1);
157+
let x_non_pad_region = min_in_x >= pad_left && max_in_x < in_w + pad_left;
158+
149159
let mut accumulator = [fold_init; N];
150160
let mut non_pad_elements = 0;
151161

152-
for k_y in 0..kernel_h {
153-
for k_x in 0..kernel_w {
154-
let in_y = out_y * stride_h + k_y;
155-
let in_x = out_x * stride_w + k_x;
156-
if in_y >= pad_top
157-
&& in_y < in_h + pad_top
158-
&& in_x >= pad_left
159-
&& in_x < in_w + pad_left
160-
{
162+
// Use faster path with fewer branches for non-padding region.
163+
if y_non_pad_region && x_non_pad_region {
164+
non_pad_elements = kernel_h * kernel_w;
165+
for k_y in 0..kernel_h {
166+
for k_x in 0..kernel_w {
167+
let in_y = out_y * stride_h + k_y;
168+
let in_x = out_x * stride_w + k_x;
161169
for (i, chan) in chans.into_iter().enumerate() {
162170
// Safety:
163171
// - We checked all `chans` are in-bounds
@@ -167,10 +175,37 @@ where
167175
};
168176
accumulator[i] = fold(accumulator[i], val);
169177
}
170-
non_pad_elements += 1;
178+
}
179+
}
180+
} else {
181+
for k_y in 0..kernel_h {
182+
for k_x in 0..kernel_w {
183+
let in_y = out_y * stride_h + k_y;
184+
let in_x = out_x * stride_w + k_x;
185+
if in_y >= pad_top
186+
&& in_y < in_h + pad_top
187+
&& in_x >= pad_left
188+
&& in_x < in_w + pad_left
189+
{
190+
for (i, chan) in chans.into_iter().enumerate() {
191+
// Safety:
192+
// - We checked all `chans` are in-bounds
193+
// - `in_y` and `in_x` are >= pad_top and pad_left
194+
let val = unsafe {
195+
*in_view.get_unchecked([
196+
chan,
197+
in_y - pad_top,
198+
in_x - pad_left,
199+
])
200+
};
201+
accumulator[i] = fold(accumulator[i], val);
202+
}
203+
non_pad_elements += 1;
204+
}
171205
}
172206
}
173207
}
208+
174209
for (i, chan) in chans.into_iter().enumerate() {
175210
// Safety:
176211
// - We checked all `chans` are in-bounds

0 commit comments

Comments
 (0)