@@ -14,7 +14,7 @@ use crate::ops::{
14
14
OutputList ,
15
15
} ;
16
16
use crate :: slice_reductions:: slice_sum;
17
- use crate :: tensor_pool:: { AutoReturn , TensorPool } ;
17
+ use crate :: tensor_pool:: TensorPool ;
18
18
19
19
/// Compute the indices of the max elements along an axis, according to a
20
20
/// comparison function `compare`.
@@ -226,6 +226,38 @@ impl Operator for NonZero {
226
226
}
227
227
}
228
228
229
+ /// Manages a scratch buffer allocated from a pool.
230
+ struct TempBuffer < ' a , T > {
231
+ pool : & ' a TensorPool ,
232
+ buf : Vec < T > ,
233
+ }
234
+
235
+ impl < ' a , T > TempBuffer < ' a , T > {
236
+ fn new ( pool : & ' a TensorPool ) -> Self {
237
+ TempBuffer {
238
+ pool,
239
+ buf : Vec :: new ( ) ,
240
+ }
241
+ }
242
+
243
+ /// Prepare the buffer by allocating it from the pool and clearing it.
244
+ fn reserve ( & mut self , capacity : usize ) -> & mut Vec < T > {
245
+ self . buf . clear ( ) ;
246
+ if self . buf . capacity ( ) < capacity {
247
+ self . buf = self . pool . alloc ( capacity) ;
248
+ }
249
+ & mut self . buf
250
+ }
251
+ }
252
+
253
+ impl < T > Drop for TempBuffer < ' _ , T > {
254
+ fn drop ( & mut self ) {
255
+ if self . buf . capacity ( ) > 0 {
256
+ self . pool . add ( std:: mem:: take ( & mut self . buf ) )
257
+ }
258
+ }
259
+ }
260
+
229
261
/// Kernel that handles reducing a single slice of the input.
230
262
trait ReduceKernel < T > {
231
263
/// Reduce a contiguous slice of values to a single value.
@@ -250,15 +282,9 @@ fn reduce<T: Copy>(
250
282
} ;
251
283
resolved_axes. sort ( ) ;
252
284
253
- // Allocate temporary buffer where slices of the input to be reduced are
254
- // packed first if non-contiguous.
255
- let mut tmp_buf = if !input. is_contiguous ( ) {
256
- let reduced_slice_len = resolved_axes. iter ( ) . map ( |& dim| input. size ( dim) ) . product ( ) ;
257
- pool. alloc ( reduced_slice_len)
258
- } else {
259
- Vec :: new ( )
260
- }
261
- . auto_return ( pool) ;
285
+ // Temporary buffer where slices of the input to be reduced are packed first
286
+ // if non-contiguous.
287
+ let mut tmp_buf = TempBuffer :: new ( pool) ;
262
288
263
289
if input. ndim ( ) == 0 {
264
290
let item = input. item ( ) . unwrap ( ) ;
@@ -316,9 +342,9 @@ fn reduce<T: Copy>(
316
342
if let Some ( lane_slice) = lane. as_slice ( ) {
317
343
kernel. reduce_slice ( lane_slice)
318
344
} else {
319
- tmp_buf. clear ( ) ;
320
- tmp_buf . extend ( lane. copied ( ) ) ;
321
- kernel. reduce_slice ( & tmp_buf )
345
+ let buf = tmp_buf. reserve ( lane . len ( ) ) ;
346
+ buf . extend ( lane. copied ( ) ) ;
347
+ kernel. reduce_slice ( buf )
322
348
}
323
349
} ) ) ;
324
350
} else {
@@ -334,8 +360,8 @@ fn reduce<T: Copy>(
334
360
let reduced = if let Some ( data) = slice. data ( ) {
335
361
kernel. reduce_slice ( data)
336
362
} else {
337
- tmp_buf. clear ( ) ;
338
- let tmp_uninit = & mut tmp_buf . spare_capacity_mut ( ) [ ..slice. len ( ) ] ;
363
+ let buf = tmp_buf. reserve ( slice . len ( ) ) ;
364
+ let tmp_uninit = & mut buf . spare_capacity_mut ( ) [ ..slice. len ( ) ] ;
339
365
let tmp = slice. copy_into_slice ( tmp_uninit) ;
340
366
kernel. reduce_slice ( tmp)
341
367
} ;
@@ -1093,6 +1119,17 @@ mod tests {
1093
1119
. unwrap ( ) ;
1094
1120
assert_eq ! ( result. to_vec( ) , & [ 0.5 , 4.5 ] ) ;
1095
1121
1122
+ // Reduce multiple non-contiguous (outer) dimensions in contiguous tensor
1123
+ let tensor = Tensor :: from ( [ [ [ 1. , 2. ] , [ 3. , 4. ] ] , [ [ 5. , 6. ] , [ 7. , 8. ] ] ] ) ;
1124
+ let result = reduce_mean (
1125
+ & pool,
1126
+ tensor. view ( ) ,
1127
+ Some ( & [ 0 , 1 ] ) ,
1128
+ false , /* keep_dims */
1129
+ )
1130
+ . unwrap ( ) ;
1131
+ assert_eq ! ( result. to_vec( ) , & [ 4. , 5. ] ) ;
1132
+
1096
1133
Ok ( ( ) )
1097
1134
}
1098
1135
0 commit comments