1
1
use rten_tensor:: prelude:: * ;
2
2
use rten_tensor:: { NdTensorView , Tensor , TensorView } ;
3
3
4
+ use crate :: iter_util:: range_chunks;
4
5
use crate :: ops:: {
5
6
map_input, resolve_axis, static_dims, Input , InputList , OpError , Operator , OutputList ,
6
7
} ;
7
8
use crate :: tensor_pool:: TensorPool ;
8
9
10
+ #[ derive( Clone , Debug ) ]
11
+ pub enum SplitSizes < ' a > {
12
+ /// Split a tensor into pieces with sizes specified by a vector. The sum of
13
+ /// the piece sizes must match the size of the axis.
14
+ Sizes ( NdTensorView < ' a , i32 , 1 > ) ,
15
+ /// Split a tensor into N equal-sized pieces. If the size of the axis being
16
+ /// split is not evenly divisible by N, the last chunk will be smaller.
17
+ NumSplits ( u32 ) ,
18
+ }
19
+
20
+ impl < ' a > From < & ' a [ i32 ] > for SplitSizes < ' a > {
21
+ fn from ( val : & ' a [ i32 ] ) -> Self {
22
+ Self :: Sizes ( val. into ( ) )
23
+ }
24
+ }
25
+
9
26
pub fn split < T : Copy > (
10
27
pool : & TensorPool ,
11
28
input : TensorView < T > ,
12
29
axis : isize ,
13
- split : & NdTensorView < i32 , 1 > ,
30
+ split : SplitSizes ,
14
31
) -> Result < Vec < Tensor < T > > , OpError > {
15
32
let axis = resolve_axis ( input. ndim ( ) , axis) ?;
16
33
17
- if split. iter ( ) . any ( |size| * size < 0 ) {
18
- return Err ( OpError :: InvalidValue ( "Split sizes must be >= 0" ) ) ;
19
- }
20
- let split_sum = split. iter ( ) . sum :: < i32 > ( ) as usize ;
21
- if split_sum != input. size ( axis) {
22
- return Err ( OpError :: InvalidValue (
23
- "Split sizes do not sum to dimension size" ,
24
- ) ) ;
25
- }
26
-
27
- let mut split_start = 0 ;
28
- let outputs = split
29
- . iter ( )
30
- . map ( |& split_size| {
31
- let split_size = split_size as usize ;
32
- let split_range = split_start..split_start + split_size;
33
- split_start += split_size;
34
- input. slice_axis ( axis, split_range) . to_tensor_in ( pool)
35
- } )
36
- . collect ( ) ;
34
+ let outputs = match split {
35
+ SplitSizes :: Sizes ( split) => {
36
+ if split. iter ( ) . any ( |size| * size < 0 ) {
37
+ return Err ( OpError :: InvalidValue ( "Split sizes must be >= 0" ) ) ;
38
+ }
39
+ let split_sum = split. iter ( ) . sum :: < i32 > ( ) as usize ;
40
+ if split_sum != input. size ( axis) {
41
+ return Err ( OpError :: InvalidValue (
42
+ "Split sizes do not sum to dimension size" ,
43
+ ) ) ;
44
+ }
45
+
46
+ let mut split_start = 0 ;
47
+ split
48
+ . iter ( )
49
+ . map ( |& split_size| {
50
+ let split_size = split_size as usize ;
51
+ let split_range = split_start..split_start + split_size;
52
+ split_start += split_size;
53
+ input. slice_axis ( axis, split_range) . to_tensor_in ( pool)
54
+ } )
55
+ . collect ( )
56
+ }
57
+ SplitSizes :: NumSplits ( n_splits) => {
58
+ let n_splits = n_splits as usize ;
59
+ if n_splits == 0 {
60
+ return Err ( OpError :: InvalidValue ( "num_outputs must be > 0" ) ) ;
61
+ }
62
+ let dim_size = input. size ( axis) ;
63
+ if n_splits > dim_size {
64
+ return Err ( OpError :: InvalidValue ( "num_outputs exceeds dim size" ) ) ;
65
+ }
66
+ let chunk_size = dim_size. div_ceil ( n_splits) ;
67
+ range_chunks ( 0 ..dim_size, chunk_size)
68
+ . map ( |chunk| input. slice_axis ( axis, chunk) . to_tensor_in ( pool) )
69
+ . collect ( )
70
+ }
71
+ } ;
37
72
38
73
Ok ( outputs)
39
74
}
40
75
41
76
#[ derive( Debug ) ]
42
77
pub struct Split {
43
78
pub axis : isize ,
79
+ pub num_outputs : Option < u32 > ,
44
80
}
45
81
46
82
impl Operator for Split {
@@ -50,11 +86,21 @@ impl Operator for Split {
50
86
51
87
fn run ( & self , pool : & TensorPool , inputs : InputList ) -> Result < OutputList , OpError > {
52
88
let input = inputs. require ( 0 ) ?;
53
- let splits = inputs. require_as :: < i32 > ( 1 ) ?;
54
- let splits = static_dims ! ( splits, 1 ) ?;
89
+ let splits = inputs. get_as :: < i32 > ( 1 ) ?;
90
+
91
+ let split_sizes = if let Some ( splits) = splits {
92
+ let splits = static_dims ! ( splits, 1 ) ?;
93
+ SplitSizes :: Sizes ( splits)
94
+ } else if let Some ( num_outputs) = self . num_outputs {
95
+ SplitSizes :: NumSplits ( num_outputs)
96
+ } else {
97
+ return Err ( OpError :: InvalidValue (
98
+ "Either `num_outputs` or `splits` must be set" ,
99
+ ) ) ;
100
+ } ;
55
101
56
102
map_input ! ( input, x, {
57
- split( pool, x, self . axis, & splits )
103
+ split( pool, x, self . axis, split_sizes )
58
104
. map( |tensors| tensors. into_iter( ) . map( |t| t. into( ) ) . collect( ) )
59
105
} )
60
106
}
@@ -64,60 +110,113 @@ impl Operator for Split {
64
110
mod tests {
65
111
use rten_tensor:: prelude:: * ;
66
112
use rten_tensor:: Tensor ;
113
+ use rten_testing:: TestCases ;
67
114
68
115
use crate :: ops:: tests:: new_pool;
69
116
use crate :: ops:: { split, OpError } ;
70
117
118
+ use super :: SplitSizes ;
119
+
71
120
#[ test]
72
121
fn test_split ( ) {
73
- let pool = new_pool ( ) ;
74
-
75
122
let input = Tensor :: from ( [ [ 0. , 1. ] , [ 2. , 3. ] , [ 4. , 5. ] , [ 6. , 7. ] , [ 8. , 9. ] ] ) ;
76
123
77
- // Split with positive axis
78
- let splits = & [ 1 , 1 ] ;
79
- let results = split ( & pool, input. view ( ) , 1 , & splits. into ( ) ) . unwrap ( ) ;
80
-
81
- assert_eq ! ( results. len( ) , 2 ) ;
82
- assert_eq ! ( results[ 0 ] . data( ) . unwrap( ) , & [ 0. , 2. , 4. , 6. , 8. ] ) ;
83
- assert_eq ! ( results[ 1 ] . data( ) . unwrap( ) , & [ 1. , 3. , 5. , 7. , 9. ] ) ;
84
-
85
- // Split with negative axis
86
- let splits = & [ 1 , 1 ] ;
87
- let results = split ( & pool, input. view ( ) , -1 , & splits. into ( ) ) . unwrap ( ) ;
88
-
89
- assert_eq ! ( results. len( ) , 2 ) ;
90
- assert_eq ! ( results[ 0 ] . data( ) . unwrap( ) , & [ 0. , 2. , 4. , 6. , 8. ] ) ;
91
- assert_eq ! ( results[ 1 ] . data( ) . unwrap( ) , & [ 1. , 3. , 5. , 7. , 9. ] ) ;
124
+ #[ derive( Debug ) ]
125
+ struct Case < ' a > {
126
+ axis : isize ,
127
+ splits : SplitSizes < ' a > ,
128
+ expected : Vec < Tensor > ,
129
+ }
130
+
131
+ let cases = [
132
+ // Positive axis
133
+ Case {
134
+ axis : 1 ,
135
+ splits : [ 1 , 1 ] . as_slice ( ) . into ( ) ,
136
+ expected : [
137
+ Tensor :: from ( [ [ 0. ] , [ 2. ] , [ 4. ] , [ 6. ] , [ 8. ] ] ) ,
138
+ Tensor :: from ( [ [ 1. ] , [ 3. ] , [ 5. ] , [ 7. ] , [ 9. ] ] ) ,
139
+ ]
140
+ . into ( ) ,
141
+ } ,
142
+ // Negative axis
143
+ Case {
144
+ axis : -1 ,
145
+ splits : [ 1 , 1 ] . as_slice ( ) . into ( ) ,
146
+ expected : [
147
+ Tensor :: from ( [ [ 0. ] , [ 2. ] , [ 4. ] , [ 6. ] , [ 8. ] ] ) ,
148
+ Tensor :: from ( [ [ 1. ] , [ 3. ] , [ 5. ] , [ 7. ] , [ 9. ] ] ) ,
149
+ ]
150
+ . into ( ) ,
151
+ } ,
152
+ // Splits specified as count
153
+ Case {
154
+ axis : 0 ,
155
+ splits : SplitSizes :: NumSplits ( 3 ) ,
156
+ expected : [
157
+ Tensor :: from ( [ [ 0. , 1. ] , [ 2. , 3. ] ] ) ,
158
+ Tensor :: from ( [ [ 4. , 5. ] , [ 6. , 7. ] ] ) ,
159
+ Tensor :: from ( [ [ 8. , 9. ] ] ) ,
160
+ ]
161
+ . into ( ) ,
162
+ } ,
163
+ ] ;
164
+
165
+ cases. test_each ( |case| {
166
+ let pool = new_pool ( ) ;
167
+ let results = split ( & pool, input. view ( ) , case. axis , case. splits . clone ( ) ) . unwrap ( ) ;
168
+ let expected_splits = match case. splits {
169
+ SplitSizes :: NumSplits ( n) => n as usize ,
170
+ SplitSizes :: Sizes ( sizes) => sizes. len ( ) ,
171
+ } ;
172
+ assert_eq ! ( results. len( ) , expected_splits) ;
173
+ assert_eq ! ( results, case. expected) ;
174
+ } )
92
175
}
93
176
94
177
#[ test]
95
178
fn test_split_invalid_inputs ( ) {
96
- let pool = new_pool ( ) ;
97
-
98
179
let input = Tensor :: from ( [ [ 0. , 1. ] , [ 2. , 3. ] , [ 4. , 5. ] , [ 6. , 7. ] , [ 8. , 9. ] ] ) ;
99
180
100
- let splits = & [ 1 , 1 ] ;
101
- let result = split ( & pool, input. view ( ) , 2 , & splits. into ( ) ) ;
102
- assert_eq ! ( result. err( ) , Some ( OpError :: InvalidValue ( "Axis is invalid" ) ) ) ;
103
-
104
- let result = split ( & pool, input. view ( ) , -3 , & splits. into ( ) ) ;
105
- assert_eq ! ( result. err( ) , Some ( OpError :: InvalidValue ( "Axis is invalid" ) ) ) ;
106
-
107
- let splits = & [ 1 , 2 ] ;
108
- let result = split ( & pool, input. view ( ) , 1 , & splits. into ( ) ) ;
109
- assert_eq ! (
110
- result. err( ) ,
111
- Some ( OpError :: InvalidValue (
112
- "Split sizes do not sum to dimension size"
113
- ) )
114
- ) ;
115
-
116
- let splits = & [ 1 , -2 ] ;
117
- let result = split ( & pool, input. view ( ) , 1 , & splits. into ( ) ) ;
118
- assert_eq ! (
119
- result. err( ) ,
120
- Some ( OpError :: InvalidValue ( "Split sizes must be >= 0" ) )
121
- ) ;
181
+ #[ derive( Debug ) ]
182
+ struct Case < ' a > {
183
+ axis : isize ,
184
+ splits : SplitSizes < ' a > ,
185
+ expected : OpError ,
186
+ }
187
+
188
+ let cases = [
189
+ Case {
190
+ axis : 2 ,
191
+ splits : [ 1 , 1 ] . as_slice ( ) . into ( ) ,
192
+ expected : OpError :: InvalidValue ( "Axis is invalid" ) ,
193
+ } ,
194
+ Case {
195
+ axis : 1 ,
196
+ splits : [ 1 , 2 ] . as_slice ( ) . into ( ) ,
197
+ expected : OpError :: InvalidValue ( "Split sizes do not sum to dimension size" ) ,
198
+ } ,
199
+ Case {
200
+ axis : 1 ,
201
+ splits : [ 1 , -2 ] . as_slice ( ) . into ( ) ,
202
+ expected : OpError :: InvalidValue ( "Split sizes must be >= 0" ) ,
203
+ } ,
204
+ Case {
205
+ axis : 1 ,
206
+ splits : SplitSizes :: NumSplits ( 0 ) ,
207
+ expected : OpError :: InvalidValue ( "num_outputs must be > 0" ) ,
208
+ } ,
209
+ Case {
210
+ axis : 1 ,
211
+ splits : SplitSizes :: NumSplits ( 3 ) ,
212
+ expected : OpError :: InvalidValue ( "num_outputs exceeds dim size" ) ,
213
+ } ,
214
+ ] ;
215
+
216
+ cases. test_each ( |case| {
217
+ let pool = new_pool ( ) ;
218
+ let result = split ( & pool, input. view ( ) , case. axis , case. splits . clone ( ) ) ;
219
+ assert_eq ! ( result. err( ) . as_ref( ) , Some ( & case. expected) ) ;
220
+ } )
122
221
}
123
222
}
0 commit comments