Commit 0b03514

feat: onnx ceil & round
1 parent c51045c

16 files changed: +375, −9 lines

16 files changed

+375
-9
lines changed

crates/burn-autodiff/src/ops/tensor.rs

Lines changed: 2 additions & 2 deletions
@@ -2046,9 +2046,9 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
         {
             OpsKind::Tracked(preps) => preps.finish(
                 (tensor.primitive.shape(), B::float_device(&tensor.primitive)),
-                B::float_floor(tensor.primitive),
+                B::float_ceil(tensor.primitive),
             ),
-            OpsKind::UnTracked(preps) => preps.finish(B::float_floor(tensor.primitive)),
+            OpsKind::UnTracked(preps) => preps.finish(B::float_ceil(tensor.primitive)),
         }
     }
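This hunk fixes what looks like a copy-paste slip: both branches of `float_ceil` were still forwarding to `B::float_floor`. The backward pass is unaffected either way, since floor and ceil are both piecewise constant with zero derivative almost everywhere. A minimal check of that (not part of the commit; assumes burn with the `autodiff` and `ndarray` features enabled):

use burn::backend::{Autodiff, NdArray};
use burn::tensor::Tensor;

fn main() {
    type B = Autodiff<NdArray<f32>>;
    let device = Default::default();

    let x = Tensor::<B, 1>::from_floats([-0.5, 1.5, 2.1], &device).require_grad();
    let y = x.clone().ceil();

    // d/dx ceil(x) = 0 wherever it is defined, so the gradient is all zeros.
    let grads = y.sum().backward();
    let grad_x = x.grad(&grads).unwrap();
    println!("{grad_x}"); // expected: [0.0, 0.0, 0.0]
}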

crates/burn-import/SUPPORTED-ONNX-OPS.md

Lines changed: 2 additions & 2 deletions
@@ -33,7 +33,7 @@ functionality.
 | [BlackmanWindow][21] | ❌ | ❌ |
 | [Cast][22] | ✅ | ✅ |
 | [CastLike][23] | ❌ | ❌ |
-| [Ceil][24] | ❌ | ❌ |
+| [Ceil][24] | ✅ | ✅ |
 | [Celu][25] | ❌ | ❌ |
 | [CenterCropPad][26] | ❌ | ❌ |
 | [Clip][27] | ✅ | ✅ |
@@ -158,7 +158,7 @@ functionality.
 | [ReverseSequence][144] | ❌ | ❌ |
 | [RNN][145] | ❌ | ❌ |
 | [RoiAlign][146] | ❌ | ❌ |
-| [Round][147] | ❌ | ❌ |
+| [Round][147] | ✅ | ✅ |
 | [Scan][148] | ❌ | ❌ |
 | [Scatter][149] | ❌ | ❌ |
 | [ScatterElements][150] | ❌ | ❌ |

crates/burn-import/onnx-tests/build.rs

Lines changed: 2 additions & 0 deletions
@@ -14,6 +14,7 @@ fn main() {
         .input("tests/avg_pool2d/avg_pool2d.onnx")
         .input("tests/batch_norm/batch_norm.onnx")
         .input("tests/cast/cast.onnx")
+        .input("tests/ceil/ceil.onnx")
         .input("tests/clip/clip.onnx")
         .input("tests/concat/concat.onnx")
         .input("tests/constant/constant_f32.onnx")
@@ -109,6 +110,7 @@ fn main() {
         .input("tests/resize/resize_2d_bicubic_scale.onnx")
         .input("tests/resize/resize_2d_bilinear_scale.onnx")
         .input("tests/resize/resize_2d_nearest_scale.onnx")
+        .input("tests/round/round.onnx")
         .input("tests/shape/shape.onnx")
         .input("tests/sigmoid/sigmoid.onnx")
         .input("tests/sign/sign.onnx")
crates/burn-import/onnx-tests/tests/ceil/ceil.onnx

Lines changed: 12 additions & 0 deletions

(Binary ONNX protobuf exported by PyTorch 2.3.1; not meaningfully renderable as text. Recoverable structure: graph "main_graph" with input "onnx::Ceil_0", a single Ceil node, and output "1".)
crates/burn-import/onnx-tests/tests/ceil/ceil.py

Lines changed: 40 additions & 0 deletions

#!/usr/bin/env python3

# Used to generate model: onnx-tests/tests/ceil/ceil.onnx

import torch
import torch.nn as nn
import onnx

class CeilModel(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.ceil(x)

def main():
    model = CeilModel()
    model.eval()

    test_input = torch.tensor([-0.5, 1.5, 2.1])

    onnx_file = "ceil.onnx"

    torch.onnx.export(
        model,
        test_input,
        onnx_file,
        opset_version=16,
    )

    print(f"Finished exporting model to {onnx_file}")
    print(f"Test input data of ones: {test_input}")
    print(f"Test input data shape of ones: {test_input.shape}")
    output = model.forward(test_input)
    print(f"Test output data shape: {output.shape}")
    print(f"Test output: {output}")


if __name__ == '__main__':
    main()
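Running `python3 ceil.py` from `tests/ceil/` regenerates `ceil.onnx` and prints the same sample input and output used by the Rust test below.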
crates/burn-import/onnx-tests/tests/ceil/mod.rs

Lines changed: 27 additions & 0 deletions

use crate::include_models;
include_models!(ceil);

#[cfg(test)]
mod tests {
    use super::*;
    use burn::tensor::{Tensor, ops::FloatElem};

    type Backend = burn_ndarray::NdArray<f32>;
    type FT = FloatElem<Backend>;

    #[test]
    fn ceil_test() {
        // Test for ceil
        let device = Default::default();
        let model = ceil::Model::<Backend>::new(&device);

        let input = Tensor::<Backend, 1>::from_floats([-0.5, 1.5, 2.1], &device);
        let expected = Tensor::<Backend, 1>::from_floats([0., 2., 3.], &device);

        let output = model.forward(input);

        output
            .to_data()
            .assert_approx_eq::<FT>(&expected.to_data(), burn::tensor::Tolerance::default());
    }
}
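The expected values follow directly from IEEE ceil semantics; note in particular that `ceil(-0.5)` is negative zero, which compares equal to `0.`. A standalone sanity check (illustrative only, not part of the commit):

fn main() {
    // ceil(-0.5) = -0.0, and -0.0 == 0.0 under IEEE 754 comparison.
    assert_eq!((-0.5f32).ceil(), 0.0);
    assert_eq!((1.5f32).ceil(), 2.0);
    assert_eq!((2.1f32).ceil(), 3.0);
}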
crates/burn-import/onnx-tests/tests/round/mod.rs

Lines changed: 27 additions & 0 deletions

use crate::include_models;
include_models!(round);

#[cfg(test)]
mod tests {
    use super::*;
    use burn::tensor::{Tensor, ops::FloatElem};

    type Backend = burn_ndarray::NdArray<f32>;
    type FT = FloatElem<Backend>;

    #[test]
    fn round_test() {
        // Test for round
        let device = Default::default();
        let model = round::Model::<Backend>::new(&device);

        let input = Tensor::<Backend, 1>::from_floats([-0.5, 1.5, 2.1], &device);
        let expected = Tensor::<Backend, 1>::from_floats([0., 2., 2.], &device);

        let output = model.forward(input);

        output
            .to_data()
            .assert_approx_eq::<FT>(&expected.to_data(), burn::tensor::Tolerance::default());
    }
}
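The expected `[0., 2., 2.]` reflects that ONNX Round, like `torch.round`, rounds halfway cases to the nearest even integer rather than away from zero: `-0.5` goes to `-0.` (equal to `0.`) and `1.5` goes to `2.`. Rust exposes both behaviors on floats (illustrative, not part of the commit; `round_ties_even` needs Rust 1.77+):

fn main() {
    for x in [-0.5f32, 0.5, 1.5, 2.1] {
        println!(
            "{x}: ties-to-even = {}, ties-away-from-zero = {}",
            x.round_ties_even(), // matches ONNX Round / torch.round
            x.round(),           // plain round(): halves go away from zero
        );
    }
    // -0.5 -> -0 vs -1, 0.5 -> 0 vs 1, 1.5 -> 2 vs 2, 2.1 -> 2 vs 2
}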

crates/burn-import/onnx-tests/tests/round/round.onnx

Lines changed: 12 additions & 0 deletions

(Binary ONNX protobuf exported by PyTorch 2.3.1; not meaningfully renderable as text. Recoverable structure: graph "main_graph" with input "onnx::Round_0", a single Round node, and output "1".)
crates/burn-import/onnx-tests/tests/round/round.py

Lines changed: 40 additions & 0 deletions

#!/usr/bin/env python3

# Used to generate model: onnx-tests/tests/round/round.onnx

import torch
import torch.nn as nn
import onnx

class RoundModel(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.round(x)

def main():
    model = RoundModel()
    model.eval()

    test_input = torch.tensor([-0.5, 1.5, 2.1])

    onnx_file = "round.onnx"

    torch.onnx.export(
        model,
        test_input,
        onnx_file,
        opset_version=16,
    )

    print(f"Finished exporting model to {onnx_file}")
    print(f"Test input data of ones: {test_input}")
    print(f"Test input data shape of ones: {test_input.shape}")
    output = model.forward(test_input)
    print(f"Test output data shape: {output.shape}")
    print(f"Test output: {output}")


if __name__ == '__main__':
    main()
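As with `ceil.py`, running `python3 round.py` from `tests/round/` regenerates the `round.onnx` consumed by the build script and tests above.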

crates/burn-import/onnx-tests/tests/test_mod.rs

Lines changed: 2 additions & 0 deletions
@@ -9,6 +9,7 @@ pub mod argmax;
 pub mod avg_pool;
 pub mod batch_norm;
 pub mod cast;
+pub mod ceil;
 pub mod clip;
 pub mod concat;
 pub mod constant;
@@ -69,6 +70,7 @@ pub mod reduce_sum;
 pub mod relu;
 pub mod reshape;
 pub mod resize;
+pub mod round;
 pub mod shape;
 pub mod sigmoid;
 pub mod sign;
crates/burn-import/src/burn/node/base.rs

Lines changed: 11 additions & 5 deletions
@@ -2,8 +2,8 @@ use std::marker::PhantomData;

 use super::{
     argmax::ArgMaxNode, avg_pool1d::AvgPool1dNode, avg_pool2d::AvgPool2dNode,
-    batch_norm::BatchNormNode, binary::BinaryNode, clip::ClipNode, concat::ConcatNode,
-    constant::ConstantNode, constant_of_shape::ConstantOfShapeNode,
+    batch_norm::BatchNormNode, binary::BinaryNode, ceil::CeilNode, clip::ClipNode,
+    concat::ConcatNode, constant::ConstantNode, constant_of_shape::ConstantOfShapeNode,
     conv_transpose_1d::ConvTranspose1dNode, conv_transpose_2d::ConvTranspose2dNode,
     conv_transpose_3d::ConvTranspose3dNode, conv1d::Conv1dNode, conv2d::Conv2dNode,
     conv3d::Conv3dNode, dropout::DropoutNode, expand::ExpandNode, floor::FloorNode,
@@ -13,9 +13,9 @@ use super::{
     max_pool2d::MaxPool2dNode, mean::MeanNode, one_hot::OneHotNode, pad::PadNode, prelu::PReluNode,
     random_normal::RandomNormalNode, random_normal_like::RandomNormalLikeNode,
     random_uniform::RandomUniformNode, random_uniform_like::RandomUniformLikeNode,
-    range::RangeNode, reshape::ReshapeNode, resize::ResizeNode, slice::SliceNode, split::SplitNode,
-    squeeze::SqueezeNode, sum::SumNode, tile::TileNode, top_k::TopKNode, trilu::TriluNode,
-    unary::UnaryNode, unsqueeze::UnsqueezeNode,
+    range::RangeNode, reshape::ReshapeNode, resize::ResizeNode, round::RoundNode, slice::SliceNode,
+    split::SplitNode, squeeze::SqueezeNode, sum::SumNode, tile::TileNode, top_k::TopKNode,
+    trilu::TriluNode, unary::UnaryNode, unsqueeze::UnsqueezeNode,
 };
 use crate::burn::{BurnImports, Scope, Type};
 use burn::record::PrecisionSettings;
@@ -102,6 +102,7 @@ pub enum Node<PS: PrecisionSettings> {
     Dropout(DropoutNode),
     Expand(ExpandNode),
     Floor(FloorNode),
+    Ceil(CeilNode),
     Gather(GatherNode),
     GatherElements(GatherElementsNode),
     Gemm(GemmNode),
@@ -118,6 +119,7 @@ pub enum Node<PS: PrecisionSettings> {
     Range(RangeNode),
     Reshape(ReshapeNode),
     Resize(ResizeNode),
+    Round(RoundNode),
     Slice(SliceNode),
     Squeeze(SqueezeNode),
     Split(SplitNode),
@@ -160,6 +162,7 @@ macro_rules! match_all {
         Node::Dropout(node) => $func(node),
         Node::Expand(node) => $func(node),
         Node::Floor(node) => $func(node),
+        Node::Ceil(node) => $func(node),
         Node::Gather(node) => $func(node),
         Node::GatherElements(node) => $func(node),
         Node::Gemm(node) => $func(node),
@@ -176,6 +179,7 @@ macro_rules! match_all {
         Node::Range(node) => $func(node),
         Node::Reshape(node) => $func(node),
         Node::Resize(node) => $func(node),
+        Node::Round(node) => $func(node),
         Node::Slice(node) => $func(node),
         Node::Squeeze(node) => $func(node),
         Node::Sum(node) => $func(node),
@@ -226,6 +230,7 @@ impl<PS: PrecisionSettings> Node<PS> {
         Node::Dropout(_) => "dropout",
         Node::Expand(_) => "expand",
         Node::Floor(_) => "floor",
+        Node::Ceil(_) => "ceil",
         Node::Gather(_) => "gather",
         Node::GatherElements(_) => "gather_elements",
         Node::Gemm(_) => "gemm",
@@ -242,6 +247,7 @@ impl<PS: PrecisionSettings> Node<PS> {
         Node::Range(_) => "range",
         Node::Reshape(_) => "reshape",
         Node::Resize(_) => "resize",
+        Node::Round(_) => "round",
         Node::Slice(_) => "slice",
         Node::Squeeze(_) => "squeeze",
         Node::Sum(_) => "add",
crates/burn-import/src/burn/node/ceil.rs

Lines changed: 88 additions & 0 deletions

use super::{Node, NodeCodegen};
use crate::burn::{Scope, TensorType, Type};
use burn::record::PrecisionSettings;
use proc_macro2::TokenStream;
use quote::quote;

#[derive(Debug, Clone, new)]
pub struct CeilNode {
    pub input: TensorType,
    pub output: TensorType,
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for CeilNode {
    fn input_types(&self) -> Vec<Type> {
        vec![Type::Tensor(self.input.clone())]
    }

    fn output_types(&self) -> Vec<Type> {
        vec![Type::Tensor(self.output.clone())]
    }

    fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
        let input = scope.tensor_use_owned(&self.input, node_position);
        let output = &self.output.name;

        quote! {
            let #output = #input.ceil();
        }
    }

    fn into_node(self) -> Node<PS> {
        Node::Ceil(self)
    }
}

#[cfg(test)]
mod tests {
    use burn::record::FullPrecisionSettings;

    use super::*;
    use crate::burn::{
        TensorType,
        graph::BurnGraph,
        node::{ceil::CeilNode, test::assert_tokens},
    };

    #[test]
    fn test_codegen_nodes() {
        let mut graph = BurnGraph::<FullPrecisionSettings>::default();

        graph.register(CeilNode::new(
            TensorType::new_float("tensor1", 1),
            TensorType::new_float("tensor2", 1),
        ));

        graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]);

        let expected = quote! {
            use burn::{
                module::Module,
                tensor::{backend::Backend, Tensor},
            };

            #[derive(Module, Debug)]
            pub struct Model<B: Backend> {
                phantom: core::marker::PhantomData<B>,
                device: burn::module::Ignored<B::Device>,
            }

            impl<B: Backend> Model<B> {
                #[allow(unused_variables)]
                pub fn new(device: &B::Device) -> Self {
                    Self {
                        phantom: core::marker::PhantomData,
                        device: burn::module::Ignored(device.clone()),
                    }
                }
                #[allow(clippy::let_and_return, clippy::approx_constant)]
                pub fn forward(&self, tensor1: Tensor<B, 1>) -> Tensor<B, 1> {
                    let tensor2 = tensor1.ceil();
                    tensor2
                }
            }
        };

        assert_tokens(graph.codegen(), expected);
    }
}
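base.rs also imports `round::RoundNode`, whose file is not shown in this excerpt; given the symmetry of the commit, it presumably mirrors `CeilNode` with `round()` in place of `ceil()`. A sketch of the forward it would emit, under that assumption (hypothetical, runnable on its own with the proc-macro2 and quote crates):

use proc_macro2::{Ident, Span, TokenStream};
use quote::quote;

// Hypothetical: mirrors CeilNode::forward above, emitting `.round()` instead.
fn round_forward(input: TokenStream, output: &Ident) -> TokenStream {
    quote! {
        let #output = #input.round();
    }
}

fn main() {
    let input = quote! { tensor1 };
    let output = Ident::new("tensor2", Span::call_site());
    // Prints the token stream: let tensor2 = tensor1 . round () ;
    println!("{}", round_forward(input, &output));
}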
