diff --git a/rten-convert/rten_convert/converter.py b/rten-convert/rten_convert/converter.py index b3ab6c9d..2ecbdb52 100644 --- a/rten-convert/rten_convert/converter.py +++ b/rten-convert/rten_convert/converter.py @@ -653,7 +653,7 @@ def op_node_from_onnx_operator( case "Split": attrs = sg.SplitAttrsT() attrs.axis = attr_reader.get_attr("axis", "int", 0) - attr_reader.check_attr("num_outputs", "int", 0) + attrs.numOutputs = attr_reader.get_attr("num_outputs", "int", None) attr_reader.generate_input_from_attr(1, "split", "ints") case "Squeeze": diff --git a/rten-convert/rten_convert/schema_generated.py b/rten-convert/rten_convert/schema_generated.py index f446bc71..e62a8b5f 100644 --- a/rten-convert/rten_convert/schema_generated.py +++ b/rten-convert/rten_convert/schema_generated.py @@ -4911,12 +4911,22 @@ def Axis(self): return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) return 0 + # SplitAttrs + def NumOutputs(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return None + def SplitAttrsStart(builder): - builder.StartObject(1) + builder.StartObject(2) def SplitAttrsAddAxis(builder, axis): builder.PrependInt32Slot(0, axis, 0) +def SplitAttrsAddNumOutputs(builder, numOutputs): + builder.PrependInt32Slot(1, numOutputs, None) + def SplitAttrsEnd(builder): return builder.EndObject() @@ -4927,6 +4937,7 @@ class SplitAttrsT(object): # SplitAttrsT def __init__(self): self.axis = 0 # type: int + self.numOutputs = None # type: Optional[int] @classmethod def InitFromBuf(cls, buf, pos): @@ -4950,11 +4961,13 @@ def _UnPack(self, splitAttrs): if splitAttrs is None: return self.axis = splitAttrs.Axis() + self.numOutputs = splitAttrs.NumOutputs() # SplitAttrsT def Pack(self, builder): SplitAttrsStart(builder) SplitAttrsAddAxis(builder, self.axis) + SplitAttrsAddNumOutputs(builder, self.numOutputs) splitAttrs = SplitAttrsEnd(builder) return splitAttrs diff --git a/src/model.rs b/src/model.rs index af6338b3..37111566 100644 --- a/src/model.rs +++ b/src/model.rs @@ -1646,7 +1646,10 @@ mod tests { let split_out_2 = graph_builder.add_value("Split_out_2", None, None); graph_builder.add_operator( "Split", - OpType::Split(ops::Split { axis: 1 }), + OpType::Split(ops::Split { + axis: 1, + num_outputs: None, + }), &[input_2d, split_splits].map(Some), &[split_out_1, split_out_2], ); diff --git a/src/model_builder.rs b/src/model_builder.rs index 7c865805..017014ac 100644 --- a/src/model_builder.rs +++ b/src/model_builder.rs @@ -858,6 +858,7 @@ impl<'mb, 'a> GraphBuilder<'mb, 'a> { OpType::Split(args) => op_with_attrs!(Split, SplitAttrs, { sg::SplitAttrsArgs { axis: args.axis as i32, + num_outputs: args.num_outputs.map(|n| n as i32), } }), OpType::Sqrt => op!(Sqrt), diff --git a/src/op_registry.rs b/src/op_registry.rs index 17262c2e..23375193 100644 --- a/src/op_registry.rs +++ b/src/op_registry.rs @@ -880,7 +880,11 @@ impl_read_op!(Size); impl_read_op!(Slice); impl_read_op!(Softmax, attrs_as_softmax_attrs, axis); impl_read_op!(Softplus); -impl_read_op!(Split, attrs_as_split_attrs, axis); +impl_read_op!(Split, attrs_as_split_attrs, |attrs: sg::SplitAttrs| { + let axis = attrs.axis() as isize; + let num_outputs = attrs.num_outputs().map(|n| n as u32); + Ok(ops::Split { axis, num_outputs }) +}); impl_read_op!(Sqrt); impl_read_op!(Squeeze); impl_read_op!(Sub); diff --git a/src/ops/rnn.rs b/src/ops/rnn.rs index fd0ca0cb..9c562330 100644 --- a/src/ops/rnn.rs +++ b/src/ops/rnn.rs @@ -807,7 +807,7 @@ mod tests { let splits = &[size as i32; 4]; // Split input into seperate tensor for each of the gates. - let ifco = split(&pool, x.view(), axis, &splits.into()).expect("split failed"); + let ifco = split(&pool, x.view(), axis, splits.as_slice().into()).expect("split failed"); // Recombine in a new gate order. concat( @@ -831,7 +831,7 @@ mod tests { let splits = &[size as i32; 3]; // Split input into seperate tensor for each of the gates. - let ruh = split(&pool, x.view(), axis, &splits.into()).expect("split failed"); + let ruh = split(&pool, x.view(), axis, splits.as_slice().into()).expect("split failed"); // Recombine in a new gate order. concat(&pool, &[ruh[1].view(), ruh[0].view(), ruh[2].view()], axis).expect("concat failed") diff --git a/src/ops/split.rs b/src/ops/split.rs index c070aa7f..4d8f7cca 100644 --- a/src/ops/split.rs +++ b/src/ops/split.rs @@ -1,39 +1,74 @@ use rten_tensor::prelude::*; use rten_tensor::{NdTensorView, Tensor, TensorView}; +use crate::iter_util::range_chunks; use crate::ops::{ map_input, resolve_axis, static_dims, Input, InputList, OpError, Operator, OutputList, }; use crate::tensor_pool::TensorPool; +#[derive(Clone, Debug)] +pub enum SplitSizes<'a> { + /// Split a tensor into pieces with sizes specified by a vector. The sum of + /// the piece sizes must match the size of the axis. + Sizes(NdTensorView<'a, i32, 1>), + /// Split a tensor into N equal-sized pieces. If the size of the axis being + /// split is not evenly divisible by N, the last chunk will be smaller. + NumSplits(u32), +} + +impl<'a> From<&'a [i32]> for SplitSizes<'a> { + fn from(val: &'a [i32]) -> Self { + Self::Sizes(val.into()) + } +} + pub fn split( pool: &TensorPool, input: TensorView, axis: isize, - split: &NdTensorView, + split: SplitSizes, ) -> Result>, OpError> { let axis = resolve_axis(input.ndim(), axis)?; - if split.iter().any(|size| *size < 0) { - return Err(OpError::InvalidValue("Split sizes must be >= 0")); - } - let split_sum = split.iter().sum::() as usize; - if split_sum != input.size(axis) { - return Err(OpError::InvalidValue( - "Split sizes do not sum to dimension size", - )); - } - - let mut split_start = 0; - let outputs = split - .iter() - .map(|&split_size| { - let split_size = split_size as usize; - let split_range = split_start..split_start + split_size; - split_start += split_size; - input.slice_axis(axis, split_range).to_tensor_in(pool) - }) - .collect(); + let outputs = match split { + SplitSizes::Sizes(split) => { + if split.iter().any(|size| *size < 0) { + return Err(OpError::InvalidValue("Split sizes must be >= 0")); + } + let split_sum = split.iter().sum::() as usize; + if split_sum != input.size(axis) { + return Err(OpError::InvalidValue( + "Split sizes do not sum to dimension size", + )); + } + + let mut split_start = 0; + split + .iter() + .map(|&split_size| { + let split_size = split_size as usize; + let split_range = split_start..split_start + split_size; + split_start += split_size; + input.slice_axis(axis, split_range).to_tensor_in(pool) + }) + .collect() + } + SplitSizes::NumSplits(n_splits) => { + let n_splits = n_splits as usize; + if n_splits == 0 { + return Err(OpError::InvalidValue("num_outputs must be > 0")); + } + let dim_size = input.size(axis); + if n_splits > dim_size { + return Err(OpError::InvalidValue("num_outputs exceeds dim size")); + } + let chunk_size = dim_size.div_ceil(n_splits); + range_chunks(0..dim_size, chunk_size) + .map(|chunk| input.slice_axis(axis, chunk).to_tensor_in(pool)) + .collect() + } + }; Ok(outputs) } @@ -41,6 +76,7 @@ pub fn split( #[derive(Debug)] pub struct Split { pub axis: isize, + pub num_outputs: Option, } impl Operator for Split { @@ -50,11 +86,21 @@ impl Operator for Split { fn run(&self, pool: &TensorPool, inputs: InputList) -> Result { let input = inputs.require(0)?; - let splits = inputs.require_as::(1)?; - let splits = static_dims!(splits, 1)?; + let splits = inputs.get_as::(1)?; + + let split_sizes = if let Some(splits) = splits { + let splits = static_dims!(splits, 1)?; + SplitSizes::Sizes(splits) + } else if let Some(num_outputs) = self.num_outputs { + SplitSizes::NumSplits(num_outputs) + } else { + return Err(OpError::InvalidValue( + "Either `num_outputs` or `splits` must be set", + )); + }; map_input!(input, x, { - split(pool, x, self.axis, &splits) + split(pool, x, self.axis, split_sizes) .map(|tensors| tensors.into_iter().map(|t| t.into()).collect()) }) } @@ -64,60 +110,113 @@ impl Operator for Split { mod tests { use rten_tensor::prelude::*; use rten_tensor::Tensor; + use rten_testing::TestCases; use crate::ops::tests::new_pool; use crate::ops::{split, OpError}; + use super::SplitSizes; + #[test] fn test_split() { - let pool = new_pool(); - let input = Tensor::from([[0., 1.], [2., 3.], [4., 5.], [6., 7.], [8., 9.]]); - // Split with positive axis - let splits = &[1, 1]; - let results = split(&pool, input.view(), 1, &splits.into()).unwrap(); - - assert_eq!(results.len(), 2); - assert_eq!(results[0].data().unwrap(), &[0., 2., 4., 6., 8.]); - assert_eq!(results[1].data().unwrap(), &[1., 3., 5., 7., 9.]); - - // Split with negative axis - let splits = &[1, 1]; - let results = split(&pool, input.view(), -1, &splits.into()).unwrap(); - - assert_eq!(results.len(), 2); - assert_eq!(results[0].data().unwrap(), &[0., 2., 4., 6., 8.]); - assert_eq!(results[1].data().unwrap(), &[1., 3., 5., 7., 9.]); + #[derive(Debug)] + struct Case<'a> { + axis: isize, + splits: SplitSizes<'a>, + expected: Vec, + } + + let cases = [ + // Positive axis + Case { + axis: 1, + splits: [1, 1].as_slice().into(), + expected: [ + Tensor::from([[0.], [2.], [4.], [6.], [8.]]), + Tensor::from([[1.], [3.], [5.], [7.], [9.]]), + ] + .into(), + }, + // Negative axis + Case { + axis: -1, + splits: [1, 1].as_slice().into(), + expected: [ + Tensor::from([[0.], [2.], [4.], [6.], [8.]]), + Tensor::from([[1.], [3.], [5.], [7.], [9.]]), + ] + .into(), + }, + // Splits specified as count + Case { + axis: 0, + splits: SplitSizes::NumSplits(3), + expected: [ + Tensor::from([[0., 1.], [2., 3.]]), + Tensor::from([[4., 5.], [6., 7.]]), + Tensor::from([[8., 9.]]), + ] + .into(), + }, + ]; + + cases.test_each(|case| { + let pool = new_pool(); + let results = split(&pool, input.view(), case.axis, case.splits.clone()).unwrap(); + let expected_splits = match case.splits { + SplitSizes::NumSplits(n) => n as usize, + SplitSizes::Sizes(sizes) => sizes.len(), + }; + assert_eq!(results.len(), expected_splits); + assert_eq!(results, case.expected); + }) } #[test] fn test_split_invalid_inputs() { - let pool = new_pool(); - let input = Tensor::from([[0., 1.], [2., 3.], [4., 5.], [6., 7.], [8., 9.]]); - let splits = &[1, 1]; - let result = split(&pool, input.view(), 2, &splits.into()); - assert_eq!(result.err(), Some(OpError::InvalidValue("Axis is invalid"))); - - let result = split(&pool, input.view(), -3, &splits.into()); - assert_eq!(result.err(), Some(OpError::InvalidValue("Axis is invalid"))); - - let splits = &[1, 2]; - let result = split(&pool, input.view(), 1, &splits.into()); - assert_eq!( - result.err(), - Some(OpError::InvalidValue( - "Split sizes do not sum to dimension size" - )) - ); - - let splits = &[1, -2]; - let result = split(&pool, input.view(), 1, &splits.into()); - assert_eq!( - result.err(), - Some(OpError::InvalidValue("Split sizes must be >= 0")) - ); + #[derive(Debug)] + struct Case<'a> { + axis: isize, + splits: SplitSizes<'a>, + expected: OpError, + } + + let cases = [ + Case { + axis: 2, + splits: [1, 1].as_slice().into(), + expected: OpError::InvalidValue("Axis is invalid"), + }, + Case { + axis: 1, + splits: [1, 2].as_slice().into(), + expected: OpError::InvalidValue("Split sizes do not sum to dimension size"), + }, + Case { + axis: 1, + splits: [1, -2].as_slice().into(), + expected: OpError::InvalidValue("Split sizes must be >= 0"), + }, + Case { + axis: 1, + splits: SplitSizes::NumSplits(0), + expected: OpError::InvalidValue("num_outputs must be > 0"), + }, + Case { + axis: 1, + splits: SplitSizes::NumSplits(3), + expected: OpError::InvalidValue("num_outputs exceeds dim size"), + }, + ]; + + cases.test_each(|case| { + let pool = new_pool(); + let result = split(&pool, input.view(), case.axis, case.splits.clone()); + assert_eq!(result.err().as_ref(), Some(&case.expected)); + }) } } diff --git a/src/schema.fbs b/src/schema.fbs index e0c80ccc..30a42769 100644 --- a/src/schema.fbs +++ b/src/schema.fbs @@ -485,6 +485,7 @@ table SoftmaxAttrs { table SplitAttrs { axis:int; + num_outputs:int = null; } table TopKAttrs { diff --git a/src/schema_generated.rs b/src/schema_generated.rs index 30c29812..5bfbe6a2 100644 --- a/src/schema_generated.rs +++ b/src/schema_generated.rs @@ -7865,6 +7865,7 @@ impl<'a> flatbuffers::Follow<'a> for SplitAttrs<'a> { impl<'a> SplitAttrs<'a> { pub const VT_AXIS: flatbuffers::VOffsetT = 4; + pub const VT_NUM_OUTPUTS: flatbuffers::VOffsetT = 6; #[inline] pub unsafe fn init_from_table(table: flatbuffers::Table<'a>) -> Self { @@ -7876,6 +7877,9 @@ impl<'a> SplitAttrs<'a> { args: &'args SplitAttrsArgs, ) -> flatbuffers::WIPOffset> { let mut builder = SplitAttrsBuilder::new(_fbb); + if let Some(x) = args.num_outputs { + builder.add_num_outputs(x); + } builder.add_axis(args.axis); builder.finish() } @@ -7887,6 +7891,13 @@ impl<'a> SplitAttrs<'a> { // which contains a valid value in this slot unsafe { self._tab.get::(SplitAttrs::VT_AXIS, Some(0)).unwrap() } } + #[inline] + pub fn num_outputs(&self) -> Option { + // Safety: + // Created from valid Table for this object + // which contains a valid value in this slot + unsafe { self._tab.get::(SplitAttrs::VT_NUM_OUTPUTS, None) } + } } impl flatbuffers::Verifiable for SplitAttrs<'_> { @@ -7898,17 +7909,22 @@ impl flatbuffers::Verifiable for SplitAttrs<'_> { use self::flatbuffers::Verifiable; v.visit_table(pos)? .visit_field::("axis", Self::VT_AXIS, false)? + .visit_field::("num_outputs", Self::VT_NUM_OUTPUTS, false)? .finish(); Ok(()) } } pub struct SplitAttrsArgs { pub axis: i32, + pub num_outputs: Option, } impl<'a> Default for SplitAttrsArgs { #[inline] fn default() -> Self { - SplitAttrsArgs { axis: 0 } + SplitAttrsArgs { + axis: 0, + num_outputs: None, + } } } @@ -7922,6 +7938,11 @@ impl<'a: 'b, 'b, A: flatbuffers::Allocator + 'a> SplitAttrsBuilder<'a, 'b, A> { self.fbb_.push_slot::(SplitAttrs::VT_AXIS, axis, 0); } #[inline] + pub fn add_num_outputs(&mut self, num_outputs: i32) { + self.fbb_ + .push_slot_always::(SplitAttrs::VT_NUM_OUTPUTS, num_outputs); + } + #[inline] pub fn new( _fbb: &'b mut flatbuffers::FlatBufferBuilder<'a, A>, ) -> SplitAttrsBuilder<'a, 'b, A> { @@ -7942,6 +7963,7 @@ impl core::fmt::Debug for SplitAttrs<'_> { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { let mut ds = f.debug_struct("SplitAttrs"); ds.field("axis", &self.axis()); + ds.field("num_outputs", &self.num_outputs()); ds.finish() } }