Commit d6533da

ONNX Import: switch to rank inferencing, rename shape to static_shape, decouple tensor shape info (#3037)
1 parent 6d0db87

25 files changed: +1213 -1024
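
The theme running through the diffs below: `TensorType` no longer carries an optional static shape, only a name, rank, and element kind; static shape information moves to the places that actually need it (the constant's `TensorData`, or a node-level `shape: Vec<usize>` field). A minimal runnable sketch of that decoupling, using simplified stand-in structs rather than burn-import's real definitions:

```rust
// Simplified stand-ins for burn-import's types (not the real definitions).
struct TensorType {
    name: String,
    rank: usize, // rank only; the Option<Vec<usize>> shape field is gone
}

struct TensorData {
    shape: Vec<usize>, // the static shape now travels with the data
}

fn main() {
    let ty = TensorType { name: "const_tensor".into(), rank: 1 };
    let data = TensorData { shape: vec![4] };
    // The invariant the new constant codegen asserts:
    assert_eq!(data.shape.len(), ty.rank, "data shape must match type rank");
    println!("{}: rank {}", ty.name, ty.rank);
}
```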

crates/burn-import/onnx-tests/tests/test_onnx.rs (+9 -2)

@@ -1,5 +1,7 @@
 #![no_std]
 
+extern crate alloc;
+
 /// Include generated models in the `model` directory in the target directory.
 macro_rules! include_models {
     ($($model:ident),*) => {

@@ -2086,9 +2088,14 @@ mod tests {
         let device = Default::default();
         let model: unsqueeze::Model<Backend> = unsqueeze::Model::new(&device);
         let input_shape = Shape::from([3, 4, 5]);
-        let expected_shape = Shape::from([1, 1, 3, 4, 5, 1]);
+        let expected_shape = Shape::from([1, 3, 1, 4, 5, 1]);
         let input = Tensor::ones(input_shape, &device);
-        let output = model.forward(input);
+
+        // Note: The axes tensor must have rank 1 with a single element
+        // as the generated ONNX requires a 1D tensor for static shape operations
+        // see unsqueeze.onnx
+        let axes = Tensor::from_ints([2], &device);
+        let output = model.forward(input, axes);
         assert_eq!(output.shape(), expected_shape);
     }
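
The updated test passes the axes as a runtime tensor instead of baking them into the model. For reference, the same size-1 axis insertion is available on burn's public API as `unsqueeze_dim`; a minimal sketch assuming burn with the `ndarray` feature enabled (a single axis here, unlike the multi-axis model above):

```rust
use burn::backend::NdArray;
use burn::tensor::{Shape, Tensor};

fn main() {
    let device = Default::default();
    let input = Tensor::<NdArray, 3>::ones(Shape::from([3, 4, 5]), &device);
    // Insert a size-1 axis at position 2, matching the runtime axes
    // tensor [2] passed to model.forward in the test above.
    let output = input.unsqueeze_dim::<4>(2);
    assert_eq!(output.shape(), Shape::from([3, 4, 1, 5]));
}
```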

crates/burn-import/src/burn/node/constant.rs (+19 -39)

@@ -117,33 +117,40 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for ConstantNode {
 
     fn field_init(&self) -> Option<TokenStream> {
         match &self.value {
-            ConstantValue::Tensor(tensor_type, _) => {
+            ConstantValue::Tensor(tensor_type, data) => {
                 let ty = tensor_type.ty();
                 let name = Ident::new(self.name.as_ref(), Span::call_site());
-                let shape = tensor_type.clone().shape.unwrap().to_tokens();
-                let dim = tensor_type.rank.to_tokens();
+
+                assert_eq!(
+                    data.shape.len(),
+                    tensor_type.rank,
+                    "Tensor data shape does not match tensor type rank"
+                );
+
+                let shape = data.shape.to_tokens();
+                let rank = tensor_type.rank.to_tokens();
 
                 match tensor_type.kind {
                     crate::burn::TensorKind::Int => Some(quote! {
                         let #name: burn::module::Param<#ty> = burn::module::Param::uninitialized(
                             burn::module::ParamId::new(),
-                            move |device, _require_grad| Tensor::<B, #dim, Int>::zeros(#shape, &device),
+                            move |device, _require_grad| Tensor::<B, #rank, Int>::zeros(#shape, &device),
                             device.clone(),
                             false
                         );
                     }),
                     crate::burn::TensorKind::Float => Some(quote! {
                         let #name: burn::module::Param<#ty> = burn::module::Param::uninitialized(
                             burn::module::ParamId::new(),
-                            move |device, _require_grad| Tensor::<B, #dim>::zeros(#shape, &device),
+                            move |device, _require_grad| Tensor::<B, #rank>::zeros(#shape, &device),
                             device.clone(),
                             false,
                         );
                     }),
                     crate::burn::TensorKind::Bool => Some(quote! {
                         let #name: burn::module::Param<#ty> = burn::module::Param::uninitialized(
                             burn::module::ParamId::new(),
-                            move |device, _require_grad| Tensor::<B, #dim, Bool>::empty(#shape, &device),
+                            move |device, _require_grad| Tensor::<B, #rank, Bool>::empty(#shape, &device),
                             device.clone(),
                             false,
                         );

@@ -288,23 +295,14 @@ mod tests {
 
         let const_tensor = Ident::new("const_tensor", Span::call_site());
         let dimensions = 1;
-        let shape = vec![4];
         let data = TensorData::from([2f32, 2f32, 2f32, 2f32]);
-        let tensor_type = TensorType::new_float_with_shape(
-            const_tensor.to_string(),
-            dimensions,
-            Some(shape.clone()),
-        );
+        let tensor_type = TensorType::new_float(const_tensor.to_string(), dimensions);
         let constant = ConstantValue::Tensor(tensor_type.clone(), data);
 
         graph.register(ConstantNode::new(
             const_tensor.to_string(),
             constant.clone(),
-            Type::Tensor(TensorType::new_float_with_shape(
-                "output",
-                dimensions,
-                Some(shape.clone()),
-            )),
+            Type::Tensor(TensorType::new_float("output", dimensions)),
         ));
 
         graph.register_input_output(vec![], vec!["output".to_string()]);

@@ -356,23 +354,14 @@ mod tests {
 
         let const_tensor = Ident::new("const_tensor_int", Span::call_site());
         let dimensions = 1;
-        let shape = vec![3];
         let data = TensorData::from([1i32, 2i32, 3i32]);
-        let tensor_type = TensorType::new_int_with_shape(
-            const_tensor.to_string(),
-            dimensions,
-            Some(shape.clone()),
-        );
+        let tensor_type = TensorType::new_int(const_tensor.to_string(), dimensions);
         let constant = ConstantValue::Tensor(tensor_type.clone(), data);
 
         graph.register(ConstantNode::new(
             const_tensor.to_string(),
             constant.clone(),
-            Type::Tensor(TensorType::new_int_with_shape(
-                "output",
-                dimensions,
-                Some(shape.clone()),
-            )),
+            Type::Tensor(TensorType::new_int("output", dimensions)),
        ));
 
         graph.register_input_output(vec![], vec!["output".to_string()]);

@@ -425,23 +414,14 @@ mod tests {
 
         let const_tensor = Ident::new("const_tensor_3d", Span::call_site());
         let dimensions = 3;
-        let shape = vec![1, 3, 2];
         let data = TensorData::from([[[true, false], [true, false], [true, false]]]);
-        let tensor_type = TensorType::new_bool_with_shape(
-            const_tensor.to_string(),
-            dimensions,
-            Some(shape.clone()),
-        );
+        let tensor_type = TensorType::new_bool(const_tensor.to_string(), dimensions);
         let constant = ConstantValue::Tensor(tensor_type.clone(), data);
 
         graph.register(ConstantNode::new(
             const_tensor.to_string(),
             constant.clone(),
-            Type::Tensor(TensorType::new_bool_with_shape(
-                "output",
-                dimensions,
-                Some(shape.clone()),
-            )),
+            Type::Tensor(TensorType::new_bool("output", dimensions)),
         ));
 
         graph.register_input_output(vec![], vec!["output".to_string()]);

crates/burn-import/src/burn/node/expand.rs (+8 -9)

@@ -35,22 +35,22 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for ExpandNode {
     fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
         let input = scope.tensor_use_owned(&self.input, node_position);
         let output = &self.output.name;
+        let output_rank = &self.output.rank;
 
         let shape = match &self.shape {
             ExpandShape::Static(static_shape) => static_shape.to_tokens(),
             ExpandShape::Runtime(Type::Tensor(shape_tensor)) => {
-                // since we don't take ownership of the shape_tensor, we don't need `tensor_use_owned` here:
+                // Since we don't take ownership of the shape_tensor, `tensor_use_owned` is not needed here.
                 let tensor_name = &shape_tensor.name;
-                let dim = shape_tensor.shape.as_ref().unwrap()[0];
-                // the shape of the tensor is already validated statically to be rank one when parsing the input
-                // we'll need to download the Tensor from device to cpu for expand operation.
-                // Also, we'll need to convert it to an array for conversion into BroadcastArgs
+                // The shape of the tensor is statically validated to be rank one during input parsing.
+                // The tensor must be downloaded from device to CPU for the expand operation.
+                // Additionally, it needs to be converted to an array for use in BroadcastArgs.
                 quote! {
-                    TryInto::<[B::IntElem; #dim]>::try_into(#tensor_name.to_data().as_slice::<B::IntElem>().unwrap()).unwrap()
+                    TryInto::<[B::IntElem; #output_rank]>::try_into(#tensor_name.to_data().as_slice::<B::IntElem>().unwrap()).unwrap()
                 }
             }
             ExpandShape::Runtime(Type::Shape(shape)) => {
-                // Shape implements BroadcastArgs, so it can be passed to expand directly
+                // Shape implements BroadcastArgs, allowing it to be passed directly to the expand method.
                 let shape_name = &shape.name;
                 quote! { #shape_name }
             }

@@ -177,8 +177,7 @@ mod tests {
     fn test_codegen_expand_tensor() {
         let mut graph = BurnGraph::<FullPrecisionSettings>::default();
 
-        let mut shape_tensor_type = TensorType::new_int("tensor3", 4);
-        shape_tensor_type.shape = Some(vec![4]);
+        let shape_tensor_type = TensorType::new_int("tensor3", 4);
 
         graph.register(ExpandNode::new(
             TensorType::new_float("tensor1", 4),

crates/burn-import/src/burn/node/random_normal.rs (+6 -8)

@@ -9,24 +9,21 @@ pub struct RandomNormalNode {
     pub mean: f64,
     pub scale: f64,
     pub output_ty: TensorType,
+    pub shape: Vec<usize>,
 }
 
 impl RandomNormalNode {
-    pub fn new(output_ty: TensorType, mean: f64, scale: f64) -> Self {
+    pub fn new(output_ty: TensorType, mean: f64, scale: f64, shape: Vec<usize>) -> Self {
         Self {
             mean,
             scale,
             output_ty,
+            shape,
         }
     }
 
     fn get_output_shape(&self) -> TokenStream {
-        let shape_it = self
-            .output_ty
-            .shape
-            .as_ref()
-            .expect("RandomNormal output has no shape!")
-            .iter();
+        let shape_it = self.shape.iter();
         quote! { Shape::new([#(#shape_it),*]) }
     }
 
@@ -81,9 +78,10 @@ mod tests {
         let mut graph = BurnGraph::<FullPrecisionSettings>::default();
 
         graph.register(RandomNormalNode::new(
-            TensorType::new("tensor1", 2, TensorKind::Float, Some(vec![2, 3])),
+            TensorType::new("tensor1", 2, TensorKind::Float),
             0.0f64,
             1.0f64,
+            vec![2, 3],
         ));
 
         graph.register_input_output(vec![], vec!["tensor1".to_string()]);

crates/burn-import/src/burn/node/random_normal_like.rs (+3 -2)

@@ -50,6 +50,7 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for RandomNormalLikeNode {
 
 #[cfg(test)]
 mod tests {
+
     use super::*;
     use crate::burn::{TensorKind, TensorType, graph::BurnGraph, node::test::assert_tokens};
     use burn::record::FullPrecisionSettings;

@@ -61,8 +62,8 @@ mod tests {
         graph.register(RandomNormalLikeNode::new(
             0.0f64,
             1.0f64,
-            TensorType::new("input", 2, TensorKind::Float, Some(vec![2, 3])),
-            TensorType::new("output", 2, TensorKind::Float, Some(vec![2, 3])),
+            TensorType::new("input", 2, TensorKind::Float),
+            TensorType::new("output", 2, TensorKind::Float),
         ));
 
         graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);

crates/burn-import/src/burn/node/random_uniform.rs (+6 -8)

@@ -9,24 +9,21 @@ pub struct RandomUniformNode {
     pub low: f64,
     pub high: f64,
     pub output_ty: TensorType,
+    pub shape: Vec<usize>,
 }
 
 impl RandomUniformNode {
-    pub fn new(output_ty: TensorType, low: f64, high: f64) -> Self {
+    pub fn new(output_ty: TensorType, low: f64, high: f64, shape: Vec<usize>) -> Self {
         Self {
             low,
             high,
             output_ty,
+            shape,
         }
     }
 
     fn get_output_shape(&self) -> TokenStream {
-        let shape_it = self
-            .output_ty
-            .shape
-            .as_ref()
-            .expect("RandomUniform output has no shape!")
-            .iter();
+        let shape_it = self.shape.iter();
         quote! { Shape::new([#(#shape_it),*]) }
     }
 
@@ -81,9 +78,10 @@ mod tests {
         let mut graph = BurnGraph::<FullPrecisionSettings>::default();
 
         graph.register(RandomUniformNode::new(
-            TensorType::new("tensor1", 2, TensorKind::Float, Some(vec![2, 3])),
+            TensorType::new("tensor1", 2, TensorKind::Float),
             0.0f64,
             1.0f64,
+            vec![2, 3],
         ));
 
         graph.register_input_output(vec![], vec!["tensor1".to_string()]);

crates/burn-import/src/burn/node/random_uniform_like.rs (+3 -2)

@@ -50,6 +50,7 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for RandomUniformLikeNode {
 
 #[cfg(test)]
 mod tests {
+
     use super::*;
     use crate::burn::{TensorKind, TensorType, graph::BurnGraph, node::test::assert_tokens};
     use burn::record::FullPrecisionSettings;

@@ -61,8 +62,8 @@ mod tests {
         graph.register(RandomUniformLikeNode::new(
             0.0f64,
             1.0f64,
-            TensorType::new("input", 2, TensorKind::Float, Some(vec![2, 3])),
-            TensorType::new("output", 2, TensorKind::Float, Some(vec![2, 3])),
+            TensorType::new("input", 2, TensorKind::Float),
+            TensorType::new("output", 2, TensorKind::Float),
         ));
 
         graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);

crates/burn-import/src/burn/node/split.rs (+3 -3)

@@ -48,14 +48,14 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for SplitNode {
         if let Some(split_sizes) = &self.config.split_sizes {
             let split_sizes_tokens = split_sizes.to_tokens();
             quote! {
-                let mut split_tensors = #input.split_with_sizes(#split_sizes_tokens, #axis);
+                let split_tensors = #input.split_with_sizes(#split_sizes_tokens.to_vec(), #axis);
                 #unpack_outputs
             }
         } else {
             let split_size = &self.config.split_size.unwrap();
             let split_size_tokens = split_size.to_tokens();
             quote! {
-                let mut split_tensors = #input.split(#split_size_tokens, #axis);
+                let split_tensors = #input.split(#split_size_tokens, #axis);
                 #unpack_outputs
             }
         }

@@ -125,7 +125,7 @@ mod tests {
             &self,
             tensor1: Tensor<B, 2>,
         ) -> (Tensor<B, 2>, Tensor<B, 2>) {
-            let mut split_tensors = tensor1.split(2, 0);
+            let split_tensors = tensor1.split(2, 0);
 
             let [tensor2, tensor3] = split_tensors.try_into().unwrap();
             (tensor2, tensor3)
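
The `mut` removal works because the generated code only consumes the returned `Vec` via `try_into`; nothing mutates it. The added `.to_vec()` matters because `split_with_sizes` takes a `Vec<usize>` while the emitted token literal is an array. A runnable sketch of the generated pattern, assuming burn with the `ndarray` feature:

```rust
use burn::backend::NdArray;
use burn::tensor::Tensor;

fn main() {
    let device = Default::default();
    let tensor1 = Tensor::<NdArray, 2>::ones([4, 3], &device);

    // split returns a Vec of tensors; no `mut` needed since it is only moved.
    let split_tensors = tensor1.split(2, 0);
    let [tensor2, tensor3]: [Tensor<NdArray, 2>; 2] = split_tensors.try_into().unwrap();
    assert_eq!(tensor2.dims(), [2, 3]);
    assert_eq!(tensor3.dims(), [2, 3]);

    // With explicit sizes, the argument is a Vec<usize>, hence the generated
    // `#split_sizes_tokens.to_vec()`.
    let tensor4 = Tensor::<NdArray, 2>::ones([4, 3], &device);
    let parts = tensor4.split_with_sizes(vec![1, 3], 0);
    assert_eq!(parts.len(), 2);
}
```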
