@@ -145,19 +145,27 @@ where
145
145
assert ! ( chans. into_iter( ) . all( |c| c < out_chans && c < in_chans) ) ;
146
146
147
147
for out_y in 0 ..out_h {
148
+ // Compute min/max input Y coordinates for this output position.
149
+ let min_in_y = out_y * stride_h;
150
+ let max_in_y = min_in_y + kernel_h. saturating_sub ( 1 ) ;
151
+ let y_non_pad_region = min_in_y >= pad_top && max_in_y < in_h + pad_top;
152
+
148
153
for out_x in 0 ..out_w {
154
+ // Compute min/max input X coordinates for this output position.
155
+ let min_in_x = out_x * stride_w;
156
+ let max_in_x = min_in_x + kernel_w. saturating_sub ( 1 ) ;
157
+ let x_non_pad_region = min_in_x >= pad_left && max_in_x < in_w + pad_left;
158
+
149
159
let mut accumulator = [ fold_init; N ] ;
150
160
let mut non_pad_elements = 0 ;
151
161
152
- for k_y in 0 ..kernel_h {
153
- for k_x in 0 ..kernel_w {
154
- let in_y = out_y * stride_h + k_y;
155
- let in_x = out_x * stride_w + k_x;
156
- if in_y >= pad_top
157
- && in_y < in_h + pad_top
158
- && in_x >= pad_left
159
- && in_x < in_w + pad_left
160
- {
162
+ // Use faster path with fewer branches for non-padding region.
163
+ if y_non_pad_region && x_non_pad_region {
164
+ non_pad_elements = kernel_h * kernel_w;
165
+ for k_y in 0 ..kernel_h {
166
+ for k_x in 0 ..kernel_w {
167
+ let in_y = out_y * stride_h + k_y;
168
+ let in_x = out_x * stride_w + k_x;
161
169
for ( i, chan) in chans. into_iter ( ) . enumerate ( ) {
162
170
// Safety:
163
171
// - We checked all `chans` are in-bounds
@@ -167,10 +175,37 @@ where
167
175
} ;
168
176
accumulator[ i] = fold ( accumulator[ i] , val) ;
169
177
}
170
- non_pad_elements += 1 ;
178
+ }
179
+ }
180
+ } else {
181
+ for k_y in 0 ..kernel_h {
182
+ for k_x in 0 ..kernel_w {
183
+ let in_y = out_y * stride_h + k_y;
184
+ let in_x = out_x * stride_w + k_x;
185
+ if in_y >= pad_top
186
+ && in_y < in_h + pad_top
187
+ && in_x >= pad_left
188
+ && in_x < in_w + pad_left
189
+ {
190
+ for ( i, chan) in chans. into_iter ( ) . enumerate ( ) {
191
+ // Safety:
192
+ // - We checked all `chans` are in-bounds
193
+ // - `in_y` and `in_x` are >= pad_top and pad_left
194
+ let val = unsafe {
195
+ * in_view. get_unchecked ( [
196
+ chan,
197
+ in_y - pad_top,
198
+ in_x - pad_left,
199
+ ] )
200
+ } ;
201
+ accumulator[ i] = fold ( accumulator[ i] , val) ;
202
+ }
203
+ non_pad_elements += 1 ;
204
+ }
171
205
}
172
206
}
173
207
}
208
+
174
209
for ( i, chan) in chans. into_iter ( ) . enumerate ( ) {
175
210
// Safety:
176
211
// - We checked all `chans` are in-bounds
0 commit comments