Skip to content

Commit 16ae864

Browse files
authored
Merge pull request #657 from robertknight/split-slice-axis
Improve `Split` operator implementation, add `TensorBase::{slice_axis, slice_axis_mut}`
2 parents a926116 + 4fd720d commit 16ae864

File tree

4 files changed

+102
-21
lines changed

4 files changed

+102
-21
lines changed

Diff for: rten-tensor/src/errors.rs

+5
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ pub enum SliceError {
5656
range_ndim: usize,
5757
},
5858

59+
/// The slice spec specified an axis that is equal to or greater than the
60+
/// dimension count.
61+
InvalidAxis { axis: usize },
62+
5963
/// An index in the slice spec is out of bounds for the corresponding tensor
6064
/// dimension.
6165
InvalidIndex {
@@ -105,6 +109,7 @@ impl Display for SliceError {
105109
range_ndim, ndim
106110
)
107111
}
112+
SliceError::InvalidAxis { axis } => write!(f, "slice axis {} is invalid", axis),
108113
SliceError::InvalidIndex { axis, index, size } => write!(
109114
f,
110115
"slice index {} is invalid for axis ({}) of size {}",

Diff for: rten-tensor/src/layout.rs

+53-4
Original file line numberDiff line numberDiff line change
@@ -970,8 +970,21 @@ pub trait MutLayout: Layout + Clone {
970970
/// Slice the layout along a given axis.
971971
///
972972
/// Returns a tuple of `(offset_range, sliced_layout)`.
973-
fn slice_axis(&self, axis: usize, range: Range<usize>) -> (Range<usize>, Self) {
974-
assert!(range.end >= range.start);
973+
fn slice_axis(
974+
&self,
975+
axis: usize,
976+
range: Range<usize>,
977+
) -> Result<(Range<usize>, Self), SliceError> {
978+
if axis >= self.ndim() {
979+
return Err(SliceError::InvalidAxis { axis });
980+
}
981+
if range.end < range.start || range.end > self.size(axis) {
982+
return Err(SliceError::InvalidRange {
983+
axis,
984+
range: range.into(),
985+
size: self.size(axis),
986+
});
987+
}
975988

976989
let mut sliced_layout = self.clone();
977990
sliced_layout.resize_dim(axis, range.len());
@@ -982,7 +995,7 @@ pub trait MutLayout: Layout + Clone {
982995
let end_offset = start_offset + sliced_layout.min_data_len();
983996
start_offset..end_offset
984997
};
985-
(range, sliced_layout)
998+
Ok((range, sliced_layout))
986999
}
9871000

9881001
/// Return a layout with all size-one dimensions removed.
@@ -1782,13 +1795,49 @@ mod tests {
17821795
} = case;
17831796

17841797
let layout = DynLayout::from_shape(shape);
1785-
let (offset_range, sliced_layout) = layout.slice_axis(axis, range);
1798+
let (offset_range, sliced_layout) = layout.slice_axis(axis, range).unwrap();
17861799
assert_eq!(sliced_layout.shape(), sliced_shape);
17871800
assert_eq!(sliced_layout.strides(), layout.strides());
17881801
assert_eq!(offset_range, offsets);
17891802
})
17901803
}
17911804

1805+
#[test]
1806+
fn test_slice_axis_invalid() {
1807+
#[derive(Debug)]
1808+
struct Case<'a> {
1809+
shape: &'a [usize],
1810+
axis: usize,
1811+
range: Range<usize>,
1812+
expected: SliceError,
1813+
}
1814+
1815+
let cases = [
1816+
Case {
1817+
shape: &[1, 2, 3],
1818+
axis: 4,
1819+
range: 0..1,
1820+
expected: SliceError::InvalidAxis { axis: 4 },
1821+
},
1822+
Case {
1823+
shape: &[1, 2, 3],
1824+
axis: 0,
1825+
range: 0..2,
1826+
expected: SliceError::InvalidRange {
1827+
axis: 0,
1828+
range: (0..2).into(),
1829+
size: 1,
1830+
},
1831+
},
1832+
];
1833+
1834+
cases.test_each(|case| {
1835+
let layout = DynLayout::from_shape(case.shape);
1836+
let result = layout.slice_axis(case.axis, case.range.clone());
1837+
assert_eq!(result, Err(case.expected.clone()));
1838+
})
1839+
}
1840+
17921841
#[test]
17931842
fn test_slice_invalid() {
17941843
#[derive(Debug)]

Diff for: rten-tensor/src/tensor.rs

+41-4
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,15 @@ pub trait AsView: Layout {
321321
self.view().slice(range)
322322
}
323323

324+
/// Slice this tensor along a given axis.
325+
fn slice_axis(
326+
&self,
327+
axis: usize,
328+
range: Range<usize>,
329+
) -> TensorBase<ViewData<Self::Elem>, Self::Layout> {
330+
self.view().slice_axis(axis, range)
331+
}
332+
324333
/// A variant of [`slice`](Self::slice) that returns a result
325334
/// instead of panicking.
326335
#[allow(clippy::type_complexity)]
@@ -791,14 +800,13 @@ impl<S: StorageMut, L: MutLayout> TensorBase<S, L> {
791800
}
792801

793802
/// Slice this tensor along a given axis.
794-
fn slice_axis_mut(
803+
pub fn slice_axis_mut(
795804
&mut self,
796805
axis: usize,
797806
range: Range<usize>,
798807
) -> TensorBase<ViewMutData<S::Elem>, L> {
799-
let (offset_range, sliced_layout) = self.layout.slice_axis(axis, range.clone());
808+
let (offset_range, sliced_layout) = self.layout.slice_axis(axis, range.clone()).unwrap();
800809
debug_assert_eq!(sliced_layout.size(axis), range.len());
801-
802810
TensorBase {
803811
data: self.data.slice_mut(offset_range),
804812
layout: sliced_layout,
@@ -1486,6 +1494,16 @@ impl<'a, T, L: Clone + MutLayout> TensorBase<ViewData<'a, T>, L> {
14861494
self.try_slice(range).expect("slice failed")
14871495
}
14881496

1497+
/// Slice this tensor along a given axis.
1498+
pub fn slice_axis(&self, axis: usize, range: Range<usize>) -> TensorBase<ViewData<'a, T>, L> {
1499+
let (offset_range, sliced_layout) = self.layout.slice_axis(axis, range.clone()).unwrap();
1500+
debug_assert_eq!(sliced_layout.size(axis), range.len());
1501+
TensorBase {
1502+
data: self.data.slice(offset_range),
1503+
layout: sliced_layout,
1504+
}
1505+
}
1506+
14891507
/// A variant of [`slice`](Self::slice) that returns a result
14901508
/// instead of panicking.
14911509
#[allow(clippy::type_complexity)]
@@ -1513,7 +1531,7 @@ impl<'a, T, L: Clone + MutLayout> TensorBase<ViewData<'a, T>, L> {
15131531
}
15141532
}
15151533

1516-
/// Divide this tensor into two mutable views along a given axis.
1534+
/// Divide this tensor into two views along a given axis.
15171535
///
15181536
/// Returns a `(left, right)` tuple of views, where the `left` view
15191537
/// contains the slice from `[0, mid)` along `axis` and the `right`
@@ -3458,6 +3476,25 @@ mod tests {
34583476
assert_eq!(row.data().unwrap(), &[1, 2, 3]);
34593477
}
34603478

3479+
#[test]
3480+
fn test_slice_axis() {
3481+
let data = NdTensor::from([[1, 2, 3], [4, 5, 6]]);
3482+
let row = data.slice_axis(0, 0..1);
3483+
let col = data.slice_axis(1, 1..2);
3484+
assert_eq!(row, data.slice((0..1, ..)));
3485+
assert_eq!(col, data.slice((.., 1..2)));
3486+
}
3487+
3488+
#[test]
3489+
fn test_slice_axis_mut() {
3490+
let mut data = NdTensor::from([[1, 2, 3], [4, 5, 6]]);
3491+
let mut row = data.slice_axis_mut(0, 0..1);
3492+
row.fill(8);
3493+
let mut col = data.slice_axis_mut(1, 1..2);
3494+
col.fill(9);
3495+
assert_eq!(data, NdTensor::from([[8, 9, 8], [4, 9, 6]]));
3496+
}
3497+
34613498
#[test]
34623499
fn test_slice_mut() {
34633500
// Slice static-rank array. The rank of the slice is inferred.

Diff for: src/ops/split.rs

+3-13
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use rten_tensor::prelude::*;
2-
use rten_tensor::{NdTensorView, SliceItem, Tensor, TensorView};
2+
use rten_tensor::{NdTensorView, Tensor, TensorView};
33

44
use crate::ops::{
55
map_input, resolve_axis, static_dims, Input, InputList, OpError, Operator, OutputList,
@@ -29,19 +29,9 @@ pub fn split<T: Copy>(
2929
.iter()
3030
.map(|&split_size| {
3131
let split_size = split_size as usize;
32-
let slice_range: Vec<SliceItem> = (0..input.ndim())
33-
.map(|dim| {
34-
if dim == axis {
35-
(split_start..split_start + split_size).into()
36-
} else {
37-
SliceItem::full_range()
38-
}
39-
})
40-
.collect();
41-
32+
let split_range = split_start..split_start + split_size;
4233
split_start += split_size;
43-
44-
input.slice(slice_range.as_slice()).to_tensor_in(pool)
34+
input.slice_axis(axis, split_range).to_tensor_in(pool)
4535
})
4636
.collect();
4737

0 commit comments

Comments
 (0)