Commit f18cb7f
Refactor: Move op_configuration.rs from burn-import to onnx-ir (#3126)
Refactor: Move op_configuration.rs from burn-import to onnx-ir (#3126)

* Move conv1d, avg_pool1d, max_pool1d and conv_transpose1d to onnx-ir
* Fix format
* Move conv1d rank update
* Move conv_transpose1d_update_outputs (removed and replaced as same_as_input)
* Move op config from burn-import, including the 2d and 3d config functions
* Move some config functions to node
* Move op configs to individual node modules (still remaining at this point: expand_config, tile_config, top_k_config, trilu_config, pad_config, unsqueeze_config, split_config)
* Break down reduce module into individual modules
* Move rank inference functions and rank updates to node module
* Add documentation
* Repoint config functions from onnx-ir
* Remove Burn types from onnx-ir
* Fix leftover tests
* No default init for config structs
* Remove op_configuration.rs and decouple burn-import types from op_configuration
* Move remaining configs from burn-import to onnx-ir
* Update the documentation: remove "Features" section, convert links to resources, shorten burn-import README
* Remove dead code and inline comments
* Add NodeBuilder and refactor test code to use it

1 parent 987bcc9 · commit f18cb7f

Note: large commits have some content hidden by default.

89 files changed: +9311 additions, -3743 deletions

contributor-book/src/SUMMARY.md

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@
 - [Tensor](./project-architecture/tensor.md)
 - [Backend](./project-architecture/backend.md)
 - [Guides for Contributors](./guides/README.md)
-- [Onnx To Burn Conversion Tool: A Development Guide](./guides/onnx-to-burn-conversion-tool.md)
+- [ONNX to Burn: Development Guide](./guides/onnx-to-burn-conversion-tool.md)
 - [Adding a New Operation to Burn](./guides/adding-a-new-operation-to-burn.md)
 - [Submitting Examples to Burn](./guides/submitting-examples.md)
 - [Frequently Encountered Issues](./frequently-encountered-issues/README.md)

contributor-book/src/guides/onnx-to-burn-conversion-tool.md

Lines changed: 206 additions & 193 deletions
Large diffs are not rendered by default.

crates/burn-core/src/nn/padding.rs

Lines changed: 9 additions & 12 deletions
@@ -7,12 +7,11 @@ use crate::config::Config;
 /// Padding configuration for 1D operators.
 #[derive(Config, Debug, PartialEq)]
 pub enum PaddingConfig1d {
-    /// Dynamically calculate the amount of padding necessary to ensure that the output size will be
-    /// the same as the input.
+    /// Dynamically calculates padding to ensure output size matches input size.
     Same,
-    /// Same as no padding.
+    /// No padding applied.
     Valid,
-    /// Applies the specified amount of padding to all inputs.
+    /// Applies a specific amount of padding to all inputs.
     Explicit(usize),
 }

@@ -35,12 +34,11 @@ impl PaddingConfig1d {
 /// Padding configuration for 2D operators.
 #[derive(Config, Debug, PartialEq)]
 pub enum PaddingConfig2d {
-    /// Dynamically calculate the amount of padding necessary to ensure that the output size will be
-    /// the same as the input.
+    /// Dynamically calculates padding to preserve input dimensions in output.
     Same,
-    /// Same as no padding.
+    /// No padding applied.
     Valid,
-    /// Applies the specified amount of padding to all inputs.
+    /// Applies specified padding values to height and width dimensions.
     Explicit(usize, usize),
 }

@@ -70,12 +68,11 @@ impl PaddingConfig2d {
 /// Padding configuration for 3D operators.
 #[derive(Config, Debug, PartialEq)]
 pub enum PaddingConfig3d {
-    /// Dynamically calculate the amount of padding necessary to ensure that the output size will be
-    /// the same as the input.
+    /// Dynamically calculates padding to preserve input dimensions in output.
     Same,
-    /// Same as no padding.
+    /// No padding applied.
     Valid,
-    /// Applies the specified amount of padding to all inputs.
+    /// Applies specified padding values to depth, height, and width dimensions.
     Explicit(usize, usize, usize),
 }
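The behavior of the `Same` variant described in the doc comments above can be sketched as follows. This is a stand-alone illustration, not burn-core's actual implementation: `resolve` is a hypothetical helper, and it assumes stride 1 and an odd kernel size so that output length equals input length.

```rust
// Illustrative sketch only. Mirrors the PaddingConfig1d enum from the diff
// above; `resolve` is a hypothetical helper, not part of burn-core's API.
#[derive(Debug, PartialEq)]
pub enum PaddingConfig1d {
    Same,
    Valid,
    Explicit(usize),
}

impl PaddingConfig1d {
    /// Returns the per-side padding for a given kernel size, assuming
    /// stride 1: output == input when total padding is kernel_size - 1.
    pub fn resolve(&self, kernel_size: usize) -> usize {
        match self {
            Self::Same => (kernel_size - 1) / 2,
            Self::Valid => 0,
            Self::Explicit(p) => *p,
        }
    }
}

fn main() {
    // A kernel of 3 needs 1 element of padding on each side to keep the
    // output length equal to the input length at stride 1.
    assert_eq!(PaddingConfig1d::Same.resolve(3), 1);
    assert_eq!(PaddingConfig1d::Valid.resolve(3), 0);
    assert_eq!(PaddingConfig1d::Explicit(2).resolve(3), 2);
    println!("ok");
}
```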

crates/burn-import/DEVELOPMENT.md

Lines changed: 0 additions & 80 deletions
This file was deleted.

crates/burn-import/README.md

Lines changed: 19 additions & 10 deletions
@@ -1,15 +1,24 @@
-# Importing Models
+# Burn Import

-The Burn project supports the import of models from various frameworks, emphasizing efficiency and
-compatibility. Currently, it handles two primary model formats:
+The `burn-import` crate enables seamless integration of pre-trained models from popular machine
+learning frameworks into the Burn ecosystem. This functionality allows you to leverage existing
+models while benefiting from Burn's performance optimizations and native Rust integration.

-1. [ONNX](https://burn.dev/burn-book/import/onnx-model.html): Facilitates direct import, ensuring the
-   model's performance and structure are maintained.
+## Supported Import Formats

-2. [PyTorch](https://burn.dev/burn-book/import/pytorch-model.html): Enables the loading of PyTorch model
-   weights into Burn's native model architecture, ensuring seamless integration.
+Burn currently supports three primary model import formats, each serving different use cases:

-## Contribution
+| Format | Description | Use Case |
+| ------ | ----------- | -------- |
+| [**ONNX** (Guide)](https://burn.dev/burn-book/import/onnx-model.html) | Open Neural Network Exchange format | Direct import of complete model architectures and weights from any framework that supports ONNX export |
+| [**PyTorch** (Guide)](https://burn.dev/burn-book/import/pytorch-model.html) | PyTorch weights (.pt, .pth) | Loading weights from PyTorch models into a matching Burn architecture |
+| [**Safetensors** (Guide)](https://burn.dev/burn-book/import/safetensors-model.html) | Hugging Face's model serialization format | Loading a model's tensor weights into a matching Burn architecture |

-Interested in contributing to `burn-import`? Check out our [development guide](DEVELOPMENT.md) for
-more information.
+## ONNX Contributor Resources
+
+- [ONNX to Burn conversion guide](https://burn.dev/contributor-book/guides/onnx-to-burn-conversion-tool.html) -
+  Instructions for adding support for additional ONNX operators
+- [ONNX tests README](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/onnx-tests/README.md) -
+  Testing procedures for ONNX operators
+- [Supported ONNX Operators table](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/SUPPORTED-ONNX-OPS.md) -
+  Complete list of currently supported ONNX operators

crates/burn-import/src/burn/codegen.rs

Lines changed: 1 addition & 7 deletions
@@ -1,10 +1,7 @@
+use onnx_ir::node::padding::{PaddingConfig1d, PaddingConfig2d, PaddingConfig3d};
 use proc_macro2::TokenStream;
 use quote::quote;

-use burn::nn::PaddingConfig1d;
-use burn::nn::PaddingConfig2d;
-use burn::nn::PaddingConfig3d;
-
 fn convert_primitive<T: core::fmt::Debug>(primitive: T) -> TokenStream {
     let value = format!("{:?}", primitive);

@@ -76,7 +73,6 @@ impl ToTokens for f32 {
 impl ToTokens for PaddingConfig1d {
     fn to_tokens(&self) -> TokenStream {
         match self {
-            Self::Same => quote! { PaddingConfig1d::Same },
             Self::Valid => quote! { PaddingConfig1d::Valid },
             Self::Explicit(padding) => {
                 let padding = padding.to_tokens();
@@ -90,7 +86,6 @@ impl ToTokens for PaddingConfig1d {
 impl ToTokens for PaddingConfig2d {
     fn to_tokens(&self) -> TokenStream {
         match self {
-            Self::Same => quote! { PaddingConfig2d::Same },
             Self::Valid => quote! { PaddingConfig2d::Valid },
             Self::Explicit(padding1, padding2) => {
                 let padding1 = padding1.to_tokens();
@@ -105,7 +100,6 @@ impl ToTokens for PaddingConfig2d {
 impl ToTokens for PaddingConfig3d {
     fn to_tokens(&self) -> TokenStream {
         match self {
-            Self::Same => quote! { PaddingConfig3d::Same },
             Self::Valid => quote! { PaddingConfig3d::Valid },
             Self::Explicit(padding1, padding2, padding3) => {
                 let padding1 = padding1.to_tokens();
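The `ToTokens` impls above map each padding variant to the Rust tokens emitted into generated model code; the commit drops the `Same` arms, suggesting `Same` padding no longer reaches codegen once the onnx-ir config types are used. A simplified, string-based sketch of the same match-arm codegen pattern (the real code builds a `proc_macro2::TokenStream` via the `quote!` macro rather than a `String`):

```rust
// Simplified stand-in for burn-import's ToTokens trait: emits the Rust
// source text for a padding config as a String instead of token streams.
#[derive(Debug)]
enum PaddingConfig1d {
    Valid,
    Explicit(usize),
}

fn to_source(config: &PaddingConfig1d) -> String {
    // Each enum variant maps to the literal Rust expression that the
    // generated model file would contain for that configuration.
    match config {
        PaddingConfig1d::Valid => "PaddingConfig1d::Valid".to_string(),
        PaddingConfig1d::Explicit(p) => format!("PaddingConfig1d::Explicit({p})"),
    }
}

fn main() {
    assert_eq!(to_source(&PaddingConfig1d::Valid), "PaddingConfig1d::Valid");
    assert_eq!(
        to_source(&PaddingConfig1d::Explicit(2)),
        "PaddingConfig1d::Explicit(2)"
    );
    println!("ok");
}
```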

crates/burn-import/src/burn/node/avg_pool1d.rs

Lines changed: 5 additions & 5 deletions
@@ -1,7 +1,8 @@
+use onnx_ir::node::avg_pool1d::AvgPool1dConfig;
 use proc_macro2::TokenStream;
 use quote::quote;

-use burn::{nn::pool::AvgPool1dConfig, record::PrecisionSettings};
+use burn::record::PrecisionSettings;

 use super::{Node, NodeCodegen};
 use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type};
@@ -93,7 +94,8 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for AvgPool1dNode {
 mod tests {
     use super::*;
     use crate::burn::{TensorType, graph::BurnGraph, node::test::assert_tokens};
-    use burn::{nn::PaddingConfig1d, record::FullPrecisionSettings};
+    use burn::record::FullPrecisionSettings;
+    use onnx_ir::node::padding::PaddingConfig1d;

     #[test]
     fn test_codegen() {
@@ -103,9 +105,7 @@ mod tests {
             "avg_pool1d",
             TensorType::new_float("input", 3),
             TensorType::new_float("output", 3),
-            AvgPool1dConfig::new(3)
-                .with_stride(1)
-                .with_padding(PaddingConfig1d::Valid),
+            AvgPool1dConfig::new(3, 1, PaddingConfig1d::Valid, true),
         ));

         graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);
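The test above switches from the builder style (`AvgPool1dConfig::new(3).with_stride(1)…`) to a positional constructor, consistent with the commit message's "No default init for config structs". A toy sketch of that design change; the field names here are assumptions inferred from the call site `AvgPool1dConfig::new(3, 1, PaddingConfig1d::Valid, true)`, not confirmed from onnx-ir's source:

```rust
// Hypothetical mirror of onnx-ir's AvgPool1dConfig. Field order is
// inferred from the test call; names are assumptions for illustration.
#![allow(dead_code)]

#[derive(Debug, PartialEq)]
enum PaddingConfig1d {
    Valid,
    Explicit(usize),
}

#[derive(Debug)]
struct AvgPool1dConfig {
    kernel_size: usize,
    stride: usize,
    padding: PaddingConfig1d,
    count_include_pad: bool,
}

impl AvgPool1dConfig {
    // All fields are required up front: no defaults, no builder methods,
    // so adding a field later forces every call site to be updated.
    fn new(
        kernel_size: usize,
        stride: usize,
        padding: PaddingConfig1d,
        count_include_pad: bool,
    ) -> Self {
        Self { kernel_size, stride, padding, count_include_pad }
    }
}

fn main() {
    let cfg = AvgPool1dConfig::new(3, 1, PaddingConfig1d::Valid, true);
    assert_eq!(cfg.kernel_size, 3);
    assert_eq!(cfg.stride, 1);
    assert_eq!(cfg.padding, PaddingConfig1d::Valid);
    assert!(cfg.count_include_pad);
    println!("ok");
}
```

The positional form trades the builder's readability for the guarantee that no field is ever silently left at a default, which matters when configs are parsed from ONNX attributes.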

crates/burn-import/src/burn/node/avg_pool2d.rs

Lines changed: 5 additions & 5 deletions
@@ -1,7 +1,8 @@
+use onnx_ir::node::avg_pool2d::AvgPool2dConfig;
 use proc_macro2::TokenStream;
 use quote::quote;

-use burn::{nn::pool::AvgPool2dConfig, record::PrecisionSettings};
+use burn::record::PrecisionSettings;

 use super::{Node, NodeCodegen};
 use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type};
@@ -97,7 +98,8 @@ mod tests {
         graph::BurnGraph,
         node::{avg_pool2d::AvgPool2dNode, test::assert_tokens},
     };
-    use burn::{nn::PaddingConfig2d, nn::pool::AvgPool2dConfig, record::FullPrecisionSettings};
+    use burn::record::FullPrecisionSettings;
+    use onnx_ir::node::padding::PaddingConfig2d;

     #[test]
     fn test_codegen() {
@@ -107,9 +109,7 @@ mod tests {
             "avg_pool2d",
             TensorType::new_float("input", 4),
             TensorType::new_float("output", 4),
-            AvgPool2dConfig::new([3, 3])
-                .with_strides([1, 1])
-                .with_padding(PaddingConfig2d::Valid),
+            AvgPool2dConfig::new([3, 3], [1, 1], PaddingConfig2d::Valid, true),
         ));

         graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);

crates/burn-import/src/burn/node/base.rs

Lines changed: 20 additions & 6 deletions
@@ -308,10 +308,8 @@ pub(crate) mod tests {
         graph::BurnGraph,
         node::{NodeCodegen, conv2d::Conv2dNode, matmul::MatmulNode, test::assert_tokens},
     };
-    use burn::{
-        nn::PaddingConfig2d, nn::conv::Conv2dConfig, record::FullPrecisionSettings,
-        tensor::TensorData,
-    };
+    use burn::{record::FullPrecisionSettings, tensor::TensorData};
+    use onnx_ir::node::{conv2d::Conv2dConfig, padding::PaddingConfig2d};
     use proc_macro2::TokenStream;
     use quote::quote;

@@ -373,7 +371,15 @@ pub(crate) mod tests {
             TensorType::new_float("tensor4", 4),
             TensorData::from([2f32]),
             None,
-            Conv2dConfig::new([3, 3], [3, 3]).with_padding(PaddingConfig2d::Valid),
+            Conv2dConfig::new(
+                [3, 3],
+                [3, 3],
+                [1, 1],
+                PaddingConfig2d::Valid,
+                [1, 1],
+                1,
+                true,
+            ),
         ));

         graph.register_input_output(
@@ -446,7 +452,15 @@ pub(crate) mod tests {
             TensorType::new_float("tensor4", 4),
             TensorData::from([2f32]),
             None,
-            Conv2dConfig::new([3, 3], [3, 3]).with_padding(PaddingConfig2d::Valid),
+            Conv2dConfig::new(
+                [3, 3],
+                [3, 3],
+                [1, 1],
+                PaddingConfig2d::Valid,
+                [1, 1],
+                1,
+                true,
+            ),
         ));
         graph.register(MatmulNode::new(
             TensorType::new_float("tensor3", 4),

crates/burn-import/src/burn/node/batch_norm.rs

Lines changed: 3 additions & 2 deletions
@@ -2,10 +2,11 @@ use super::{Node, NodeCodegen, SerializationBackend};
 use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type};
 use burn::{
     module::{ConstantRecord, Param, ParamId},
-    nn::{BatchNormConfig, BatchNormRecord},
+    nn::BatchNormRecord,
     record::{PrecisionSettings, Record},
     tensor::{Tensor, TensorData},
 };
+use onnx_ir::node::batch_norm::BatchNormConfig;
 use proc_macro2::TokenStream;
 use quote::quote;
 use serde::Serialize;
@@ -171,7 +172,7 @@ mod tests {
             TensorData::from([2f32]),
             TensorData::from([2f32]),
             TensorData::from([2f32]),
-            BatchNormConfig::new(128),
+            BatchNormConfig::new(128, 0.00001, 0.1),
         ));

         graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);
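As with the pooling configs, `BatchNormConfig` now takes all its parameters positionally: `BatchNormConfig::new(128, 0.00001, 0.1)`. A sketch of what that constructor plausibly looks like; the field names (`num_features`, `epsilon`, `momentum`) are assumptions based on common batch-norm conventions and the literal values at the call site, not confirmed from onnx-ir's source:

```rust
// Hypothetical mirror of onnx-ir's BatchNormConfig; field names are
// assumptions inferred from BatchNormConfig::new(128, 0.00001, 0.1).
#[derive(Debug)]
struct BatchNormConfig {
    num_features: usize,
    epsilon: f64,
    momentum: f64,
}

impl BatchNormConfig {
    // No defaults: epsilon and momentum must be supplied explicitly,
    // typically taken straight from the ONNX BatchNormalization node.
    fn new(num_features: usize, epsilon: f64, momentum: f64) -> Self {
        Self { num_features, epsilon, momentum }
    }
}

fn main() {
    let cfg = BatchNormConfig::new(128, 0.00001, 0.1);
    assert_eq!(cfg.num_features, 128);
    assert!((cfg.epsilon - 1e-5).abs() < 1e-12);
    assert!((cfg.momentum - 0.1).abs() < 1e-12);
    println!("ok");
}
```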
