Skip to content

Support num_outputs attribute in Split operator #658

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion rten-convert/rten_convert/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,7 +653,7 @@ def op_node_from_onnx_operator(
case "Split":
attrs = sg.SplitAttrsT()
attrs.axis = attr_reader.get_attr("axis", "int", 0)
attr_reader.check_attr("num_outputs", "int", 0)
attrs.numOutputs = attr_reader.get_attr("num_outputs", "int", None)
attr_reader.generate_input_from_attr(1, "split", "ints")

case "Squeeze":
Expand Down
15 changes: 14 additions & 1 deletion rten-convert/rten_convert/schema_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -4911,12 +4911,22 @@ def Axis(self):
return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
return 0

# SplitAttrs
def NumOutputs(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
if o != 0:
return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
return None

def SplitAttrsStart(builder):
builder.StartObject(1)
builder.StartObject(2)

def SplitAttrsAddAxis(builder, axis):
builder.PrependInt32Slot(0, axis, 0)

def SplitAttrsAddNumOutputs(builder, numOutputs):
builder.PrependInt32Slot(1, numOutputs, None)

def SplitAttrsEnd(builder):
return builder.EndObject()

Expand All @@ -4927,6 +4937,7 @@ class SplitAttrsT(object):
# SplitAttrsT
def __init__(self):
self.axis = 0 # type: int
self.numOutputs = None # type: Optional[int]

@classmethod
def InitFromBuf(cls, buf, pos):
Expand All @@ -4950,11 +4961,13 @@ def _UnPack(self, splitAttrs):
if splitAttrs is None:
return
self.axis = splitAttrs.Axis()
self.numOutputs = splitAttrs.NumOutputs()

# SplitAttrsT
def Pack(self, builder):
SplitAttrsStart(builder)
SplitAttrsAddAxis(builder, self.axis)
SplitAttrsAddNumOutputs(builder, self.numOutputs)
splitAttrs = SplitAttrsEnd(builder)
return splitAttrs

Expand Down
5 changes: 4 additions & 1 deletion src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1646,7 +1646,10 @@ mod tests {
let split_out_2 = graph_builder.add_value("Split_out_2", None, None);
graph_builder.add_operator(
"Split",
OpType::Split(ops::Split { axis: 1 }),
OpType::Split(ops::Split {
axis: 1,
num_outputs: None,
}),
&[input_2d, split_splits].map(Some),
&[split_out_1, split_out_2],
);
Expand Down
1 change: 1 addition & 0 deletions src/model_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -858,6 +858,7 @@ impl<'mb, 'a> GraphBuilder<'mb, 'a> {
OpType::Split(args) => op_with_attrs!(Split, SplitAttrs, {
sg::SplitAttrsArgs {
axis: args.axis as i32,
num_outputs: args.num_outputs.map(|n| n as i32),
}
}),
OpType::Sqrt => op!(Sqrt),
Expand Down
6 changes: 5 additions & 1 deletion src/op_registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -880,7 +880,11 @@ impl_read_op!(Size);
impl_read_op!(Slice);
impl_read_op!(Softmax, attrs_as_softmax_attrs, axis);
impl_read_op!(Softplus);
impl_read_op!(Split, attrs_as_split_attrs, axis);
impl_read_op!(Split, attrs_as_split_attrs, |attrs: sg::SplitAttrs| {
let axis = attrs.axis() as isize;
let num_outputs = attrs.num_outputs().map(|n| n as u32);
Ok(ops::Split { axis, num_outputs })
});
impl_read_op!(Sqrt);
impl_read_op!(Squeeze);
impl_read_op!(Sub);
Expand Down
4 changes: 2 additions & 2 deletions src/ops/rnn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -807,7 +807,7 @@ mod tests {
let splits = &[size as i32; 4];

// Split input into separate tensors for each of the gates.
let ifco = split(&pool, x.view(), axis, &splits.into()).expect("split failed");
let ifco = split(&pool, x.view(), axis, splits.as_slice().into()).expect("split failed");

// Recombine in a new gate order.
concat(
Expand All @@ -831,7 +831,7 @@ mod tests {
let splits = &[size as i32; 3];

// Split input into separate tensors for each of the gates.
let ruh = split(&pool, x.view(), axis, &splits.into()).expect("split failed");
let ruh = split(&pool, x.view(), axis, splits.as_slice().into()).expect("split failed");

// Recombine in a new gate order.
concat(&pool, &[ruh[1].view(), ruh[0].view(), ruh[2].view()], axis).expect("concat failed")
Expand Down
229 changes: 164 additions & 65 deletions src/ops/split.rs
Original file line number Diff line number Diff line change
@@ -1,46 +1,82 @@
use rten_tensor::prelude::*;
use rten_tensor::{NdTensorView, Tensor, TensorView};

use crate::iter_util::range_chunks;
use crate::ops::{
map_input, resolve_axis, static_dims, Input, InputList, OpError, Operator, OutputList,
};
use crate::tensor_pool::TensorPool;

/// Specifies how a tensor is divided along an axis by the `split` operation.
///
/// This mirrors the two ways ONNX `Split` can be configured: an explicit
/// vector of piece sizes (the `split` input) or a count of equal-sized
/// pieces (the `num_outputs` attribute).
#[derive(Clone, Debug)]
pub enum SplitSizes<'a> {
/// Split a tensor into pieces with sizes specified by a vector. The sum of
/// the piece sizes must match the size of the axis.
Sizes(NdTensorView<'a, i32, 1>),
/// Split a tensor into N equal-sized pieces. If the size of the axis being
/// split is not evenly divisible by N, the last chunk will be smaller.
NumSplits(u32),
}

/// Convenience conversion so callers (e.g. tests) can pass a plain slice of
/// sizes where `SplitSizes` is expected.
impl<'a> From<&'a [i32]> for SplitSizes<'a> {
fn from(val: &'a [i32]) -> Self {
Self::Sizes(val.into())
}
}

pub fn split<T: Copy>(
pool: &TensorPool,
input: TensorView<T>,
axis: isize,
split: &NdTensorView<i32, 1>,
split: SplitSizes,
) -> Result<Vec<Tensor<T>>, OpError> {
let axis = resolve_axis(input.ndim(), axis)?;

if split.iter().any(|size| *size < 0) {
return Err(OpError::InvalidValue("Split sizes must be >= 0"));
}
let split_sum = split.iter().sum::<i32>() as usize;
if split_sum != input.size(axis) {
return Err(OpError::InvalidValue(
"Split sizes do not sum to dimension size",
));
}

let mut split_start = 0;
let outputs = split
.iter()
.map(|&split_size| {
let split_size = split_size as usize;
let split_range = split_start..split_start + split_size;
split_start += split_size;
input.slice_axis(axis, split_range).to_tensor_in(pool)
})
.collect();
let outputs = match split {
SplitSizes::Sizes(split) => {
if split.iter().any(|size| *size < 0) {
return Err(OpError::InvalidValue("Split sizes must be >= 0"));
}
let split_sum = split.iter().sum::<i32>() as usize;
if split_sum != input.size(axis) {
return Err(OpError::InvalidValue(
"Split sizes do not sum to dimension size",
));
}

let mut split_start = 0;
split
.iter()
.map(|&split_size| {
let split_size = split_size as usize;
let split_range = split_start..split_start + split_size;
split_start += split_size;
input.slice_axis(axis, split_range).to_tensor_in(pool)
})
.collect()
}
SplitSizes::NumSplits(n_splits) => {
let n_splits = n_splits as usize;
if n_splits == 0 {
return Err(OpError::InvalidValue("num_outputs must be > 0"));
}
let dim_size = input.size(axis);
if n_splits > dim_size {
return Err(OpError::InvalidValue("num_outputs exceeds dim size"));
}
let chunk_size = dim_size.div_ceil(n_splits);
range_chunks(0..dim_size, chunk_size)
.map(|chunk| input.slice_axis(axis, chunk).to_tensor_in(pool))
.collect()
}
};

Ok(outputs)
}

/// Operator that splits a tensor into multiple pieces along an axis.
#[derive(Debug)]
pub struct Split {
/// Axis along which to split. May be negative to count from the end.
pub axis: isize,
/// Number of equal-sized pieces to produce when the optional `split`
/// sizes input is not provided (ONNX `num_outputs` attribute).
pub num_outputs: Option<u32>,
}

impl Operator for Split {
Expand All @@ -50,11 +86,21 @@ impl Operator for Split {

fn run(&self, pool: &TensorPool, inputs: InputList) -> Result<OutputList, OpError> {
let input = inputs.require(0)?;
let splits = inputs.require_as::<i32>(1)?;
let splits = static_dims!(splits, 1)?;
let splits = inputs.get_as::<i32>(1)?;

let split_sizes = if let Some(splits) = splits {
let splits = static_dims!(splits, 1)?;
SplitSizes::Sizes(splits)
} else if let Some(num_outputs) = self.num_outputs {
SplitSizes::NumSplits(num_outputs)
} else {
return Err(OpError::InvalidValue(
"Either `num_outputs` or `splits` must be set",
));
};

map_input!(input, x, {
split(pool, x, self.axis, &splits)
split(pool, x, self.axis, split_sizes)
.map(|tensors| tensors.into_iter().map(|t| t.into()).collect())
})
}
Expand All @@ -64,60 +110,113 @@ impl Operator for Split {
mod tests {
use rten_tensor::prelude::*;
use rten_tensor::Tensor;
use rten_testing::TestCases;

use crate::ops::tests::new_pool;
use crate::ops::{split, OpError};

use super::SplitSizes;

#[test]
fn test_split() {
let pool = new_pool();

let input = Tensor::from([[0., 1.], [2., 3.], [4., 5.], [6., 7.], [8., 9.]]);

// Split with positive axis
let splits = &[1, 1];
let results = split(&pool, input.view(), 1, &splits.into()).unwrap();

assert_eq!(results.len(), 2);
assert_eq!(results[0].data().unwrap(), &[0., 2., 4., 6., 8.]);
assert_eq!(results[1].data().unwrap(), &[1., 3., 5., 7., 9.]);

// Split with negative axis
let splits = &[1, 1];
let results = split(&pool, input.view(), -1, &splits.into()).unwrap();

assert_eq!(results.len(), 2);
assert_eq!(results[0].data().unwrap(), &[0., 2., 4., 6., 8.]);
assert_eq!(results[1].data().unwrap(), &[1., 3., 5., 7., 9.]);
#[derive(Debug)]
struct Case<'a> {
axis: isize,
splits: SplitSizes<'a>,
expected: Vec<Tensor>,
}

let cases = [
// Positive axis
Case {
axis: 1,
splits: [1, 1].as_slice().into(),
expected: [
Tensor::from([[0.], [2.], [4.], [6.], [8.]]),
Tensor::from([[1.], [3.], [5.], [7.], [9.]]),
]
.into(),
},
// Negative axis
Case {
axis: -1,
splits: [1, 1].as_slice().into(),
expected: [
Tensor::from([[0.], [2.], [4.], [6.], [8.]]),
Tensor::from([[1.], [3.], [5.], [7.], [9.]]),
]
.into(),
},
// Splits specified as count
Case {
axis: 0,
splits: SplitSizes::NumSplits(3),
expected: [
Tensor::from([[0., 1.], [2., 3.]]),
Tensor::from([[4., 5.], [6., 7.]]),
Tensor::from([[8., 9.]]),
]
.into(),
},
];

cases.test_each(|case| {
let pool = new_pool();
let results = split(&pool, input.view(), case.axis, case.splits.clone()).unwrap();
let expected_splits = match case.splits {
SplitSizes::NumSplits(n) => n as usize,
SplitSizes::Sizes(sizes) => sizes.len(),
};
assert_eq!(results.len(), expected_splits);
assert_eq!(results, case.expected);
})
}

#[test]
fn test_split_invalid_inputs() {
let pool = new_pool();

let input = Tensor::from([[0., 1.], [2., 3.], [4., 5.], [6., 7.], [8., 9.]]);

let splits = &[1, 1];
let result = split(&pool, input.view(), 2, &splits.into());
assert_eq!(result.err(), Some(OpError::InvalidValue("Axis is invalid")));

let result = split(&pool, input.view(), -3, &splits.into());
assert_eq!(result.err(), Some(OpError::InvalidValue("Axis is invalid")));

let splits = &[1, 2];
let result = split(&pool, input.view(), 1, &splits.into());
assert_eq!(
result.err(),
Some(OpError::InvalidValue(
"Split sizes do not sum to dimension size"
))
);

let splits = &[1, -2];
let result = split(&pool, input.view(), 1, &splits.into());
assert_eq!(
result.err(),
Some(OpError::InvalidValue("Split sizes must be >= 0"))
);
#[derive(Debug)]
struct Case<'a> {
axis: isize,
splits: SplitSizes<'a>,
expected: OpError,
}

let cases = [
Case {
axis: 2,
splits: [1, 1].as_slice().into(),
expected: OpError::InvalidValue("Axis is invalid"),
},
Case {
axis: 1,
splits: [1, 2].as_slice().into(),
expected: OpError::InvalidValue("Split sizes do not sum to dimension size"),
},
Case {
axis: 1,
splits: [1, -2].as_slice().into(),
expected: OpError::InvalidValue("Split sizes must be >= 0"),
},
Case {
axis: 1,
splits: SplitSizes::NumSplits(0),
expected: OpError::InvalidValue("num_outputs must be > 0"),
},
Case {
axis: 1,
splits: SplitSizes::NumSplits(3),
expected: OpError::InvalidValue("num_outputs exceeds dim size"),
},
];

cases.test_each(|case| {
let pool = new_pool();
let result = split(&pool, input.view(), case.axis, case.splits.clone());
assert_eq!(result.err().as_ref(), Some(&case.expected));
})
}
}
1 change: 1 addition & 0 deletions src/schema.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,7 @@ table SoftmaxAttrs {

table SplitAttrs {
axis:int;
num_outputs:int = null;
}

table TopKAttrs {
Expand Down
Loading