Merged
40 commits
1ede17c
Move conv1d, avg_pool1d, max_pool1d and conv_transpose1d to onnx-ir
antimora Apr 30, 2025
519ddd3
Merge remote-tracking branch 'upstream/main' into refactor-opconfig
antimora Apr 30, 2025
1b1287b
Fix format
antimora Apr 30, 2025
e523498
Fix format
antimora Apr 30, 2025
e0083c0
Move conv1d rank update
antimora Apr 30, 2025
9f39454
Move conv_transpose1d_update_outputs
antimora Apr 30, 2025
4fa7636
Removed and replaced as same_as_input
antimora Apr 30, 2025
4e41de9
Move op config from burn-import
antimora Apr 30, 2025
8e81541
Moved 2d and 3d config functions
antimora Apr 30, 2025
81b38d9
Move some config functions to node
antimora Apr 30, 2025
b19ce42
Moved op configs to individual node modules
antimora Apr 30, 2025
d38a7e9
Break down reduce module into individual modules
antimora Apr 30, 2025
b21369b
Move rank inference functions
antimora Apr 30, 2025
03ea300
Move rank updates to node module
antimora Apr 30, 2025
c60cde3
Add documentation
antimora Apr 30, 2025
a66a43b
Repoint config function from onnx-ir
antimora Apr 30, 2025
4834dff
Remove burn types from onnx-ir
antimora May 1, 2025
e1d9682
Fixed left over tests
antimora May 1, 2025
29a0896
No default init for config structs
antimora May 1, 2025
995b7eb
Fix format
antimora May 1, 2025
fa67a32
remove op_configuration.rs
antimora May 1, 2025
8e8c5a3
Decouple burn-import types from op_configuration
antimora May 2, 2025
5fbfbec
Moved remaining configs from burn-import to onnx-ir
antimora May 2, 2025
1c5f098
Merge remote-tracking branch 'upstream/main' into refactor-opconfig
antimora May 2, 2025
9efe0a0
Update the documentation
antimora May 2, 2025
781d75c
Merge remote-tracking branch 'upstream/main' into refactor-opconfig
antimora May 8, 2025
d571dd6
Remove "Features" section
antimora May 8, 2025
67d0067
Converted links to resources
antimora May 8, 2025
2ee9317
Remove deadcode
antimora May 8, 2025
f27645e
Shorten burn-import readme
antimora May 8, 2025
5ef9401
Remove inline comments
antimora May 8, 2025
92890e7
Add NodeBuilder and refactor test code
antimora May 8, 2025
83112de
Refactor tests to use NodeBuilder
antimora May 8, 2025
6dfe2c0
Refactored tests to use NodeBuilder
antimora May 8, 2025
a4592a3
Refactored tests to use NodeBuilder
antimora May 8, 2025
52ed266
Refactored tests to use NodeBuilder
antimora May 8, 2025
eca4a8a
Refactored tests to use NodeBuilder
antimora May 8, 2025
689ba26
Refactored tests to use NodeBuilder
antimora May 8, 2025
a898404
Refactored tests to use NodeBuilder
antimora May 8, 2025
86d8c72
Remove inline comments
antimora May 8, 2025
2 changes: 1 addition & 1 deletion contributor-book/src/SUMMARY.md
@@ -10,7 +10,7 @@
   - [Tensor](./project-architecture/tensor.md)
   - [Backend](./project-architecture/backend.md)
 - [Guides for Contributors](./guides/README.md)
-  - [Onnx To Burn Conversion Tool: A Development Guide](./guides/onnx-to-burn-conversion-tool.md)
+  - [ONNX to Burn: Development Guide](./guides/onnx-to-burn-conversion-tool.md)
   - [Adding a New Operation to Burn](./guides/adding-a-new-operation-to-burn.md)
   - [Submitting Examples to Burn](./guides/submitting-examples.md)
 - [Frequently Encountered Issues](./frequently-encountered-issues/README.md)
399 changes: 206 additions & 193 deletions contributor-book/src/guides/onnx-to-burn-conversion-tool.md

Large diffs are not rendered by default.

21 changes: 9 additions & 12 deletions crates/burn-core/src/nn/padding.rs
@@ -7,12 +7,11 @@ use crate::config::Config;
 /// Padding configuration for 1D operators.
 #[derive(Config, Debug, PartialEq)]
 pub enum PaddingConfig1d {
-    /// Dynamically calculate the amount of padding necessary to ensure that the output size will be
-    /// the same as the input.
+    /// Dynamically calculates padding to ensure output size matches input size.
     Same,
-    /// Same as no padding.
+    /// No padding applied.
     Valid,
-    /// Applies the specified amount of padding to all inputs.
+    /// Applies a specific amount of padding to all inputs.
     Explicit(usize),
 }

@@ -35,12 +34,11 @@ impl PaddingConfig1d {
 /// Padding configuration for 2D operators.
 #[derive(Config, Debug, PartialEq)]
 pub enum PaddingConfig2d {
-    /// Dynamically calculate the amount of padding necessary to ensure that the output size will be
-    /// the same as the input.
+    /// Dynamically calculates padding to preserve input dimensions in output.
     Same,
-    /// Same as no padding.
+    /// No padding applied.
     Valid,
-    /// Applies the specified amount of padding to all inputs.
+    /// Applies specified padding values to height and width dimensions.
     Explicit(usize, usize),
 }

@@ -70,12 +68,11 @@ impl PaddingConfig2d {
 /// Padding configuration for 3D operators.
 #[derive(Config, Debug, PartialEq)]
 pub enum PaddingConfig3d {
-    /// Dynamically calculate the amount of padding necessary to ensure that the output size will be
-    /// the same as the input.
+    /// Dynamically calculates padding to preserve input dimensions in output.
     Same,
-    /// Same as no padding.
+    /// No padding applied.
     Valid,
-    /// Applies the specified amount of padding to all inputs.
+    /// Applies specified padding values to depth, height, and width dimensions.
     Explicit(usize, usize, usize),
 }
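These variants are consumed when building layer configs. A minimal usage sketch, assuming burn's `Config`-derived `with_padding` setter on `Conv1dConfig` (not part of this diff):

```rust
use burn::nn::PaddingConfig1d;
use burn::nn::conv::Conv1dConfig;

fn main() {
    // `Same` pads so the output length equals the input length; `Valid`
    // applies no padding; `Explicit(n)` pads by a fixed amount.
    let _config = Conv1dConfig::new(3, 16, 5).with_padding(PaddingConfig1d::Explicit(2));
}
```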
80 changes: 0 additions & 80 deletions crates/burn-import/DEVELOPMENT.md

This file was deleted.

29 changes: 19 additions & 10 deletions crates/burn-import/README.md
@@ -1,15 +1,24 @@
-# Importing Models
+# Burn Import

-The Burn project supports the import of models from various frameworks, emphasizing efficiency and
-compatibility. Currently, it handles two primary model formats:
+The `burn-import` crate enables seamless integration of pre-trained models from popular machine
+learning frameworks into the Burn ecosystem. This functionality allows you to leverage existing
+models while benefiting from Burn's performance optimizations and native Rust integration.

-1. [ONNX](https://burn.dev/burn-book/import/onnx-model.html): Facilitates direct import, ensuring the
-   model's performance and structure are maintained.
+## Supported Import Formats

-2. [PyTorch](https://burn.dev/burn-book/import/pytorch-model.html): Enables the loading of PyTorch model
-   weights into Burn's native model architecture, ensuring seamless integration.
+Burn currently supports three primary model import formats, each serving different use cases:

-## Contribution
+| Format | Description | Use Case |
+| ------ | ----------- | -------- |
+| [**ONNX** (Guide)](https://burn.dev/burn-book/import/onnx-model.html) | Open Neural Network Exchange format | Direct import of complete model architectures and weights from any framework that supports ONNX export |
+| [**PyTorch** (Guide)](https://burn.dev/burn-book/import/pytorch-model.html) | PyTorch weights (.pt, .pth) | Loading weights from PyTorch models into a matching Burn architecture |
+| [**Safetensors** (Guide)](https://burn.dev/burn-book/import/safetensors-model.html) | Hugging Face's model serialization format | Loading a model's tensor weights into a matching Burn architecture |

-Interested in contributing to `burn-import`? Check out our [development guide](DEVELOPMENT.md) for
-more information.
+## ONNX Contributor Resources
+
+- [ONNX to Burn conversion guide](https://burn.dev/contributor-book/guides/onnx-to-burn-conversion-tool.html) -
+  Instructions for adding support for additional ONNX operators
+- [ONNX tests README](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/onnx-tests/README.md) -
+  Testing procedures for ONNX operators
+- [Supported ONNX Operators table](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/SUPPORTED-ONNX-OPS.md) -
+  Complete list of currently supported ONNX operators
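For orientation, the ONNX path in that table is typically driven from a build script. A minimal sketch using the `ModelGen` API documented in the Burn book (the model path is a hypothetical example):

```rust
// build.rs — generates Rust model code plus a weights record at build time.
use burn_import::onnx::ModelGen;

fn main() {
    ModelGen::new()
        .input("src/model/model.onnx") // hypothetical path to your exported ONNX model
        .out_dir("model/")             // output directory under OUT_DIR
        .run_from_script();
}
```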
8 changes: 1 addition & 7 deletions crates/burn-import/src/burn/codegen.rs
@@ -1,10 +1,7 @@
+use onnx_ir::node::padding::{PaddingConfig1d, PaddingConfig2d, PaddingConfig3d};
 use proc_macro2::TokenStream;
 use quote::quote;

-use burn::nn::PaddingConfig1d;
-use burn::nn::PaddingConfig2d;
-use burn::nn::PaddingConfig3d;
-
 fn convert_primitive<T: core::fmt::Debug>(primitive: T) -> TokenStream {
     let value = format!("{:?}", primitive);

@@ -76,7 +73,6 @@ impl ToTokens for f32 {
 impl ToTokens for PaddingConfig1d {
     fn to_tokens(&self) -> TokenStream {
         match self {
-            Self::Same => quote! { PaddingConfig1d::Same },
             Self::Valid => quote! { PaddingConfig1d::Valid },
             Self::Explicit(padding) => {
                 let padding = padding.to_tokens();

@@ -90,7 +86,6 @@ impl ToTokens for PaddingConfig1d {
 impl ToTokens for PaddingConfig2d {
     fn to_tokens(&self) -> TokenStream {
         match self {
-            Self::Same => quote! { PaddingConfig2d::Same },
             Self::Valid => quote! { PaddingConfig2d::Valid },
             Self::Explicit(padding1, padding2) => {
                 let padding1 = padding1.to_tokens();

@@ -105,7 +100,6 @@ impl ToTokens for PaddingConfig2d {
 impl ToTokens for PaddingConfig3d {
     fn to_tokens(&self) -> TokenStream {
         match self {
-            Self::Same => quote! { PaddingConfig3d::Same },
             Self::Valid => quote! { PaddingConfig3d::Valid },
             Self::Explicit(padding1, padding2, padding3) => {
                 let padding1 = padding1.to_tokens();
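A standalone sketch of the `quote!` pattern these impls rely on — a hypothetical free function, not this crate's API. The removed `Same` arms suggest `Same` padding is now resolved to explicit values before codegen, though the diff does not state this:

```rust
use proc_macro2::TokenStream;
use quote::quote;

// Hypothetical helper mirroring the `Explicit` match arm: interpolates a
// padding amount into the token stream the generated model code will contain.
fn explicit_padding_tokens(padding: usize) -> TokenStream {
    // `#padding` interpolation renders the literal with its suffix, e.g. `2usize`.
    quote! { PaddingConfig1d::Explicit(#padding) }
}

fn main() {
    // Prints the generated tokens, e.g. `PaddingConfig1d :: Explicit (2usize)`.
    println!("{}", explicit_padding_tokens(2));
}
```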
10 changes: 5 additions & 5 deletions crates/burn-import/src/burn/node/avg_pool1d.rs
@@ -1,7 +1,8 @@
+use onnx_ir::node::avg_pool1d::AvgPool1dConfig;
 use proc_macro2::TokenStream;
 use quote::quote;

-use burn::{nn::pool::AvgPool1dConfig, record::PrecisionSettings};
+use burn::record::PrecisionSettings;

 use super::{Node, NodeCodegen};
 use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type};

@@ -93,7 +94,8 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for AvgPool1dNode {
 mod tests {
     use super::*;
     use crate::burn::{TensorType, graph::BurnGraph, node::test::assert_tokens};
-    use burn::{nn::PaddingConfig1d, record::FullPrecisionSettings};
+    use burn::record::FullPrecisionSettings;
+    use onnx_ir::node::padding::PaddingConfig1d;

     #[test]
     fn test_codegen() {

@@ -103,9 +105,7 @@ mod tests {
             "avg_pool1d",
             TensorType::new_float("input", 3),
             TensorType::new_float("output", 3),
-            AvgPool1dConfig::new(3)
-                .with_stride(1)
-                .with_padding(PaddingConfig1d::Valid),
+            AvgPool1dConfig::new(3, 1, PaddingConfig1d::Valid, true),
         ));

         graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);
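This test change shows the pattern that repeats through the rest of the PR (avg_pool2d, conv, batch_norm): burn's builder-style configs give way to onnx-ir configs whose `new` takes every field positionally. A hedged reading of the call above, with field names inferred from the old builder calls (assumptions, not confirmed by the diff):

```rust
use onnx_ir::node::avg_pool1d::AvgPool1dConfig;
use onnx_ir::node::padding::PaddingConfig1d;

fn main() {
    // Presumed order: kernel_size = 3, stride = 1, padding, and a final flag
    // (likely ONNX's count_include_pad) — the names are assumptions.
    let _config = AvgPool1dConfig::new(3, 1, PaddingConfig1d::Valid, true);
}
```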
10 changes: 5 additions & 5 deletions crates/burn-import/src/burn/node/avg_pool2d.rs
@@ -1,7 +1,8 @@
+use onnx_ir::node::avg_pool2d::AvgPool2dConfig;
 use proc_macro2::TokenStream;
 use quote::quote;

-use burn::{nn::pool::AvgPool2dConfig, record::PrecisionSettings};
+use burn::record::PrecisionSettings;

 use super::{Node, NodeCodegen};
 use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type};

@@ -97,7 +98,8 @@ mod tests {
         graph::BurnGraph,
         node::{avg_pool2d::AvgPool2dNode, test::assert_tokens},
     };
-    use burn::{nn::PaddingConfig2d, nn::pool::AvgPool2dConfig, record::FullPrecisionSettings};
+    use burn::record::FullPrecisionSettings;
+    use onnx_ir::node::padding::PaddingConfig2d;

     #[test]
     fn test_codegen() {

@@ -107,9 +109,7 @@ mod tests {
             "avg_pool2d",
             TensorType::new_float("input", 4),
             TensorType::new_float("output", 4),
-            AvgPool2dConfig::new([3, 3])
-                .with_strides([1, 1])
-                .with_padding(PaddingConfig2d::Valid),
+            AvgPool2dConfig::new([3, 3], [1, 1], PaddingConfig2d::Valid, true),
         ));

         graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);
26 changes: 20 additions & 6 deletions crates/burn-import/src/burn/node/base.rs
@@ -308,10 +308,8 @@ pub(crate) mod tests {
         graph::BurnGraph,
         node::{NodeCodegen, conv2d::Conv2dNode, matmul::MatmulNode, test::assert_tokens},
     };
-    use burn::{
-        nn::PaddingConfig2d, nn::conv::Conv2dConfig, record::FullPrecisionSettings,
-        tensor::TensorData,
-    };
+    use burn::{record::FullPrecisionSettings, tensor::TensorData};
+    use onnx_ir::node::{conv2d::Conv2dConfig, padding::PaddingConfig2d};
     use proc_macro2::TokenStream;
     use quote::quote;

@@ -373,7 +371,15 @@
             TensorType::new_float("tensor4", 4),
             TensorData::from([2f32]),
             None,
-            Conv2dConfig::new([3, 3], [3, 3]).with_padding(PaddingConfig2d::Valid),
+            Conv2dConfig::new(
+                [3, 3],
+                [3, 3],
+                [1, 1],
+                PaddingConfig2d::Valid,
+                [1, 1],
+                1,
+                true,
+            ),
         ));

         graph.register_input_output(

@@ -446,7 +452,15 @@ pub(crate) mod tests {
             TensorType::new_float("tensor4", 4),
             TensorData::from([2f32]),
             None,
-            Conv2dConfig::new([3, 3], [3, 3]).with_padding(PaddingConfig2d::Valid),
+            Conv2dConfig::new(
+                [3, 3],
+                [3, 3],
+                [1, 1],
+                PaddingConfig2d::Valid,
+                [1, 1],
+                1,
+                true,
+            ),
         ));
         graph.register(MatmulNode::new(
             TensorType::new_float("tensor3", 4),
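Seven positional arguments are easy to misread; an annotated sketch of the call above, with names inferred from the usual convolution config fields (assumptions, not confirmed by the diff):

```rust
use onnx_ir::node::{conv2d::Conv2dConfig, padding::PaddingConfig2d};

fn main() {
    // Presumed parameter order (all names are assumptions):
    let _config = Conv2dConfig::new(
        [3, 3],                 // channels: [in, out]
        [3, 3],                 // kernel_size
        [1, 1],                 // stride
        PaddingConfig2d::Valid, // padding
        [1, 1],                 // dilation
        1,                      // groups
        true,                   // bias
    );
}
```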
5 changes: 3 additions & 2 deletions crates/burn-import/src/burn/node/batch_norm.rs
@@ -2,10 +2,11 @@ use super::{Node, NodeCodegen, SerializationBackend};
 use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type};
 use burn::{
     module::{ConstantRecord, Param, ParamId},
-    nn::{BatchNormConfig, BatchNormRecord},
+    nn::BatchNormRecord,
     record::{PrecisionSettings, Record},
     tensor::{Tensor, TensorData},
 };
+use onnx_ir::node::batch_norm::BatchNormConfig;
 use proc_macro2::TokenStream;
 use quote::quote;
 use serde::Serialize;

@@ -171,7 +172,7 @@ mod tests {
             TensorData::from([2f32]),
             TensorData::from([2f32]),
             TensorData::from([2f32]),
-            BatchNormConfig::new(128),
+            BatchNormConfig::new(128, 0.00001, 0.1),
         ));

         graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);
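The two new scalars match batch norm's usual hyperparameters; a hedged reading (names inferred from convention, not stated in the diff):

```rust
use onnx_ir::node::batch_norm::BatchNormConfig;

fn main() {
    // Presumed order: num_features = 128, epsilon = 1e-5, momentum = 0.1.
    let _config = BatchNormConfig::new(128, 0.00001, 0.1);
}
```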
11 changes: 5 additions & 6 deletions crates/burn-import/src/burn/node/conv1d.rs
@@ -2,10 +2,11 @@ use super::{Node, NodeCodegen, SerializationBackend};
 use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type};
 use burn::{
     module::{ConstantRecord, Param, ParamId},
-    nn::conv::{Conv1dConfig, Conv1dRecord},
+    nn::conv::Conv1dRecord,
     record::{PrecisionSettings, Record},
     tensor::{Tensor, TensorData},
 };
+use onnx_ir::node::conv1d::Conv1dConfig;
 use proc_macro2::TokenStream;
 use quote::quote;
 use serde::Serialize;

@@ -135,10 +136,8 @@ mod tests {
         graph::BurnGraph,
         node::{conv1d::Conv1dNode, test::assert_tokens},
     };
-    use burn::{
-        nn::{PaddingConfig1d, conv::Conv1dConfig},
-        record::FullPrecisionSettings,
-    };
+    use burn::record::FullPrecisionSettings;
+    use onnx_ir::node::padding::PaddingConfig1d;

     #[test]
     fn test_codegen() {

@@ -150,7 +149,7 @@
             TensorType::new_float("output", 4),
             TensorData::from([2f32]),
             None,
-            Conv1dConfig::new(3, 3, 3).with_padding(PaddingConfig1d::Valid),
+            Conv1dConfig::new(3, 3, 3, 1, PaddingConfig1d::Valid, 1, 1, true),
         ));

         graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);
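Same pattern once more, now with eight positional arguments; an annotated sketch (names inferred, treat as assumptions):

```rust
use onnx_ir::node::conv1d::Conv1dConfig;
use onnx_ir::node::padding::PaddingConfig1d;

fn main() {
    // Presumed order: channels_in = 3, channels_out = 3, kernel_size = 3,
    // stride = 1, padding, dilation = 1, groups = 1, bias = true.
    let _config = Conv1dConfig::new(3, 3, 3, 1, PaddingConfig1d::Valid, 1, 1, true);
}
```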