Skip to content

Commit 46adff9

Browse files
committed
Support num_outputs attribute in Split operator
1 parent 16ae864 commit 46adff9

File tree

9 files changed

+215
-72
lines changed

9 files changed

+215
-72
lines changed

rten-convert/rten_convert/converter.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -653,7 +653,7 @@ def op_node_from_onnx_operator(
653653
case "Split":
654654
attrs = sg.SplitAttrsT()
655655
attrs.axis = attr_reader.get_attr("axis", "int", 0)
656-
attr_reader.check_attr("num_outputs", "int", 0)
656+
attrs.numOutputs = attr_reader.get_attr("num_outputs", "int", None)
657657
attr_reader.generate_input_from_attr(1, "split", "ints")
658658

659659
case "Squeeze":

rten-convert/rten_convert/schema_generated.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -4911,12 +4911,22 @@ def Axis(self):
49114911
return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
49124912
return 0
49134913

4914+
# SplitAttrs
4915+
def NumOutputs(self):
4916+
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
4917+
if o != 0:
4918+
return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
4919+
return None
4920+
49144921
def SplitAttrsStart(builder):
4915-
builder.StartObject(1)
4922+
builder.StartObject(2)
49164923

49174924
def SplitAttrsAddAxis(builder, axis):
49184925
builder.PrependInt32Slot(0, axis, 0)
49194926

4927+
def SplitAttrsAddNumOutputs(builder, numOutputs):
4928+
builder.PrependInt32Slot(1, numOutputs, None)
4929+
49204930
def SplitAttrsEnd(builder):
49214931
return builder.EndObject()
49224932

@@ -4927,6 +4937,7 @@ class SplitAttrsT(object):
49274937
# SplitAttrsT
49284938
def __init__(self):
49294939
self.axis = 0 # type: int
4940+
self.numOutputs = None # type: Optional[int]
49304941

49314942
@classmethod
49324943
def InitFromBuf(cls, buf, pos):
@@ -4950,11 +4961,13 @@ def _UnPack(self, splitAttrs):
49504961
if splitAttrs is None:
49514962
return
49524963
self.axis = splitAttrs.Axis()
4964+
self.numOutputs = splitAttrs.NumOutputs()
49534965

49544966
# SplitAttrsT
49554967
def Pack(self, builder):
49564968
SplitAttrsStart(builder)
49574969
SplitAttrsAddAxis(builder, self.axis)
4970+
SplitAttrsAddNumOutputs(builder, self.numOutputs)
49584971
splitAttrs = SplitAttrsEnd(builder)
49594972
return splitAttrs
49604973

src/model.rs

+4-1
Original file line numberDiff line numberDiff line change
@@ -1646,7 +1646,10 @@ mod tests {
16461646
let split_out_2 = graph_builder.add_value("Split_out_2", None, None);
16471647
graph_builder.add_operator(
16481648
"Split",
1649-
OpType::Split(ops::Split { axis: 1 }),
1649+
OpType::Split(ops::Split {
1650+
axis: 1,
1651+
num_outputs: None,
1652+
}),
16501653
&[input_2d, split_splits].map(Some),
16511654
&[split_out_1, split_out_2],
16521655
);

src/model_builder.rs

+1
Original file line numberDiff line numberDiff line change
@@ -858,6 +858,7 @@ impl<'mb, 'a> GraphBuilder<'mb, 'a> {
858858
OpType::Split(args) => op_with_attrs!(Split, SplitAttrs, {
859859
sg::SplitAttrsArgs {
860860
axis: args.axis as i32,
861+
num_outputs: args.num_outputs.map(|n| n as i32),
861862
}
862863
}),
863864
OpType::Sqrt => op!(Sqrt),

src/op_registry.rs

+5-1
Original file line numberDiff line numberDiff line change
@@ -880,7 +880,11 @@ impl_read_op!(Size);
880880
impl_read_op!(Slice);
881881
impl_read_op!(Softmax, attrs_as_softmax_attrs, axis);
882882
impl_read_op!(Softplus);
883-
impl_read_op!(Split, attrs_as_split_attrs, axis);
883+
impl_read_op!(Split, attrs_as_split_attrs, |attrs: sg::SplitAttrs| {
884+
let axis = attrs.axis() as isize;
885+
let num_outputs = attrs.num_outputs().map(|n| n as u32);
886+
Ok(ops::Split { axis, num_outputs })
887+
});
884888
impl_read_op!(Sqrt);
885889
impl_read_op!(Squeeze);
886890
impl_read_op!(Sub);

src/ops/rnn.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -807,7 +807,7 @@ mod tests {
807807
let splits = &[size as i32; 4];
808808

809809
// Split input into seperate tensor for each of the gates.
810-
let ifco = split(&pool, x.view(), axis, &splits.into()).expect("split failed");
810+
let ifco = split(&pool, x.view(), axis, splits.as_slice().into()).expect("split failed");
811811

812812
// Recombine in a new gate order.
813813
concat(
@@ -831,7 +831,7 @@ mod tests {
831831
let splits = &[size as i32; 3];
832832

833833
// Split input into seperate tensor for each of the gates.
834-
let ruh = split(&pool, x.view(), axis, &splits.into()).expect("split failed");
834+
let ruh = split(&pool, x.view(), axis, splits.as_slice().into()).expect("split failed");
835835

836836
// Recombine in a new gate order.
837837
concat(&pool, &[ruh[1].view(), ruh[0].view(), ruh[2].view()], axis).expect("concat failed")

src/ops/split.rs

+164-65
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,82 @@
11
use rten_tensor::prelude::*;
22
use rten_tensor::{NdTensorView, Tensor, TensorView};
33

4+
use crate::iter_util::range_chunks;
45
use crate::ops::{
56
map_input, resolve_axis, static_dims, Input, InputList, OpError, Operator, OutputList,
67
};
78
use crate::tensor_pool::TensorPool;
89

10+
#[derive(Clone, Debug)]
11+
pub enum SplitSizes<'a> {
12+
/// Split a tensor into pieces with sizes specified by a vector. The sum of
13+
/// the piece sizes must match the size of the axis.
14+
Sizes(NdTensorView<'a, i32, 1>),
15+
/// Split a tensor into N equal-sized pieces. If the size of the axis being
16+
/// split is not evenly divisible by N, the last chunk will be smaller.
17+
NumSplits(u32),
18+
}
19+
20+
impl<'a> From<&'a [i32]> for SplitSizes<'a> {
21+
fn from(val: &'a [i32]) -> Self {
22+
Self::Sizes(val.into())
23+
}
24+
}
25+
926
pub fn split<T: Copy>(
1027
pool: &TensorPool,
1128
input: TensorView<T>,
1229
axis: isize,
13-
split: &NdTensorView<i32, 1>,
30+
split: SplitSizes,
1431
) -> Result<Vec<Tensor<T>>, OpError> {
1532
let axis = resolve_axis(input.ndim(), axis)?;
1633

17-
if split.iter().any(|size| *size < 0) {
18-
return Err(OpError::InvalidValue("Split sizes must be >= 0"));
19-
}
20-
let split_sum = split.iter().sum::<i32>() as usize;
21-
if split_sum != input.size(axis) {
22-
return Err(OpError::InvalidValue(
23-
"Split sizes do not sum to dimension size",
24-
));
25-
}
26-
27-
let mut split_start = 0;
28-
let outputs = split
29-
.iter()
30-
.map(|&split_size| {
31-
let split_size = split_size as usize;
32-
let split_range = split_start..split_start + split_size;
33-
split_start += split_size;
34-
input.slice_axis(axis, split_range).to_tensor_in(pool)
35-
})
36-
.collect();
34+
let outputs = match split {
35+
SplitSizes::Sizes(split) => {
36+
if split.iter().any(|size| *size < 0) {
37+
return Err(OpError::InvalidValue("Split sizes must be >= 0"));
38+
}
39+
let split_sum = split.iter().sum::<i32>() as usize;
40+
if split_sum != input.size(axis) {
41+
return Err(OpError::InvalidValue(
42+
"Split sizes do not sum to dimension size",
43+
));
44+
}
45+
46+
let mut split_start = 0;
47+
split
48+
.iter()
49+
.map(|&split_size| {
50+
let split_size = split_size as usize;
51+
let split_range = split_start..split_start + split_size;
52+
split_start += split_size;
53+
input.slice_axis(axis, split_range).to_tensor_in(pool)
54+
})
55+
.collect()
56+
}
57+
SplitSizes::NumSplits(n_splits) => {
58+
let n_splits = n_splits as usize;
59+
if n_splits == 0 {
60+
return Err(OpError::InvalidValue("num_outputs must be > 0"));
61+
}
62+
let dim_size = input.size(axis);
63+
if n_splits > dim_size {
64+
return Err(OpError::InvalidValue("num_outputs exceeds dim size"));
65+
}
66+
let chunk_size = dim_size.div_ceil(n_splits);
67+
range_chunks(0..dim_size, chunk_size)
68+
.map(|chunk| input.slice_axis(axis, chunk).to_tensor_in(pool))
69+
.collect()
70+
}
71+
};
3772

3873
Ok(outputs)
3974
}
4075

4176
#[derive(Debug)]
4277
pub struct Split {
4378
pub axis: isize,
79+
pub num_outputs: Option<u32>,
4480
}
4581

4682
impl Operator for Split {
@@ -50,11 +86,21 @@ impl Operator for Split {
5086

5187
fn run(&self, pool: &TensorPool, inputs: InputList) -> Result<OutputList, OpError> {
5288
let input = inputs.require(0)?;
53-
let splits = inputs.require_as::<i32>(1)?;
54-
let splits = static_dims!(splits, 1)?;
89+
let splits = inputs.get_as::<i32>(1)?;
90+
91+
let split_sizes = if let Some(splits) = splits {
92+
let splits = static_dims!(splits, 1)?;
93+
SplitSizes::Sizes(splits)
94+
} else if let Some(num_outputs) = self.num_outputs {
95+
SplitSizes::NumSplits(num_outputs)
96+
} else {
97+
return Err(OpError::InvalidValue(
98+
"Either `num_outputs` or `splits` must be set",
99+
));
100+
};
55101

56102
map_input!(input, x, {
57-
split(pool, x, self.axis, &splits)
103+
split(pool, x, self.axis, split_sizes)
58104
.map(|tensors| tensors.into_iter().map(|t| t.into()).collect())
59105
})
60106
}
@@ -64,60 +110,113 @@ impl Operator for Split {
64110
mod tests {
65111
use rten_tensor::prelude::*;
66112
use rten_tensor::Tensor;
113+
use rten_testing::TestCases;
67114

68115
use crate::ops::tests::new_pool;
69116
use crate::ops::{split, OpError};
70117

118+
use super::SplitSizes;
119+
71120
#[test]
72121
fn test_split() {
73-
let pool = new_pool();
74-
75122
let input = Tensor::from([[0., 1.], [2., 3.], [4., 5.], [6., 7.], [8., 9.]]);
76123

77-
// Split with positive axis
78-
let splits = &[1, 1];
79-
let results = split(&pool, input.view(), 1, &splits.into()).unwrap();
80-
81-
assert_eq!(results.len(), 2);
82-
assert_eq!(results[0].data().unwrap(), &[0., 2., 4., 6., 8.]);
83-
assert_eq!(results[1].data().unwrap(), &[1., 3., 5., 7., 9.]);
84-
85-
// Split with negative axis
86-
let splits = &[1, 1];
87-
let results = split(&pool, input.view(), -1, &splits.into()).unwrap();
88-
89-
assert_eq!(results.len(), 2);
90-
assert_eq!(results[0].data().unwrap(), &[0., 2., 4., 6., 8.]);
91-
assert_eq!(results[1].data().unwrap(), &[1., 3., 5., 7., 9.]);
124+
#[derive(Debug)]
125+
struct Case<'a> {
126+
axis: isize,
127+
splits: SplitSizes<'a>,
128+
expected: Vec<Tensor>,
129+
}
130+
131+
let cases = [
132+
// Positive axis
133+
Case {
134+
axis: 1,
135+
splits: [1, 1].as_slice().into(),
136+
expected: [
137+
Tensor::from([[0.], [2.], [4.], [6.], [8.]]),
138+
Tensor::from([[1.], [3.], [5.], [7.], [9.]]),
139+
]
140+
.into(),
141+
},
142+
// Negative axis
143+
Case {
144+
axis: -1,
145+
splits: [1, 1].as_slice().into(),
146+
expected: [
147+
Tensor::from([[0.], [2.], [4.], [6.], [8.]]),
148+
Tensor::from([[1.], [3.], [5.], [7.], [9.]]),
149+
]
150+
.into(),
151+
},
152+
// Splits specified as count
153+
Case {
154+
axis: 0,
155+
splits: SplitSizes::NumSplits(3),
156+
expected: [
157+
Tensor::from([[0., 1.], [2., 3.]]),
158+
Tensor::from([[4., 5.], [6., 7.]]),
159+
Tensor::from([[8., 9.]]),
160+
]
161+
.into(),
162+
},
163+
];
164+
165+
cases.test_each(|case| {
166+
let pool = new_pool();
167+
let results = split(&pool, input.view(), case.axis, case.splits.clone()).unwrap();
168+
let expected_splits = match case.splits {
169+
SplitSizes::NumSplits(n) => n as usize,
170+
SplitSizes::Sizes(sizes) => sizes.len(),
171+
};
172+
assert_eq!(results.len(), expected_splits);
173+
assert_eq!(results, case.expected);
174+
})
92175
}
93176

94177
#[test]
95178
fn test_split_invalid_inputs() {
96-
let pool = new_pool();
97-
98179
let input = Tensor::from([[0., 1.], [2., 3.], [4., 5.], [6., 7.], [8., 9.]]);
99180

100-
let splits = &[1, 1];
101-
let result = split(&pool, input.view(), 2, &splits.into());
102-
assert_eq!(result.err(), Some(OpError::InvalidValue("Axis is invalid")));
103-
104-
let result = split(&pool, input.view(), -3, &splits.into());
105-
assert_eq!(result.err(), Some(OpError::InvalidValue("Axis is invalid")));
106-
107-
let splits = &[1, 2];
108-
let result = split(&pool, input.view(), 1, &splits.into());
109-
assert_eq!(
110-
result.err(),
111-
Some(OpError::InvalidValue(
112-
"Split sizes do not sum to dimension size"
113-
))
114-
);
115-
116-
let splits = &[1, -2];
117-
let result = split(&pool, input.view(), 1, &splits.into());
118-
assert_eq!(
119-
result.err(),
120-
Some(OpError::InvalidValue("Split sizes must be >= 0"))
121-
);
181+
#[derive(Debug)]
182+
struct Case<'a> {
183+
axis: isize,
184+
splits: SplitSizes<'a>,
185+
expected: OpError,
186+
}
187+
188+
let cases = [
189+
Case {
190+
axis: 2,
191+
splits: [1, 1].as_slice().into(),
192+
expected: OpError::InvalidValue("Axis is invalid"),
193+
},
194+
Case {
195+
axis: 1,
196+
splits: [1, 2].as_slice().into(),
197+
expected: OpError::InvalidValue("Split sizes do not sum to dimension size"),
198+
},
199+
Case {
200+
axis: 1,
201+
splits: [1, -2].as_slice().into(),
202+
expected: OpError::InvalidValue("Split sizes must be >= 0"),
203+
},
204+
Case {
205+
axis: 1,
206+
splits: SplitSizes::NumSplits(0),
207+
expected: OpError::InvalidValue("num_outputs must be > 0"),
208+
},
209+
Case {
210+
axis: 1,
211+
splits: SplitSizes::NumSplits(3),
212+
expected: OpError::InvalidValue("num_outputs exceeds dim size"),
213+
},
214+
];
215+
216+
cases.test_each(|case| {
217+
let pool = new_pool();
218+
let result = split(&pool, input.view(), case.axis, case.splits.clone());
219+
assert_eq!(result.err().as_ref(), Some(&case.expected));
220+
})
122221
}
123222
}

src/schema.fbs

+1
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,7 @@ table SoftmaxAttrs {
485485

486486
table SplitAttrs {
487487
axis:int;
488+
num_outputs:int = null;
488489
}
489490

490491
table TopKAttrs {

0 commit comments

Comments
 (0)