
Commit 3e20ff7

Fix error when reducing multiple axes if reduced chunks are non-contiguous
Fix a bounds check error when slicing `tmp_buf` if:

- The input is contiguous
- Multiple axes are reduced
- The reduced slices are non-contiguous

In this case the code expected capacity to have been reserved in `tmp_buf`, but it was actually empty. Change how the buffer is allocated to ensure it is always allocated when needed.
1 parent 64fcbd2 · commit 3e20ff7
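
The fix replaces the up-front, conditional allocation of `tmp_buf` with a small `TempBuffer` helper that allocates from the pool lazily, so every code path that packs a non-contiguous slice gets a buffer with enough capacity. The sketch below illustrates that pattern in isolation. The `Pool` type here is a hypothetical stand-in for the crate's `TensorPool` (a trivial free list), not the real implementation; the `TempBuffer` shown mirrors the struct added in the diff.

```rust
use std::cell::RefCell;

/// Hypothetical stand-in for the crate's `TensorPool`: a trivial free list of
/// reusable buffers.
struct Pool {
    free: RefCell<Vec<Vec<f32>>>,
}

impl Pool {
    fn new() -> Self {
        Pool {
            free: RefCell::new(Vec::new()),
        }
    }

    /// Hand out a buffer with at least `capacity` elements of capacity,
    /// reusing a previously returned buffer when one is large enough.
    fn alloc(&self, capacity: usize) -> Vec<f32> {
        self.free
            .borrow_mut()
            .pop()
            .filter(|buf| buf.capacity() >= capacity)
            .unwrap_or_else(|| Vec::with_capacity(capacity))
    }

    /// Take a buffer back so a later `alloc` can reuse its allocation.
    fn add(&self, buf: Vec<f32>) {
        self.free.borrow_mut().push(buf);
    }
}

/// Scratch buffer that is only allocated from the pool when actually needed
/// and handed back to the pool when dropped (mirrors the `TempBuffer` added
/// in this commit).
struct TempBuffer<'a> {
    pool: &'a Pool,
    buf: Vec<f32>,
}

impl<'a> TempBuffer<'a> {
    fn new(pool: &'a Pool) -> Self {
        TempBuffer {
            pool,
            buf: Vec::new(),
        }
    }

    /// Clear the buffer and make sure it can hold `capacity` elements.
    fn reserve(&mut self, capacity: usize) -> &mut Vec<f32> {
        self.buf.clear();
        if self.buf.capacity() < capacity {
            self.buf = self.pool.alloc(capacity);
        }
        &mut self.buf
    }
}

impl Drop for TempBuffer<'_> {
    fn drop(&mut self) {
        if self.buf.capacity() > 0 {
            self.pool.add(std::mem::take(&mut self.buf));
        }
    }
}

fn main() {
    let pool = Pool::new();
    let mut tmp = TempBuffer::new(&pool);

    // Every call site that needs scratch space calls `reserve` first, so the
    // required capacity is guaranteed regardless of how the buffer started out.
    let buf = tmp.reserve(4);
    buf.extend([1.0_f32, 2.0, 3.0, 4.0]);
    assert_eq!(buf.iter().sum::<f32>(), 10.0);
}
```

Because `reserve` is called at each use site, the buffer's capacity no longer depends on whether the input happened to be contiguous, which is what the original bounds check implicitly assumed.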

1 file changed: +52 -15 lines


Diff for: src/ops/reduce.rs (+52 -15)

@@ -14,7 +14,7 @@ use crate::ops::{
     OutputList,
 };
 use crate::slice_reductions::slice_sum;
-use crate::tensor_pool::{AutoReturn, TensorPool};
+use crate::tensor_pool::TensorPool;
 
 /// Compute the indices of the max elements along an axis, according to a
 /// comparison function `compare`.
@@ -226,6 +226,38 @@ impl Operator for NonZero {
     }
 }
 
+/// Manages a scratch buffer allocated from a pool.
+struct TempBuffer<'a, T> {
+    pool: &'a TensorPool,
+    buf: Vec<T>,
+}
+
+impl<'a, T> TempBuffer<'a, T> {
+    fn new(pool: &'a TensorPool) -> Self {
+        TempBuffer {
+            pool,
+            buf: Vec::new(),
+        }
+    }
+
+    /// Prepare the buffer by allocating it from the pool and clearing it.
+    fn reserve(&mut self, capacity: usize) -> &mut Vec<T> {
+        self.buf.clear();
+        if self.buf.capacity() < capacity {
+            self.buf = self.pool.alloc(capacity);
+        }
+        &mut self.buf
+    }
+}
+
+impl<T> Drop for TempBuffer<'_, T> {
+    fn drop(&mut self) {
+        if self.buf.capacity() > 0 {
+            self.pool.add(std::mem::take(&mut self.buf))
+        }
+    }
+}
+
 /// Kernel that handles reducing a single slice of the input.
 trait ReduceKernel<T> {
     /// Reduce a contiguous slice of values to a single value.
@@ -250,15 +282,9 @@ fn reduce<T: Copy>(
     };
     resolved_axes.sort();
 
-    // Allocate temporary buffer where slices of the input to be reduced are
-    // packed first if non-contiguous.
-    let mut tmp_buf = if !input.is_contiguous() {
-        let reduced_slice_len = resolved_axes.iter().map(|&dim| input.size(dim)).product();
-        pool.alloc(reduced_slice_len)
-    } else {
-        Vec::new()
-    }
-    .auto_return(pool);
+    // Temporary buffer where slices of the input to be reduced are packed first
+    // if non-contiguous.
+    let mut tmp_buf = TempBuffer::new(pool);
 
     if input.ndim() == 0 {
         let item = input.item().unwrap();
@@ -316,9 +342,9 @@ fn reduce<T: Copy>(
             if let Some(lane_slice) = lane.as_slice() {
                 kernel.reduce_slice(lane_slice)
             } else {
-                tmp_buf.clear();
-                tmp_buf.extend(lane.copied());
-                kernel.reduce_slice(&tmp_buf)
+                let buf = tmp_buf.reserve(lane.len());
+                buf.extend(lane.copied());
+                kernel.reduce_slice(buf)
             }
         }));
     } else {
@@ -334,8 +360,8 @@ fn reduce<T: Copy>(
         let reduced = if let Some(data) = slice.data() {
             kernel.reduce_slice(data)
         } else {
-            tmp_buf.clear();
-            let tmp_uninit = &mut tmp_buf.spare_capacity_mut()[..slice.len()];
+            let buf = tmp_buf.reserve(slice.len());
+            let tmp_uninit = &mut buf.spare_capacity_mut()[..slice.len()];
             let tmp = slice.copy_into_slice(tmp_uninit);
             kernel.reduce_slice(tmp)
         };
@@ -1093,6 +1119,17 @@ mod tests {
         .unwrap();
         assert_eq!(result.to_vec(), &[0.5, 4.5]);
 
+        // Reduce multiple non-contiguous (outer) dimensions in contiguous tensor
+        let tensor = Tensor::from([[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]]]);
+        let result = reduce_mean(
+            &pool,
+            tensor.view(),
+            Some(&[0, 1]),
+            false, /* keep_dims */
+        )
+        .unwrap();
+        assert_eq!(result.to_vec(), &[4., 5.]);
+
         Ok(())
     }
 
