From 1ede17c8cfddfc7ce8d514942aecb92954c4d497 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Tue, 29 Apr 2025 20:12:17 -0500 Subject: [PATCH 01/37] Move conv1d, avg_pool1d, max_pool1d and conv_transpose1d to onnx-ir --- .../burn-import/src/onnx/op_configuration.rs | 222 +------------- crates/burn-import/src/onnx/to_burn.rs | 86 +++++- crates/onnx-ir/src/node/avg_pool1d.rs | 155 ++++++++++ crates/onnx-ir/src/node/conv1d.rs | 288 ++++++++++++++++++ crates/onnx-ir/src/node/conv_transpose1d.rs | 239 +++++++++++++++ crates/onnx-ir/src/node/max_pool1d.rs | 145 +++++++++ crates/onnx-ir/src/node/mod.rs | 4 + 7 files changed, 917 insertions(+), 222 deletions(-) create mode 100644 crates/onnx-ir/src/node/avg_pool1d.rs create mode 100644 crates/onnx-ir/src/node/conv1d.rs create mode 100644 crates/onnx-ir/src/node/conv_transpose1d.rs create mode 100644 crates/onnx-ir/src/node/max_pool1d.rs diff --git a/crates/burn-import/src/onnx/op_configuration.rs b/crates/burn-import/src/onnx/op_configuration.rs index e16c5f9772..5a99e3aacc 100644 --- a/crates/burn-import/src/onnx/op_configuration.rs +++ b/crates/burn-import/src/onnx/op_configuration.rs @@ -2,13 +2,13 @@ // See https://github.com/tracel-ai/burn/issues/3091 use burn::nn::{ - BatchNormConfig, DropoutConfig, LayerNormConfig, LinearConfig, PaddingConfig1d, - PaddingConfig2d, PaddingConfig3d, + BatchNormConfig, DropoutConfig, LayerNormConfig, LinearConfig, PaddingConfig2d, + PaddingConfig3d, conv::{ - Conv1dConfig, Conv2dConfig, Conv3dConfig, ConvTranspose1dConfig, ConvTranspose2dConfig, + Conv2dConfig, Conv3dConfig, ConvTranspose2dConfig, ConvTranspose3dConfig, }, - pool::{AvgPool1dConfig, AvgPool2dConfig, MaxPool1dConfig, MaxPool2dConfig}, + pool::{AvgPool2dConfig, MaxPool2dConfig}, }; use crate::burn::node::{ @@ -17,48 +17,7 @@ use crate::burn::node::{ }; use onnx_ir::ir::{ArgType, AttributeValue, Data, ElementType, Node, TensorData}; -/// Create a Conv1dConfig from the attributes of the node -pub fn conv1d_config(curr: &Node) -> Conv1dConfig { - let mut kernel_shape = Vec::new(); // TODO default inferred from weight tensor per spec - let mut strides = vec![1]; - let mut pads = vec![0, 0]; - let mut dilations = vec![1]; - let mut group: usize = 1; - - let weight_shape = curr.inputs[1] - .value - .as_ref() - .expect("Conv1d: weight tensor must be present") - .shape - .clone(); - - // check if the bias is present - let bias = curr.inputs.len() == 3; - - for (key, value) in curr.attrs.iter() { - match key.as_str() { - "kernel_shape" => kernel_shape = value.clone().into_i64s(), - "strides" => strides = value.clone().into_i64s(), - "pads" => pads = value.clone().into_i64s(), - "dilations" => dilations = value.clone().into_i64s(), - "group" => group = value.clone().into_i64() as usize, - _ => {} - } - } - - // the channels are inverted in the weight tensor - let channels_in = weight_shape[1] * group; - let channels_out = weight_shape[0]; - - let padding = padding_config_1d(&pads); - - Conv1dConfig::new(channels_in, channels_out, kernel_shape[0] as usize) - .with_stride(strides[0] as usize) - .with_dilation(dilations[0] as usize) - .with_groups(group) - .with_bias(bias) - .with_padding(padding) -} +// Conv1dConfig implementation moved to onnx-ir::node::conv1d /// Create a Conv2dConfig from the attributes of the node pub fn conv2d_config(curr: &Node) -> Conv2dConfig { @@ -164,32 +123,7 @@ pub fn conv3d_config(curr: &Node) -> Conv3dConfig { .with_padding(padding) } -/// Create a MaxPool2dConfig from the 
attributes of the node -pub fn max_pool1d_config(curr: &Node) -> MaxPool1dConfig { - let mut kernel_shape = Vec::new(); - let mut stride = vec![1]; - let mut pads = vec![0, 0]; - let mut dilation = vec![1]; - - for (key, value) in curr.attrs.iter() { - match key.as_str() { - "kernel_shape" => kernel_shape = value.clone().into_i64s(), - "strides" => stride = value.clone().into_i64s(), - "pads" => pads = value.clone().into_i64s(), - "dilations" => dilation = value.clone().into_i64s(), - _ => {} - } - } - assert_eq!(kernel_shape.len(), 1); - assert_eq!(dilation.len(), 1); - assert_eq!(stride.len(), 1); - let padding = padding_config_1d(&pads); - - MaxPool1dConfig::new(kernel_shape[0] as usize) - .with_stride(stride[0] as usize) - .with_padding(padding) - .with_dilation(dilation[0] as usize) -} +// MaxPool1dConfig implementation moved to onnx-ir::node::max_pool1d /// Create a MaxPool2dConfig from the attributes of the node pub fn max_pool2d_config(curr: &Node) -> MaxPool2dConfig { @@ -216,79 +150,7 @@ pub fn max_pool2d_config(curr: &Node) -> MaxPool2dConfig { .with_dilation([dilations[0] as usize, dilations[1] as usize]) } -pub fn conv_transpose1d_config(curr: &Node) -> ConvTranspose1dConfig { - let mut attrs = curr.attrs.clone(); - - // Extract kernel_shape, default to an empty vector if not present - let kernel_shape = attrs - .remove("kernel_shape") - .map(AttributeValue::into_i64s) - .unwrap_or_default(); - - // Extract strides, default to 1 if not present - let stride = attrs - .remove("strides") - .map(AttributeValue::into_i64s) - .unwrap_or_else(|| vec![1]); - - // Extract padding, default to 0 if not present - let pads = attrs - .remove("pads") - .map(AttributeValue::into_i64s) - .unwrap_or_else(|| vec![0, 0]); - - // Extract dilations, default to 1 if not present - let dilations = attrs - .remove("dilations") - .map(AttributeValue::into_i64s) - .unwrap_or_else(|| vec![1]); - - // Extract group attribute, default to 1 - let group = attrs - .remove("group") - .map(AttributeValue::into_i64) - .unwrap_or(1) as usize; - - // Extract output_padding, default to 0 if not present - let output_padding = attrs - .remove("output_padding") - .map(AttributeValue::into_i64s) - .unwrap_or_else(|| vec![0]); - - // Ensure no unused attributes remain - if !attrs.is_empty() { - panic!("Not all attributes are used: {attrs:?}"); - } - // Check the pads are symmetric. 
- if pads.len() != 2 || pads[0] != pads[1] { - panic!( - "Asymmetric padding is not supported for ConvTranspose1d: {:?}", - pads - ); - } - - let weight_shape = curr.inputs[1] - .value - .as_ref() - .expect("ConvTranspose1d: weight tensor must be present") - .shape - .clone(); - - // Check if bias is present (third input) - let bias = curr.inputs.len() == 3; - - // Extract channels from the weight tensor shape [out_channels, in_channels] - let channels: [usize; 2] = [weight_shape[1] * group, weight_shape[0]]; - - // Create the ConvTranspose1d configuration - ConvTranspose1dConfig::new(channels, kernel_shape[0] as usize) - .with_stride(stride[0] as usize) - .with_padding(pads[0] as usize) - .with_dilation(dilations[0] as usize) - .with_padding_out(output_padding[0] as usize) - .with_groups(group) - .with_bias(bias) -} +// ConvTranspose1dConfig implementation moved to onnx-ir::node::conv_transpose1d pub fn conv_transpose2d_config(curr: &Node) -> ConvTranspose2dConfig { let mut attrs = curr.attrs.clone(); @@ -432,37 +294,7 @@ pub fn conv_transpose3d_config(curr: &Node) -> ConvTranspose3dConfig { .with_bias(bias) } -pub fn avg_pool1d_config(curr: &Node) -> AvgPool1dConfig { - let mut kernel_shape = Vec::new(); - let mut strides = vec![1]; - let mut pads = vec![0, 0]; - let mut count_include_pad: i64 = 0; - let mut ceil_mode: i64 = 0; - - for (key, value) in curr.attrs.iter() { - match key.as_str() { - "kernel_shape" => kernel_shape = value.clone().into_i64s(), - "strides" => strides = value.clone().into_i64s(), - "pads" => pads = value.clone().into_i64s(), - "count_include_pad" => count_include_pad = value.clone().into_i64(), - "ceil_mode" => ceil_mode = value.clone().into_i64(), - _ => {} - } - } - assert_eq!(kernel_shape.len(), 1); - assert_eq!(strides.len(), 1); - - if ceil_mode == 1 { - panic!("ceil_mode is not supported"); - } - - let padding = padding_config_1d(&pads); - - AvgPool1dConfig::new(kernel_shape[0] as usize) - .with_stride(strides[0] as usize) - .with_padding(padding) - .with_count_include_pad(count_include_pad == 1) -} +// AvgPool1dConfig implementation moved to onnx-ir::node::avg_pool1d /// Create a AvgPool2dConfig from the attributes of the node pub fn avg_pool2d_config(curr: &Node) -> AvgPool2dConfig { let mut kernel_shape = Vec::new(); @@ -1083,43 +915,7 @@ pub fn pad_config(node: &Node) -> PadConfig { PadConfig::new(pads, constant_value) } -/// Calculate the padding configuration for a 1D operations such as Convolution and Pooling. -/// -/// # Arguments -/// -/// * `pads` - The padding values -/// -/// # Panics -/// -/// * If the padding is negative -/// * If the padding is not symmetric -/// -/// # Returns -/// -/// * The padding configuration -/// -/// # Remarks -/// -/// This function is used when the padding is specified as a list of integers, -/// and not used when the padding is specified as a string, e.g. "SAME_UPPER". -fn padding_config_1d(pads: &[i64]) -> PaddingConfig1d { - let [left, right] = [pads[0], pads[1]]; - - if left < 0 || right < 0 { - panic!("Negative pad values are not supported"); - } else if left != right { - panic!("Asymmetric padding is not supported"); - } else if left == 0 && right == 0 { - // i.e. [0, 0] - PaddingConfig1d::Valid - } else if left == right { - // i.e. 
[2, 2] - PaddingConfig1d::Explicit(left as usize) - } else { - // Unaccounted for padding configuration - panic!("Padding configuration ({:?}) not supported", pads); - } -} +// padding_config_1d moved to onnx-ir::node::conv1d /// Calculate the padding configuration for a 2D operations such as Convolution and Pooling. /// diff --git a/crates/burn-import/src/onnx/to_burn.rs b/crates/burn-import/src/onnx/to_burn.rs index 503cf9f733..e3aca7e852 100644 --- a/crates/burn-import/src/onnx/to_burn.rs +++ b/crates/burn-import/src/onnx/to_burn.rs @@ -70,11 +70,11 @@ use crate::{ }; use super::op_configuration::{ - argmax_config, avg_pool1d_config, avg_pool2d_config, batch_norm_config, clip_config, - concat_config, conv_transpose1d_config, conv_transpose2d_config, conv_transpose3d_config, - conv1d_config, conv2d_config, conv3d_config, dropout_config, expand_config, flatten_config, + argmax_config, avg_pool2d_config, batch_norm_config, clip_config, + concat_config, conv_transpose2d_config, conv_transpose3d_config, + conv2d_config, conv3d_config, dropout_config, expand_config, flatten_config, gather_config, gemm_config, hard_sigmoid_config, layer_norm_config, leaky_relu_config, - linear_config, log_softmax_config, max_pool1d_config, max_pool2d_config, one_hot_config, + linear_config, log_softmax_config, max_pool2d_config, one_hot_config, pad_config, reduce_max_config, reduce_mean_config, reduce_min_config, reduce_prod_config, reduce_sum_config, reshape_config, resize_config, shape_config, softmax_config, split_config, squeeze_config, tile_config, top_k_config, transpose_config, trilu_config, unsqueeze_config, @@ -85,7 +85,13 @@ use onnx_ir::{ ArgType, Argument as OnnxArgument, Data, ElementType, Node, NodeType, OnnxGraph, TensorType as OnnxTensorType, }, - node::slice::slice_config, + node::{ + slice::slice_config, + conv1d::conv1d_config, + avg_pool1d::avg_pool1d_config, + max_pool1d::max_pool1d_config, + conv_transpose1d::conv_transpose1d_config, + }, parse_onnx, }; @@ -1022,7 +1028,27 @@ impl ParsedOnnxGraph { fn conv1d_conversion(node: Node) -> Conv1dNode { let input = TensorType::from(node.inputs.first().unwrap()); let output = TensorType::from(node.outputs.first().unwrap()); - let config = conv1d_config(&node); + + // Get configuration from onnx-ir + let onnx_config = conv1d_config(&node); + + // Convert onnx-ir padding to burn padding + let burn_padding = match onnx_config.padding { + onnx_ir::node::conv1d::PaddingConfig1d::Valid => burn::nn::PaddingConfig1d::Valid, + onnx_ir::node::conv1d::PaddingConfig1d::Explicit(size) => burn::nn::PaddingConfig1d::Explicit(size), + }; + + // Convert to burn Conv1dConfig + let config = burn::nn::conv::Conv1dConfig::new( + onnx_config.channels_in, + onnx_config.channels_out, + onnx_config.kernel_size + ) + .with_stride(onnx_config.stride) + .with_dilation(onnx_config.dilation) + .with_groups(onnx_config.groups) + .with_bias(onnx_config.bias) + .with_padding(burn_padding); let bias = node.inputs.len() == 3; let weight = extract_data_serialize::(1, &node).unwrap(); @@ -1070,7 +1096,21 @@ impl ParsedOnnxGraph { fn max_pool1d_conversion(node: Node) -> MaxPool1dNode { let input = TensorType::from(node.inputs.first().unwrap()); let output = TensorType::from(node.outputs.first().unwrap()); - let config = max_pool1d_config(&node); + + // Get configuration from onnx-ir + let onnx_config = max_pool1d_config(&node); + + // Convert onnx-ir padding to burn padding + let burn_padding = match onnx_config.padding { + onnx_ir::node::conv1d::PaddingConfig1d::Valid => 
burn::nn::PaddingConfig1d::Valid, + onnx_ir::node::conv1d::PaddingConfig1d::Explicit(size) => burn::nn::PaddingConfig1d::Explicit(size), + }; + + // Convert to burn MaxPool1dConfig + let config = burn::nn::pool::MaxPool1dConfig::new(onnx_config.kernel_size) + .with_stride(onnx_config.stride) + .with_padding(burn_padding) + .with_dilation(onnx_config.dilation); let name = &node.name; MaxPool1dNode::new(name, input, output, config) @@ -1114,7 +1154,21 @@ impl ParsedOnnxGraph { fn conv_transpose1d_conversion(node: Node) -> ConvTranspose1dNode { let input = TensorType::from(node.inputs.first().unwrap()); let output = TensorType::from(node.outputs.first().unwrap()); - let config = conv_transpose1d_config(&node); + + // Get configuration from onnx-ir + let onnx_config = conv_transpose1d_config(&node); + + // Convert to burn ConvTranspose1dConfig + let config = burn::nn::conv::ConvTranspose1dConfig::new( + [onnx_config.channels_in, onnx_config.channels_out], + onnx_config.kernel_size + ) + .with_stride(onnx_config.stride) + .with_padding(onnx_config.padding) + .with_dilation(onnx_config.dilation) + .with_padding_out(onnx_config.padding_out) + .with_groups(onnx_config.groups) + .with_bias(onnx_config.bias); let bias = node.inputs.len() == 3; let weight = extract_data_serialize::(1, &node).unwrap(); @@ -1160,7 +1214,21 @@ impl ParsedOnnxGraph { fn avg_pool_1d_conversion(node: Node) -> AvgPool1dNode { let input = TensorType::from(node.inputs.first().unwrap()); let output = TensorType::from(node.outputs.first().unwrap()); - let config = avg_pool1d_config(&node); + + // Get configuration from onnx-ir + let onnx_config = avg_pool1d_config(&node); + + // Convert onnx-ir padding to burn padding + let burn_padding = match onnx_config.padding { + onnx_ir::node::conv1d::PaddingConfig1d::Valid => burn::nn::PaddingConfig1d::Valid, + onnx_ir::node::conv1d::PaddingConfig1d::Explicit(size) => burn::nn::PaddingConfig1d::Explicit(size), + }; + + // Convert to burn AvgPool1dConfig + let config = burn::nn::pool::AvgPool1dConfig::new(onnx_config.kernel_size) + .with_stride(onnx_config.stride) + .with_padding(burn_padding) + .with_count_include_pad(onnx_config.count_include_pad); let name = &node.name; AvgPool1dNode::new(name, input, output, config) diff --git a/crates/onnx-ir/src/node/avg_pool1d.rs b/crates/onnx-ir/src/node/avg_pool1d.rs new file mode 100644 index 0000000000..401566c978 --- /dev/null +++ b/crates/onnx-ir/src/node/avg_pool1d.rs @@ -0,0 +1,155 @@ +use crate::ir::Node; + +// Reuse PaddingConfig1d from conv1d module +pub use super::conv1d::PaddingConfig1d; + +/// Configuration for AvgPool1d operations extracted from ONNX nodes +#[derive(Debug, Clone)] +pub struct AvgPool1dConfig { + /// Kernel size + pub kernel_size: usize, + /// Stride + pub stride: usize, + /// Padding configuration + pub padding: PaddingConfig1d, + /// Whether to include padding in the average calculation + pub count_include_pad: bool, +} + +/// Create an AvgPool1dConfig from the attributes of the node +pub fn avg_pool1d_config(curr: &Node) -> AvgPool1dConfig { + let mut kernel_shape = Vec::new(); + let mut strides = vec![1]; + let mut pads = vec![0, 0]; + let mut count_include_pad: i64 = 0; + let mut ceil_mode: i64 = 0; + + for (key, value) in curr.attrs.iter() { + match key.as_str() { + "kernel_shape" => kernel_shape = value.clone().into_i64s(), + "strides" => strides = value.clone().into_i64s(), + "pads" => pads = value.clone().into_i64s(), + "count_include_pad" => count_include_pad = value.clone().into_i64(), + "ceil_mode" => 
ceil_mode = value.clone().into_i64(), + _ => {} + } + } + + assert_eq!( + kernel_shape.len(), + 1, + "AvgPool1d: kernel shape must have length 1" + ); + assert_eq!(strides.len(), 1, "AvgPool1d: stride must have length 1"); + + if ceil_mode == 1 { + panic!("ceil_mode is not supported"); + } + + let padding = super::conv1d::padding_config_1d(&pads); + + AvgPool1dConfig { + kernel_size: kernel_shape[0] as usize, + stride: strides[0] as usize, + padding, + count_include_pad: count_include_pad == 1, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{ + ArgType, Argument, AttributeValue, ElementType, NodeType, TensorType, + }; + use std::collections::HashMap; + + fn create_test_node( + kernel_shape: Vec, + strides: Vec, + pads: Vec, + count_include_pad: i64, + ceil_mode: i64, + ) -> Node { + let inputs = vec![Argument { + name: "data".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 3, + static_shape: None, + }), + value: None, + passed: true, + }]; + + let mut attrs = HashMap::new(); + attrs.insert( + "kernel_shape".to_string(), + AttributeValue::Int64s(kernel_shape), + ); + attrs.insert("strides".to_string(), AttributeValue::Int64s(strides)); + attrs.insert("pads".to_string(), AttributeValue::Int64s(pads)); + attrs.insert( + "count_include_pad".to_string(), + AttributeValue::Int64(count_include_pad), + ); + attrs.insert("ceil_mode".to_string(), AttributeValue::Int64(ceil_mode)); + + Node { + node_type: NodeType::AveragePool1d, + name: "test_avgpool1d".to_string(), + inputs, + outputs: vec![Argument { + name: "output".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 3, + static_shape: None, + }), + value: None, + passed: true, + }], + attrs, + } + } + + #[test] + fn test_avg_pool1d_config_basic() { + let node = create_test_node(vec![4], vec![1], vec![0, 0], 0, 0); + let config = avg_pool1d_config(&node); + + assert_eq!(config.kernel_size, 4); + assert_eq!(config.stride, 1); + assert!(!config.count_include_pad); + assert!(matches!(config.padding, PaddingConfig1d::Valid)); + } + + #[test] + fn test_avg_pool1d_config_with_padding() { + let node = create_test_node(vec![4], vec![2], vec![2, 2], 0, 0); + let config = avg_pool1d_config(&node); + + assert_eq!(config.kernel_size, 4); + assert_eq!(config.stride, 2); + assert!(!config.count_include_pad); + assert!(matches!(config.padding, PaddingConfig1d::Explicit(2))); + } + + #[test] + fn test_avg_pool1d_config_with_count_include_pad() { + let node = create_test_node(vec![4], vec![1], vec![2, 2], 1, 0); + let config = avg_pool1d_config(&node); + + assert_eq!(config.kernel_size, 4); + assert_eq!(config.stride, 1); + assert!(config.count_include_pad); + assert!(matches!(config.padding, PaddingConfig1d::Explicit(2))); + } + + #[test] + #[should_panic(expected = "ceil_mode is not supported")] + fn test_avg_pool1d_config_with_ceil_mode() { + let node = create_test_node(vec![4], vec![1], vec![0, 0], 0, 1); + let _ = avg_pool1d_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/conv1d.rs b/crates/onnx-ir/src/node/conv1d.rs new file mode 100644 index 0000000000..3ad98e212c --- /dev/null +++ b/crates/onnx-ir/src/node/conv1d.rs @@ -0,0 +1,288 @@ +use crate::ir::Node; +use std::fmt; + +/// Padding configuration for 1D operations such as convolution +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum PaddingConfig1d { + /// No padding (valid padding) + Valid, + /// Explicit padding with a specific size + Explicit(usize), +} + +impl fmt::Display for 
PaddingConfig1d { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + PaddingConfig1d::Valid => write!(f, "Valid"), + PaddingConfig1d::Explicit(size) => write!(f, "Explicit({})", size), + } + } +} + +/// Configuration for Conv1d operations extracted from ONNX nodes +#[derive(Debug, Clone)] +pub struct Conv1dConfig { + /// Input channels + pub channels_in: usize, + /// Output channels + pub channels_out: usize, + /// Kernel size + pub kernel_size: usize, + /// Stride + pub stride: usize, + /// Dilation + pub dilation: usize, + /// Number of groups + pub groups: usize, + /// Whether bias is used + pub bias: bool, + /// Padding configuration + pub padding: PaddingConfig1d, +} + +/// Create a Conv1dConfig from the attributes of the node +pub fn conv1d_config(curr: &Node) -> Conv1dConfig { + let mut kernel_shape = Vec::new(); // TODO default inferred from weight tensor per spec + let mut strides = vec![1]; + let mut pads = vec![0, 0]; + let mut dilations = vec![1]; + let mut group: usize = 1; + + let weight_shape = curr.inputs[1] + .value + .as_ref() + .expect("Conv1d: weight tensor must be present") + .shape + .clone(); + + // check if the bias is present + let bias = curr.inputs.len() == 3; + + for (key, value) in curr.attrs.iter() { + match key.as_str() { + "kernel_shape" => kernel_shape = value.clone().into_i64s(), + "strides" => strides = value.clone().into_i64s(), + "pads" => pads = value.clone().into_i64s(), + "dilations" => dilations = value.clone().into_i64s(), + "group" => group = value.clone().into_i64() as usize, + _ => {} + } + } + + // the channels are inverted in the weight tensor + let channels_in = weight_shape[1] * group; + let channels_out = weight_shape[0]; + + let padding = padding_config_1d(&pads); + + Conv1dConfig { + channels_in, + channels_out, + kernel_size: kernel_shape[0] as usize, + stride: strides[0] as usize, + dilation: dilations[0] as usize, + groups: group, + bias, + padding, + } +} + +/// Calculate the padding configuration for a 1D operations such as Convolution and Pooling. +/// +/// # Arguments +/// +/// * `pads` - The padding values +/// +/// # Panics +/// +/// * If the padding is negative +/// * If the padding is not symmetric +/// +/// # Returns +/// +/// * The padding configuration +/// +/// # Remarks +/// +/// This function is used when the padding is specified as a list of integers, +/// and not used when the padding is specified as a string, e.g. "SAME_UPPER". +pub fn padding_config_1d(pads: &[i64]) -> PaddingConfig1d { + let [left, right] = [pads[0], pads[1]]; + + if left < 0 || right < 0 { + panic!("Negative pad values are not supported"); + } else if left != right { + panic!("Asymmetric padding is not supported"); + } else if left == 0 && right == 0 { + // i.e. [0, 0] + PaddingConfig1d::Valid + } else if left == right { + // i.e. 
[2, 2] + PaddingConfig1d::Explicit(left as usize) + } else { + // Unaccounted for padding configuration + panic!("Padding configuration ({:?}) not supported", pads); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{ + ArgType, Argument, AttributeValue, Data, ElementType, NodeType, TensorData, TensorType, + }; + use std::collections::HashMap; + + fn create_test_node( + kernel_shape: Vec, + strides: Vec, + pads: Vec, + dilations: Vec, + group: i64, + has_bias: bool, + ) -> Node { + let mut inputs = vec![Argument { + name: "data".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 3, + static_shape: None, + }), + value: None, + passed: true, + }]; + + // Add weight tensor + inputs.push(Argument { + name: "weight".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 3, + static_shape: None, + }), + value: Some(TensorData { + data: Data::Float32s(vec![0.1; 16]), + shape: vec![2, 2, 4], // [out_channels, in_channels, kernel_size] + }), + passed: true, + }); + + // Add bias if needed + if has_bias { + inputs.push(Argument { + name: "bias".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 1, + static_shape: None, + }), + value: Some(TensorData { + data: Data::Float32s(vec![0.1, 0.2]), + shape: vec![2], + }), + passed: true, + }); + } + + let mut attrs = HashMap::new(); + attrs.insert( + "kernel_shape".to_string(), + AttributeValue::Int64s(kernel_shape), + ); + attrs.insert("strides".to_string(), AttributeValue::Int64s(strides)); + attrs.insert("pads".to_string(), AttributeValue::Int64s(pads)); + attrs.insert("dilations".to_string(), AttributeValue::Int64s(dilations)); + attrs.insert("group".to_string(), AttributeValue::Int64(group)); + + Node { + node_type: NodeType::Conv1d, + name: "test_conv1d".to_string(), + inputs, + outputs: vec![Argument { + name: "output".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 3, + static_shape: None, + }), + value: None, + passed: true, + }], + attrs, + } + } + + #[test] + fn test_conv1d_config_basic() { + let node = create_test_node(vec![4], vec![1], vec![0, 0], vec![1], 1, false); + let config = conv1d_config(&node); + + assert_eq!(config.channels_in, 2); + assert_eq!(config.channels_out, 2); + assert_eq!(config.kernel_size, 4); + assert_eq!(config.stride, 1); + assert_eq!(config.dilation, 1); + assert_eq!(config.groups, 1); + assert!(!config.bias); + assert!(matches!(config.padding, PaddingConfig1d::Valid)); + } + + #[test] + fn test_conv1d_config_with_padding() { + let node = create_test_node(vec![4], vec![2], vec![2, 2], vec![1], 1, true); + let config = conv1d_config(&node); + + assert_eq!(config.channels_in, 2); + assert_eq!(config.channels_out, 2); + assert_eq!(config.kernel_size, 4); + assert_eq!(config.stride, 2); + assert_eq!(config.dilation, 1); + assert_eq!(config.groups, 1); + assert!(config.bias); + assert!(matches!(config.padding, PaddingConfig1d::Explicit(2))); + } + + #[test] + fn test_conv1d_config_with_dilation() { + let node = create_test_node(vec![4], vec![1], vec![0, 0], vec![2], 1, false); + let config = conv1d_config(&node); + + assert_eq!(config.channels_in, 2); + assert_eq!(config.channels_out, 2); + assert_eq!(config.kernel_size, 4); + assert_eq!(config.stride, 1); + assert_eq!(config.dilation, 2); + assert_eq!(config.groups, 1); + assert!(!config.bias); + assert!(matches!(config.padding, PaddingConfig1d::Valid)); + } + + #[test] + fn 
test_conv1d_config_with_groups() { + let node = create_test_node(vec![4], vec![1], vec![0, 0], vec![1], 2, false); + let config = conv1d_config(&node); + + assert_eq!(config.channels_in, 4); + assert_eq!(config.channels_out, 2); + assert_eq!(config.kernel_size, 4); + assert_eq!(config.stride, 1); + assert_eq!(config.dilation, 1); + assert_eq!(config.groups, 2); + assert!(!config.bias); + assert!(matches!(config.padding, PaddingConfig1d::Valid)); + } + + #[test] + #[should_panic(expected = "Asymmetric padding is not supported")] + fn test_conv1d_config_asymmetric_padding() { + let node = create_test_node(vec![4], vec![1], vec![1, 2], vec![1], 1, false); + let _ = conv1d_config(&node); + } + + #[test] + #[should_panic(expected = "Negative pad values are not supported")] + fn test_conv1d_config_negative_padding() { + let node = create_test_node(vec![4], vec![1], vec![-1, -1], vec![1], 1, false); + let _ = conv1d_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/conv_transpose1d.rs b/crates/onnx-ir/src/node/conv_transpose1d.rs new file mode 100644 index 0000000000..e6d046b4e3 --- /dev/null +++ b/crates/onnx-ir/src/node/conv_transpose1d.rs @@ -0,0 +1,239 @@ +use crate::ir::{AttributeValue, Node}; + +// Reuse PaddingConfig1d from conv1d module +pub use super::conv1d::PaddingConfig1d; + +/// Configuration for ConvTranspose1d operations extracted from ONNX nodes +#[derive(Debug, Clone)] +pub struct ConvTranspose1dConfig { + /// Input channels + pub channels_in: usize, + /// Output channels + pub channels_out: usize, + /// Kernel size + pub kernel_size: usize, + /// Stride + pub stride: usize, + /// Dilation + pub dilation: usize, + /// Number of groups + pub groups: usize, + /// Whether bias is used + pub bias: bool, + /// Padding size + pub padding: usize, + /// Output padding size + pub padding_out: usize, +} + +/// Create a ConvTranspose1dConfig from the attributes of the node +pub fn conv_transpose1d_config(curr: &Node) -> ConvTranspose1dConfig { + let mut attrs = curr.attrs.clone(); + + // Extract kernel_shape, default to an empty vector if not present + let kernel_shape = attrs + .remove("kernel_shape") + .map(AttributeValue::into_i64s) + .unwrap_or_default(); + + // Extract strides, default to 1 if not present + let stride = attrs + .remove("strides") + .map(AttributeValue::into_i64s) + .unwrap_or_else(|| vec![1]); + + // Extract padding, default to 0 if not present + let pads = attrs + .remove("pads") + .map(AttributeValue::into_i64s) + .unwrap_or_else(|| vec![0, 0]); + + // Extract dilations, default to 1 if not present + let dilations = attrs + .remove("dilations") + .map(AttributeValue::into_i64s) + .unwrap_or_else(|| vec![1]); + + // Extract group attribute, default to 1 + let group = attrs + .remove("group") + .map(AttributeValue::into_i64) + .unwrap_or(1) as usize; + + // Extract output_padding, default to 0 if not present + let output_padding = attrs + .remove("output_padding") + .map(AttributeValue::into_i64s) + .unwrap_or_else(|| vec![0]); + + // Ensure no unused attributes remain + if !attrs.is_empty() { + panic!("Not all attributes are used: {attrs:?}"); + } + + // Check the pads are symmetric + if pads.len() != 2 || pads[0] != pads[1] { + panic!( + "Asymmetric padding is not supported for ConvTranspose1d: {:?}", + pads + ); + } + + let weight_shape = curr.inputs[1] + .value + .as_ref() + .expect("ConvTranspose1d: weight tensor must be present") + .shape + .clone(); + + // Check if bias is present (third input) + let bias = curr.inputs.len() == 3; + + // Extract 
channels from the weight tensor shape [out_channels, in_channels] + let channels_in = weight_shape[1] * group; + let channels_out = weight_shape[0]; + + ConvTranspose1dConfig { + channels_in, + channels_out, + kernel_size: kernel_shape[0] as usize, + stride: stride[0] as usize, + padding: pads[0] as usize, + dilation: dilations[0] as usize, + padding_out: output_padding[0] as usize, + groups: group, + bias, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{ + ArgType, Argument, AttributeValue, Data, ElementType, NodeType, TensorData, TensorType, + }; + use std::collections::HashMap; + + fn create_test_node( + kernel_shape: Vec, + stride: Vec, + pads: Vec, + dilations: Vec, + group: i64, + output_padding: Vec, + has_bias: bool, + ) -> Node { + let mut inputs = vec![Argument { + name: "data".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 3, + static_shape: None, + }), + value: None, + passed: true, + }]; + + // Add weight tensor + inputs.push(Argument { + name: "weight".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 3, + static_shape: None, + }), + value: Some(TensorData { + data: Data::Float32s(vec![0.1; 16]), + shape: vec![2, 2, 4], // [out_channels, in_channels, kernel_size] + }), + passed: true, + }); + + // Add bias if needed + if has_bias { + inputs.push(Argument { + name: "bias".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 1, + static_shape: None, + }), + value: Some(TensorData { + data: Data::Float32s(vec![0.1, 0.2]), + shape: vec![2], + }), + passed: true, + }); + } + + let mut attrs = HashMap::new(); + attrs.insert( + "kernel_shape".to_string(), + AttributeValue::Int64s(kernel_shape), + ); + attrs.insert("strides".to_string(), AttributeValue::Int64s(stride)); + attrs.insert("pads".to_string(), AttributeValue::Int64s(pads)); + attrs.insert("dilations".to_string(), AttributeValue::Int64s(dilations)); + attrs.insert("group".to_string(), AttributeValue::Int64(group)); + attrs.insert( + "output_padding".to_string(), + AttributeValue::Int64s(output_padding), + ); + + Node { + node_type: NodeType::ConvTranspose1d, + name: "test_conv_transpose1d".to_string(), + inputs, + outputs: vec![Argument { + name: "output".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 3, + static_shape: None, + }), + value: None, + passed: true, + }], + attrs, + } + } + + #[test] + fn test_conv_transpose1d_config_basic() { + let node = create_test_node(vec![4], vec![1], vec![0, 0], vec![1], 1, vec![0], false); + let config = conv_transpose1d_config(&node); + + assert_eq!(config.channels_in, 2); + assert_eq!(config.channels_out, 2); + assert_eq!(config.kernel_size, 4); + assert_eq!(config.stride, 1); + assert_eq!(config.padding, 0); + assert_eq!(config.dilation, 1); + assert_eq!(config.padding_out, 0); + assert_eq!(config.groups, 1); + assert!(!config.bias); + } + + #[test] + fn test_conv_transpose1d_config_with_params() { + let node = create_test_node(vec![4], vec![2], vec![1, 1], vec![2], 2, vec![1], true); + let config = conv_transpose1d_config(&node); + + assert_eq!(config.channels_in, 4); // weight_shape[1] * group = 2 * 2 + assert_eq!(config.channels_out, 2); + assert_eq!(config.kernel_size, 4); + assert_eq!(config.stride, 2); + assert_eq!(config.padding, 1); + assert_eq!(config.dilation, 2); + assert_eq!(config.padding_out, 1); + assert_eq!(config.groups, 2); + assert!(config.bias); + } + + #[test] + 
#[should_panic(expected = "Asymmetric padding is not supported")] + fn test_conv_transpose1d_config_asymmetric_padding() { + let node = create_test_node(vec![4], vec![1], vec![1, 2], vec![1], 1, vec![0], false); + let _ = conv_transpose1d_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/max_pool1d.rs b/crates/onnx-ir/src/node/max_pool1d.rs new file mode 100644 index 0000000000..dc1ec430c1 --- /dev/null +++ b/crates/onnx-ir/src/node/max_pool1d.rs @@ -0,0 +1,145 @@ +use crate::ir::Node; + +// Reuse PaddingConfig1d from conv1d module +pub use super::conv1d::PaddingConfig1d; + +/// Configuration for MaxPool1d operations extracted from ONNX nodes +#[derive(Debug, Clone)] +pub struct MaxPool1dConfig { + /// Kernel size + pub kernel_size: usize, + /// Stride + pub stride: usize, + /// Dilation + pub dilation: usize, + /// Padding configuration + pub padding: PaddingConfig1d, +} + +/// Create a MaxPool1dConfig from the attributes of the node +pub fn max_pool1d_config(curr: &Node) -> MaxPool1dConfig { + let mut kernel_shape = Vec::new(); + let mut stride = vec![1]; + let mut pads = vec![0, 0]; + let mut dilation = vec![1]; + + for (key, value) in curr.attrs.iter() { + match key.as_str() { + "kernel_shape" => kernel_shape = value.clone().into_i64s(), + "strides" => stride = value.clone().into_i64s(), + "pads" => pads = value.clone().into_i64s(), + "dilations" => dilation = value.clone().into_i64s(), + _ => {} + } + } + + assert_eq!( + kernel_shape.len(), + 1, + "MaxPool1d: kernel shape must have length 1" + ); + assert_eq!(dilation.len(), 1, "MaxPool1d: dilation must have length 1"); + assert_eq!(stride.len(), 1, "MaxPool1d: stride must have length 1"); + + let padding = super::conv1d::padding_config_1d(&pads); + + MaxPool1dConfig { + kernel_size: kernel_shape[0] as usize, + stride: stride[0] as usize, + dilation: dilation[0] as usize, + padding, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{ + ArgType, Argument, AttributeValue, ElementType, NodeType, TensorType, + }; + use std::collections::HashMap; + + fn create_test_node( + kernel_shape: Vec, + stride: Vec, + pads: Vec, + dilation: Vec, + ) -> Node { + let inputs = vec![Argument { + name: "data".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 3, + static_shape: None, + }), + value: None, + passed: true, + }]; + + let mut attrs = HashMap::new(); + attrs.insert( + "kernel_shape".to_string(), + AttributeValue::Int64s(kernel_shape), + ); + attrs.insert("strides".to_string(), AttributeValue::Int64s(stride)); + attrs.insert("pads".to_string(), AttributeValue::Int64s(pads)); + attrs.insert("dilations".to_string(), AttributeValue::Int64s(dilation)); + + Node { + node_type: NodeType::MaxPool1d, + name: "test_maxpool1d".to_string(), + inputs, + outputs: vec![Argument { + name: "output".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 3, + static_shape: None, + }), + value: None, + passed: true, + }], + attrs, + } + } + + #[test] + fn test_max_pool1d_config_basic() { + let node = create_test_node(vec![4], vec![1], vec![0, 0], vec![1]); + let config = max_pool1d_config(&node); + + assert_eq!(config.kernel_size, 4); + assert_eq!(config.stride, 1); + assert_eq!(config.dilation, 1); + assert!(matches!(config.padding, PaddingConfig1d::Valid)); + } + + #[test] + fn test_max_pool1d_config_with_padding() { + let node = create_test_node(vec![4], vec![2], vec![2, 2], vec![1]); + let config = max_pool1d_config(&node); + + 
assert_eq!(config.kernel_size, 4); + assert_eq!(config.stride, 2); + assert_eq!(config.dilation, 1); + assert!(matches!(config.padding, PaddingConfig1d::Explicit(2))); + } + + #[test] + fn test_max_pool1d_config_with_dilation() { + let node = create_test_node(vec![4], vec![1], vec![0, 0], vec![2]); + let config = max_pool1d_config(&node); + + assert_eq!(config.kernel_size, 4); + assert_eq!(config.stride, 1); + assert_eq!(config.dilation, 2); + assert!(matches!(config.padding, PaddingConfig1d::Valid)); + } + + #[test] + #[should_panic(expected = "Asymmetric padding is not supported")] + fn test_max_pool1d_config_asymmetric_padding() { + let node = create_test_node(vec![4], vec![1], vec![1, 2], vec![1]); + let _ = max_pool1d_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/mod.rs b/crates/onnx-ir/src/node/mod.rs index 913812d439..63fe32edb4 100644 --- a/crates/onnx-ir/src/node/mod.rs +++ b/crates/onnx-ir/src/node/mod.rs @@ -1 +1,5 @@ pub mod slice; +pub mod conv1d; +pub mod avg_pool1d; +pub mod max_pool1d; +pub mod conv_transpose1d; From 1b1287b6faec76586c4521924b19b681335527ba Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Wed, 30 Apr 2025 12:12:36 -0500 Subject: [PATCH 02/37] Fix format --- .../burn-import/src/onnx/op_configuration.rs | 5 +- crates/burn-import/src/onnx/to_burn.rs | 67 ++++++++++--------- 2 files changed, 36 insertions(+), 36 deletions(-) diff --git a/crates/burn-import/src/onnx/op_configuration.rs b/crates/burn-import/src/onnx/op_configuration.rs index 5a99e3aacc..84b5b2601b 100644 --- a/crates/burn-import/src/onnx/op_configuration.rs +++ b/crates/burn-import/src/onnx/op_configuration.rs @@ -4,10 +4,7 @@ use burn::nn::{ BatchNormConfig, DropoutConfig, LayerNormConfig, LinearConfig, PaddingConfig2d, PaddingConfig3d, - conv::{ - Conv2dConfig, Conv3dConfig, ConvTranspose2dConfig, - ConvTranspose3dConfig, - }, + conv::{Conv2dConfig, Conv3dConfig, ConvTranspose2dConfig, ConvTranspose3dConfig}, pool::{AvgPool2dConfig, MaxPool2dConfig}, }; diff --git a/crates/burn-import/src/onnx/to_burn.rs b/crates/burn-import/src/onnx/to_burn.rs index e3aca7e852..da924042ac 100644 --- a/crates/burn-import/src/onnx/to_burn.rs +++ b/crates/burn-import/src/onnx/to_burn.rs @@ -70,14 +70,14 @@ use crate::{ }; use super::op_configuration::{ - argmax_config, avg_pool2d_config, batch_norm_config, clip_config, - concat_config, conv_transpose2d_config, conv_transpose3d_config, - conv2d_config, conv3d_config, dropout_config, expand_config, flatten_config, - gather_config, gemm_config, hard_sigmoid_config, layer_norm_config, leaky_relu_config, - linear_config, log_softmax_config, max_pool2d_config, one_hot_config, - pad_config, reduce_max_config, reduce_mean_config, reduce_min_config, reduce_prod_config, - reduce_sum_config, reshape_config, resize_config, shape_config, softmax_config, split_config, - squeeze_config, tile_config, top_k_config, transpose_config, trilu_config, unsqueeze_config, + argmax_config, avg_pool2d_config, batch_norm_config, clip_config, concat_config, + conv_transpose2d_config, conv_transpose3d_config, conv2d_config, conv3d_config, dropout_config, + expand_config, flatten_config, gather_config, gemm_config, hard_sigmoid_config, + layer_norm_config, leaky_relu_config, linear_config, log_softmax_config, max_pool2d_config, + one_hot_config, pad_config, reduce_max_config, reduce_mean_config, reduce_min_config, + reduce_prod_config, reduce_sum_config, reshape_config, resize_config, shape_config, + softmax_config, split_config, 
squeeze_config, tile_config, top_k_config, transpose_config, + trilu_config, unsqueeze_config, }; use onnx_ir::{ convert_constant_value, @@ -86,11 +86,8 @@ use onnx_ir::{ TensorType as OnnxTensorType, }, node::{ - slice::slice_config, - conv1d::conv1d_config, - avg_pool1d::avg_pool1d_config, - max_pool1d::max_pool1d_config, - conv_transpose1d::conv_transpose1d_config, + avg_pool1d::avg_pool1d_config, conv_transpose1d::conv_transpose1d_config, + conv1d::conv1d_config, max_pool1d::max_pool1d_config, slice::slice_config, }, parse_onnx, }; @@ -1028,21 +1025,23 @@ impl ParsedOnnxGraph { fn conv1d_conversion(node: Node) -> Conv1dNode { let input = TensorType::from(node.inputs.first().unwrap()); let output = TensorType::from(node.outputs.first().unwrap()); - + // Get configuration from onnx-ir let onnx_config = conv1d_config(&node); - + // Convert onnx-ir padding to burn padding let burn_padding = match onnx_config.padding { onnx_ir::node::conv1d::PaddingConfig1d::Valid => burn::nn::PaddingConfig1d::Valid, - onnx_ir::node::conv1d::PaddingConfig1d::Explicit(size) => burn::nn::PaddingConfig1d::Explicit(size), + onnx_ir::node::conv1d::PaddingConfig1d::Explicit(size) => { + burn::nn::PaddingConfig1d::Explicit(size) + } }; - + // Convert to burn Conv1dConfig let config = burn::nn::conv::Conv1dConfig::new( - onnx_config.channels_in, - onnx_config.channels_out, - onnx_config.kernel_size + onnx_config.channels_in, + onnx_config.channels_out, + onnx_config.kernel_size, ) .with_stride(onnx_config.stride) .with_dilation(onnx_config.dilation) @@ -1096,16 +1095,18 @@ impl ParsedOnnxGraph { fn max_pool1d_conversion(node: Node) -> MaxPool1dNode { let input = TensorType::from(node.inputs.first().unwrap()); let output = TensorType::from(node.outputs.first().unwrap()); - + // Get configuration from onnx-ir let onnx_config = max_pool1d_config(&node); - + // Convert onnx-ir padding to burn padding let burn_padding = match onnx_config.padding { onnx_ir::node::conv1d::PaddingConfig1d::Valid => burn::nn::PaddingConfig1d::Valid, - onnx_ir::node::conv1d::PaddingConfig1d::Explicit(size) => burn::nn::PaddingConfig1d::Explicit(size), + onnx_ir::node::conv1d::PaddingConfig1d::Explicit(size) => { + burn::nn::PaddingConfig1d::Explicit(size) + } }; - + // Convert to burn MaxPool1dConfig let config = burn::nn::pool::MaxPool1dConfig::new(onnx_config.kernel_size) .with_stride(onnx_config.stride) @@ -1154,14 +1155,14 @@ impl ParsedOnnxGraph { fn conv_transpose1d_conversion(node: Node) -> ConvTranspose1dNode { let input = TensorType::from(node.inputs.first().unwrap()); let output = TensorType::from(node.outputs.first().unwrap()); - + // Get configuration from onnx-ir let onnx_config = conv_transpose1d_config(&node); - + // Convert to burn ConvTranspose1dConfig let config = burn::nn::conv::ConvTranspose1dConfig::new( - [onnx_config.channels_in, onnx_config.channels_out], - onnx_config.kernel_size + [onnx_config.channels_in, onnx_config.channels_out], + onnx_config.kernel_size, ) .with_stride(onnx_config.stride) .with_padding(onnx_config.padding) @@ -1214,16 +1215,18 @@ impl ParsedOnnxGraph { fn avg_pool_1d_conversion(node: Node) -> AvgPool1dNode { let input = TensorType::from(node.inputs.first().unwrap()); let output = TensorType::from(node.outputs.first().unwrap()); - + // Get configuration from onnx-ir let onnx_config = avg_pool1d_config(&node); - + // Convert onnx-ir padding to burn padding let burn_padding = match onnx_config.padding { onnx_ir::node::conv1d::PaddingConfig1d::Valid => burn::nn::PaddingConfig1d::Valid, - 
onnx_ir::node::conv1d::PaddingConfig1d::Explicit(size) => burn::nn::PaddingConfig1d::Explicit(size), + onnx_ir::node::conv1d::PaddingConfig1d::Explicit(size) => { + burn::nn::PaddingConfig1d::Explicit(size) + } }; - + // Convert to burn AvgPool1dConfig let config = burn::nn::pool::AvgPool1dConfig::new(onnx_config.kernel_size) .with_stride(onnx_config.stride) From e5234989149f592534404522086f0fd319d96395 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Wed, 30 Apr 2025 12:13:13 -0500 Subject: [PATCH 03/37] Fix format --- crates/onnx-ir/src/node/avg_pool1d.rs | 4 +--- crates/onnx-ir/src/node/max_pool1d.rs | 4 +--- crates/onnx-ir/src/node/mod.rs | 6 +++--- 3 files changed, 5 insertions(+), 9 deletions(-) diff --git a/crates/onnx-ir/src/node/avg_pool1d.rs b/crates/onnx-ir/src/node/avg_pool1d.rs index 401566c978..8b45b473f8 100644 --- a/crates/onnx-ir/src/node/avg_pool1d.rs +++ b/crates/onnx-ir/src/node/avg_pool1d.rs @@ -59,9 +59,7 @@ pub fn avg_pool1d_config(curr: &Node) -> AvgPool1dConfig { #[cfg(test)] mod tests { use super::*; - use crate::ir::{ - ArgType, Argument, AttributeValue, ElementType, NodeType, TensorType, - }; + use crate::ir::{ArgType, Argument, AttributeValue, ElementType, NodeType, TensorType}; use std::collections::HashMap; fn create_test_node( diff --git a/crates/onnx-ir/src/node/max_pool1d.rs b/crates/onnx-ir/src/node/max_pool1d.rs index dc1ec430c1..b4e8ede8a4 100644 --- a/crates/onnx-ir/src/node/max_pool1d.rs +++ b/crates/onnx-ir/src/node/max_pool1d.rs @@ -54,9 +54,7 @@ pub fn max_pool1d_config(curr: &Node) -> MaxPool1dConfig { #[cfg(test)] mod tests { use super::*; - use crate::ir::{ - ArgType, Argument, AttributeValue, ElementType, NodeType, TensorType, - }; + use crate::ir::{ArgType, Argument, AttributeValue, ElementType, NodeType, TensorType}; use std::collections::HashMap; fn create_test_node( diff --git a/crates/onnx-ir/src/node/mod.rs b/crates/onnx-ir/src/node/mod.rs index 63fe32edb4..b44e62f217 100644 --- a/crates/onnx-ir/src/node/mod.rs +++ b/crates/onnx-ir/src/node/mod.rs @@ -1,5 +1,5 @@ -pub mod slice; -pub mod conv1d; pub mod avg_pool1d; -pub mod max_pool1d; +pub mod conv1d; pub mod conv_transpose1d; +pub mod max_pool1d; +pub mod slice; From e0083c0693d25d69fb113ee3127a81cf90964013 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Wed, 30 Apr 2025 12:41:18 -0500 Subject: [PATCH 04/37] Move conv1d rank update --- crates/onnx-ir/src/node/conv1d.rs | 45 +++++++++++++++++++++++++++- crates/onnx-ir/src/rank_inference.rs | 24 +++------------ 2 files changed, 48 insertions(+), 21 deletions(-) diff --git a/crates/onnx-ir/src/node/conv1d.rs b/crates/onnx-ir/src/node/conv1d.rs index 3ad98e212c..876543c410 100644 --- a/crates/onnx-ir/src/node/conv1d.rs +++ b/crates/onnx-ir/src/node/conv1d.rs @@ -1,4 +1,4 @@ -use crate::ir::Node; +use crate::ir::{ArgType, Node, TensorType}; use std::fmt; /// Padding configuration for 1D operations such as convolution @@ -125,6 +125,25 @@ pub fn padding_config_1d(pads: &[i64]) -> PaddingConfig1d { } } +/// Update output rank for Conv1d (same as input). 
+pub fn conv1d_update_outputs(node: &mut Node) { + log::debug!("Conv1d rank inference for node {}", node.name); + + if let ArgType::Tensor(tensor) = &node.inputs[0].ty { + log::debug!("Conv1d input rank for {}: {}", node.name, tensor.rank); + + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type: tensor.elem_type.clone(), + rank: tensor.rank, + static_shape: None, + }); + + log::debug!("Conv1d output rank for {}: {}", node.name, tensor.rank); + } else { + panic!("Only tensor input is valid"); + } +} + #[cfg(test)] mod tests { use super::*; @@ -285,4 +304,28 @@ mod tests { let node = create_test_node(vec![4], vec![1], vec![-1, -1], vec![1], 1, false); let _ = conv1d_config(&node); } + + #[test] + fn test_conv1d_update_outputs() { + let mut node = create_test_node(vec![4], vec![1], vec![0, 0], vec![1], 1, false); + + // Before calling, check that input and output ranks exist + if let ArgType::Tensor(tensor) = &node.inputs[0].ty { + assert_eq!(tensor.rank, 3); + } else { + panic!("Expected tensor input"); + } + + // Run the function + conv1d_update_outputs(&mut node); + + // After calling, output should have same rank as input + if let ArgType::Tensor(tensor) = &node.outputs[0].ty { + assert_eq!(tensor.rank, 3); + assert_eq!(tensor.elem_type, ElementType::Float32); + assert!(tensor.static_shape.is_none()); + } else { + panic!("Expected tensor output"); + } + } } diff --git a/crates/onnx-ir/src/rank_inference.rs b/crates/onnx-ir/src/rank_inference.rs index b3a0bd2bbf..ef448651a3 100644 --- a/crates/onnx-ir/src/rank_inference.rs +++ b/crates/onnx-ir/src/rank_inference.rs @@ -5,7 +5,10 @@ use protobuf::Enum; use crate::{ ir::{ArgType, AttributeValue, Data, ElementType, Node, NodeType, TensorType}, - node::slice::slice_update_output_rank, + node::{ + conv1d::conv1d_update_outputs, + slice::slice_update_output_rank, + }, protos::tensor_proto::DataType, util::shape_config, }; @@ -817,25 +820,6 @@ fn flatten_update_outputs(node: &mut Node) { }); } -/// Update output rank for Conv1d (same as input). -fn conv1d_update_outputs(node: &mut Node) { - log::debug!("Conv1d rank inference for node {}", node.name); - - if let ArgType::Tensor(tensor) = &node.inputs[0].ty { - log::debug!("Conv1d input rank for {}: {}", node.name, tensor.rank); - - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type: tensor.elem_type.clone(), - rank: tensor.rank, - static_shape: None, - }); - - log::debug!("Conv1d output rank for {}: {}", node.name, tensor.rank); - } else { - panic!("Only tensor input is valid"); - } -} - /// Update output rank for Conv2d (same as input). 
fn conv2d_update_outputs(node: &mut Node) { log::debug!("Conv2d rank inference for node {}", node.name); From 9f39454718b295720637ff9a89fa32a55f897b9a Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Wed, 30 Apr 2025 12:45:30 -0500 Subject: [PATCH 05/37] Move conv_transpose1d_update_outputs --- crates/onnx-ir/src/node/conv1d.rs | 2 +- crates/onnx-ir/src/node/conv_transpose1d.rs | 53 ++++++++++++++++++++- crates/onnx-ir/src/rank_inference.rs | 29 +---------- 3 files changed, 54 insertions(+), 30 deletions(-) diff --git a/crates/onnx-ir/src/node/conv1d.rs b/crates/onnx-ir/src/node/conv1d.rs index 876543c410..a73e4e8224 100644 --- a/crates/onnx-ir/src/node/conv1d.rs +++ b/crates/onnx-ir/src/node/conv1d.rs @@ -308,7 +308,7 @@ mod tests { #[test] fn test_conv1d_update_outputs() { let mut node = create_test_node(vec![4], vec![1], vec![0, 0], vec![1], 1, false); - + // Before calling, check that input and output ranks exist if let ArgType::Tensor(tensor) = &node.inputs[0].ty { assert_eq!(tensor.rank, 3); diff --git a/crates/onnx-ir/src/node/conv_transpose1d.rs b/crates/onnx-ir/src/node/conv_transpose1d.rs index e6d046b4e3..eaf168b9d8 100644 --- a/crates/onnx-ir/src/node/conv_transpose1d.rs +++ b/crates/onnx-ir/src/node/conv_transpose1d.rs @@ -1,4 +1,4 @@ -use crate::ir::{AttributeValue, Node}; +use crate::ir::{ArgType, AttributeValue, Node, TensorType}; // Reuse PaddingConfig1d from conv1d module pub use super::conv1d::PaddingConfig1d; @@ -106,6 +106,33 @@ pub fn conv_transpose1d_config(curr: &Node) -> ConvTranspose1dConfig { } } +/// Update output rank for ConvTranspose1d (same as input). +pub fn conv_transpose1d_update_outputs(node: &mut Node) { + log::debug!("ConvTranspose1d rank inference for node {}", node.name); + + if let ArgType::Tensor(tensor) = &node.inputs[0].ty { + log::debug!( + "ConvTranspose1d input rank for {}: {}", + node.name, + tensor.rank + ); + + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type: tensor.elem_type.clone(), + rank: tensor.rank, + static_shape: None, + }); + + log::debug!( + "ConvTranspose1d output rank for {}: {}", + node.name, + tensor.rank + ); + } else { + panic!("Only tensor input is valid"); + } +} + #[cfg(test)] mod tests { use super::*; @@ -236,4 +263,28 @@ mod tests { let node = create_test_node(vec![4], vec![1], vec![1, 2], vec![1], 1, vec![0], false); let _ = conv_transpose1d_config(&node); } + + #[test] + fn test_conv_transpose1d_update_outputs() { + let mut node = create_test_node(vec![4], vec![1], vec![0, 0], vec![1], 1, vec![0], false); + + // Before calling, check that input and output ranks exist + if let ArgType::Tensor(tensor) = &node.inputs[0].ty { + assert_eq!(tensor.rank, 3); + } else { + panic!("Expected tensor input"); + } + + // Run the function + conv_transpose1d_update_outputs(&mut node); + + // After calling, output should have same rank as input + if let ArgType::Tensor(tensor) = &node.outputs[0].ty { + assert_eq!(tensor.rank, 3); + assert_eq!(tensor.elem_type, ElementType::Float32); + assert!(tensor.static_shape.is_none()); + } else { + panic!("Expected tensor output"); + } + } } diff --git a/crates/onnx-ir/src/rank_inference.rs b/crates/onnx-ir/src/rank_inference.rs index ef448651a3..a3a83486fb 100644 --- a/crates/onnx-ir/src/rank_inference.rs +++ b/crates/onnx-ir/src/rank_inference.rs @@ -6,7 +6,7 @@ use protobuf::Enum; use crate::{ ir::{ArgType, AttributeValue, Data, ElementType, Node, NodeType, TensorType}, node::{ - conv1d::conv1d_update_outputs, + 
conv_transpose1d::conv_transpose1d_update_outputs, conv1d::conv1d_update_outputs, slice::slice_update_output_rank, }, protos::tensor_proto::DataType, @@ -839,33 +839,6 @@ fn conv2d_update_outputs(node: &mut Node) { } } -/// Update output rank for ConvTranspose1d (same as input). -fn conv_transpose1d_update_outputs(node: &mut Node) { - log::debug!("ConvTranspose1d rank inference for node {}", node.name); - - if let ArgType::Tensor(tensor) = &node.inputs[0].ty { - log::debug!( - "ConvTranspose1d input rank for {}: {}", - node.name, - tensor.rank - ); - - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type: tensor.elem_type.clone(), - rank: tensor.rank, - static_shape: None, - }); - - log::debug!( - "ConvTranspose1d output rank for {}: {}", - node.name, - tensor.rank - ); - } else { - panic!("Only tensor input is valid"); - } -} - /// Update output rank for ConvTranspose2d (same as input). fn conv_transpose2d_update_outputs(node: &mut Node) { log::debug!("ConvTranspose2d rank inference for node {}", node.name); From 4fa7636456a65bf7fc0ca42af61cca05668427e6 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Wed, 30 Apr 2025 12:55:05 -0500 Subject: [PATCH 06/37] Removed and replaced as same_as_input --- crates/onnx-ir/src/node/conv1d.rs | 45 +--------------- crates/onnx-ir/src/node/conv_transpose1d.rs | 53 +----------------- crates/onnx-ir/src/rank_inference.rs | 59 ++------------------- 3 files changed, 7 insertions(+), 150 deletions(-) diff --git a/crates/onnx-ir/src/node/conv1d.rs b/crates/onnx-ir/src/node/conv1d.rs index a73e4e8224..3ad98e212c 100644 --- a/crates/onnx-ir/src/node/conv1d.rs +++ b/crates/onnx-ir/src/node/conv1d.rs @@ -1,4 +1,4 @@ -use crate::ir::{ArgType, Node, TensorType}; +use crate::ir::Node; use std::fmt; /// Padding configuration for 1D operations such as convolution @@ -125,25 +125,6 @@ pub fn padding_config_1d(pads: &[i64]) -> PaddingConfig1d { } } -/// Update output rank for Conv1d (same as input). 
-pub fn conv1d_update_outputs(node: &mut Node) { - log::debug!("Conv1d rank inference for node {}", node.name); - - if let ArgType::Tensor(tensor) = &node.inputs[0].ty { - log::debug!("Conv1d input rank for {}: {}", node.name, tensor.rank); - - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type: tensor.elem_type.clone(), - rank: tensor.rank, - static_shape: None, - }); - - log::debug!("Conv1d output rank for {}: {}", node.name, tensor.rank); - } else { - panic!("Only tensor input is valid"); - } -} - #[cfg(test)] mod tests { use super::*; @@ -304,28 +285,4 @@ mod tests { let node = create_test_node(vec![4], vec![1], vec![-1, -1], vec![1], 1, false); let _ = conv1d_config(&node); } - - #[test] - fn test_conv1d_update_outputs() { - let mut node = create_test_node(vec![4], vec![1], vec![0, 0], vec![1], 1, false); - - // Before calling, check that input and output ranks exist - if let ArgType::Tensor(tensor) = &node.inputs[0].ty { - assert_eq!(tensor.rank, 3); - } else { - panic!("Expected tensor input"); - } - - // Run the function - conv1d_update_outputs(&mut node); - - // After calling, output should have same rank as input - if let ArgType::Tensor(tensor) = &node.outputs[0].ty { - assert_eq!(tensor.rank, 3); - assert_eq!(tensor.elem_type, ElementType::Float32); - assert!(tensor.static_shape.is_none()); - } else { - panic!("Expected tensor output"); - } - } } diff --git a/crates/onnx-ir/src/node/conv_transpose1d.rs b/crates/onnx-ir/src/node/conv_transpose1d.rs index eaf168b9d8..e6d046b4e3 100644 --- a/crates/onnx-ir/src/node/conv_transpose1d.rs +++ b/crates/onnx-ir/src/node/conv_transpose1d.rs @@ -1,4 +1,4 @@ -use crate::ir::{ArgType, AttributeValue, Node, TensorType}; +use crate::ir::{AttributeValue, Node}; // Reuse PaddingConfig1d from conv1d module pub use super::conv1d::PaddingConfig1d; @@ -106,33 +106,6 @@ pub fn conv_transpose1d_config(curr: &Node) -> ConvTranspose1dConfig { } } -/// Update output rank for ConvTranspose1d (same as input). 
-pub fn conv_transpose1d_update_outputs(node: &mut Node) { - log::debug!("ConvTranspose1d rank inference for node {}", node.name); - - if let ArgType::Tensor(tensor) = &node.inputs[0].ty { - log::debug!( - "ConvTranspose1d input rank for {}: {}", - node.name, - tensor.rank - ); - - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type: tensor.elem_type.clone(), - rank: tensor.rank, - static_shape: None, - }); - - log::debug!( - "ConvTranspose1d output rank for {}: {}", - node.name, - tensor.rank - ); - } else { - panic!("Only tensor input is valid"); - } -} - #[cfg(test)] mod tests { use super::*; @@ -263,28 +236,4 @@ mod tests { let node = create_test_node(vec![4], vec![1], vec![1, 2], vec![1], 1, vec![0], false); let _ = conv_transpose1d_config(&node); } - - #[test] - fn test_conv_transpose1d_update_outputs() { - let mut node = create_test_node(vec![4], vec![1], vec![0, 0], vec![1], 1, vec![0], false); - - // Before calling, check that input and output ranks exist - if let ArgType::Tensor(tensor) = &node.inputs[0].ty { - assert_eq!(tensor.rank, 3); - } else { - panic!("Expected tensor input"); - } - - // Run the function - conv_transpose1d_update_outputs(&mut node); - - // After calling, output should have same rank as input - if let ArgType::Tensor(tensor) = &node.outputs[0].ty { - assert_eq!(tensor.rank, 3); - assert_eq!(tensor.elem_type, ElementType::Float32); - assert!(tensor.static_shape.is_none()); - } else { - panic!("Expected tensor output"); - } - } } diff --git a/crates/onnx-ir/src/rank_inference.rs b/crates/onnx-ir/src/rank_inference.rs index a3a83486fb..7e4f8792bb 100644 --- a/crates/onnx-ir/src/rank_inference.rs +++ b/crates/onnx-ir/src/rank_inference.rs @@ -5,10 +5,7 @@ use protobuf::Enum; use crate::{ ir::{ArgType, AttributeValue, Data, ElementType, Node, NodeType, TensorType}, - node::{ - conv_transpose1d::conv_transpose1d_update_outputs, conv1d::conv1d_update_outputs, - slice::slice_update_output_rank, - }, + node::slice::slice_update_output_rank, protos::tensor_proto::DataType, util::shape_config, }; @@ -28,8 +25,8 @@ pub fn rank_inference(node: &mut Node) { NodeType::Concat => concat_update_outputs(node), NodeType::Constant => constant_update_outputs(node), NodeType::ConstantOfShape => constant_of_shape_update_output(node), - NodeType::Conv1d => conv1d_update_outputs(node), - NodeType::Conv2d => conv2d_update_outputs(node), + NodeType::Conv1d => same_as_input(node), + NodeType::Conv2d => same_as_input(node), NodeType::Cos => same_as_input(node), NodeType::Cosh => same_as_input(node), NodeType::Div => same_as_input_broadcast(node), @@ -48,8 +45,8 @@ pub fn rank_inference(node: &mut Node) { NodeType::GreaterOrEqual => elementwise_comparison_outputs(node), NodeType::HardSigmoid => same_as_input(node), NodeType::GlobalAveragePool => same_as_input(node), - NodeType::ConvTranspose1d => conv_transpose1d_update_outputs(node), - NodeType::ConvTranspose2d => conv_transpose2d_update_outputs(node), + NodeType::ConvTranspose1d => same_as_input(node), + NodeType::ConvTranspose2d => same_as_input(node), NodeType::LayerNormalization => same_as_input(node), NodeType::LeakyRelu => same_as_input(node), NodeType::Less => elementwise_comparison_outputs(node), @@ -820,52 +817,6 @@ fn flatten_update_outputs(node: &mut Node) { }); } -/// Update output rank for Conv2d (same as input). 
-fn conv2d_update_outputs(node: &mut Node) { - log::debug!("Conv2d rank inference for node {}", node.name); - - if let ArgType::Tensor(tensor) = &node.inputs[0].ty { - log::debug!("Conv2d input rank for {}: {}", node.name, tensor.rank); - - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type: tensor.elem_type.clone(), - rank: tensor.rank, - static_shape: None, - }); - - log::debug!("Conv2d output rank for {}: {}", node.name, tensor.rank); - } else { - panic!("Only tensor input is valid"); - } -} - -/// Update output rank for ConvTranspose2d (same as input). -fn conv_transpose2d_update_outputs(node: &mut Node) { - log::debug!("ConvTranspose2d rank inference for node {}", node.name); - - if let ArgType::Tensor(tensor) = &node.inputs[0].ty { - log::debug!( - "ConvTranspose2d input rank for {}: {}", - node.name, - tensor.rank - ); - - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type: tensor.elem_type.clone(), - rank: tensor.rank, - static_shape: None, - }); - - log::debug!( - "ConvTranspose2d output rank for {}: {}", - node.name, - tensor.rank - ); - } else { - panic!("Only tensor input is valid"); - } -} - /// Update output rank for MatMul based on input ranks. fn matmul_update_outputs(node: &mut Node) { log::debug!("MatMul rank inference for node {}", node.name); From 4e41de9d50912bac00888f878fe07b082db9be31 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Wed, 30 Apr 2025 13:05:26 -0500 Subject: [PATCH 07/37] Move op config from burn-import --- Cargo.lock | 1 + crates/onnx-ir/Cargo.toml | 8 + crates/onnx-ir/src/lib.rs | 1 + crates/onnx-ir/src/op_configuration.rs | 1737 ++++++++++++++++++++++++ 4 files changed, 1747 insertions(+) create mode 100644 crates/onnx-ir/src/op_configuration.rs diff --git a/Cargo.lock b/Cargo.lock index cb3d74d4e0..1a890ff4fa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4474,6 +4474,7 @@ dependencies = [ name = "onnx-ir" version = "0.18.0" dependencies = [ + "burn", "bytemuck", "half", "log", diff --git a/crates/onnx-ir/Cargo.toml b/crates/onnx-ir/Cargo.toml index bc720ebda0..931c5a3652 100644 --- a/crates/onnx-ir/Cargo.toml +++ b/crates/onnx-ir/Cargo.toml @@ -14,6 +14,14 @@ version.workspace = true [dependencies] + +# REMOVE burn specific crates +burn = { path = "../burn", version = "0.18.0", default-features = false, features = [ + "std", +] } +# burn-import = { path = "../burn-import", version = "0.18.0" } + + bytemuck = { workspace = true } half = { workspace = true } log = { workspace = true } diff --git a/crates/onnx-ir/src/lib.rs b/crates/onnx-ir/src/lib.rs index b195b88adb..0a4d0e847c 100644 --- a/crates/onnx-ir/src/lib.rs +++ b/crates/onnx-ir/src/lib.rs @@ -3,6 +3,7 @@ mod from_onnx; pub mod ir; pub mod node; mod node_remap; +pub mod op_configuration; mod proto_conversion; mod protos; mod rank_inference; diff --git a/crates/onnx-ir/src/op_configuration.rs b/crates/onnx-ir/src/op_configuration.rs new file mode 100644 index 0000000000..fee98f68d8 --- /dev/null +++ b/crates/onnx-ir/src/op_configuration.rs @@ -0,0 +1,1737 @@ +// TODO Move op_configuration.rs from burn-import to onnx-ir #3091 +// See https://github.com/tracel-ai/burn/issues/3091 + +use burn::nn::{ + BatchNormConfig, DropoutConfig, LayerNormConfig, LinearConfig, PaddingConfig2d, + PaddingConfig3d, + conv::{Conv2dConfig, Conv3dConfig, ConvTranspose2dConfig, ConvTranspose3dConfig}, + pool::{AvgPool2dConfig, MaxPool2dConfig}, +}; + +use crate::ir::{ArgType, AttributeValue, Data, ElementType, Node, TensorData}; + +// use 
burn_import::burn::node::{ +// expand::ExpandShape, pad::PadConfig, split::SplitConfig, tile::TileConfig, top_k::TopKConfig, +// trilu::TriluConfig, unsqueeze::UnsqueezeAxes, +// }; + +// Conv1dConfig implementation moved to onnx-ir::node::conv1d + +/// Create a Conv2dConfig from the attributes of the node +pub fn conv2d_config(curr: &Node) -> Conv2dConfig { + let mut kernel_shape = Vec::new(); // TODO default inferred from weight tensor per spec + let mut strides = vec![1, 1]; + let mut pads = vec![0, 0, 0, 0]; + let mut dilations = vec![1, 1]; + let mut group: usize = 1; + + let weight_shape = curr.inputs[1] + .value + .as_ref() + .expect("Conv2d: weight tensor must be present") + .shape + .clone(); + + // check if the bias is present + let bias = curr.inputs.len() == 3; + + for (key, value) in curr.attrs.iter() { + match key.as_str() { + "kernel_shape" => kernel_shape = value.clone().into_i64s(), + "strides" => strides = value.clone().into_i64s(), + "pads" => pads = value.clone().into_i64s(), + "dilations" => dilations = value.clone().into_i64s(), + "group" => group = value.clone().into_i64() as usize, + _ => {} + } + } + + // the channels are inverted in the weight tensor + let channels_in = weight_shape[1] * group; + let channels_out = weight_shape[0]; + + let padding = padding_config_2d(&pads); + + Conv2dConfig::new( + [channels_in, channels_out], + [kernel_shape[0] as usize, kernel_shape[1] as usize], + ) + .with_stride([strides[0] as usize, strides[1] as usize]) + .with_dilation([dilations[0] as usize, dilations[1] as usize]) + .with_groups(group) + .with_bias(bias) + .with_padding(padding) +} + +/// Create a Conv3dConfig from the attributes of the node +pub fn conv3d_config(curr: &Node) -> Conv3dConfig { + let mut kernel_shape = Vec::new(); // TODO default inferred from weight tensor per spec + let mut strides = vec![1, 1, 1]; + let mut pads = vec![0, 0, 0, 0, 0, 0]; + let mut dilations = vec![1, 1, 1]; + let mut group: usize = 1; + + let weight_shape = curr.inputs[1] + .value + .as_ref() + .expect("Conv3d: weight tensor must be present") + .shape + .clone(); + + // check if the bias is present + let bias = curr.inputs.len() == 3; + + for (key, value) in curr.attrs.iter() { + match key.as_str() { + "kernel_shape" => kernel_shape = value.clone().into_i64s(), + "strides" => strides = value.clone().into_i64s(), + "pads" => pads = value.clone().into_i64s(), + "dilations" => dilations = value.clone().into_i64s(), + "group" => group = value.clone().into_i64() as usize, + _ => {} + } + } + + // the channels are inverted in the weight tensor + let channels_in = weight_shape[1] * group; + let channels_out = weight_shape[0]; + + let padding = padding_config_3d(&pads); + + Conv3dConfig::new( + [channels_in, channels_out], + [ + kernel_shape[0] as usize, + kernel_shape[1] as usize, + kernel_shape[2] as usize, + ], + ) + .with_stride([ + strides[0] as usize, + strides[1] as usize, + strides[2] as usize, + ]) + .with_dilation([ + dilations[0] as usize, + dilations[1] as usize, + dilations[2] as usize, + ]) + .with_groups(group) + .with_bias(bias) + .with_padding(padding) +} + +// MaxPool1dConfig implementation moved to onnx-ir::node::max_pool1d + +/// Create a MaxPool2dConfig from the attributes of the node +pub fn max_pool2d_config(curr: &Node) -> MaxPool2dConfig { + let mut kernel_shape = Vec::new(); + let mut strides = vec![1, 1]; + let mut pads = vec![0, 0, 0, 0]; + let mut dilations = vec![1, 1]; + + for (key, value) in curr.attrs.iter() { + match key.as_str() { + "kernel_shape" => 
kernel_shape = value.clone().into_i64s(), + "strides" => strides = value.clone().into_i64s(), + "pads" => pads = value.clone().into_i64s(), + "dilations" => dilations = value.clone().into_i64s(), + _ => {} + } + } + + let padding = padding_config_2d(&pads); + + MaxPool2dConfig::new([kernel_shape[0] as usize, kernel_shape[1] as usize]) + .with_strides([strides[0] as usize, strides[1] as usize]) + .with_padding(padding) + .with_dilation([dilations[0] as usize, dilations[1] as usize]) +} + +// ConvTranspose1dConfig implementation moved to onnx-ir::node::conv_transpose1d + +pub fn conv_transpose2d_config(curr: &Node) -> ConvTranspose2dConfig { + let mut attrs = curr.attrs.clone(); + let kernel_shape = attrs + .remove("kernel_shape") + .map(AttributeValue::into_i64s) + .unwrap_or_default(); + let stride = attrs + .remove("strides") + .map(AttributeValue::into_i64s) + .unwrap_or_else(|| vec![1, 1]); + let pads = attrs + .remove("pads") + .map(AttributeValue::into_i64s) + .unwrap_or_else(|| vec![0, 0, 0, 0]); + let dilations = attrs + .remove("dilations") + .map(AttributeValue::into_i64s) + .unwrap_or_else(|| vec![1, 1]); + let group = attrs + .remove("group") + .map(AttributeValue::into_i64) + .unwrap_or(1) as usize; + let output_padding = attrs + .remove("output_padding") + .map(AttributeValue::into_i64s) + .unwrap_or_else(|| vec![0, 0]); + + // Trick with remove + empty check is simplest way to not forget some attribute for runtime: + if !attrs.is_empty() { + panic!("Not all attributes are used: {attrs:?}"); + } + // Check the pads are symmetric. + let [left, top, right, bottom] = [pads[0], pads[1], pads[2], pads[3]]; + if left < 0 || top < 0 || right < 0 || bottom < 0 { + panic!("Negative pad values are not supported"); + } else if (left != right) || (top != bottom) { + panic!("Asymmetric padding is not supported"); + } + + let weight_shape = curr.inputs[1] + .value + .as_ref() + .expect("ConvTranspose2d: weight tensor must be present") + .shape + .clone(); + + // check if the bias is present + let bias = curr.inputs.len() == 3; + + // the channels are inverted in the weight tensor + let channels: [usize; 2] = [weight_shape[1] * group, weight_shape[0]]; + + ConvTranspose2dConfig::new( + channels, + [kernel_shape[0] as usize, kernel_shape[1] as usize], + ) + .with_stride([stride[0] as usize, stride[1] as usize]) + .with_padding([pads[0] as usize, pads[1] as usize]) + .with_dilation([dilations[0] as usize, dilations[1] as usize]) + .with_padding_out([output_padding[0] as usize, output_padding[1] as usize]) + .with_groups(group) + .with_bias(bias) +} + +pub fn conv_transpose3d_config(curr: &Node) -> ConvTranspose3dConfig { + let mut attrs = curr.attrs.clone(); + let kernel_shape = attrs + .remove("kernel_shape") + .map(AttributeValue::into_i64s) + .unwrap_or_default(); + let stride = attrs + .remove("strides") + .map(AttributeValue::into_i64s) + .unwrap_or_else(|| vec![1, 1, 1]); + let pads = attrs + .remove("pads") + .map(AttributeValue::into_i64s) + .unwrap_or_else(|| vec![0, 0, 0, 0, 0, 0]); + let dilations = attrs + .remove("dilations") + .map(AttributeValue::into_i64s) + .unwrap_or_else(|| vec![1, 1, 1]); + let group = attrs + .remove("group") + .map(AttributeValue::into_i64) + .unwrap_or(1) as usize; + let output_padding = attrs + .remove("output_padding") + .map(AttributeValue::into_i64s) + .unwrap_or_else(|| vec![0, 0, 0]); + + // Trick with remove + empty check is simplest way to not forget some attribute for runtime: + if !attrs.is_empty() { + panic!("Not all attributes are used: 
{attrs:?}"); + } + // Check the pads are symmetric. + let [left, top, front, right, bottom, back] = + [pads[0], pads[1], pads[2], pads[3], pads[4], pads[5]]; + + if left < 0 || top < 0 || front < 0 || right < 0 || bottom < 0 || back < 0 { + panic!("Negative pad values are not supported"); + } else if (left != right) || (top != bottom) || (front != back) { + panic!("Asymmetric padding is not supported"); + } + + let weight_shape = curr.inputs[1] + .value + .as_ref() + .expect("ConvTranspose3d: weight tensor must be present") + .shape + .clone(); + + // check if the bias is present + let bias = curr.inputs.len() == 3; + + // the channels are inverted in the weight tensor + let channels: [usize; 2] = [weight_shape[1] * group, weight_shape[0]]; + + ConvTranspose3dConfig::new( + channels, + [ + kernel_shape[0] as usize, + kernel_shape[1] as usize, + kernel_shape[2] as usize, + ], + ) + .with_stride([stride[0] as usize, stride[1] as usize, stride[2] as usize]) + .with_padding([pads[0] as usize, pads[1] as usize, pads[2] as usize]) + .with_dilation([ + dilations[0] as usize, + dilations[1] as usize, + dilations[2] as usize, + ]) + .with_padding_out([ + output_padding[0] as usize, + output_padding[1] as usize, + output_padding[2] as usize, + ]) + .with_groups(group) + .with_bias(bias) +} + +// AvgPool1dConfig implementation moved to onnx-ir::node::avg_pool1d +/// Create a AvgPool2dConfig from the attributes of the node +pub fn avg_pool2d_config(curr: &Node) -> AvgPool2dConfig { + let mut kernel_shape = Vec::new(); + let mut strides = vec![1, 1]; + let mut pads = vec![0, 0, 0, 0]; + let mut count_include_pad: i64 = 0; + let mut ceil_mode: i64 = 0; + + for (key, value) in curr.attrs.iter() { + match key.as_str() { + "kernel_shape" => kernel_shape = value.clone().into_i64s(), + "strides" => strides = value.clone().into_i64s(), + "pads" => pads = value.clone().into_i64s(), + "count_include_pad" => count_include_pad = value.clone().into_i64(), + "ceil_mode" => ceil_mode = value.clone().into_i64(), + _ => {} + } + } + + if ceil_mode == 1 { + panic!("ceil_mode is not supported"); + } + + let padding = padding_config_2d(&pads); + + AvgPool2dConfig::new([kernel_shape[0] as usize, kernel_shape[1] as usize]) + .with_strides([strides[0] as usize, strides[1] as usize]) + .with_padding(padding) + .with_count_include_pad(count_include_pad == 1) +} + +// pub fn expand_config(node: &Node) -> ExpandShape { +// match &node.inputs[1].ty { +// ArgType::Tensor(tensor) => { +// assert_eq!(tensor.rank, 1, "Expand: shape tensor must be 1D"); +// assert!( +// matches!(tensor.elem_type, ElementType::Int64), +// "Expand: shape tensor must have element type int64" +// ); +// } +// ArgType::Shape(_) => { +// // Shapes are always 1-D int64 data, so nothing to assert here +// } +// _ => panic!("Only tensor input is valid for shape"), +// } + +// match &node.inputs[1].value { +// Some(TensorData { +// data: Data::Int64s(shape), +// .. 
+// }) => ExpandShape::Static(shape.clone()), +// None => { +// // we were unable to statically determine the input value, so we'll need to fetch it at runtime +// ExpandShape::Runtime(crate::burn::Type::from(&node.inputs[1])) +// } +// _ => panic!( +// "Shape data type must be int64, is {:?}", +// &node.inputs[1].value +// ), +// } +// } + +/// Create a FlattenConfig from the attributes of the node +pub fn flatten_config(curr: &Node) -> usize { + // the begin dimension is the first dimension (Default: 1 per ONNX spec) + let mut axis: i64 = 1; + + // check if the node has only one input + if curr.inputs.len() != 1 { + panic!( + "Flatten: multiple inputs are not supported (got {:?})", + curr.inputs.len() + ); + } + + // extract the shape of the input tensor + let tensor = match curr.inputs.first().unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + // check if the input tensor has at least 2 dimensions + if tensor.rank < 2 { + panic!( + "Flatten: input tensor must have at least 2 dimensions (got {:?})", + tensor.rank + ); + } + + // extract the attributes + for (key, value) in curr.attrs.iter() { + match key.as_str() { + "axis" => axis = value.clone().into_i64(), + _ => {} + } + } + + // if beg_dim is negative, it is counted from the end + if axis < 0 { + axis += tensor.rank as i64; + } + + axis as usize +} + +/// Create a GatherConfig from the attributes of the node +pub fn gather_config(curr: &Node) -> usize { + // Default: 0 per ONNX spec + let mut dim: i64 = 0; + + // check if the node has only one input + if curr.inputs.len() != 2 { + panic!("Gather: index tensor must be present"); + } + + // extract the shape of the input tensor + let input_dim = match curr.inputs.first().unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor.rank as i64, + ArgType::Shape(_shape) => 1, //Shape is always 1-D + other => panic!("Only tensor or shape input is valid, got {:?}", other), + }; + + // extract the attributes + for (key, value) in curr.attrs.iter() { + match key.as_str() { + "axis" => dim = value.clone().into_i64(), + _ => {} + } + } + + // if dim is negative, it is counted from the end + if dim < 0 { + dim += input_dim; + } + + dim as usize +} + +/// Create a LinearConfig from the attributes of the node +pub fn linear_config(node: &Node) -> LinearConfig { + if node.inputs.len() < 2 { + panic!("Linear: missing weight tensor"); + } + + let weight_shape = node.inputs[1] + .value + .as_ref() + .expect("Linear: weight tensor must be present") + .shape + .clone(); + + // check if the weight tensor has at least 2 dimensions + if weight_shape.len() < 2 { + panic!( + "Linear: weight tensor must have at least 2 dimensions (got {:?})", + weight_shape.len() + ); + } + + let (in_size, out_size) = (weight_shape[0], weight_shape[1]); + + // check if the bias is present + let bias = node.inputs.len() == 3 && node.inputs[2].value.is_some(); + + LinearConfig::new(in_size, out_size).with_bias(bias) +} + +/// Create a DropoutConfig from an attribute and state of the node +pub fn dropout_config(node: &Node) -> DropoutConfig { + // Opset 7 and older store probability as an attribute + if node.attrs.contains_key("ratio") { + let prob = node.attrs.get("ratio").unwrap().clone().into_f32(); + return DropoutConfig::new(prob as f64); + } + + if node.inputs.len() < 2 { + panic!("Dropout configuration must have at least 2 inputs"); + } + + let ratio = node.inputs[1] + .value + .clone() + .expect("Dropout ratio must be passed in the second input") + .data + 
.into_scalar(); + + let prob = match ratio { + Data::Float16(ratio) => f64::from(f32::from(ratio)), + Data::Float32(ratio) => ratio as f64, + Data::Float64(ratio) => ratio, + _ => panic!("Dropout ratio must be a float"), + }; + + DropoutConfig::new(prob) +} + +/// Create log_softmax config from the attributes of the node +pub fn log_softmax_config(node: &Node) -> usize { + // the axis is the last dimension (Default: 1 per ONNX spec) + let mut axis: i64 = -1; + + // check if the node has only one input + if node.inputs.len() != 1 { + panic!( + "LogSoftmax: multiple inputs are not supported (got {:?})", + node.inputs.len() + ); + } + + // extract the shape of the input tensor + let tensor = match node.inputs.first().unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + // extract the attributes + for (key, value) in node.attrs.iter() { + match key.as_str() { + "axis" => axis = value.clone().into_i64(), + _ => {} + } + } + + // if axis is negative, it is counted from the end + if axis < 0 { + axis += tensor.rank as i64; + } + + axis as usize +} + +/// Create softmax config from the attributes of the node +pub fn softmax_config(node: &Node) -> usize { + // the axis is the last dimension (Default: 1 per ONNX spec) + let mut axis: i64 = -1; + + // check if the node has only one input + if node.inputs.len() != 1 { + panic!( + "Softmax: multiple inputs are not supported (got {:?})", + node.inputs.len() + ); + } + + // extract the shape of the input tensor + let tensor = match node.inputs.first().unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + // extract the attributes + for (key, value) in node.attrs.iter() { + match key.as_str() { + "axis" => axis = value.clone().into_i64(), + _ => {} + } + } + + // if axis is negative, it is counted from the end + if axis < 0 { + axis += tensor.rank as i64; + } + + axis as usize +} + +/// Create argmax config from the attributes of the node +pub fn argmax_config(node: &Node) -> usize { + let mut axis: i64 = 0; + + // check if the node has only one input + if node.inputs.len() != 1 { + panic!( + "Argmax: multiple inputs are not supported (got {:?})", + node.inputs.len() + ); + } + + // extract the shape of the input tensor + let tensor = match node.inputs.first().unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + // extract the attributes + for (key, value) in node.attrs.iter() { + match key.as_str() { + "axis" => axis = value.clone().into_i64(), + "select_last_index" => { + // not all params are supported in burn + if value.clone().into_i64() != 0 { + log::warn!( + "only select_last_index=0 is supported for argmax in burn. 
Ignoring supplied value (got {:?})", + value + ); + } + } + "keepdims" => { + // not all params are supported in burn + if value.clone().into_i64() != 1 { + panic!( + "Only keepdims=1 is supported for argmax in burn (got {:?})", + value + ); + } + } + _ => {} + } + } + + // if axis is negative, it is counted from the end + if axis < 0 { + axis += tensor.rank as i64; + } + + axis as usize +} + +/// Create concat config from the attributes of the node +pub fn concat_config(node: &Node) -> usize { + // the axis is the last dimension (Default: 1 per ONNX spec) + let mut axis: i64 = 1; + + // extract the shape of the input tensor + let tensor = match node.inputs.first().unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + // extract the attributes + for (key, value) in node.attrs.iter() { + match key.as_str() { + "axis" => axis = value.clone().into_i64(), + _ => {} + } + } + + // if axis is negative, it is counted from the end + if axis < 0 { + axis += tensor.rank as i64; + } + + axis as usize +} + +/// Create a BatchNormConfig from the attributes of the node +pub fn batch_norm_config(node: &Node) -> BatchNormConfig { + let weight_shape = node.inputs[1] + .value + .as_ref() + .expect("BatchNorm: weight tensor must be present") + .shape + .clone(); + + let num_features = weight_shape[0]; + + let mut epsilon = 0f32; + let mut momentum = 0f32; + + for (key, value) in node.attrs.iter() { + match key.as_str() { + "momentum" => momentum = value.clone().into_f32(), + "epsilon" => epsilon = value.clone().into_f32(), + _ => {} + } + } + + BatchNormConfig::new(num_features) + .with_epsilon(epsilon as f64) + .with_momentum(momentum as f64) +} + +/// Create a LayerNormConfig from the attributes of the node +pub fn layer_norm_config(node: &Node) -> (LayerNormConfig, bool) { + let weight_shape = node.inputs[1] + .value + .as_ref() + .expect("LayerNorm: weight tensor must be present") + .shape + .clone(); + + let num_features = weight_shape[0]; + + // When `stash_type` is `1` (default), perform operations in 32-bit float and + // cast the results back to original dtype + let mut stash_type = 1; + let mut axis = -1; + let mut epsilon = 1e-5; + + for (key, value) in node.attrs.iter() { + match key.as_str() { + "axis" => axis = value.clone().into_i64(), + "epsilon" => epsilon = value.clone().into_f32(), + "stash_type" => stash_type = value.clone().into_i64(), + _ => {} + } + } + + if axis != -1 && axis != weight_shape.len() as i64 - 1 { + panic!("LayerNorm: normalization is only supported on the last axis right now") + } + + ( + LayerNormConfig::new(num_features).with_epsilon(epsilon as f64), + stash_type == 1, + ) +} + +// /// Create a TileConfig from the attributes of the node +// pub fn tile_config(node: &Node) -> TileConfig { +// let repeat = node +// .inputs +// .get(1) +// .map(|input| { +// if let Some(TensorData { data, .. }) = &input.value { +// data.clone() +// .into_i64s() +// .iter() +// .map(|&x| x as usize) +// .collect() +// } else { +// vec![] +// } +// }) +// .unwrap_or_default(); +// TileConfig::new(repeat) +// } + +// /// Create a TopKConfig from the attributes of the node. 
+// pub fn top_k_config(node: &Node) -> TopKConfig { +// // extract the shape of the input data tensor +// let data_tensor = match node.inputs.first().unwrap().clone().ty { +// ArgType::Tensor(tensor) => tensor, +// _ => panic!("Only tensor input is valid"), +// }; + +// let k = match node.inputs.get(1) { +// Some(k_tensor) => k_tensor +// .clone() +// .value +// .expect("TopK: only constant 'k' tensor is currently supported") +// .data +// .into_i64s()[0], +// _ => node +// .attrs +// .get("k") +// .expect("TopK: number of top elements 'k' is missing") +// .clone() +// .into_i64(), +// }; + +// let mut axis = match node.attrs.get("axis") { +// Some(axis) => axis.clone().into_i64(), +// None => -1, +// }; + +// // if axis is negative, it is counted from the end +// if axis < 0 { +// axis += data_tensor.rank as i64; +// } + +// if let Some(largest) = node.attrs.get("largest") { +// if largest.clone().into_i64() != 1 { +// unimplemented!("TopK: only largest elements is supported") +// } +// }; + +// if let Some(sorted) = node.attrs.get("sorted") { +// if sorted.clone().into_i64() != 1 { +// unimplemented!("TopK: only sorted elements is supported") +// } +// }; + +// TopKConfig::new(axis as usize, k as usize) +// } + +// /// Create a TriluConfig from the attributes of the node +// pub fn trilu_config(node: &Node) -> TriluConfig { +// let mut upper = true; +// let mut diagonal = 0; +// for (key, value) in node.attrs.iter() { +// match key.as_str() { +// "upper" => upper = value.clone().into_i64() != 0, +// _ => {} +// } +// } +// // The second input of the Trilu node is the diagonal value, coming from a constant node +// if let Some(diagonal_arg) = node.inputs.get(1) { +// if let Some(TensorData { +// data: Data::Int64(diagonal_val), +// .. +// }) = &diagonal_arg.value +// { +// diagonal = *diagonal_val; +// } +// } +// TriluConfig::new(upper, diagonal) +// } + +// /// Create a PadConfig from the attributes of the node +// pub fn pad_config(node: &Node) -> PadConfig { +// fn get_pads_input(node: &Node) -> Vec { +// match &node.inputs[1].value { +// Some(TensorData { data, .. 
}) => data.clone().into_i64s(), +// _ => Vec::new(), +// } +// } +// fn get_pads(node: &Node) -> Vec { +// if node.inputs.is_empty() { +// panic!("Pad: must provide data as input") +// } +// if node.inputs.len() >= 4 { +// panic!("Pad: axes input is not supported") +// } + +// let input_dim = match &node.inputs.first().unwrap().ty { +// ArgType::Tensor(tensor) => tensor.rank, +// _ => panic!("Pad: Only tensor input is valid"), +// }; + +// //TODO : handle more possible attributes +// let mut pads: Vec = get_pads_input(node) +// .into_iter() +// .map(|x| x as usize) +// .collect(); + +// for (key, value) in node.attrs.iter() { +// match key.as_str() { +// "pads" => { +// pads = value +// .clone() +// .into_i64s() +// .iter() +// .map(|&x| { +// if x < 0 { +// panic!("Pad: Negative pad is not supported"); +// } +// x as usize +// }) +// .collect() +// } +// "mode" => { +// let mode = value.clone().into_string(); +// if mode != "constant" { +// panic!("only constant mode is supported, given mode is {}", mode); +// } +// } + +// _ => {} +// } +// } + +// if pads.is_empty() { +// panic!("Pad: pads should be given as attribute or as input"); +// } + +// if pads.len() != input_dim * 2 { +// panic!("Pad: pads should be a 1D tensor of shape [2 * num_axes]"); +// } +// // TODO: Burn's pad should support 1D tensor +// if input_dim < 2 { +// panic!("Pad: input tensor should be rank 2 or higher"); +// } + +// let left_index = input_dim - 1; +// let top_index = input_dim - 2; +// let right_index = pads.len() - 1; +// let bottom_index = pads.len() - 2; +// let index_list = [left_index, top_index, right_index, bottom_index]; + +// for (index, &item) in pads.iter().enumerate() { +// if !index_list.contains(&index) && item != 0 { +// panic!( +// "Pad: padding will only be applied to the last two dimensions but found non zero padding for other dimensions" +// ); +// } +// } + +// let left = pads[left_index]; +// let top = pads[top_index]; +// let right = pads[right_index]; +// let bottom = pads[bottom_index]; +// vec![left, right, top, bottom] +// } +// fn get_constant_value(node: &Node) -> f32 { +// // TODO: support int, boolean +// let mut constant_value = node.inputs +// .get(2) +// .and_then(|input| match &input.value.as_ref().expect("Value input must be present").data { +// Data::Float16s(constant_value) => { +// constant_value.first().map(|&f| f32::from(f)) +// } +// Data::Float32s(constant_value) => { +// constant_value.first().copied() +// } +// Data::Float64s(constant_value) => { +// constant_value.first().map(|&f| f as f32) +// } +// Data::Float16(constant_value) => Some(f32::from(*constant_value)), +// Data::Float32(constant_value) => Some(*constant_value), +// Data::Float64(constant_value) => Some(*constant_value as f32), +// _ => panic!("Pad: only float values are currently supported for constant value, submit an issue on github"), +// }) +// .unwrap_or(0.0); + +// if node.attrs.contains_key("value") { +// constant_value = node.attrs.get("value").map(|value| match value { +// AttributeValue::Float32(value) => *value, +// _ => panic!("Pad: only float32 values are currently supported for constant value as attribute, submit an issue on github"), +// }).expect("constant_value should have had a value now"); +// } +// constant_value +// } + +// let pads = get_pads(node); +// let constant_value = get_constant_value(node); + +// PadConfig::new(pads, constant_value) +// } + +// padding_config_1d moved to onnx-ir::node::conv1d + +/// Calculate the padding configuration for a 2D operations such as 
Convolution and Pooling. +/// +/// # Arguments +/// +/// * `pads` - The padding values +/// +/// # Panics +/// +/// * If the padding is negative +/// * If the padding is not symmetric +/// +/// # Returns +/// +/// * The padding configuration +/// +/// # Remarks +/// +/// This function is used when the padding is specified as a list of integers, +/// and not used when the padding is specified as a string, e.g. "SAME_UPPER". +fn padding_config_2d(pads: &[i64]) -> PaddingConfig2d { + let [left, top, right, bottom] = [pads[0], pads[1], pads[2], pads[3]]; + + if left < 0 || top < 0 || right < 0 || bottom < 0 { + panic!("Negative pad values are not supported"); + } else if (left != right) || (top != bottom) { + panic!("Asymmetric padding is not supported"); + } else if left == 0 && top == 0 && right == 0 && bottom == 0 { + // i.e [0, 0, 0, 0] + PaddingConfig2d::Valid + } else if left == right && top == bottom { + // i.e [2, 3, 2, 3] + PaddingConfig2d::Explicit(left as usize, top as usize) + } else { + // Unaccounted for padding configuration + panic!("Padding configuration ({:?}) not supported", pads); + } +} + +/// Calculate the padding configuration for a 3D operations such as Convolution and Pooling. +/// +/// # Arguments +/// +/// * `pads` - The padding values +/// +/// # Panics +/// +/// * If the padding is negative +/// * If the padding is not symmetric +/// +/// # Returns +/// +/// * The padding configuration +/// +/// # Remarks +/// +/// This function is used when the padding is specified as a list of integers, +/// and not used when the padding is specified as a string, e.g. "SAME_UPPER". +fn padding_config_3d(pads: &[i64]) -> PaddingConfig3d { + let [left, top, front, right, bottom, back] = + [pads[0], pads[1], pads[2], pads[3], pads[4], pads[5]]; + + if left < 0 || top < 0 || front < 0 || right < 0 || bottom < 0 || back < 0 { + panic!("Negative pad values are not supported"); + } else if (left != right) || (top != bottom) || (front != back) { + panic!("Asymmetric padding is not supported"); + } else if left == 0 && top == 0 && front == 0 && right == 0 && bottom == 0 && back == 0 { + // i.e [0, 0, 0, 0] + PaddingConfig3d::Valid + } else if left == right && top == bottom && front == back { + // i.e [2, 3, 2, 3] + PaddingConfig3d::Explicit(left as usize, top as usize, front as usize) + } else { + // Unaccounted for padding configuration + panic!("Padding configuration ({:?}) not supported", pads); + } +} + +// Create a LeakyReluConfig from the alpha attribute of the node +pub fn leaky_relu_config(node: &Node) -> f64 { + let mut alpha = 0.01; + + for (key, value) in node.attrs.iter() { + match key.as_str() { + "alpha" => alpha = value.clone().into_f32() as f64, + _ => {} + } + } + + alpha +} + +// Create a HardSigmoidConfig from the alpha and beta attributes of the node +pub fn hard_sigmoid_config(node: &Node) -> (f64, f64) { + let mut alpha = 0.2; + let mut beta = 0.5; + + for (key, value) in node.attrs.iter() { + match key.as_str() { + "alpha" => alpha = value.clone().into_f32() as f64, + "beta" => beta = value.clone().into_f32() as f64, + _ => {} + } + } + + (alpha, beta) +} + +pub fn reshape_config(node: &Node) -> Vec { + let mut allowzero = 0; + + for (key, value) in node.attrs.iter() { + match key.as_str() { + "allowzero" => allowzero = value.clone().into_i64(), + _ => {} + } + } + + // Burn does not support zero size shape (0 means false in ONNX) + // (see https://onnx.ai/onnx/operators/onnx__Reshape.html#attributes) + if allowzero != 0 { + panic!("Zero shape size is not 
supported"); + } + + // TODO: check "shape" attribute + if node.inputs.len() != 2 || node.inputs[1].value.is_none() { + panic!("Reshape: shape tensor must be present for {:?}", node); + } + + match &node.inputs[1].value { + Some(TensorData { data, shape, .. }) => { + assert_eq!(shape.len(), 1, "Reshape: shape tensor must be 1D"); + data.clone().into_i64s() + } + _ => panic!("Only tensor input is valid for shape"), + } +} + +pub fn resize_config(node: &Node) -> (String, Vec, Vec) { + let mut mode: String = "".to_string(); + + let mut scales: Vec; + let mut sizes: Vec; + + let input = if let ArgType::Tensor(tensor) = &node + .inputs + .first() + .expect("Resize: Input tensor must be present") + .ty + { + tensor + } else { + panic!("Resize: input must be a tensor") + }; + + // Note: we are ignoring some attributes because results are approximately the same + // and we are not supporting all the attributes of the Resize operator. + // However, some attributes are important to be checked and we are checking + // against the default values of the attributes. + // TODO revisit this when we have more Resize operators in the model + for (key, value) in node.attrs.iter() { + match key.as_str() { + "antialias" => assert_eq!( + value.clone().into_i32(), + 0, + "Resize: antialias other than 0 is not supported" + ), + "axes" => panic!("Resize: custom axes attribute is not supported"), + "coordinate_transformation_mode" => { + log::warn!("Resize: coordinate_transformation_mode is ignored") + } + + "cubic_coeff_a" => log::warn!("Resize: cubic_coeff_a is ignored"), + "exclude_outside" => assert_eq!( + value.clone().into_i32(), + 0, + "Resize: exclude_outside other than 0 is not supported" + ), + "extrapolation_value" => assert_eq!( + value.clone().into_f32(), + 0.0, + "Resize: extrapolation_value other than 0.0 is not supported" + ), + "keep_aspect_ratio_policy" => { + assert_eq!( + value.clone().into_string().to_lowercase(), + "stretch", + "Resize: keep_aspect_ratio_policy other than 'stretch' is not supported" + ) + } + "mode" => mode = value.clone().into_string().to_lowercase(), + "nearest_mode" => log::warn!("Resize: nearest_mode is ignored"), + + _ => {} + } + } + + let roi: Vec = node + .inputs + .get(1) + .map(|input| { + if let Some(TensorData { data, .. }) = &input.value { + data.clone().into_f32s() + } else { + vec![] + } + }) + .unwrap_or_default(); + + scales = node + .inputs + .get(2) + .map(|input| { + if let Some(TensorData { data, .. }) = &input.value { + data.clone().into_f32s() + } else { + vec![] + } + }) + .unwrap_or_default(); + + sizes = node + .inputs + .get(3) + .map(|input| { + if let Some(TensorData { data, .. 
}) = &input.value {
+                data.clone()
+                    .into_i64s()
+                    .iter()
+                    .map(|&x| x as usize)
+                    .collect()
+            } else {
+                vec![]
+            }
+        })
+        .unwrap_or_default();
+
+    if mode.is_empty() {
+        panic!("Resize: mode attribute is required")
+    }
+
+    if !roi.is_empty() {
+        panic!("Resize: roi input is not supported")
+    }
+
+    if scales.is_empty() && sizes.is_empty() {
+        panic!("Resize: either scales or sizes input is required")
+    }
+
+    if !scales.is_empty() {
+        assert!(scales.len() == input.rank);
+        // ignore the first two items from scales
+        // because they are the batch and channel dimensions
+        scales = scales.iter().skip(2).cloned().collect();
+    }
+
+    if !sizes.is_empty() {
+        assert!(sizes.len() == input.rank);
+        // ignore the first two items from sizes
+        // because they are the batch and channel dimensions
+        sizes = sizes.iter().skip(2).cloned().collect();
+    }
+
+    (mode, scales, sizes)
+}
+
+// //Note this function should only execute if the second input is a constant
+// //if it wasn't and the output shape was known, unsqueeze has been remapped to reshape
+// pub fn unsqueeze_config(node: &Node) -> UnsqueezeAxes {
+//     // Check if axes attribute exists
+//     for (key, value) in node.attrs.iter() {
+//         match key.as_str() {
+//             "axes" => return UnsqueezeAxes::Static(value.clone().into_i64s()),
+//             _ => {}
+//         }
+//     }
+
+//     assert!(
+//         !node.inputs.is_empty(),
+//         "Unsqueeze: axes tensor must be present"
+//     );
+
+//     let input_value = &node.inputs[1];
+
+//     match &node.inputs[1].ty {
+//         ArgType::Tensor(tensor) => {
+//             assert_eq!(tensor.rank, 1, "Unsqueeze: axes tensor must be 1D");
+//             if let Some(TensorData {
+//                 data: Data::Int64s(shape),
+//                 ..
+//             }) = input_value.value.as_ref()
+//             {
+//                 UnsqueezeAxes::Static(shape.clone())
+//             } else {
+//                 UnsqueezeAxes::Runtime(crate::burn::Type::from(&node.inputs[1]))
+//             }
+//         }
+//         _ => panic!("Arg for unsqueeze must be tensor or scalar"),
+//     }
+// }
+
+pub fn clip_config(node: &Node) -> (Option<f64>, Option<f64>) {
+    let mut min_result: Option<f64> = None;
+    let mut max_result: Option<f64> = None;
+
+    // For Clip Opset 6+ , the min and max values are attributes
+    for (key, value) in node.attrs.iter() {
+        match key.as_str() {
+            "min" => {
+                let min = value.clone().into_f32() as f64;
+                min_result = Some(min);
+            }
+            "max" => {
+                let max = value.clone().into_f32();
+                max_result = Some(max as f64);
+            }
+            _ => {}
+        }
+    }
+
+    // For Clip Opset 11+ , the min and max values are inputs
+    // Get the min and max values from the input values
+    if min_result.is_none() && max_result.is_none() {
+        let min = &node.inputs[1].value;
+        let max = &node.inputs[2].value;
+
+        if min_result.is_none() && min.is_some() {
+            let min = min.clone().unwrap().data.into_scalar();
+            min_result = match min {
+                Data::Float16(min) => Some(f32::from(min) as f64),
+                Data::Float32(min) => Some(min as f64),
+                Data::Float64(min) => Some(min),
+                _ => panic!("Clip: only float min is supported"),
+            };
+        }
+
+        if max_result.is_none() && max.is_some() {
+            let max = max.clone().unwrap().data.into_scalar();
+            max_result = match max {
+                Data::Float16(max) => Some(f32::from(max) as f64),
+                Data::Float32(max) => Some(max as f64),
+                Data::Float64(max) => Some(max),
+                _ => panic!("Clip: only float max is supported"),
+            };
+        }
+    }
+
+    if min_result.is_none() && max_result.is_none() {
+        panic!("Clip: min and max values must be either attributes or inputs");
+    }
+
+    (min_result, max_result)
+}
+
+pub fn reduce_max_config(node: &Node) -> Option<usize> {
+    let mut axes = Vec::new();
+    let mut keepdims = 1;
+
+    let tensor = match node.inputs.first().unwrap().clone().ty {
+        ArgType::Tensor(tensor) => tensor,
+        _ => panic!("Only tensor input is valid"),
+    };
+
+    // Extract the attributes
+    for (key, value) in node.attrs.iter() {
+        match key.as_str() {
+            "axes" => axes = value.clone().into_i64s(),
+            "keepdims" => keepdims = value.clone().into_i64(),
+            _ => {}
+        }
+    }
+
+    if axes.len() > 1 {
+        panic!("ReduceMax: reducing on multiple dimensions is not supported")
+    }
+
+    if axes.is_empty() && keepdims == 1 {
+        panic!("ReduceMax: axes must be provided with keepdims")
+    }
+
+    if !axes.is_empty() && keepdims == 0 {
+        // Not supported in Burn
+        panic!("ReduceMax: the reduce operation must preserve the reduced dimension")
+    }
+
+    if axes.is_empty() {
+        None
+    } else {
+        let mut dim = axes[0];
+
+        if dim < 0 {
+            // Accepted range is [-r, r-1] where r = rank(data) but Burn only supports positive dim
+            dim += tensor.rank as i64;
+        }
+        Some(dim as usize)
+    }
+}
+
+pub fn reduce_min_config(node: &Node) -> Option<usize> {
+    let mut axes = Vec::new();
+    let mut keepdims = 1;
+
+    let tensor = match node.inputs.first().unwrap().clone().ty {
+        ArgType::Tensor(tensor) => tensor,
+        _ => panic!("Only tensor input is valid"),
+    };
+
+    // Extract the attributes
+    for (key, value) in node.attrs.iter() {
+        match key.as_str() {
+            "axes" => axes = value.clone().into_i64s(),
+            "keepdims" => keepdims = value.clone().into_i64(),
+            _ => {}
+        }
+    }
+
+    if axes.len() > 1 {
+        panic!("ReduceMin: reducing on multiple dimensions is not supported")
+    }
+
+    if axes.is_empty() && keepdims == 1 {
+        panic!("ReduceMin: axes must be provided with keepdims")
+    }
+
+    if !axes.is_empty() && keepdims == 0 {
+        panic!("ReduceMin: the reduce operation must preserve the reduced dimension")
+    }
+
+    if axes.is_empty() {
+        None
+    } else {
+        let mut dim = axes[0];
+
+        if dim < 0 {
+            dim += tensor.rank as i64;
+        }
+        Some(dim as usize)
+    }
+}
+
+pub fn reduce_mean_config(node: &Node) -> Option<usize> {
+    let mut axes = Vec::new();
+    let mut keepdims = 1;
+
+    let tensor = match node.inputs.first().unwrap().clone().ty {
+        ArgType::Tensor(tensor) => tensor,
+        _ => panic!("Only tensor input is valid"),
+    };
+
+    // Extract the attributes
+    for (key, value) in node.attrs.iter() {
+        match key.as_str() {
+            "axes" => axes = value.clone().into_i64s(),
+            "keepdims" => keepdims = value.clone().into_i64(),
+            _ => {}
+        }
+    }
+
+    if axes.len() > 1 {
+        panic!("ReduceMean: reducing on multiple dimensions is not supported")
+    }
+
+    if axes.is_empty() && keepdims == 1 {
+        panic!("ReduceMean: axes must be provided with keepdims")
+    }
+
+    if !axes.is_empty() && keepdims == 0 {
+        // Not supported in Burn
+        panic!("ReduceMean: the reduce operation must preserve the reduced dimension")
+    }
+
+    if axes.is_empty() {
+        None
+    } else {
+        let mut dim = axes[0];
+
+        if dim < 0 {
+            // Accepted range is [-r, r-1] where r = rank(data) but Burn only supports positive dim
+            dim += tensor.rank as i64;
+        }
+        Some(dim as usize)
+    }
+}
+
+pub fn reduce_prod_config(node: &Node) -> Option<usize> {
+    let mut axes = Vec::new();
+    let mut keepdims = 1;
+
+    let tensor = match node.inputs.first().unwrap().clone().ty {
+        ArgType::Tensor(tensor) => tensor,
+        _ => panic!("Only tensor input is valid"),
+    };
+
+    // Extract the attributes
+    for (key, value) in node.attrs.iter() {
+        match key.as_str() {
+            "axes" => axes = value.clone().into_i64s(),
+            "keepdims" => keepdims = value.clone().into_i64(),
+            // TODO: handle noop_with_empty_axes (opset 18)
+            _ => {}
+        }
+    }
+
+    if axes.len() > 1 {
+        panic!("ReduceProd: reducing on multiple dimensions is not supported")
+    }
+
+    if axes.is_empty() && keepdims == 1 {
+        panic!("ReduceProd: axes must be provided with keepdims")
+    }
+
+    if !axes.is_empty() && keepdims == 0 {
+        // Not supported in Burn
+        panic!("ReduceProd: the reduce operation must preserve the reduced dimension")
+    }
+
+    if axes.is_empty() {
+        None
+    } else {
+        let mut dim = axes[0];
+
+        if dim < 0 {
+            // Accepted range is [-r, r-1] where r = rank(data) but Burn only supports positive dim
+            dim += tensor.rank as i64;
+        }
+        Some(dim as usize)
+    }
+}
+
+pub fn reduce_sum_config(node: &Node) -> Option<usize> {
+    let mut axes = Vec::new();
+    let mut keepdims = 1;
+
+    let tensor = match node.inputs.first().unwrap().clone().ty {
+        ArgType::Tensor(tensor) => tensor,
+        _ => panic!("Only tensor input is valid"),
+    };
+
+    // Extract the attributes
+    for (key, value) in node.attrs.iter() {
+        match key.as_str() {
+            "keepdims" => keepdims = value.clone().into_i64(),
+            "axes" => axes = value.clone().into_i64s(),
+            // TODO: handle noop_with_empty_axes
+            _ => {}
+        }
+    }
+
+    // TODO: Handle case where axes are passed in. Will require its own ReduceSumNode instead of a UnaryNode.
+    if let Some(value) = node
+        .inputs
+        .get(1)
+        .and_then(|argument| argument.value.as_ref())
+    {
+        axes = value.clone().data.into_i64s();
+    }
+
+    if axes.len() > 1 {
+        panic!("ReduceSum: reducing on multiple dimensions is not supported")
+    }
+
+    if axes.is_empty() && keepdims == 1 {
+        panic!("ReduceSum: axes must be provided with keepdims")
+    }
+
+    if !axes.is_empty() && keepdims == 0 {
+        // Not supported in Burn
+        panic!("ReduceSum: the reduce operation must preserve the reduced dimension")
+    }
+
+    if axes.is_empty() {
+        None
+    } else {
+        let mut dim = axes[0];
+
+        if dim < 0 {
+            // Accepted range is [-r, r-1] where r = rank(data) but Burn only supports positive dim
+            dim += tensor.rank as i64;
+        }
+        Some(dim as usize)
+    }
+}
+
+pub fn shape_config(curr: &Node) -> (usize, usize) {
+    if curr.inputs.len() != 1 {
+        panic!(
+            "Shape: multiple inputs are not supported (got {:?})",
+            curr.inputs.len()
+        );
+    }
+
+    // Extract the shape of the input tensor
+    let tensor = match curr.inputs.first().unwrap().clone().ty {
+        ArgType::Tensor(tensor) => tensor,
+        _ => panic!("Only tensor input is valid"),
+    };
+
+    // Default: all axes up to the last one (included)
+    let mut start_dim: i64 = 0;
+    let mut end_dim: i64 = tensor.rank as i64;
+
+    // Extract the attributes
+    for (key, value) in curr.attrs.iter() {
+        match key.as_str() {
+            "start" => start_dim = value.clone().into_i64(),
+            "end" => end_dim = value.clone().into_i64(),
+            _ => {}
+        }
+    }
+
+    // If dim is negative, it is counted from the end
+    if start_dim < 0 {
+        start_dim += tensor.rank as i64;
+    }
+    if end_dim < 0 {
+        end_dim += tensor.rank as i64;
+    }
+
+    (start_dim as usize, end_dim as usize)
+}
+
+pub fn transpose_config(curr: &Node) -> Vec<i64> {
+    if curr.inputs.len() != 1 {
+        panic!(
+            "Transpose: multiple inputs are not supported (got {:?})",
+            curr.inputs.len()
+        );
+    }
+
+    // Extract the shape of the input tensor
+    let tensor = match curr.inputs.first().unwrap().clone().ty {
+        ArgType::Tensor(tensor) => tensor,
+        _ => panic!("Only tensor input is valid"),
+    };
+
+    // Default: reverse the dimensions
+    let mut perm = (0..tensor.rank as i64).rev().collect::<Vec<i64>>();
+
+    if let Some(axes) = curr.attrs.get("perm") {
+        perm = axes.clone().into_i64s();
+    }
+
+    perm
+}
+
+pub fn squeeze_config(curr: &Node) -> Vec<i64> {
+    let axes = curr
+        .attrs
+        .iter()
+        .filter_map(|(key, value)| {
+            if key == "axes" {
Some(value.clone().into_i64s()) + } else { + None + } + }) + .next() + .unwrap_or_else(Vec::new); + + match curr.inputs.first().unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + axes +} +// pub fn split_config(node: &Node) -> SplitConfig { +// // Initialize the axis to split along (default is 0 as per ONNX specification) +// let mut axis: i64 = 0; +// // Holds the uniform split size if calculated or provided +// let mut split_size: Option = None; +// // Holds the custom split sizes if provided as input +// let mut split_sizes: Option> = None; + +// // Extract the input tensor type to determine rank and shape +// let tensor = match node.inputs.first().unwrap().ty { +// ArgType::Tensor(ref tensor) => tensor, +// _ => panic!("Split: Input must be a valid tensor"), +// }; + +// // Optionally store the number of outputs if provided as an attribute +// let mut num_outputs: Option = None; + +// // Iterate through node attributes to extract relevant values +// for (key, value) in node.attrs.iter() { +// match key.as_str() { +// "axis" => axis = value.clone().into_i64(), +// "num_outputs" => num_outputs = Some(value.clone().into_i64() as usize), +// _ => {} +// } +// } + +// // Handle the case when num_outputs is provided to calculate uniform split size +// if let Some(num_outputs) = num_outputs { +// if num_outputs == 0 { +// panic!("Split: 'num_outputs' must be a positive value greater than zero"); +// } + +// let dim_size = tensor +// .static_shape +// .as_ref() +// .expect("Split: Static shape must be known to calculate split size")[axis as usize]; + +// // Calculate the split size considering any remainder for non-evenly divisible dimensions +// let calculated_split_size = +// dim_size / (num_outputs - (dim_size % num_outputs != 0) as usize); + +// if calculated_split_size == 0 { +// panic!( +// "Split: Calculated split size is zero. Please ensure 'num_outputs' is valid for the dimension size" +// ); +// } + +// // Assign the calculated split size +// split_size = Some(calculated_split_size); +// } + +// // Adjust axis if negative to count from the end as per ONNX spec +// if axis < 0 { +// axis += tensor.rank as i64; +// } + +// // Check for custom split sizes provided as a second input +// if node.inputs.len() > 1 && node.inputs[1].value.is_some() { +// let sizes = node.inputs[1] +// .value +// .as_ref() +// .unwrap() +// .data +// .clone() +// .into_usizes(); + +// if !sizes.is_empty() { +// split_sizes = Some(sizes); +// } +// } + +// // Ensure that only one of 'split_sizes' or 'num_outputs' is specified +// if split_sizes.is_some() && split_size.is_some() { +// panic!( +// "Split: Cannot specify both 'split' input and 'num_outputs' attribute simultaneously" +// ); +// } + +// // Infer split_size if neither custom split_sizes nor split_size is provided +// if split_sizes.is_none() && split_size.is_none() { +// let num_outputs = node.outputs.len(); +// let dim_size = tensor +// .static_shape +// .as_ref() +// .expect("Split: Static shape must be known to infer split size")[axis as usize]; + +// // Calculate inferred split size based on number of outputs +// let calculated_split_size = +// dim_size / (num_outputs - (dim_size % num_outputs != 0) as usize); + +// if calculated_split_size == 0 { +// panic!( +// "Split: Inferred split size is zero. 
Please ensure the number of outputs is valid for the dimension size" +// ); +// } + +// split_size = Some(calculated_split_size); +// } + +// // Return the configuration for splitting operation +// SplitConfig { +// axis: axis as usize, +// split_size, +// split_sizes, +// } +// } + +pub fn one_hot_config(curr: &Node) -> (usize, [f32; 2], i64) { + let depth = curr.inputs[1] + .value + .clone() + .expect("OneHot: Only constant depth is currently supported") + .data + .into_i64(); + + let values = curr.inputs[2] + .value + .clone() + .expect("OneHot: Only constant on/off values is currently supported") + .data + .into_f32s(); + + let axis = curr + .attrs + .get("axis") + .map(|val| val.clone().into_i64()) + .unwrap_or(-1); + + (depth as usize, values.try_into().unwrap(), axis) +} + +pub fn gemm_config(curr: &Node) -> (f32, f32, i64, i64) { + let alpha = curr + .attrs + .get("alpha") + .map(|val| val.clone().into_f32()) + .unwrap_or(1.0); + let beta = curr + .attrs + .get("beta") + .map(|val| val.clone().into_f32()) + .unwrap_or(1.0); + let trans_a = curr + .attrs + .get("transA") + .map(|val| val.clone().into_i64()) + .unwrap_or(0); + let trans_b = curr + .attrs + .get("transB") + .map(|val| val.clone().into_i64()) + .unwrap_or(0); + + (alpha, beta, trans_a, trans_b) +} From 8e815410a4821f17a180b2a0c3b634e0c75f75e3 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Wed, 30 Apr 2025 13:30:56 -0500 Subject: [PATCH 08/37] Moved 2d and 3d config functions --- crates/onnx-ir/src/node/avg_pool2d.rs | 168 ++++++++ crates/onnx-ir/src/node/conv2d.rs | 241 +++++++++++ crates/onnx-ir/src/node/conv3d.rs | 254 ++++++++++++ crates/onnx-ir/src/node/conv_transpose2d.rs | 262 ++++++++++++ crates/onnx-ir/src/node/conv_transpose3d.rs | 276 +++++++++++++ crates/onnx-ir/src/node/max_pool2d.rs | 150 +++++++ crates/onnx-ir/src/node/mod.rs | 6 + crates/onnx-ir/src/op_configuration.rs | 437 +------------------- 8 files changed, 1373 insertions(+), 421 deletions(-) create mode 100644 crates/onnx-ir/src/node/avg_pool2d.rs create mode 100644 crates/onnx-ir/src/node/conv2d.rs create mode 100644 crates/onnx-ir/src/node/conv3d.rs create mode 100644 crates/onnx-ir/src/node/conv_transpose2d.rs create mode 100644 crates/onnx-ir/src/node/conv_transpose3d.rs create mode 100644 crates/onnx-ir/src/node/max_pool2d.rs diff --git a/crates/onnx-ir/src/node/avg_pool2d.rs b/crates/onnx-ir/src/node/avg_pool2d.rs new file mode 100644 index 0000000000..7603c2489a --- /dev/null +++ b/crates/onnx-ir/src/node/avg_pool2d.rs @@ -0,0 +1,168 @@ +use crate::ir::Node; +use burn::nn::PaddingConfig2d; +use burn::nn::pool::AvgPool2dConfig; + +/// Create a AvgPool2dConfig from the attributes of the node +pub fn avg_pool2d_config(curr: &Node) -> AvgPool2dConfig { + let mut kernel_shape = Vec::new(); + let mut strides = vec![1, 1]; + let mut pads = vec![0, 0, 0, 0]; + let mut count_include_pad: i64 = 0; + let mut ceil_mode: i64 = 0; + + for (key, value) in curr.attrs.iter() { + match key.as_str() { + "kernel_shape" => kernel_shape = value.clone().into_i64s(), + "strides" => strides = value.clone().into_i64s(), + "pads" => pads = value.clone().into_i64s(), + "count_include_pad" => count_include_pad = value.clone().into_i64(), + "ceil_mode" => ceil_mode = value.clone().into_i64(), + _ => {} + } + } + + if ceil_mode == 1 { + panic!("ceil_mode is not supported"); + } + + let padding = padding_config_2d(&pads); + + AvgPool2dConfig::new([kernel_shape[0] as usize, kernel_shape[1] as usize]) + 
.with_strides([strides[0] as usize, strides[1] as usize]) + .with_padding(padding) + .with_count_include_pad(count_include_pad == 1) +} + +/// Calculate the padding configuration for a 2D operations such as Convolution and Pooling. +/// +/// # Arguments +/// +/// * `pads` - The padding values +/// +/// # Panics +/// +/// * If the padding is negative +/// * If the padding is not symmetric +/// +/// # Returns +/// +/// * The padding configuration +/// +/// # Remarks +/// +/// This function is used when the padding is specified as a list of integers, +/// and not used when the padding is specified as a string, e.g. "SAME_UPPER". +fn padding_config_2d(pads: &[i64]) -> PaddingConfig2d { + let [left, top, right, bottom] = [pads[0], pads[1], pads[2], pads[3]]; + + if left < 0 || top < 0 || right < 0 || bottom < 0 { + panic!("Negative pad values are not supported"); + } else if (left != right) || (top != bottom) { + panic!("Asymmetric padding is not supported"); + } else if left == 0 && top == 0 && right == 0 && bottom == 0 { + // i.e [0, 0, 0, 0] + PaddingConfig2d::Valid + } else if left == right && top == bottom { + // i.e [2, 3, 2, 3] + PaddingConfig2d::Explicit(left as usize, top as usize) + } else { + // Unaccounted for padding configuration + panic!("Padding configuration ({:?}) not supported", pads); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{ArgType, Argument, AttributeValue, ElementType, NodeType, TensorType}; + use std::collections::HashMap; + + fn create_test_node( + kernel_shape: Vec, + strides: Vec, + pads: Vec, + count_include_pad: i64, + ceil_mode: i64, + ) -> Node { + let inputs = vec![Argument { + name: "data".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 4, + static_shape: None, + }), + value: None, + passed: true, + }]; + + let mut attrs = HashMap::new(); + attrs.insert( + "kernel_shape".to_string(), + AttributeValue::Int64s(kernel_shape), + ); + attrs.insert("strides".to_string(), AttributeValue::Int64s(strides)); + attrs.insert("pads".to_string(), AttributeValue::Int64s(pads)); + attrs.insert( + "count_include_pad".to_string(), + AttributeValue::Int64(count_include_pad), + ); + attrs.insert("ceil_mode".to_string(), AttributeValue::Int64(ceil_mode)); + + Node { + node_type: NodeType::AveragePool2d, + name: "test_avgpool2d".to_string(), + inputs, + outputs: vec![Argument { + name: "output".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 4, + static_shape: None, + }), + value: None, + passed: true, + }], + attrs, + } + } + + #[test] + fn test_avg_pool2d_config_basic() { + let node = create_test_node(vec![3, 3], vec![1, 1], vec![0, 0, 0, 0], 0, 0); + let config = avg_pool2d_config(&node); + + assert_eq!(config.kernel_size, [3, 3]); + assert_eq!(config.strides, [1, 1]); + assert!(!config.count_include_pad); + assert!(matches!(config.padding, PaddingConfig2d::Valid)); + } + + #[test] + fn test_avg_pool2d_config_with_padding() { + let node = create_test_node(vec![2, 2], vec![2, 2], vec![1, 1, 1, 1], 0, 0); + let config = avg_pool2d_config(&node); + + assert_eq!(config.kernel_size, [2, 2]); + assert_eq!(config.strides, [2, 2]); + assert!(!config.count_include_pad); + assert!(matches!(config.padding, PaddingConfig2d::Explicit(1, 1))); + } + + #[test] + fn test_avg_pool2d_config_with_count_include_pad() { + let node = create_test_node(vec![3, 3], vec![1, 1], vec![1, 1, 1, 1], 1, 0); + let config = avg_pool2d_config(&node); + + assert_eq!(config.kernel_size, [3, 3]); + 
assert_eq!(config.strides, [1, 1]); + assert!(config.count_include_pad); + assert!(matches!(config.padding, PaddingConfig2d::Explicit(1, 1))); + } + + #[test] + #[should_panic(expected = "ceil_mode is not supported")] + fn test_avg_pool2d_config_with_ceil_mode() { + let node = create_test_node(vec![3, 3], vec![1, 1], vec![0, 0, 0, 0], 0, 1); + let _ = avg_pool2d_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/conv2d.rs b/crates/onnx-ir/src/node/conv2d.rs new file mode 100644 index 0000000000..283d885157 --- /dev/null +++ b/crates/onnx-ir/src/node/conv2d.rs @@ -0,0 +1,241 @@ +use crate::ir::Node; +use burn::nn::PaddingConfig2d; +use burn::nn::conv::Conv2dConfig; + +/// Create a Conv2dConfig from the attributes of the node +pub fn conv2d_config(curr: &Node) -> Conv2dConfig { + let mut kernel_shape = Vec::new(); // TODO default inferred from weight tensor per spec + let mut strides = vec![1, 1]; + let mut pads = vec![0, 0, 0, 0]; + let mut dilations = vec![1, 1]; + let mut group: usize = 1; + + let weight_shape = curr.inputs[1] + .value + .as_ref() + .expect("Conv2d: weight tensor must be present") + .shape + .clone(); + + // check if the bias is present + let bias = curr.inputs.len() == 3; + + for (key, value) in curr.attrs.iter() { + match key.as_str() { + "kernel_shape" => kernel_shape = value.clone().into_i64s(), + "strides" => strides = value.clone().into_i64s(), + "pads" => pads = value.clone().into_i64s(), + "dilations" => dilations = value.clone().into_i64s(), + "group" => group = value.clone().into_i64() as usize, + _ => {} + } + } + + // the channels are inverted in the weight tensor + let channels_in = weight_shape[1] * group; + let channels_out = weight_shape[0]; + + let padding = padding_config_2d(&pads); + + Conv2dConfig::new( + [channels_in, channels_out], + [kernel_shape[0] as usize, kernel_shape[1] as usize], + ) + .with_stride([strides[0] as usize, strides[1] as usize]) + .with_dilation([dilations[0] as usize, dilations[1] as usize]) + .with_groups(group) + .with_bias(bias) + .with_padding(padding) +} + +/// Calculate the padding configuration for a 2D operations such as Convolution and Pooling. +/// +/// # Arguments +/// +/// * `pads` - The padding values +/// +/// # Panics +/// +/// * If the padding is negative +/// * If the padding is not symmetric +/// +/// # Returns +/// +/// * The padding configuration +/// +/// # Remarks +/// +/// This function is used when the padding is specified as a list of integers, +/// and not used when the padding is specified as a string, e.g. "SAME_UPPER". 
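+/// For example, symmetric per-axis pads such as `[2, 3, 2, 3]` map to
+/// `PaddingConfig2d::Explicit(2, 3)`, while all-zero pads map to `PaddingConfig2d::Valid`.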
+fn padding_config_2d(pads: &[i64]) -> PaddingConfig2d { + let [left, top, right, bottom] = [pads[0], pads[1], pads[2], pads[3]]; + + if left < 0 || top < 0 || right < 0 || bottom < 0 { + panic!("Negative pad values are not supported"); + } else if (left != right) || (top != bottom) { + panic!("Asymmetric padding is not supported"); + } else if left == 0 && top == 0 && right == 0 && bottom == 0 { + // i.e [0, 0, 0, 0] + PaddingConfig2d::Valid + } else if left == right && top == bottom { + // i.e [2, 3, 2, 3] + PaddingConfig2d::Explicit(left as usize, top as usize) + } else { + // Unaccounted for padding configuration + panic!("Padding configuration ({:?}) not supported", pads); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{ + ArgType, Argument, AttributeValue, Data, ElementType, NodeType, TensorData, TensorType, + }; + use std::collections::HashMap; + + fn create_test_node( + kernel_shape: Vec, + strides: Vec, + pads: Vec, + dilations: Vec, + group: i64, + has_bias: bool, + ) -> Node { + let weight_tensor = TensorData { + data: Data::Float32s(vec![0.0; 16]), // Not important for the test + shape: vec![4, 2, 2, 2], // [output_channels, input_channels/groups, k_h, k_w] + }; + + let mut inputs = vec![ + Argument { + name: "data".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 4, + static_shape: None, + }), + value: None, + passed: true, + }, + Argument { + name: "weight".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 4, + static_shape: None, + }), + value: Some(weight_tensor), + passed: true, + }, + ]; + + if has_bias { + inputs.push(Argument { + name: "bias".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 1, + static_shape: None, + }), + value: None, + passed: true, + }); + } + + let mut attrs = HashMap::new(); + attrs.insert( + "kernel_shape".to_string(), + AttributeValue::Int64s(kernel_shape), + ); + attrs.insert("strides".to_string(), AttributeValue::Int64s(strides)); + attrs.insert("pads".to_string(), AttributeValue::Int64s(pads)); + attrs.insert("dilations".to_string(), AttributeValue::Int64s(dilations)); + attrs.insert("group".to_string(), AttributeValue::Int64(group)); + + Node { + node_type: NodeType::Conv2d, + name: "test_conv2d".to_string(), + inputs, + outputs: vec![Argument { + name: "output".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 4, + static_shape: None, + }), + value: None, + passed: true, + }], + attrs, + } + } + + #[test] + fn test_conv2d_config_basic() { + let node = create_test_node( + vec![2, 2], + vec![1, 1], + vec![0, 0, 0, 0], + vec![1, 1], + 1, + false, + ); + let config = conv2d_config(&node); + + assert_eq!(config.channels, [2, 4]); + assert_eq!(config.kernel_size, [2, 2]); + assert_eq!(config.stride, [1, 1]); + assert_eq!(config.dilation, [1, 1]); + assert_eq!(config.groups, 1); + assert!(!config.bias); + assert!(matches!(config.padding, PaddingConfig2d::Valid)); + } + + #[test] + fn test_conv2d_config_with_padding() { + let node = create_test_node( + vec![3, 3], + vec![1, 1], + vec![1, 1, 1, 1], + vec![1, 1], + 1, + false, + ); + let config = conv2d_config(&node); + + assert_eq!(config.kernel_size, [3, 3]); + assert!(matches!(config.padding, PaddingConfig2d::Explicit(1, 1))); + } + + #[test] + fn test_conv2d_config_with_groups() { + let node = create_test_node( + vec![2, 2], + vec![1, 1], + vec![0, 0, 0, 0], + vec![1, 1], + 2, + false, + ); + let config = 
conv2d_config(&node); + + assert_eq!(config.groups, 2); + assert_eq!(config.channels, [4, 4]); // channels_in is adjusted by groups + } + + #[test] + fn test_conv2d_config_with_bias() { + let node = create_test_node( + vec![2, 2], + vec![1, 1], + vec![0, 0, 0, 0], + vec![1, 1], + 1, + true, + ); + let config = conv2d_config(&node); + + assert!(config.bias); + } +} diff --git a/crates/onnx-ir/src/node/conv3d.rs b/crates/onnx-ir/src/node/conv3d.rs new file mode 100644 index 0000000000..c3ed5b5ce3 --- /dev/null +++ b/crates/onnx-ir/src/node/conv3d.rs @@ -0,0 +1,254 @@ +use crate::ir::Node; +use burn::nn::PaddingConfig3d; +use burn::nn::conv::Conv3dConfig; + +/// Create a Conv3dConfig from the attributes of the node +pub fn conv3d_config(curr: &Node) -> Conv3dConfig { + let mut kernel_shape = Vec::new(); // TODO default inferred from weight tensor per spec + let mut strides = vec![1, 1, 1]; + let mut pads = vec![0, 0, 0, 0, 0, 0]; + let mut dilations = vec![1, 1, 1]; + let mut group: usize = 1; + + let weight_shape = curr.inputs[1] + .value + .as_ref() + .expect("Conv3d: weight tensor must be present") + .shape + .clone(); + + // check if the bias is present + let bias = curr.inputs.len() == 3; + + for (key, value) in curr.attrs.iter() { + match key.as_str() { + "kernel_shape" => kernel_shape = value.clone().into_i64s(), + "strides" => strides = value.clone().into_i64s(), + "pads" => pads = value.clone().into_i64s(), + "dilations" => dilations = value.clone().into_i64s(), + "group" => group = value.clone().into_i64() as usize, + _ => {} + } + } + + // the channels are inverted in the weight tensor + let channels_in = weight_shape[1] * group; + let channels_out = weight_shape[0]; + + let padding = padding_config_3d(&pads); + + Conv3dConfig::new( + [channels_in, channels_out], + [ + kernel_shape[0] as usize, + kernel_shape[1] as usize, + kernel_shape[2] as usize, + ], + ) + .with_stride([ + strides[0] as usize, + strides[1] as usize, + strides[2] as usize, + ]) + .with_dilation([ + dilations[0] as usize, + dilations[1] as usize, + dilations[2] as usize, + ]) + .with_groups(group) + .with_bias(bias) + .with_padding(padding) +} + +/// Calculate the padding configuration for a 3D operations such as Convolution and Pooling. +/// +/// # Arguments +/// +/// * `pads` - The padding values +/// +/// # Panics +/// +/// * If the padding is negative +/// * If the padding is not symmetric +/// +/// # Returns +/// +/// * The padding configuration +/// +/// # Remarks +/// +/// This function is used when the padding is specified as a list of integers, +/// and not used when the padding is specified as a string, e.g. "SAME_UPPER". 
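+/// For example, symmetric per-axis pads such as `[1, 2, 3, 1, 2, 3]` map to
+/// `PaddingConfig3d::Explicit(1, 2, 3)`, while all-zero pads map to `PaddingConfig3d::Valid`.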
+fn padding_config_3d(pads: &[i64]) -> PaddingConfig3d { + let [left, top, front, right, bottom, back] = + [pads[0], pads[1], pads[2], pads[3], pads[4], pads[5]]; + + if left < 0 || top < 0 || front < 0 || right < 0 || bottom < 0 || back < 0 { + panic!("Negative pad values are not supported"); + } else if (left != right) || (top != bottom) || (front != back) { + panic!("Asymmetric padding is not supported"); + } else if left == 0 && top == 0 && front == 0 && right == 0 && bottom == 0 && back == 0 { + // i.e [0, 0, 0, 0] + PaddingConfig3d::Valid + } else if left == right && top == bottom && front == back { + // i.e [2, 3, 2, 3] + PaddingConfig3d::Explicit(left as usize, top as usize, front as usize) + } else { + // Unaccounted for padding configuration + panic!("Padding configuration ({:?}) not supported", pads); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{ + ArgType, Argument, AttributeValue, Data, ElementType, NodeType, TensorData, TensorType, + }; + use std::collections::HashMap; + + fn create_test_node( + kernel_shape: Vec, + strides: Vec, + pads: Vec, + dilations: Vec, + group: i64, + has_bias: bool, + ) -> Node { + let weight_tensor = TensorData { + data: Data::Float32s(vec![0.0; 32]), // Not important for the test + shape: vec![4, 2, 2, 2, 2], // [output_channels, input_channels/groups, k_d, k_h, k_w] + }; + + let mut inputs = vec![ + Argument { + name: "data".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 5, + static_shape: None, + }), + value: None, + passed: true, + }, + Argument { + name: "weight".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 5, + static_shape: None, + }), + value: Some(weight_tensor), + passed: true, + }, + ]; + + if has_bias { + inputs.push(Argument { + name: "bias".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 1, + static_shape: None, + }), + value: None, + passed: true, + }); + } + + let mut attrs = HashMap::new(); + attrs.insert( + "kernel_shape".to_string(), + AttributeValue::Int64s(kernel_shape), + ); + attrs.insert("strides".to_string(), AttributeValue::Int64s(strides)); + attrs.insert("pads".to_string(), AttributeValue::Int64s(pads)); + attrs.insert("dilations".to_string(), AttributeValue::Int64s(dilations)); + attrs.insert("group".to_string(), AttributeValue::Int64(group)); + + Node { + node_type: NodeType::Conv3d, + name: "test_conv3d".to_string(), + inputs, + outputs: vec![Argument { + name: "output".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 5, + static_shape: None, + }), + value: None, + passed: true, + }], + attrs, + } + } + + #[test] + fn test_conv3d_config_basic() { + let node = create_test_node( + vec![2, 2, 2], + vec![1, 1, 1], + vec![0, 0, 0, 0, 0, 0], + vec![1, 1, 1], + 1, + false, + ); + let config = conv3d_config(&node); + + assert_eq!(config.channels, [2, 4]); + assert_eq!(config.kernel_size, [2, 2, 2]); + assert_eq!(config.stride, [1, 1, 1]); + assert_eq!(config.dilation, [1, 1, 1]); + assert_eq!(config.groups, 1); + assert!(!config.bias); + assert!(matches!(config.padding, PaddingConfig3d::Valid)); + } + + #[test] + fn test_conv3d_config_with_padding() { + let node = create_test_node( + vec![3, 3, 3], + vec![1, 1, 1], + vec![1, 1, 1, 1, 1, 1], + vec![1, 1, 1], + 1, + false, + ); + let config = conv3d_config(&node); + + assert_eq!(config.kernel_size, [3, 3, 3]); + assert!(matches!(config.padding, PaddingConfig3d::Explicit(1, 1, 1))); 
+ } + + #[test] + fn test_conv3d_config_with_groups() { + let node = create_test_node( + vec![2, 2, 2], + vec![1, 1, 1], + vec![0, 0, 0, 0, 0, 0], + vec![1, 1, 1], + 2, + false, + ); + let config = conv3d_config(&node); + + assert_eq!(config.groups, 2); + assert_eq!(config.channels, [4, 4]); // channels_in is adjusted by groups + } + + #[test] + fn test_conv3d_config_with_bias() { + let node = create_test_node( + vec![2, 2, 2], + vec![1, 1, 1], + vec![0, 0, 0, 0, 0, 0], + vec![1, 1, 1], + 1, + true, + ); + let config = conv3d_config(&node); + + assert!(config.bias); + } +} diff --git a/crates/onnx-ir/src/node/conv_transpose2d.rs b/crates/onnx-ir/src/node/conv_transpose2d.rs new file mode 100644 index 0000000000..a0dd9e975a --- /dev/null +++ b/crates/onnx-ir/src/node/conv_transpose2d.rs @@ -0,0 +1,262 @@ +use crate::ir::{AttributeValue, Node}; +use burn::nn::conv::ConvTranspose2dConfig; + +/// Create a ConvTranspose2dConfig from the attributes of the node +pub fn conv_transpose2d_config(curr: &Node) -> ConvTranspose2dConfig { + let mut attrs = curr.attrs.clone(); + let kernel_shape = attrs + .remove("kernel_shape") + .map(AttributeValue::into_i64s) + .unwrap_or_default(); + let stride = attrs + .remove("strides") + .map(AttributeValue::into_i64s) + .unwrap_or_else(|| vec![1, 1]); + let pads = attrs + .remove("pads") + .map(AttributeValue::into_i64s) + .unwrap_or_else(|| vec![0, 0, 0, 0]); + let dilations = attrs + .remove("dilations") + .map(AttributeValue::into_i64s) + .unwrap_or_else(|| vec![1, 1]); + let group = attrs + .remove("group") + .map(AttributeValue::into_i64) + .unwrap_or(1) as usize; + let output_padding = attrs + .remove("output_padding") + .map(AttributeValue::into_i64s) + .unwrap_or_else(|| vec![0, 0]); + + // Trick with remove + empty check is simplest way to not forget some attribute for runtime: + if !attrs.is_empty() { + panic!("Not all attributes are used: {attrs:?}"); + } + // Check the pads are symmetric. 
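+    // Only one padding value per spatial axis is passed to `with_padding` below,
+    // so the begin/end pads of each axis must be equal and non-negative.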
+ let [left, top, right, bottom] = [pads[0], pads[1], pads[2], pads[3]]; + if left < 0 || top < 0 || right < 0 || bottom < 0 { + panic!("Negative pad values are not supported"); + } else if (left != right) || (top != bottom) { + panic!("Asymmetric padding is not supported"); + } + + let weight_shape = curr.inputs[1] + .value + .as_ref() + .expect("ConvTranspose2d: weight tensor must be present") + .shape + .clone(); + + // check if the bias is present + let bias = curr.inputs.len() == 3; + + // the channels are inverted in the weight tensor + let channels: [usize; 2] = [weight_shape[1] * group, weight_shape[0]]; + + ConvTranspose2dConfig::new( + channels, + [kernel_shape[0] as usize, kernel_shape[1] as usize], + ) + .with_stride([stride[0] as usize, stride[1] as usize]) + .with_padding([pads[0] as usize, pads[1] as usize]) + .with_dilation([dilations[0] as usize, dilations[1] as usize]) + .with_padding_out([output_padding[0] as usize, output_padding[1] as usize]) + .with_groups(group) + .with_bias(bias) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{ + ArgType, Argument, AttributeValue, Data, ElementType, NodeType, TensorData, TensorType, + }; + use std::collections::HashMap; + + fn create_test_node( + kernel_shape: Vec, + strides: Vec, + pads: Vec, + dilations: Vec, + output_padding: Vec, + group: i64, + has_bias: bool, + ) -> Node { + let weight_tensor = TensorData { + data: Data::Float32s(vec![0.0; 16]), // Not important for the test + shape: vec![2, 4, 2, 2], // [input_channels, output_channels/groups, k_h, k_w] + }; + + let mut inputs = vec![ + Argument { + name: "data".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 4, + static_shape: None, + }), + value: None, + passed: true, + }, + Argument { + name: "weight".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 4, + static_shape: None, + }), + value: Some(weight_tensor), + passed: true, + }, + ]; + + if has_bias { + inputs.push(Argument { + name: "bias".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 1, + static_shape: None, + }), + value: None, + passed: true, + }); + } + + let mut attrs = HashMap::new(); + attrs.insert( + "kernel_shape".to_string(), + AttributeValue::Int64s(kernel_shape), + ); + attrs.insert("strides".to_string(), AttributeValue::Int64s(strides)); + attrs.insert("pads".to_string(), AttributeValue::Int64s(pads)); + attrs.insert("dilations".to_string(), AttributeValue::Int64s(dilations)); + attrs.insert( + "output_padding".to_string(), + AttributeValue::Int64s(output_padding), + ); + attrs.insert("group".to_string(), AttributeValue::Int64(group)); + + Node { + node_type: NodeType::ConvTranspose2d, + name: "test_convtranspose2d".to_string(), + inputs, + outputs: vec![Argument { + name: "output".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 4, + static_shape: None, + }), + value: None, + passed: true, + }], + attrs, + } + } + + #[test] + fn test_conv_transpose2d_config_basic() { + let node = create_test_node( + vec![2, 2], + vec![1, 1], + vec![0, 0, 0, 0], + vec![1, 1], + vec![0, 0], + 1, + false, + ); + let config = conv_transpose2d_config(&node); + + assert_eq!(config.channels, [4, 2]); + assert_eq!(config.kernel_size, [2, 2]); + assert_eq!(config.stride, [1, 1]); + assert_eq!(config.dilation, [1, 1]); + assert_eq!(config.padding, [0, 0]); + assert_eq!(config.padding_out, [0, 0]); + assert_eq!(config.groups, 1); + 
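+        // With only data and weight inputs (no bias input), bias defaults to false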
assert!(!config.bias); + } + + #[test] + fn test_conv_transpose2d_config_with_padding() { + let node = create_test_node( + vec![3, 3], + vec![2, 2], + vec![1, 1, 1, 1], + vec![1, 1], + vec![0, 0], + 1, + false, + ); + let config = conv_transpose2d_config(&node); + + assert_eq!(config.padding, [1, 1]); + assert_eq!(config.stride, [2, 2]); + } + + #[test] + fn test_conv_transpose2d_config_with_output_padding() { + let node = create_test_node( + vec![2, 2], + vec![2, 2], + vec![0, 0, 0, 0], + vec![1, 1], + vec![1, 1], + 1, + false, + ); + let config = conv_transpose2d_config(&node); + + assert_eq!(config.padding_out, [1, 1]); + } + + #[test] + fn test_conv_transpose2d_config_with_groups() { + let node = create_test_node( + vec![2, 2], + vec![1, 1], + vec![0, 0, 0, 0], + vec![1, 1], + vec![0, 0], + 2, + false, + ); + let config = conv_transpose2d_config(&node); + + assert_eq!(config.groups, 2); + assert_eq!(config.channels, [8, 2]); // channels_in is adjusted by groups + } + + #[test] + fn test_conv_transpose2d_config_with_bias() { + let node = create_test_node( + vec![2, 2], + vec![1, 1], + vec![0, 0, 0, 0], + vec![1, 1], + vec![0, 0], + 1, + true, + ); + let config = conv_transpose2d_config(&node); + + assert!(config.bias); + } + + #[test] + #[should_panic(expected = "Asymmetric padding is not supported")] + fn test_conv_transpose2d_config_with_asymmetric_padding() { + let node = create_test_node( + vec![2, 2], + vec![1, 1], + vec![1, 1, 2, 2], + vec![1, 1], + vec![0, 0], + 1, + false, + ); + let _ = conv_transpose2d_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/conv_transpose3d.rs b/crates/onnx-ir/src/node/conv_transpose3d.rs new file mode 100644 index 0000000000..d4326f6ddb --- /dev/null +++ b/crates/onnx-ir/src/node/conv_transpose3d.rs @@ -0,0 +1,276 @@ +use crate::ir::{AttributeValue, Node}; +use burn::nn::conv::ConvTranspose3dConfig; + +/// Create a ConvTranspose3dConfig from the attributes of the node +pub fn conv_transpose3d_config(curr: &Node) -> ConvTranspose3dConfig { + let mut attrs = curr.attrs.clone(); + let kernel_shape = attrs + .remove("kernel_shape") + .map(AttributeValue::into_i64s) + .unwrap_or_default(); + let stride = attrs + .remove("strides") + .map(AttributeValue::into_i64s) + .unwrap_or_else(|| vec![1, 1, 1]); + let pads = attrs + .remove("pads") + .map(AttributeValue::into_i64s) + .unwrap_or_else(|| vec![0, 0, 0, 0, 0, 0]); + let dilations = attrs + .remove("dilations") + .map(AttributeValue::into_i64s) + .unwrap_or_else(|| vec![1, 1, 1]); + let group = attrs + .remove("group") + .map(AttributeValue::into_i64) + .unwrap_or(1) as usize; + let output_padding = attrs + .remove("output_padding") + .map(AttributeValue::into_i64s) + .unwrap_or_else(|| vec![0, 0, 0]); + + // Trick with remove + empty check is simplest way to not forget some attribute for runtime: + if !attrs.is_empty() { + panic!("Not all attributes are used: {attrs:?}"); + } + // Check the pads are symmetric. 
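+    // Only one padding value per spatial axis is passed to `with_padding` below,
+    // so the begin/end pads of each axis must be equal and non-negative.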
+ let [left, top, front, right, bottom, back] = + [pads[0], pads[1], pads[2], pads[3], pads[4], pads[5]]; + + if left < 0 || top < 0 || front < 0 || right < 0 || bottom < 0 || back < 0 { + panic!("Negative pad values are not supported"); + } else if (left != right) || (top != bottom) || (front != back) { + panic!("Asymmetric padding is not supported"); + } + + let weight_shape = curr.inputs[1] + .value + .as_ref() + .expect("ConvTranspose3d: weight tensor must be present") + .shape + .clone(); + + // check if the bias is present + let bias = curr.inputs.len() == 3; + + // the channels are inverted in the weight tensor + let channels: [usize; 2] = [weight_shape[1] * group, weight_shape[0]]; + + ConvTranspose3dConfig::new( + channels, + [ + kernel_shape[0] as usize, + kernel_shape[1] as usize, + kernel_shape[2] as usize, + ], + ) + .with_stride([stride[0] as usize, stride[1] as usize, stride[2] as usize]) + .with_padding([pads[0] as usize, pads[1] as usize, pads[2] as usize]) + .with_dilation([ + dilations[0] as usize, + dilations[1] as usize, + dilations[2] as usize, + ]) + .with_padding_out([ + output_padding[0] as usize, + output_padding[1] as usize, + output_padding[2] as usize, + ]) + .with_groups(group) + .with_bias(bias) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{ + ArgType, Argument, AttributeValue, Data, ElementType, NodeType, TensorData, TensorType, + }; + use std::collections::HashMap; + + fn create_test_node( + kernel_shape: Vec, + strides: Vec, + pads: Vec, + dilations: Vec, + output_padding: Vec, + group: i64, + has_bias: bool, + ) -> Node { + let weight_tensor = TensorData { + data: Data::Float32s(vec![0.0; 32]), // Not important for the test + shape: vec![2, 4, 2, 2, 2], // [input_channels, output_channels/groups, k_d, k_h, k_w] + }; + + let mut inputs = vec![ + Argument { + name: "data".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 5, + static_shape: None, + }), + value: None, + passed: true, + }, + Argument { + name: "weight".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 5, + static_shape: None, + }), + value: Some(weight_tensor), + passed: true, + }, + ]; + + if has_bias { + inputs.push(Argument { + name: "bias".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 1, + static_shape: None, + }), + value: None, + passed: true, + }); + } + + let mut attrs = HashMap::new(); + attrs.insert( + "kernel_shape".to_string(), + AttributeValue::Int64s(kernel_shape), + ); + attrs.insert("strides".to_string(), AttributeValue::Int64s(strides)); + attrs.insert("pads".to_string(), AttributeValue::Int64s(pads)); + attrs.insert("dilations".to_string(), AttributeValue::Int64s(dilations)); + attrs.insert( + "output_padding".to_string(), + AttributeValue::Int64s(output_padding), + ); + attrs.insert("group".to_string(), AttributeValue::Int64(group)); + + Node { + node_type: NodeType::ConvTranspose3d, + name: "test_convtranspose3d".to_string(), + inputs, + outputs: vec![Argument { + name: "output".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 5, + static_shape: None, + }), + value: None, + passed: true, + }], + attrs, + } + } + + #[test] + fn test_conv_transpose3d_config_basic() { + let node = create_test_node( + vec![2, 2, 2], + vec![1, 1, 1], + vec![0, 0, 0, 0, 0, 0], + vec![1, 1, 1], + vec![0, 0, 0], + 1, + false, + ); + let config = conv_transpose3d_config(&node); + + assert_eq!(config.channels, [4, 
2]); + assert_eq!(config.kernel_size, [2, 2, 2]); + assert_eq!(config.stride, [1, 1, 1]); + assert_eq!(config.dilation, [1, 1, 1]); + assert_eq!(config.padding, [0, 0, 0]); + assert_eq!(config.padding_out, [0, 0, 0]); + assert_eq!(config.groups, 1); + assert!(!config.bias); + } + + #[test] + fn test_conv_transpose3d_config_with_padding() { + let node = create_test_node( + vec![3, 3, 3], + vec![2, 2, 2], + vec![1, 1, 1, 1, 1, 1], + vec![1, 1, 1], + vec![0, 0, 0], + 1, + false, + ); + let config = conv_transpose3d_config(&node); + + assert_eq!(config.padding, [1, 1, 1]); + assert_eq!(config.stride, [2, 2, 2]); + } + + #[test] + fn test_conv_transpose3d_config_with_output_padding() { + let node = create_test_node( + vec![2, 2, 2], + vec![2, 2, 2], + vec![0, 0, 0, 0, 0, 0], + vec![1, 1, 1], + vec![1, 1, 1], + 1, + false, + ); + let config = conv_transpose3d_config(&node); + + assert_eq!(config.padding_out, [1, 1, 1]); + } + + #[test] + fn test_conv_transpose3d_config_with_groups() { + let node = create_test_node( + vec![2, 2, 2], + vec![1, 1, 1], + vec![0, 0, 0, 0, 0, 0], + vec![1, 1, 1], + vec![0, 0, 0], + 2, + false, + ); + let config = conv_transpose3d_config(&node); + + assert_eq!(config.groups, 2); + assert_eq!(config.channels, [8, 2]); // channels_in is adjusted by groups + } + + #[test] + fn test_conv_transpose3d_config_with_bias() { + let node = create_test_node( + vec![2, 2, 2], + vec![1, 1, 1], + vec![0, 0, 0, 0, 0, 0], + vec![1, 1, 1], + vec![0, 0, 0], + 1, + true, + ); + let config = conv_transpose3d_config(&node); + + assert!(config.bias); + } + + #[test] + #[should_panic(expected = "Asymmetric padding is not supported")] + fn test_conv_transpose3d_config_with_asymmetric_padding() { + let node = create_test_node( + vec![2, 2, 2], + vec![1, 1, 1], + vec![1, 1, 1, 2, 2, 2], + vec![1, 1, 1], + vec![0, 0, 0], + 1, + false, + ); + let _ = conv_transpose3d_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/max_pool2d.rs b/crates/onnx-ir/src/node/max_pool2d.rs new file mode 100644 index 0000000000..a5f4e3a9e1 --- /dev/null +++ b/crates/onnx-ir/src/node/max_pool2d.rs @@ -0,0 +1,150 @@ +use crate::ir::Node; +use burn::nn::PaddingConfig2d; +use burn::nn::pool::MaxPool2dConfig; + +/// Create a MaxPool2dConfig from the attributes of the node +pub fn max_pool2d_config(curr: &Node) -> MaxPool2dConfig { + let mut kernel_shape = Vec::new(); + let mut strides = vec![1, 1]; + let mut pads = vec![0, 0, 0, 0]; + let mut dilations = vec![1, 1]; + + for (key, value) in curr.attrs.iter() { + match key.as_str() { + "kernel_shape" => kernel_shape = value.clone().into_i64s(), + "strides" => strides = value.clone().into_i64s(), + "pads" => pads = value.clone().into_i64s(), + "dilations" => dilations = value.clone().into_i64s(), + _ => {} + } + } + + let padding = padding_config_2d(&pads); + + MaxPool2dConfig::new([kernel_shape[0] as usize, kernel_shape[1] as usize]) + .with_strides([strides[0] as usize, strides[1] as usize]) + .with_padding(padding) + .with_dilation([dilations[0] as usize, dilations[1] as usize]) +} + +/// Calculate the padding configuration for a 2D operations such as Convolution and Pooling. +/// +/// # Arguments +/// +/// * `pads` - The padding values +/// +/// # Panics +/// +/// * If the padding is negative +/// * If the padding is not symmetric +/// +/// # Returns +/// +/// * The padding configuration +/// +/// # Remarks +/// +/// This function is used when the padding is specified as a list of integers, +/// and not used when the padding is specified as a string, e.g. 
"SAME_UPPER". +fn padding_config_2d(pads: &[i64]) -> PaddingConfig2d { + let [left, top, right, bottom] = [pads[0], pads[1], pads[2], pads[3]]; + + if left < 0 || top < 0 || right < 0 || bottom < 0 { + panic!("Negative pad values are not supported"); + } else if (left != right) || (top != bottom) { + panic!("Asymmetric padding is not supported"); + } else if left == 0 && top == 0 && right == 0 && bottom == 0 { + // i.e [0, 0, 0, 0] + PaddingConfig2d::Valid + } else if left == right && top == bottom { + // i.e [2, 3, 2, 3] + PaddingConfig2d::Explicit(left as usize, top as usize) + } else { + // Unaccounted for padding configuration + panic!("Padding configuration ({:?}) not supported", pads); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{ArgType, Argument, AttributeValue, ElementType, NodeType, TensorType}; + use std::collections::HashMap; + + fn create_test_node( + kernel_shape: Vec, + strides: Vec, + pads: Vec, + dilations: Vec, + ) -> Node { + let inputs = vec![Argument { + name: "data".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 4, + static_shape: None, + }), + value: None, + passed: true, + }]; + + let mut attrs = HashMap::new(); + attrs.insert( + "kernel_shape".to_string(), + AttributeValue::Int64s(kernel_shape), + ); + attrs.insert("strides".to_string(), AttributeValue::Int64s(strides)); + attrs.insert("pads".to_string(), AttributeValue::Int64s(pads)); + attrs.insert("dilations".to_string(), AttributeValue::Int64s(dilations)); + + Node { + node_type: NodeType::MaxPool2d, + name: "test_maxpool2d".to_string(), + inputs, + outputs: vec![Argument { + name: "output".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 4, + static_shape: None, + }), + value: None, + passed: true, + }], + attrs, + } + } + + #[test] + fn test_max_pool2d_config_basic() { + let node = create_test_node(vec![3, 3], vec![1, 1], vec![0, 0, 0, 0], vec![1, 1]); + let config = max_pool2d_config(&node); + + assert_eq!(config.kernel_size, [3, 3]); + assert_eq!(config.strides, [1, 1]); + assert_eq!(config.dilation, [1, 1]); + assert!(matches!(config.padding, PaddingConfig2d::Valid)); + } + + #[test] + fn test_max_pool2d_config_with_padding() { + let node = create_test_node(vec![2, 2], vec![2, 2], vec![1, 1, 1, 1], vec![1, 1]); + let config = max_pool2d_config(&node); + + assert_eq!(config.kernel_size, [2, 2]); + assert_eq!(config.strides, [2, 2]); + assert_eq!(config.dilation, [1, 1]); + assert!(matches!(config.padding, PaddingConfig2d::Explicit(1, 1))); + } + + #[test] + fn test_max_pool2d_config_with_dilation() { + let node = create_test_node(vec![3, 3], vec![1, 1], vec![0, 0, 0, 0], vec![2, 2]); + let config = max_pool2d_config(&node); + + assert_eq!(config.kernel_size, [3, 3]); + assert_eq!(config.strides, [1, 1]); + assert_eq!(config.dilation, [2, 2]); + assert!(matches!(config.padding, PaddingConfig2d::Valid)); + } +} diff --git a/crates/onnx-ir/src/node/mod.rs b/crates/onnx-ir/src/node/mod.rs index b44e62f217..065a29fcb6 100644 --- a/crates/onnx-ir/src/node/mod.rs +++ b/crates/onnx-ir/src/node/mod.rs @@ -1,5 +1,11 @@ pub mod avg_pool1d; +pub mod avg_pool2d; pub mod conv1d; +pub mod conv2d; +pub mod conv3d; pub mod conv_transpose1d; +pub mod conv_transpose2d; +pub mod conv_transpose3d; pub mod max_pool1d; +pub mod max_pool2d; pub mod slice; diff --git a/crates/onnx-ir/src/op_configuration.rs b/crates/onnx-ir/src/op_configuration.rs index fee98f68d8..cff13a86ce 100644 --- 
a/crates/onnx-ir/src/op_configuration.rs +++ b/crates/onnx-ir/src/op_configuration.rs @@ -1,328 +1,9 @@ // TODO Move op_configuration.rs from burn-import to onnx-ir #3091 // See https://github.com/tracel-ai/burn/issues/3091 -use burn::nn::{ - BatchNormConfig, DropoutConfig, LayerNormConfig, LinearConfig, PaddingConfig2d, - PaddingConfig3d, - conv::{Conv2dConfig, Conv3dConfig, ConvTranspose2dConfig, ConvTranspose3dConfig}, - pool::{AvgPool2dConfig, MaxPool2dConfig}, -}; - -use crate::ir::{ArgType, AttributeValue, Data, ElementType, Node, TensorData}; - -// use burn_import::burn::node::{ -// expand::ExpandShape, pad::PadConfig, split::SplitConfig, tile::TileConfig, top_k::TopKConfig, -// trilu::TriluConfig, unsqueeze::UnsqueezeAxes, -// }; - -// Conv1dConfig implementation moved to onnx-ir::node::conv1d - -/// Create a Conv2dConfig from the attributes of the node -pub fn conv2d_config(curr: &Node) -> Conv2dConfig { - let mut kernel_shape = Vec::new(); // TODO default inferred from weight tensor per spec - let mut strides = vec![1, 1]; - let mut pads = vec![0, 0, 0, 0]; - let mut dilations = vec![1, 1]; - let mut group: usize = 1; - - let weight_shape = curr.inputs[1] - .value - .as_ref() - .expect("Conv2d: weight tensor must be present") - .shape - .clone(); - - // check if the bias is present - let bias = curr.inputs.len() == 3; - - for (key, value) in curr.attrs.iter() { - match key.as_str() { - "kernel_shape" => kernel_shape = value.clone().into_i64s(), - "strides" => strides = value.clone().into_i64s(), - "pads" => pads = value.clone().into_i64s(), - "dilations" => dilations = value.clone().into_i64s(), - "group" => group = value.clone().into_i64() as usize, - _ => {} - } - } - - // the channels are inverted in the weight tensor - let channels_in = weight_shape[1] * group; - let channels_out = weight_shape[0]; - - let padding = padding_config_2d(&pads); - - Conv2dConfig::new( - [channels_in, channels_out], - [kernel_shape[0] as usize, kernel_shape[1] as usize], - ) - .with_stride([strides[0] as usize, strides[1] as usize]) - .with_dilation([dilations[0] as usize, dilations[1] as usize]) - .with_groups(group) - .with_bias(bias) - .with_padding(padding) -} - -/// Create a Conv3dConfig from the attributes of the node -pub fn conv3d_config(curr: &Node) -> Conv3dConfig { - let mut kernel_shape = Vec::new(); // TODO default inferred from weight tensor per spec - let mut strides = vec![1, 1, 1]; - let mut pads = vec![0, 0, 0, 0, 0, 0]; - let mut dilations = vec![1, 1, 1]; - let mut group: usize = 1; - - let weight_shape = curr.inputs[1] - .value - .as_ref() - .expect("Conv3d: weight tensor must be present") - .shape - .clone(); - - // check if the bias is present - let bias = curr.inputs.len() == 3; - - for (key, value) in curr.attrs.iter() { - match key.as_str() { - "kernel_shape" => kernel_shape = value.clone().into_i64s(), - "strides" => strides = value.clone().into_i64s(), - "pads" => pads = value.clone().into_i64s(), - "dilations" => dilations = value.clone().into_i64s(), - "group" => group = value.clone().into_i64() as usize, - _ => {} - } - } - - // the channels are inverted in the weight tensor - let channels_in = weight_shape[1] * group; - let channels_out = weight_shape[0]; - - let padding = padding_config_3d(&pads); - - Conv3dConfig::new( - [channels_in, channels_out], - [ - kernel_shape[0] as usize, - kernel_shape[1] as usize, - kernel_shape[2] as usize, - ], - ) - .with_stride([ - strides[0] as usize, - strides[1] as usize, - strides[2] as usize, - ]) - .with_dilation([ - 
dilations[0] as usize, - dilations[1] as usize, - dilations[2] as usize, - ]) - .with_groups(group) - .with_bias(bias) - .with_padding(padding) -} - -// MaxPool1dConfig implementation moved to onnx-ir::node::max_pool1d - -/// Create a MaxPool2dConfig from the attributes of the node -pub fn max_pool2d_config(curr: &Node) -> MaxPool2dConfig { - let mut kernel_shape = Vec::new(); - let mut strides = vec![1, 1]; - let mut pads = vec![0, 0, 0, 0]; - let mut dilations = vec![1, 1]; - - for (key, value) in curr.attrs.iter() { - match key.as_str() { - "kernel_shape" => kernel_shape = value.clone().into_i64s(), - "strides" => strides = value.clone().into_i64s(), - "pads" => pads = value.clone().into_i64s(), - "dilations" => dilations = value.clone().into_i64s(), - _ => {} - } - } - - let padding = padding_config_2d(&pads); - - MaxPool2dConfig::new([kernel_shape[0] as usize, kernel_shape[1] as usize]) - .with_strides([strides[0] as usize, strides[1] as usize]) - .with_padding(padding) - .with_dilation([dilations[0] as usize, dilations[1] as usize]) -} - -// ConvTranspose1dConfig implementation moved to onnx-ir::node::conv_transpose1d - -pub fn conv_transpose2d_config(curr: &Node) -> ConvTranspose2dConfig { - let mut attrs = curr.attrs.clone(); - let kernel_shape = attrs - .remove("kernel_shape") - .map(AttributeValue::into_i64s) - .unwrap_or_default(); - let stride = attrs - .remove("strides") - .map(AttributeValue::into_i64s) - .unwrap_or_else(|| vec![1, 1]); - let pads = attrs - .remove("pads") - .map(AttributeValue::into_i64s) - .unwrap_or_else(|| vec![0, 0, 0, 0]); - let dilations = attrs - .remove("dilations") - .map(AttributeValue::into_i64s) - .unwrap_or_else(|| vec![1, 1]); - let group = attrs - .remove("group") - .map(AttributeValue::into_i64) - .unwrap_or(1) as usize; - let output_padding = attrs - .remove("output_padding") - .map(AttributeValue::into_i64s) - .unwrap_or_else(|| vec![0, 0]); - - // Trick with remove + empty check is simplest way to not forget some attribute for runtime: - if !attrs.is_empty() { - panic!("Not all attributes are used: {attrs:?}"); - } - // Check the pads are symmetric. 
- let [left, top, right, bottom] = [pads[0], pads[1], pads[2], pads[3]]; - if left < 0 || top < 0 || right < 0 || bottom < 0 { - panic!("Negative pad values are not supported"); - } else if (left != right) || (top != bottom) { - panic!("Asymmetric padding is not supported"); - } - - let weight_shape = curr.inputs[1] - .value - .as_ref() - .expect("ConvTranspose2d: weight tensor must be present") - .shape - .clone(); +use burn::nn::{BatchNormConfig, DropoutConfig, LayerNormConfig, LinearConfig}; - // check if the bias is present - let bias = curr.inputs.len() == 3; - - // the channels are inverted in the weight tensor - let channels: [usize; 2] = [weight_shape[1] * group, weight_shape[0]]; - - ConvTranspose2dConfig::new( - channels, - [kernel_shape[0] as usize, kernel_shape[1] as usize], - ) - .with_stride([stride[0] as usize, stride[1] as usize]) - .with_padding([pads[0] as usize, pads[1] as usize]) - .with_dilation([dilations[0] as usize, dilations[1] as usize]) - .with_padding_out([output_padding[0] as usize, output_padding[1] as usize]) - .with_groups(group) - .with_bias(bias) -} - -pub fn conv_transpose3d_config(curr: &Node) -> ConvTranspose3dConfig { - let mut attrs = curr.attrs.clone(); - let kernel_shape = attrs - .remove("kernel_shape") - .map(AttributeValue::into_i64s) - .unwrap_or_default(); - let stride = attrs - .remove("strides") - .map(AttributeValue::into_i64s) - .unwrap_or_else(|| vec![1, 1, 1]); - let pads = attrs - .remove("pads") - .map(AttributeValue::into_i64s) - .unwrap_or_else(|| vec![0, 0, 0, 0, 0, 0]); - let dilations = attrs - .remove("dilations") - .map(AttributeValue::into_i64s) - .unwrap_or_else(|| vec![1, 1, 1]); - let group = attrs - .remove("group") - .map(AttributeValue::into_i64) - .unwrap_or(1) as usize; - let output_padding = attrs - .remove("output_padding") - .map(AttributeValue::into_i64s) - .unwrap_or_else(|| vec![0, 0, 0]); - - // Trick with remove + empty check is simplest way to not forget some attribute for runtime: - if !attrs.is_empty() { - panic!("Not all attributes are used: {attrs:?}"); - } - // Check the pads are symmetric. 
- let [left, top, front, right, bottom, back] = - [pads[0], pads[1], pads[2], pads[3], pads[4], pads[5]]; - - if left < 0 || top < 0 || front < 0 || right < 0 || bottom < 0 || back < 0 { - panic!("Negative pad values are not supported"); - } else if (left != right) || (top != bottom) || (front != back) { - panic!("Asymmetric padding is not supported"); - } - - let weight_shape = curr.inputs[1] - .value - .as_ref() - .expect("ConvTranspose3d: weight tensor must be present") - .shape - .clone(); - - // check if the bias is present - let bias = curr.inputs.len() == 3; - - // the channels are inverted in the weight tensor - let channels: [usize; 2] = [weight_shape[1] * group, weight_shape[0]]; - - ConvTranspose3dConfig::new( - channels, - [ - kernel_shape[0] as usize, - kernel_shape[1] as usize, - kernel_shape[2] as usize, - ], - ) - .with_stride([stride[0] as usize, stride[1] as usize, stride[2] as usize]) - .with_padding([pads[0] as usize, pads[1] as usize, pads[2] as usize]) - .with_dilation([ - dilations[0] as usize, - dilations[1] as usize, - dilations[2] as usize, - ]) - .with_padding_out([ - output_padding[0] as usize, - output_padding[1] as usize, - output_padding[2] as usize, - ]) - .with_groups(group) - .with_bias(bias) -} - -// AvgPool1dConfig implementation moved to onnx-ir::node::avg_pool1d -/// Create a AvgPool2dConfig from the attributes of the node -pub fn avg_pool2d_config(curr: &Node) -> AvgPool2dConfig { - let mut kernel_shape = Vec::new(); - let mut strides = vec![1, 1]; - let mut pads = vec![0, 0, 0, 0]; - let mut count_include_pad: i64 = 0; - let mut ceil_mode: i64 = 0; - - for (key, value) in curr.attrs.iter() { - match key.as_str() { - "kernel_shape" => kernel_shape = value.clone().into_i64s(), - "strides" => strides = value.clone().into_i64s(), - "pads" => pads = value.clone().into_i64s(), - "count_include_pad" => count_include_pad = value.clone().into_i64(), - "ceil_mode" => ceil_mode = value.clone().into_i64(), - _ => {} - } - } - - if ceil_mode == 1 { - panic!("ceil_mode is not supported"); - } - - let padding = padding_config_2d(&pads); - - AvgPool2dConfig::new([kernel_shape[0] as usize, kernel_shape[1] as usize]) - .with_strides([strides[0] as usize, strides[1] as usize]) - .with_padding(padding) - .with_count_include_pad(count_include_pad == 1) -} +use crate::ir::{ArgType, Data, Node, TensorData}; // pub fn expand_config(node: &Node) -> ExpandShape { // match &node.inputs[1].ty { @@ -384,9 +65,8 @@ pub fn flatten_config(curr: &Node) -> usize { // extract the attributes for (key, value) in curr.attrs.iter() { - match key.as_str() { - "axis" => axis = value.clone().into_i64(), - _ => {} + if key.as_str() == "axis" { + axis = value.clone().into_i64() } } @@ -417,9 +97,8 @@ pub fn gather_config(curr: &Node) -> usize { // extract the attributes for (key, value) in curr.attrs.iter() { - match key.as_str() { - "axis" => dim = value.clone().into_i64(), - _ => {} + if key.as_str() == "axis" { + dim = value.clone().into_i64() } } @@ -510,9 +189,8 @@ pub fn log_softmax_config(node: &Node) -> usize { // extract the attributes for (key, value) in node.attrs.iter() { - match key.as_str() { - "axis" => axis = value.clone().into_i64(), - _ => {} + if key.as_str() == "axis" { + axis = value.clone().into_i64() } } @@ -545,9 +223,8 @@ pub fn softmax_config(node: &Node) -> usize { // extract the attributes for (key, value) in node.attrs.iter() { - match key.as_str() { - "axis" => axis = value.clone().into_i64(), - _ => {} + if key.as_str() == "axis" { + axis = 
value.clone().into_i64() } } @@ -624,9 +301,8 @@ pub fn concat_config(node: &Node) -> usize { // extract the attributes for (key, value) in node.attrs.iter() { - match key.as_str() { - "axis" => axis = value.clone().into_i64(), - _ => {} + if key.as_str() == "axis" { + axis = value.clone().into_i64() } } @@ -913,93 +589,13 @@ pub fn layer_norm_config(node: &Node) -> (LayerNormConfig, bool) { // PadConfig::new(pads, constant_value) // } -// padding_config_1d moved to onnx-ir::node::conv1d - -/// Calculate the padding configuration for a 2D operations such as Convolution and Pooling. -/// -/// # Arguments -/// -/// * `pads` - The padding values -/// -/// # Panics -/// -/// * If the padding is negative -/// * If the padding is not symmetric -/// -/// # Returns -/// -/// * The padding configuration -/// -/// # Remarks -/// -/// This function is used when the padding is specified as a list of integers, -/// and not used when the padding is specified as a string, e.g. "SAME_UPPER". -fn padding_config_2d(pads: &[i64]) -> PaddingConfig2d { - let [left, top, right, bottom] = [pads[0], pads[1], pads[2], pads[3]]; - - if left < 0 || top < 0 || right < 0 || bottom < 0 { - panic!("Negative pad values are not supported"); - } else if (left != right) || (top != bottom) { - panic!("Asymmetric padding is not supported"); - } else if left == 0 && top == 0 && right == 0 && bottom == 0 { - // i.e [0, 0, 0, 0] - PaddingConfig2d::Valid - } else if left == right && top == bottom { - // i.e [2, 3, 2, 3] - PaddingConfig2d::Explicit(left as usize, top as usize) - } else { - // Unaccounted for padding configuration - panic!("Padding configuration ({:?}) not supported", pads); - } -} - -/// Calculate the padding configuration for a 3D operations such as Convolution and Pooling. -/// -/// # Arguments -/// -/// * `pads` - The padding values -/// -/// # Panics -/// -/// * If the padding is negative -/// * If the padding is not symmetric -/// -/// # Returns -/// -/// * The padding configuration -/// -/// # Remarks -/// -/// This function is used when the padding is specified as a list of integers, -/// and not used when the padding is specified as a string, e.g. "SAME_UPPER". 
-fn padding_config_3d(pads: &[i64]) -> PaddingConfig3d { - let [left, top, front, right, bottom, back] = - [pads[0], pads[1], pads[2], pads[3], pads[4], pads[5]]; - - if left < 0 || top < 0 || front < 0 || right < 0 || bottom < 0 || back < 0 { - panic!("Negative pad values are not supported"); - } else if (left != right) || (top != bottom) || (front != back) { - panic!("Asymmetric padding is not supported"); - } else if left == 0 && top == 0 && front == 0 && right == 0 && bottom == 0 && back == 0 { - // i.e [0, 0, 0, 0] - PaddingConfig3d::Valid - } else if left == right && top == bottom && front == back { - // i.e [2, 3, 2, 3] - PaddingConfig3d::Explicit(left as usize, top as usize, front as usize) - } else { - // Unaccounted for padding configuration - panic!("Padding configuration ({:?}) not supported", pads); - } -} - // Create a LeakyReluConfig from the alpha attribute of the node pub fn leaky_relu_config(node: &Node) -> f64 { let mut alpha = 0.01; for (key, value) in node.attrs.iter() { - match key.as_str() { - "alpha" => alpha = value.clone().into_f32() as f64, - _ => {} + if key.as_str() == "alpha" { + alpha = value.clone().into_f32() as f64 } } @@ -1026,9 +622,8 @@ pub fn reshape_config(node: &Node) -> Vec { let mut allowzero = 0; for (key, value) in node.attrs.iter() { - match key.as_str() { - "allowzero" => allowzero = value.clone().into_i64(), - _ => {} + if key.as_str() == "allowzero" { + allowzero = value.clone().into_i64() } } From 81b38d9fbb2cd26368429fdb186010f4cb02bdbf Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Wed, 30 Apr 2025 14:05:10 -0500 Subject: [PATCH 09/37] Move some config functions to node --- crates/onnx-ir/src/node/argmax.rs | 136 ++++++++++ crates/onnx-ir/src/node/batch_norm.rs | 154 +++++++++++ crates/onnx-ir/src/node/concat.rs | 94 +++++++ crates/onnx-ir/src/node/dropout.rs | 145 +++++++++++ crates/onnx-ir/src/node/flatten.rs | 126 +++++++++ crates/onnx-ir/src/node/gather.rs | 120 +++++++++ crates/onnx-ir/src/node/layer_norm.rs | 152 +++++++++++ crates/onnx-ir/src/node/linear.rs | 140 ++++++++++ crates/onnx-ir/src/node/log_softmax.rs | 106 ++++++++ crates/onnx-ir/src/node/mod.rs | 10 + crates/onnx-ir/src/node/softmax.rs | 106 ++++++++ crates/onnx-ir/src/op_configuration.rs | 343 ------------------------- 12 files changed, 1289 insertions(+), 343 deletions(-) create mode 100644 crates/onnx-ir/src/node/argmax.rs create mode 100644 crates/onnx-ir/src/node/batch_norm.rs create mode 100644 crates/onnx-ir/src/node/concat.rs create mode 100644 crates/onnx-ir/src/node/dropout.rs create mode 100644 crates/onnx-ir/src/node/flatten.rs create mode 100644 crates/onnx-ir/src/node/gather.rs create mode 100644 crates/onnx-ir/src/node/layer_norm.rs create mode 100644 crates/onnx-ir/src/node/linear.rs create mode 100644 crates/onnx-ir/src/node/log_softmax.rs create mode 100644 crates/onnx-ir/src/node/softmax.rs diff --git a/crates/onnx-ir/src/node/argmax.rs b/crates/onnx-ir/src/node/argmax.rs new file mode 100644 index 0000000000..ae4826233f --- /dev/null +++ b/crates/onnx-ir/src/node/argmax.rs @@ -0,0 +1,136 @@ +use crate::ir::{ArgType, Node}; + +/// Create argmax config from the attributes of the node +pub fn argmax_config(node: &Node) -> usize { + let mut axis: i64 = 0; + + // check if the node has only one input + if node.inputs.len() != 1 { + panic!( + "Argmax: multiple inputs are not supported (got {:?})", + node.inputs.len() + ); + } + + // extract the shape of the input tensor + let tensor = match 
node.inputs.first().unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + // extract the attributes + for (key, value) in node.attrs.iter() { + match key.as_str() { + "axis" => axis = value.clone().into_i64(), + "select_last_index" => { + // not all params are supported in burn + if value.clone().into_i64() != 0 { + log::warn!( + "only select_last_index=0 is supported for argmax in burn. Ignoring supplied value (got {:?})", + value + ); + } + } + "keepdims" => { + // not all params are supported in burn + if value.clone().into_i64() != 1 { + panic!( + "Only keepdims=1 is supported for argmax in burn (got {:?})", + value + ); + } + } + _ => {} + } + } + + // if axis is negative, it is counted from the end + if axis < 0 { + axis += tensor.rank as i64; + } + + axis as usize +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{Argument, AttributeValue, ElementType, NodeType, TensorType}; + use std::collections::HashMap; + + fn create_test_node(axis: i64, select_last_index: i64, keepdims: i64) -> Node { + let inputs = vec![Argument { + name: "data".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 3, + static_shape: None, + }), + value: None, + passed: true, + }]; + + let mut attrs = HashMap::new(); + attrs.insert("axis".to_string(), AttributeValue::Int64(axis)); + attrs.insert( + "select_last_index".to_string(), + AttributeValue::Int64(select_last_index), + ); + attrs.insert("keepdims".to_string(), AttributeValue::Int64(keepdims)); + + Node { + node_type: NodeType::ArgMax, + name: "test_argmax".to_string(), + inputs, + outputs: vec![Argument { + name: "output".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Int64, + rank: 3, + static_shape: None, + }), + value: None, + passed: true, + }], + attrs, + } + } + + #[test] + fn test_argmax_config_basic() { + let node = create_test_node(0, 0, 1); + let config = argmax_config(&node); + assert_eq!(config, 0); + } + + #[test] + fn test_argmax_config_negative_axis() { + let node = create_test_node(-2, 0, 1); + let config = argmax_config(&node); + assert_eq!(config, 1); // -2 + 3 = 1 + } + + #[test] + #[should_panic(expected = "Argmax: multiple inputs are not supported")] + fn test_argmax_config_multiple_inputs() { + let mut node = create_test_node(0, 0, 1); + node.inputs.push(Argument { + name: "extra".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 1, + static_shape: None, + }), + value: None, + passed: true, + }); + let _ = argmax_config(&node); + } + + #[test] + #[should_panic(expected = "Only keepdims=1 is supported for argmax in burn")] + fn test_argmax_config_keepdims_not_supported() { + let node = create_test_node(0, 0, 0); + let _ = argmax_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/batch_norm.rs b/crates/onnx-ir/src/node/batch_norm.rs new file mode 100644 index 0000000000..0e23cf021d --- /dev/null +++ b/crates/onnx-ir/src/node/batch_norm.rs @@ -0,0 +1,154 @@ +use crate::ir::Node; +use burn::nn::BatchNormConfig; + +/// Create a BatchNormConfig from the attributes of the node +pub fn batch_norm_config(node: &Node) -> BatchNormConfig { + let weight_shape = node.inputs[1] + .value + .as_ref() + .expect("BatchNorm: weight tensor must be present") + .shape + .clone(); + + let num_features = weight_shape[0]; + + let mut epsilon = 0f32; + let mut momentum = 0f32; + + for (key, value) in node.attrs.iter() { + match key.as_str() { + "momentum" => momentum = 
value.clone().into_f32(), + "epsilon" => epsilon = value.clone().into_f32(), + _ => {} + } + } + + BatchNormConfig::new(num_features) + .with_epsilon(epsilon as f64) + .with_momentum(momentum as f64) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{ + ArgType, Argument, AttributeValue, Data, ElementType, NodeType, TensorData, TensorType, + }; + use std::collections::HashMap; + + fn create_test_node(epsilon: f32, momentum: f32, num_features: usize) -> Node { + let weight_tensor = TensorData { + data: Data::Float32s(vec![1.0; num_features]), // Not important for the test + shape: vec![num_features], + }; + + let bias_tensor = TensorData { + data: Data::Float32s(vec![0.0; num_features]), // Not important for the test + shape: vec![num_features], + }; + + let mean_tensor = TensorData { + data: Data::Float32s(vec![0.0; num_features]), // Not important for the test + shape: vec![num_features], + }; + + let var_tensor = TensorData { + data: Data::Float32s(vec![1.0; num_features]), // Not important for the test + shape: vec![num_features], + }; + + let inputs = vec![ + Argument { + name: "X".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 4, // NCHW format + static_shape: None, + }), + value: None, + passed: true, + }, + Argument { + name: "scale".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 1, + static_shape: None, + }), + value: Some(weight_tensor), + passed: true, + }, + Argument { + name: "bias".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 1, + static_shape: None, + }), + value: Some(bias_tensor), + passed: true, + }, + Argument { + name: "mean".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 1, + static_shape: None, + }), + value: Some(mean_tensor), + passed: true, + }, + Argument { + name: "var".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 1, + static_shape: None, + }), + value: Some(var_tensor), + passed: true, + }, + ]; + + let mut attrs = HashMap::new(); + attrs.insert("epsilon".to_string(), AttributeValue::Float32(epsilon)); + attrs.insert("momentum".to_string(), AttributeValue::Float32(momentum)); + + Node { + node_type: NodeType::BatchNormalization, + name: "test_batchnorm".to_string(), + inputs, + outputs: vec![Argument { + name: "output".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 4, + static_shape: None, + }), + value: None, + passed: true, + }], + attrs, + } + } + + #[test] + fn test_batch_norm_config_basic() { + let node = create_test_node(1e-5, 0.9, 64); + let config = batch_norm_config(&node); + + assert_eq!(config.num_features, 64); + assert!(f64::abs(config.epsilon - 1e-5) < 1e-6); + assert!(f64::abs(config.momentum - 0.9) < 1e-6); + } + + #[test] + fn test_batch_norm_config_default_values() { + let node = create_test_node(0.0, 0.0, 32); + let config = batch_norm_config(&node); + + assert_eq!(config.num_features, 32); + assert!(f64::abs(config.epsilon - 0.0) < 1e-6); + assert!(f64::abs(config.momentum - 0.0) < 1e-6); + } +} diff --git a/crates/onnx-ir/src/node/concat.rs b/crates/onnx-ir/src/node/concat.rs new file mode 100644 index 0000000000..147b552869 --- /dev/null +++ b/crates/onnx-ir/src/node/concat.rs @@ -0,0 +1,94 @@ +use crate::ir::{ArgType, Node}; + +/// Create concat config from the attributes of the node +pub fn concat_config(node: &Node) -> usize { + // the axis is the last 
dimension (Default: 1 per ONNX spec) + let mut axis: i64 = 1; + + // extract the shape of the input tensor + let tensor = match node.inputs.first().unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + // extract the attributes + for (key, value) in node.attrs.iter() { + if key.as_str() == "axis" { + axis = value.clone().into_i64() + } + } + + // if axis is negative, it is counted from the end + if axis < 0 { + axis += tensor.rank as i64; + } + + axis as usize +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{Argument, AttributeValue, ElementType, NodeType, TensorType}; + use std::collections::HashMap; + + fn create_test_node(axis: i64, input_rank: usize, num_inputs: usize) -> Node { + let mut inputs = Vec::new(); + + // Create multiple inputs for concat + for i in 0..num_inputs { + inputs.push(Argument { + name: format!("data_{}", i), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: input_rank, + static_shape: None, + }), + value: None, + passed: true, + }); + } + + let mut attrs = HashMap::new(); + attrs.insert("axis".to_string(), AttributeValue::Int64(axis)); + + Node { + node_type: NodeType::Concat, + name: "test_concat".to_string(), + inputs, + outputs: vec![Argument { + name: "output".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: input_rank, + static_shape: None, + }), + value: None, + passed: true, + }], + attrs, + } + } + + #[test] + fn test_concat_config_basic() { + let node = create_test_node(1, 3, 2); + let config = concat_config(&node); + assert_eq!(config, 1); + } + + #[test] + fn test_concat_config_negative_axis() { + let node = create_test_node(-2, 3, 2); + let config = concat_config(&node); + assert_eq!(config, 1); // -2 + 3 = 1 + } + + #[test] + #[should_panic(expected = "Only tensor input is valid")] + fn test_concat_config_invalid_input() { + let mut node = create_test_node(1, 3, 1); + node.inputs[0].ty = ArgType::Shape(1); + let _ = concat_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/dropout.rs b/crates/onnx-ir/src/node/dropout.rs new file mode 100644 index 0000000000..40fc1f03ef --- /dev/null +++ b/crates/onnx-ir/src/node/dropout.rs @@ -0,0 +1,145 @@ +use crate::ir::{Data, Node}; +use burn::nn::DropoutConfig; + +/// Create a DropoutConfig from an attribute and state of the node +pub fn dropout_config(node: &Node) -> DropoutConfig { + // Opset 7 and older store probability as an attribute + if node.attrs.contains_key("ratio") { + let prob = node.attrs.get("ratio").unwrap().clone().into_f32(); + return DropoutConfig::new(prob as f64); + } + + if node.inputs.len() < 2 { + panic!("Dropout configuration must have at least 2 inputs"); + } + + let ratio = node.inputs[1] + .value + .clone() + .expect("Dropout ratio must be passed in the second input") + .data + .into_scalar(); + + let prob = match ratio { + Data::Float16(ratio) => f64::from(f32::from(ratio)), + Data::Float32(ratio) => ratio as f64, + Data::Float64(ratio) => ratio, + _ => panic!("Dropout ratio must be a float"), + }; + + DropoutConfig::new(prob) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{ + ArgType, Argument, AttributeValue, Data, ElementType, NodeType, TensorData, TensorType, + }; + use std::collections::HashMap; + + fn create_test_node_with_attr(ratio: f32) -> Node { + let inputs = vec![Argument { + name: "data".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 3, + static_shape: None, 
+ }), + value: None, + passed: true, + }]; + + let mut attrs = HashMap::new(); + attrs.insert("ratio".to_string(), AttributeValue::Float32(ratio)); + + Node { + node_type: NodeType::Dropout, + name: "test_dropout".to_string(), + inputs, + outputs: vec![Argument { + name: "output".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 3, + static_shape: None, + }), + value: None, + passed: true, + }], + attrs, + } + } + + fn create_test_node_with_input(ratio: f32) -> Node { + let ratio_tensor = TensorData { + data: Data::Float32(ratio), + shape: vec![], + }; + + let inputs = vec![ + Argument { + name: "data".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 3, + static_shape: None, + }), + value: None, + passed: true, + }, + Argument { + name: "ratio".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 0, + static_shape: None, + }), + value: Some(ratio_tensor), + passed: true, + }, + ]; + + let attrs = HashMap::new(); + + Node { + node_type: NodeType::Dropout, + name: "test_dropout".to_string(), + inputs, + outputs: vec![Argument { + name: "output".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 3, + static_shape: None, + }), + value: None, + passed: true, + }], + attrs, + } + } + + #[test] + fn test_dropout_config_with_attr() { + let node = create_test_node_with_attr(0.3); + let config = dropout_config(&node); + assert!(f64::abs(config.prob - 0.3) < 1e-6); + } + + #[test] + fn test_dropout_config_with_input() { + let node = create_test_node_with_input(0.5); + let config = dropout_config(&node); + assert!(f64::abs(config.prob - 0.5) < 1e-6); + } + + #[test] + #[should_panic(expected = "Dropout configuration must have at least 2 inputs")] + fn test_dropout_config_missing_input() { + let mut node = create_test_node_with_input(0.5); + node.attrs = HashMap::new(); // Remove attributes + node.inputs.remove(1); // Remove ratio input + let _ = dropout_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/flatten.rs b/crates/onnx-ir/src/node/flatten.rs new file mode 100644 index 0000000000..373c2a560f --- /dev/null +++ b/crates/onnx-ir/src/node/flatten.rs @@ -0,0 +1,126 @@ +use crate::ir::{ArgType, Node}; + +/// Create a FlattenConfig from the attributes of the node +pub fn flatten_config(curr: &Node) -> usize { + // the begin dimension is the first dimension (Default: 1 per ONNX spec) + let mut axis: i64 = 1; + + // check if the node has only one input + if curr.inputs.len() != 1 { + panic!( + "Flatten: multiple inputs are not supported (got {:?})", + curr.inputs.len() + ); + } + + // extract the shape of the input tensor + let tensor = match curr.inputs.first().unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + // check if the input tensor has at least 2 dimensions + if tensor.rank < 2 { + panic!( + "Flatten: input tensor must have at least 2 dimensions (got {:?})", + tensor.rank + ); + } + + // extract the attributes + for (key, value) in curr.attrs.iter() { + if key.as_str() == "axis" { + axis = value.clone().into_i64() + } + } + + // if beg_dim is negative, it is counted from the end + if axis < 0 { + axis += tensor.rank as i64; + } + + axis as usize +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{Argument, AttributeValue, ElementType, NodeType, TensorType}; + use std::collections::HashMap; + + fn create_test_node(axis: i64) -> Node { + let inputs = 
vec![Argument { + name: "data".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 4, + static_shape: None, + }), + value: None, + passed: true, + }]; + + let mut attrs = HashMap::new(); + attrs.insert("axis".to_string(), AttributeValue::Int64(axis)); + + Node { + node_type: NodeType::Flatten, + name: "test_flatten".to_string(), + inputs, + outputs: vec![Argument { + name: "output".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 2, + static_shape: None, + }), + value: None, + passed: true, + }], + attrs, + } + } + + #[test] + fn test_flatten_config_basic() { + let node = create_test_node(1); + let config = flatten_config(&node); + assert_eq!(config, 1); + } + + #[test] + fn test_flatten_config_with_negative_axis() { + let node = create_test_node(-2); + let config = flatten_config(&node); + assert_eq!(config, 2); // -2 + 4 = 2 + } + + #[test] + #[should_panic(expected = "Flatten: input tensor must have at least 2 dimensions")] + fn test_flatten_config_with_low_rank() { + let mut node = create_test_node(1); + node.inputs[0].ty = ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 1, + static_shape: None, + }); + let _ = flatten_config(&node); + } + + #[test] + #[should_panic(expected = "Flatten: multiple inputs are not supported")] + fn test_flatten_config_with_multiple_inputs() { + let mut node = create_test_node(1); + node.inputs.push(Argument { + name: "extra".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 1, + static_shape: None, + }), + value: None, + passed: true, + }); + let _ = flatten_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/gather.rs b/crates/onnx-ir/src/node/gather.rs new file mode 100644 index 0000000000..f4a63756cf --- /dev/null +++ b/crates/onnx-ir/src/node/gather.rs @@ -0,0 +1,120 @@ +use crate::ir::{ArgType, Node}; + +/// Create a GatherConfig from the attributes of the node +pub fn gather_config(curr: &Node) -> usize { + // Default: 0 per ONNX spec + let mut dim: i64 = 0; + + // check if the node has only one input + if curr.inputs.len() != 2 { + panic!("Gather: index tensor must be present"); + } + + // extract the shape of the input tensor + let input_dim = match curr.inputs.first().unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor.rank as i64, + ArgType::Shape(_shape) => 1, // Shape is always 1-D + other => panic!("Only tensor or shape input is valid, got {:?}", other), + }; + + // extract the attributes + for (key, value) in curr.attrs.iter() { + if key.as_str() == "axis" { + dim = value.clone().into_i64() + } + } + + // if dim is negative, it is counted from the end + if dim < 0 { + dim += input_dim; + } + + dim as usize +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{Argument, AttributeValue, ElementType, NodeType, TensorType}; + use std::collections::HashMap; + + fn create_test_node(axis: i64, input_rank: usize, is_shape: bool) -> Node { + let input_ty = if is_shape { + ArgType::Shape(1) + } else { + ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: input_rank, + static_shape: None, + }) + }; + + let inputs = vec![ + Argument { + name: "data".to_string(), + ty: input_ty, + value: None, + passed: true, + }, + Argument { + name: "indices".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Int64, + rank: 1, + static_shape: None, + }), + value: None, + passed: true, + }, + ]; + + let mut attrs = HashMap::new(); + 
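+        // Gather's `axis` attribute defaults to 0 per the ONNX spec; the tests below
+        // exercise an explicit axis, a negative axis and a Shape input.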
attrs.insert("axis".to_string(), AttributeValue::Int64(axis)); + + Node { + node_type: NodeType::Gather, + name: "test_gather".to_string(), + inputs, + outputs: vec![Argument { + name: "output".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: input_rank, + static_shape: None, + }), + value: None, + passed: true, + }], + attrs, + } + } + + #[test] + fn test_gather_config_basic() { + let node = create_test_node(0, 3, false); + let config = gather_config(&node); + assert_eq!(config, 0); + } + + #[test] + fn test_gather_config_negative_axis() { + let node = create_test_node(-2, 3, false); + let config = gather_config(&node); + assert_eq!(config, 1); // -2 + 3 = 1 + } + + #[test] + fn test_gather_config_shape_input() { + let node = create_test_node(0, 0, true); + let config = gather_config(&node); + assert_eq!(config, 0); + } + + #[test] + #[should_panic(expected = "Gather: index tensor must be present")] + fn test_gather_config_missing_index() { + let mut node = create_test_node(0, 3, false); + node.inputs.pop(); // Remove the indices input + let _ = gather_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/layer_norm.rs b/crates/onnx-ir/src/node/layer_norm.rs new file mode 100644 index 0000000000..90d8538985 --- /dev/null +++ b/crates/onnx-ir/src/node/layer_norm.rs @@ -0,0 +1,152 @@ +use crate::ir::Node; +use burn::nn::LayerNormConfig; + +/// Create a LayerNormConfig from the attributes of the node +pub fn layer_norm_config(node: &Node) -> (LayerNormConfig, bool) { + let weight_shape = node.inputs[1] + .value + .as_ref() + .expect("LayerNorm: weight tensor must be present") + .shape + .clone(); + + let num_features = weight_shape[0]; + + // When `stash_type` is `1` (default), perform operations in 32-bit float and + // cast the results back to original dtype + let mut stash_type = 1; + let mut axis = -1; + let mut epsilon = 1e-5; + + for (key, value) in node.attrs.iter() { + match key.as_str() { + "axis" => axis = value.clone().into_i64(), + "epsilon" => epsilon = value.clone().into_f32(), + "stash_type" => stash_type = value.clone().into_i64(), + _ => {} + } + } + + if axis != -1 && axis != weight_shape.len() as i64 - 1 { + panic!("LayerNorm: normalization is only supported on the last axis right now") + } + + ( + LayerNormConfig::new(num_features).with_epsilon(epsilon as f64), + stash_type == 1, + ) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{ + ArgType, Argument, AttributeValue, Data, ElementType, NodeType, TensorData, TensorType, + }; + use std::collections::HashMap; + + fn create_test_node(epsilon: f32, axis: i64, stash_type: i64, num_features: usize) -> Node { + let weight_tensor = TensorData { + data: Data::Float32s(vec![1.0; num_features]), // Not important for the test + shape: vec![num_features], + }; + + let bias_tensor = TensorData { + data: Data::Float32s(vec![0.0; num_features]), // Not important for the test + shape: vec![num_features], + }; + + let inputs = vec![ + Argument { + name: "X".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 3, + static_shape: None, + }), + value: None, + passed: true, + }, + Argument { + name: "scale".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 1, + static_shape: None, + }), + value: Some(weight_tensor), + passed: true, + }, + Argument { + name: "bias".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 1, + static_shape: None, + }), + value: 
Some(bias_tensor), + passed: true, + }, + ]; + + let mut attrs = HashMap::new(); + attrs.insert("epsilon".to_string(), AttributeValue::Float32(epsilon)); + attrs.insert("axis".to_string(), AttributeValue::Int64(axis)); + attrs.insert("stash_type".to_string(), AttributeValue::Int64(stash_type)); + + Node { + node_type: NodeType::LayerNormalization, + name: "test_layernorm".to_string(), + inputs, + outputs: vec![Argument { + name: "output".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 3, + static_shape: None, + }), + value: None, + passed: true, + }], + attrs, + } + } + + #[test] + fn test_layer_norm_config_basic() { + let node = create_test_node(1e-5, -1, 1, 64); + let (config, stash_type_flag) = layer_norm_config(&node); + + assert_eq!(config.d_model, 64); + assert!(f64::abs(config.epsilon - 1e-5) < 1e-6); + assert!(stash_type_flag); + } + + #[test] + fn test_layer_norm_config_no_stash_type() { + let node = create_test_node(1e-5, -1, 0, 32); + let (config, stash_type_flag) = layer_norm_config(&node); + + assert_eq!(config.d_model, 32); + assert!(!stash_type_flag); + } + + #[test] + #[should_panic] + fn test_layer_norm_config_invalid_axis() { + // For a 1D weight tensor with shape [num_features], + // both axis=0 (the first and only dim) and axis=-1 (the last dim) are valid + // So we need to use a 2D weight tensor to test the invalid axis case + + // Create a custom node with a 2D weight tensor + let mut node = create_test_node(1e-5, 0, 1, 64); + + // Modify the weight tensor to be 2D + if let Some(ref mut tensor) = node.inputs[1].value { + tensor.shape = vec![64, 64]; // Make it 2D + } + + // Now axis=0 should trigger a panic since it's not the last dimension + let _ = layer_norm_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/linear.rs b/crates/onnx-ir/src/node/linear.rs new file mode 100644 index 0000000000..59cb043ff8 --- /dev/null +++ b/crates/onnx-ir/src/node/linear.rs @@ -0,0 +1,140 @@ +use crate::ir::Node; +use burn::nn::LinearConfig; + +/// Create a LinearConfig from the attributes of the node +pub fn linear_config(node: &Node) -> LinearConfig { + if node.inputs.len() < 2 { + panic!("Linear: missing weight tensor"); + } + + let weight_shape = node.inputs[1] + .value + .as_ref() + .expect("Linear: weight tensor must be present") + .shape + .clone(); + + // check if the weight tensor has at least 2 dimensions + if weight_shape.len() < 2 { + panic!( + "Linear: weight tensor must have at least 2 dimensions (got {:?})", + weight_shape.len() + ); + } + + let (in_size, out_size) = (weight_shape[0], weight_shape[1]); + + // check if the bias is present + let bias = node.inputs.len() == 3 && node.inputs[2].value.is_some(); + + LinearConfig::new(in_size, out_size).with_bias(bias) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{ArgType, Argument, Data, ElementType, NodeType, TensorData, TensorType}; + use std::collections::HashMap; + + fn create_test_node(has_bias: bool, weight_dims: Vec) -> Node { + let weight_tensor = TensorData { + data: Data::Float32s(vec![0.0; weight_dims.iter().product()]), // Not important for the test + shape: weight_dims.clone(), + }; + + let mut inputs = vec![ + Argument { + name: "input".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 2, + static_shape: None, + }), + value: None, + passed: true, + }, + Argument { + name: "weight".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: weight_dims.len(), + 
static_shape: None, + }), + value: Some(weight_tensor), + passed: true, + }, + ]; + + if has_bias { + let bias_tensor = TensorData { + data: Data::Float32s(vec![0.0; weight_dims[1]]), // bias size equals output size + shape: vec![weight_dims[1]], + }; + + inputs.push(Argument { + name: "bias".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 1, + static_shape: None, + }), + value: Some(bias_tensor), + passed: true, + }); + } + + let attrs = HashMap::new(); + + Node { + node_type: NodeType::Gemm, + name: "test_linear".to_string(), + inputs, + outputs: vec![Argument { + name: "output".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 2, + static_shape: None, + }), + value: None, + passed: true, + }], + attrs, + } + } + + #[test] + fn test_linear_config_basic() { + let node = create_test_node(false, vec![10, 5]); + let config = linear_config(&node); + + assert_eq!(config.d_input, 10); + assert_eq!(config.d_output, 5); + assert!(!config.bias); + } + + #[test] + fn test_linear_config_with_bias() { + let node = create_test_node(true, vec![10, 5]); + let config = linear_config(&node); + + assert_eq!(config.d_input, 10); + assert_eq!(config.d_output, 5); + assert!(config.bias); + } + + #[test] + #[should_panic(expected = "Linear: weight tensor must have at least 2 dimensions")] + fn test_linear_config_invalid_weight_dims() { + let node = create_test_node(false, vec![10]); + let _ = linear_config(&node); + } + + #[test] + #[should_panic(expected = "Linear: missing weight tensor")] + fn test_linear_config_missing_weight() { + let mut node = create_test_node(false, vec![10, 5]); + node.inputs.remove(1); + let _ = linear_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/log_softmax.rs b/crates/onnx-ir/src/node/log_softmax.rs new file mode 100644 index 0000000000..61f0619df5 --- /dev/null +++ b/crates/onnx-ir/src/node/log_softmax.rs @@ -0,0 +1,106 @@ +use crate::ir::{ArgType, Node}; + +/// Create log_softmax config from the attributes of the node +pub fn log_softmax_config(node: &Node) -> usize { + // the axis is the last dimension (Default: 1 per ONNX spec) + let mut axis: i64 = -1; + + // check if the node has only one input + if node.inputs.len() != 1 { + panic!( + "LogSoftmax: multiple inputs are not supported (got {:?})", + node.inputs.len() + ); + } + + // extract the shape of the input tensor + let tensor = match node.inputs.first().unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + // extract the attributes + for (key, value) in node.attrs.iter() { + if key.as_str() == "axis" { + axis = value.clone().into_i64() + } + } + + // if axis is negative, it is counted from the end + if axis < 0 { + axis += tensor.rank as i64; + } + + axis as usize +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{Argument, AttributeValue, ElementType, NodeType, TensorType}; + use std::collections::HashMap; + + fn create_test_node(axis: i64, input_rank: usize) -> Node { + let inputs = vec![Argument { + name: "data".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: input_rank, + static_shape: None, + }), + value: None, + passed: true, + }]; + + let mut attrs = HashMap::new(); + attrs.insert("axis".to_string(), AttributeValue::Int64(axis)); + + Node { + node_type: NodeType::LogSoftmax, + name: "test_log_softmax".to_string(), + inputs, + outputs: vec![Argument { + name: "output".to_string(), + ty: 
ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: input_rank, + static_shape: None, + }), + value: None, + passed: true, + }], + attrs, + } + } + + #[test] + fn test_log_softmax_config_basic() { + let node = create_test_node(-1, 3); + let config = log_softmax_config(&node); + assert_eq!(config, 2); // -1 + 3 = 2 (last dimension) + } + + #[test] + fn test_log_softmax_config_explicit_axis() { + let node = create_test_node(1, 3); + let config = log_softmax_config(&node); + assert_eq!(config, 1); + } + + #[test] + #[should_panic(expected = "LogSoftmax: multiple inputs are not supported")] + fn test_log_softmax_config_multiple_inputs() { + let mut node = create_test_node(1, 3); + node.inputs.push(Argument { + name: "extra".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 1, + static_shape: None, + }), + value: None, + passed: true, + }); + let _ = log_softmax_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/mod.rs b/crates/onnx-ir/src/node/mod.rs index 065a29fcb6..e3a4ac6df7 100644 --- a/crates/onnx-ir/src/node/mod.rs +++ b/crates/onnx-ir/src/node/mod.rs @@ -1,11 +1,21 @@ +pub mod argmax; pub mod avg_pool1d; pub mod avg_pool2d; +pub mod batch_norm; +pub mod concat; pub mod conv1d; pub mod conv2d; pub mod conv3d; pub mod conv_transpose1d; pub mod conv_transpose2d; pub mod conv_transpose3d; +pub mod dropout; +pub mod flatten; +pub mod gather; +pub mod layer_norm; +pub mod linear; +pub mod log_softmax; pub mod max_pool1d; pub mod max_pool2d; pub mod slice; +pub mod softmax; diff --git a/crates/onnx-ir/src/node/softmax.rs b/crates/onnx-ir/src/node/softmax.rs new file mode 100644 index 0000000000..14f3a1c37a --- /dev/null +++ b/crates/onnx-ir/src/node/softmax.rs @@ -0,0 +1,106 @@ +use crate::ir::{ArgType, Node}; + +/// Create softmax config from the attributes of the node +pub fn softmax_config(node: &Node) -> usize { + // the axis is the last dimension (Default: 1 per ONNX spec) + let mut axis: i64 = -1; + + // check if the node has only one input + if node.inputs.len() != 1 { + panic!( + "Softmax: multiple inputs are not supported (got {:?})", + node.inputs.len() + ); + } + + // extract the shape of the input tensor + let tensor = match node.inputs.first().unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + // extract the attributes + for (key, value) in node.attrs.iter() { + if key.as_str() == "axis" { + axis = value.clone().into_i64() + } + } + + // if axis is negative, it is counted from the end + if axis < 0 { + axis += tensor.rank as i64; + } + + axis as usize +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{Argument, AttributeValue, ElementType, NodeType, TensorType}; + use std::collections::HashMap; + + fn create_test_node(axis: i64, input_rank: usize) -> Node { + let inputs = vec![Argument { + name: "data".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: input_rank, + static_shape: None, + }), + value: None, + passed: true, + }]; + + let mut attrs = HashMap::new(); + attrs.insert("axis".to_string(), AttributeValue::Int64(axis)); + + Node { + node_type: NodeType::Softmax, + name: "test_softmax".to_string(), + inputs, + outputs: vec![Argument { + name: "output".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: input_rank, + static_shape: None, + }), + value: None, + passed: true, + }], + attrs, + } + } + + #[test] + fn test_softmax_config_basic() { + let node = 
create_test_node(-1, 3); + let config = softmax_config(&node); + assert_eq!(config, 2); // -1 + 3 = 2 (last dimension) + } + + #[test] + fn test_softmax_config_explicit_axis() { + let node = create_test_node(1, 3); + let config = softmax_config(&node); + assert_eq!(config, 1); + } + + #[test] + #[should_panic(expected = "Softmax: multiple inputs are not supported")] + fn test_softmax_config_multiple_inputs() { + let mut node = create_test_node(1, 3); + node.inputs.push(Argument { + name: "extra".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 1, + static_shape: None, + }), + value: None, + passed: true, + }); + let _ = softmax_config(&node); + } +} diff --git a/crates/onnx-ir/src/op_configuration.rs b/crates/onnx-ir/src/op_configuration.rs index cff13a86ce..d07568cccc 100644 --- a/crates/onnx-ir/src/op_configuration.rs +++ b/crates/onnx-ir/src/op_configuration.rs @@ -1,8 +1,6 @@ // TODO Move op_configuration.rs from burn-import to onnx-ir #3091 // See https://github.com/tracel-ai/burn/issues/3091 -use burn::nn::{BatchNormConfig, DropoutConfig, LayerNormConfig, LinearConfig}; - use crate::ir::{ArgType, Data, Node, TensorData}; // pub fn expand_config(node: &Node) -> ExpandShape { @@ -36,347 +34,6 @@ use crate::ir::{ArgType, Data, Node, TensorData}; // } // } -/// Create a FlattenConfig from the attributes of the node -pub fn flatten_config(curr: &Node) -> usize { - // the begin dimension is the first dimension (Default: 1 per ONNX spec) - let mut axis: i64 = 1; - - // check if the node has only one input - if curr.inputs.len() != 1 { - panic!( - "Flatten: multiple inputs are not supported (got {:?})", - curr.inputs.len() - ); - } - - // extract the shape of the input tensor - let tensor = match curr.inputs.first().unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // check if the input tensor has at least 2 dimensions - if tensor.rank < 2 { - panic!( - "Flatten: input tensor must have at least 2 dimensions (got {:?})", - tensor.rank - ); - } - - // extract the attributes - for (key, value) in curr.attrs.iter() { - if key.as_str() == "axis" { - axis = value.clone().into_i64() - } - } - - // if beg_dim is negative, it is counted from the end - if axis < 0 { - axis += tensor.rank as i64; - } - - axis as usize -} - -/// Create a GatherConfig from the attributes of the node -pub fn gather_config(curr: &Node) -> usize { - // Default: 0 per ONNX spec - let mut dim: i64 = 0; - - // check if the node has only one input - if curr.inputs.len() != 2 { - panic!("Gather: index tensor must be present"); - } - - // extract the shape of the input tensor - let input_dim = match curr.inputs.first().unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor.rank as i64, - ArgType::Shape(_shape) => 1, //Shape is always 1-D - other => panic!("Only tensor or shape input is valid, got {:?}", other), - }; - - // extract the attributes - for (key, value) in curr.attrs.iter() { - if key.as_str() == "axis" { - dim = value.clone().into_i64() - } - } - - // if dim is negative, it is counted from the end - if dim < 0 { - dim += input_dim; - } - - dim as usize -} - -/// Create a LinearConfig from the attributes of the node -pub fn linear_config(node: &Node) -> LinearConfig { - if node.inputs.len() < 2 { - panic!("Linear: missing weight tensor"); - } - - let weight_shape = node.inputs[1] - .value - .as_ref() - .expect("Linear: weight tensor must be present") - .shape - .clone(); - - // check if the weight tensor has at 
least 2 dimensions - if weight_shape.len() < 2 { - panic!( - "Linear: weight tensor must have at least 2 dimensions (got {:?})", - weight_shape.len() - ); - } - - let (in_size, out_size) = (weight_shape[0], weight_shape[1]); - - // check if the bias is present - let bias = node.inputs.len() == 3 && node.inputs[2].value.is_some(); - - LinearConfig::new(in_size, out_size).with_bias(bias) -} - -/// Create a DropoutConfig from an attribute and state of the node -pub fn dropout_config(node: &Node) -> DropoutConfig { - // Opset 7 and older store probability as an attribute - if node.attrs.contains_key("ratio") { - let prob = node.attrs.get("ratio").unwrap().clone().into_f32(); - return DropoutConfig::new(prob as f64); - } - - if node.inputs.len() < 2 { - panic!("Dropout configuration must have at least 2 inputs"); - } - - let ratio = node.inputs[1] - .value - .clone() - .expect("Dropout ratio must be passed in the second input") - .data - .into_scalar(); - - let prob = match ratio { - Data::Float16(ratio) => f64::from(f32::from(ratio)), - Data::Float32(ratio) => ratio as f64, - Data::Float64(ratio) => ratio, - _ => panic!("Dropout ratio must be a float"), - }; - - DropoutConfig::new(prob) -} - -/// Create log_softmax config from the attributes of the node -pub fn log_softmax_config(node: &Node) -> usize { - // the axis is the last dimension (Default: 1 per ONNX spec) - let mut axis: i64 = -1; - - // check if the node has only one input - if node.inputs.len() != 1 { - panic!( - "LogSoftmax: multiple inputs are not supported (got {:?})", - node.inputs.len() - ); - } - - // extract the shape of the input tensor - let tensor = match node.inputs.first().unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // extract the attributes - for (key, value) in node.attrs.iter() { - if key.as_str() == "axis" { - axis = value.clone().into_i64() - } - } - - // if axis is negative, it is counted from the end - if axis < 0 { - axis += tensor.rank as i64; - } - - axis as usize -} - -/// Create softmax config from the attributes of the node -pub fn softmax_config(node: &Node) -> usize { - // the axis is the last dimension (Default: 1 per ONNX spec) - let mut axis: i64 = -1; - - // check if the node has only one input - if node.inputs.len() != 1 { - panic!( - "Softmax: multiple inputs are not supported (got {:?})", - node.inputs.len() - ); - } - - // extract the shape of the input tensor - let tensor = match node.inputs.first().unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // extract the attributes - for (key, value) in node.attrs.iter() { - if key.as_str() == "axis" { - axis = value.clone().into_i64() - } - } - - // if axis is negative, it is counted from the end - if axis < 0 { - axis += tensor.rank as i64; - } - - axis as usize -} - -/// Create argmax config from the attributes of the node -pub fn argmax_config(node: &Node) -> usize { - let mut axis: i64 = 0; - - // check if the node has only one input - if node.inputs.len() != 1 { - panic!( - "Argmax: multiple inputs are not supported (got {:?})", - node.inputs.len() - ); - } - - // extract the shape of the input tensor - let tensor = match node.inputs.first().unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // extract the attributes - for (key, value) in node.attrs.iter() { - match key.as_str() { - "axis" => axis = value.clone().into_i64(), - "select_last_index" => { - // 
not all params are supported in burn - if value.clone().into_i64() != 0 { - log::warn!( - "only select_last_index=0 is supported for argmax in burn. Ignoring supplied value (got {:?})", - value - ); - } - } - "keepdims" => { - // not all params are supported in burn - if value.clone().into_i64() != 1 { - panic!( - "Only keepdims=1 is supported for argmax in burn (got {:?})", - value - ); - } - } - _ => {} - } - } - - // if axis is negative, it is counted from the end - if axis < 0 { - axis += tensor.rank as i64; - } - - axis as usize -} - -/// Create concat config from the attributes of the node -pub fn concat_config(node: &Node) -> usize { - // the axis is the last dimension (Default: 1 per ONNX spec) - let mut axis: i64 = 1; - - // extract the shape of the input tensor - let tensor = match node.inputs.first().unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // extract the attributes - for (key, value) in node.attrs.iter() { - if key.as_str() == "axis" { - axis = value.clone().into_i64() - } - } - - // if axis is negative, it is counted from the end - if axis < 0 { - axis += tensor.rank as i64; - } - - axis as usize -} - -/// Create a BatchNormConfig from the attributes of the node -pub fn batch_norm_config(node: &Node) -> BatchNormConfig { - let weight_shape = node.inputs[1] - .value - .as_ref() - .expect("BatchNorm: weight tensor must be present") - .shape - .clone(); - - let num_features = weight_shape[0]; - - let mut epsilon = 0f32; - let mut momentum = 0f32; - - for (key, value) in node.attrs.iter() { - match key.as_str() { - "momentum" => momentum = value.clone().into_f32(), - "epsilon" => epsilon = value.clone().into_f32(), - _ => {} - } - } - - BatchNormConfig::new(num_features) - .with_epsilon(epsilon as f64) - .with_momentum(momentum as f64) -} - -/// Create a LayerNormConfig from the attributes of the node -pub fn layer_norm_config(node: &Node) -> (LayerNormConfig, bool) { - let weight_shape = node.inputs[1] - .value - .as_ref() - .expect("LayerNorm: weight tensor must be present") - .shape - .clone(); - - let num_features = weight_shape[0]; - - // When `stash_type` is `1` (default), perform operations in 32-bit float and - // cast the results back to original dtype - let mut stash_type = 1; - let mut axis = -1; - let mut epsilon = 1e-5; - - for (key, value) in node.attrs.iter() { - match key.as_str() { - "axis" => axis = value.clone().into_i64(), - "epsilon" => epsilon = value.clone().into_f32(), - "stash_type" => stash_type = value.clone().into_i64(), - _ => {} - } - } - - if axis != -1 && axis != weight_shape.len() as i64 - 1 { - panic!("LayerNorm: normalization is only supported on the last axis right now") - } - - ( - LayerNormConfig::new(num_features).with_epsilon(epsilon as f64), - stash_type == 1, - ) -} - // /// Create a TileConfig from the attributes of the node // pub fn tile_config(node: &Node) -> TileConfig { // let repeat = node From b19ce42ae9a1714537ed959791618a526ec30f6a Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Wed, 30 Apr 2025 14:32:37 -0500 Subject: [PATCH 10/37] Moved op configs to individual node modules Still remaining: expand_config tile_config top_k_config trilu_config pad_config unsqueeze_config split_config --- crates/onnx-ir/src/node/clip.rs | 216 +++++ crates/onnx-ir/src/node/gemm.rs | 134 +++ crates/onnx-ir/src/node/hard_sigmoid.rs | 75 ++ crates/onnx-ir/src/node/leaky_relu.rs | 69 ++ crates/onnx-ir/src/node/mod.rs | 11 + 
crates/onnx-ir/src/node/one_hot.rs | 140 ++++ crates/onnx-ir/src/node/reduce.rs | 381 +++++++++ crates/onnx-ir/src/node/reshape.rs | 122 +++ crates/onnx-ir/src/node/resize.rs | 276 +++++++ crates/onnx-ir/src/node/shape.rs | 141 ++++ crates/onnx-ir/src/node/squeeze.rs | 79 ++ crates/onnx-ir/src/node/transpose.rs | 98 +++ crates/onnx-ir/src/op_configuration.rs | 1006 +---------------------- 13 files changed, 1759 insertions(+), 989 deletions(-) create mode 100644 crates/onnx-ir/src/node/clip.rs create mode 100644 crates/onnx-ir/src/node/gemm.rs create mode 100644 crates/onnx-ir/src/node/hard_sigmoid.rs create mode 100644 crates/onnx-ir/src/node/leaky_relu.rs create mode 100644 crates/onnx-ir/src/node/one_hot.rs create mode 100644 crates/onnx-ir/src/node/reduce.rs create mode 100644 crates/onnx-ir/src/node/reshape.rs create mode 100644 crates/onnx-ir/src/node/resize.rs create mode 100644 crates/onnx-ir/src/node/shape.rs create mode 100644 crates/onnx-ir/src/node/squeeze.rs create mode 100644 crates/onnx-ir/src/node/transpose.rs diff --git a/crates/onnx-ir/src/node/clip.rs b/crates/onnx-ir/src/node/clip.rs new file mode 100644 index 0000000000..c6f010578c --- /dev/null +++ b/crates/onnx-ir/src/node/clip.rs @@ -0,0 +1,216 @@ +use crate::ir::{Data, Node}; + +pub fn clip_config(node: &Node) -> (Option, Option) { + let mut min_result: Option = None; + let mut max_result: Option = None; + + // For Clip Opset 6+ , the min and max values are attributes + for (key, value) in node.attrs.iter() { + match key.as_str() { + "min" => { + let min = value.clone().into_f32() as f64; + min_result = Some(min); + } + "max" => { + let max = value.clone().into_f32(); + max_result = Some(max as f64); + } + _ => {} + } + } + + // For Clip Opset 11+ , the min and max values are inputs + // Get the min and max values from the input values + if min_result.is_none() && max_result.is_none() { + let min = node.inputs.get(1).and_then(|arg| arg.value.clone()); + let max = node.inputs.get(2).and_then(|arg| arg.value.clone()); + + if min_result.is_none() && min.is_some() { + let min = min.unwrap().data.into_scalar(); + min_result = match min { + Data::Float16(min) => Some(f32::from(min) as f64), + Data::Float32(min) => Some(min as f64), + Data::Float64(min) => Some(min), + _ => panic!("Clip: only float min is supported"), + }; + } + + if max_result.is_none() && max.is_some() { + let max = max.unwrap().data.into_scalar(); + max_result = match max { + Data::Float16(max) => Some(f32::from(max) as f64), + Data::Float32(max) => Some(max as f64), + Data::Float64(max) => Some(max), + _ => panic!("Clip: only float max is supported"), + }; + } + } + + if min_result.is_none() && max_result.is_none() { + panic!("Clip: min and max values must be either attributes or inputs"); + } + + (min_result, max_result) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{ + ArgType, Argument, AttributeValue, Data, ElementType, NodeType, TensorData, TensorType, + }; + use std::collections::HashMap; + + fn create_test_node_with_attributes(min: Option, max: Option) -> Node { + let inputs = vec![Argument { + name: "X".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 4, + static_shape: None, + }), + value: None, + passed: true, + }]; + + let mut attrs = HashMap::new(); + if let Some(min_val) = min { + attrs.insert("min".to_string(), AttributeValue::Float32(min_val)); + } + if let Some(max_val) = max { + attrs.insert("max".to_string(), AttributeValue::Float32(max_val)); + } + + Node { + node_type: 
NodeType::Clip, + name: "test_clip".to_string(), + inputs, + outputs: vec![Argument { + name: "Y".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 4, + static_shape: None, + }), + value: None, + passed: true, + }], + attrs, + } + } + + fn create_test_node_with_inputs(min: Option, max: Option) -> Node { + let mut inputs = vec![Argument { + name: "X".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 4, + static_shape: None, + }), + value: None, + passed: true, + }]; + + // Add min input + inputs.push(Argument { + name: "min".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 0, + static_shape: None, + }), + value: min.map(|val| TensorData { + data: Data::Float32(val), + shape: vec![], + }), + passed: true, + }); + + // Add max input + inputs.push(Argument { + name: "max".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 0, + static_shape: None, + }), + value: max.map(|val| TensorData { + data: Data::Float32(val), + shape: vec![], + }), + passed: true, + }); + + Node { + node_type: NodeType::Clip, + name: "test_clip".to_string(), + inputs, + outputs: vec![Argument { + name: "Y".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 4, + static_shape: None, + }), + value: None, + passed: true, + }], + attrs: HashMap::new(), + } + } + + #[test] + fn test_clip_config_with_attributes() { + let node = create_test_node_with_attributes(Some(-1.0), Some(1.0)); + let (min, max) = clip_config(&node); + assert_eq!(min, Some(-1.0)); + assert_eq!(max, Some(1.0)); + } + + #[test] + fn test_clip_config_with_attributes_min_only() { + let node = create_test_node_with_attributes(Some(-1.0), None); + let (min, max) = clip_config(&node); + assert_eq!(min, Some(-1.0)); + assert_eq!(max, None); + } + + #[test] + fn test_clip_config_with_attributes_max_only() { + let node = create_test_node_with_attributes(None, Some(1.0)); + let (min, max) = clip_config(&node); + assert_eq!(min, None); + assert_eq!(max, Some(1.0)); + } + + #[test] + fn test_clip_config_with_inputs() { + let node = create_test_node_with_inputs(Some(-1.0), Some(1.0)); + let (min, max) = clip_config(&node); + assert_eq!(min, Some(-1.0)); + assert_eq!(max, Some(1.0)); + } + + #[test] + fn test_clip_config_with_inputs_min_only() { + let node = create_test_node_with_inputs(Some(-1.0), None); + let (min, max) = clip_config(&node); + assert_eq!(min, Some(-1.0)); + assert_eq!(max, None); + } + + #[test] + fn test_clip_config_with_inputs_max_only() { + let node = create_test_node_with_inputs(None, Some(1.0)); + let (min, max) = clip_config(&node); + assert_eq!(min, None); + assert_eq!(max, Some(1.0)); + } + + #[test] + #[should_panic(expected = "Clip: min and max values must be either attributes or inputs")] + fn test_clip_config_no_min_max() { + let node = create_test_node_with_attributes(None, None); + let _ = clip_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/gemm.rs b/crates/onnx-ir/src/node/gemm.rs new file mode 100644 index 0000000000..ef1c5dc90b --- /dev/null +++ b/crates/onnx-ir/src/node/gemm.rs @@ -0,0 +1,134 @@ +use crate::ir::Node; + +pub fn gemm_config(curr: &Node) -> (f32, f32, i64, i64) { + let alpha = curr + .attrs + .get("alpha") + .map(|val| val.clone().into_f32()) + .unwrap_or(1.0); + let beta = curr + .attrs + .get("beta") + .map(|val| val.clone().into_f32()) + .unwrap_or(1.0); + let trans_a = curr + .attrs + .get("transA") + 
.map(|val| val.clone().into_i64()) + .unwrap_or(0); + let trans_b = curr + .attrs + .get("transB") + .map(|val| val.clone().into_i64()) + .unwrap_or(0); + + (alpha, beta, trans_a, trans_b) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{ArgType, Argument, AttributeValue, ElementType, NodeType, TensorType}; + use std::collections::HashMap; + + fn create_test_node( + alpha: Option, + beta: Option, + trans_a: Option, + trans_b: Option, + ) -> Node { + let inputs = vec![ + Argument { + name: "A".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 2, + static_shape: None, + }), + value: None, + passed: true, + }, + Argument { + name: "B".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 2, + static_shape: None, + }), + value: None, + passed: true, + }, + Argument { + name: "C".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 2, + static_shape: None, + }), + value: None, + passed: true, + }, + ]; + + let mut attrs = HashMap::new(); + if let Some(alpha_val) = alpha { + attrs.insert("alpha".to_string(), AttributeValue::Float32(alpha_val)); + } + if let Some(beta_val) = beta { + attrs.insert("beta".to_string(), AttributeValue::Float32(beta_val)); + } + if let Some(trans_a_val) = trans_a { + attrs.insert("transA".to_string(), AttributeValue::Int64(trans_a_val)); + } + if let Some(trans_b_val) = trans_b { + attrs.insert("transB".to_string(), AttributeValue::Int64(trans_b_val)); + } + + Node { + node_type: NodeType::Gemm, + name: "test_gemm".to_string(), + inputs, + outputs: vec![Argument { + name: "Y".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 2, + static_shape: None, + }), + value: None, + passed: true, + }], + attrs, + } + } + + #[test] + fn test_gemm_config_defaults() { + let node = create_test_node(None, None, None, None); + let (alpha, beta, trans_a, trans_b) = gemm_config(&node); + assert_eq!(alpha, 1.0); + assert_eq!(beta, 1.0); + assert_eq!(trans_a, 0); + assert_eq!(trans_b, 0); + } + + #[test] + fn test_gemm_config_with_attrs() { + let node = create_test_node(Some(2.0), Some(3.0), Some(1), Some(1)); + let (alpha, beta, trans_a, trans_b) = gemm_config(&node); + assert_eq!(alpha, 2.0); + assert_eq!(beta, 3.0); + assert_eq!(trans_a, 1); + assert_eq!(trans_b, 1); + } + + #[test] + fn test_gemm_config_partial_attrs() { + let node = create_test_node(Some(0.5), None, Some(1), None); + let (alpha, beta, trans_a, trans_b) = gemm_config(&node); + assert_eq!(alpha, 0.5); + assert_eq!(beta, 1.0); // default + assert_eq!(trans_a, 1); + assert_eq!(trans_b, 0); // default + } +} diff --git a/crates/onnx-ir/src/node/hard_sigmoid.rs b/crates/onnx-ir/src/node/hard_sigmoid.rs new file mode 100644 index 0000000000..636dbc48a6 --- /dev/null +++ b/crates/onnx-ir/src/node/hard_sigmoid.rs @@ -0,0 +1,75 @@ +use crate::ir::Node; + +/// Create a HardSigmoidConfig from the alpha and beta attributes of the node +pub fn hard_sigmoid_config(node: &Node) -> (f64, f64) { + let mut alpha = 0.2; + let mut beta = 0.5; + + for (key, value) in node.attrs.iter() { + match key.as_str() { + "alpha" => alpha = value.clone().into_f32() as f64, + "beta" => beta = value.clone().into_f32() as f64, + _ => {} + } + } + + (alpha, beta) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{ArgType, Argument, AttributeValue, ElementType, NodeType, TensorType}; + use std::collections::HashMap; + + fn create_test_node(alpha: f32, beta: f32) -> 
Node { + let inputs = vec![Argument { + name: "X".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 4, + static_shape: None, + }), + value: None, + passed: true, + }]; + + let mut attrs = HashMap::new(); + attrs.insert("alpha".to_string(), AttributeValue::Float32(alpha)); + attrs.insert("beta".to_string(), AttributeValue::Float32(beta)); + + Node { + node_type: NodeType::HardSigmoid, + name: "test_hard_sigmoid".to_string(), + inputs, + outputs: vec![Argument { + name: "Y".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 4, + static_shape: None, + }), + value: None, + passed: true, + }], + attrs, + } + } + + #[test] + fn test_hard_sigmoid_config_with_attrs() { + let node = create_test_node(0.3, 0.6); + let (alpha, beta) = hard_sigmoid_config(&node); + assert!((alpha - 0.3).abs() < 1e-6); + assert!((beta - 0.6).abs() < 1e-6); + } + + #[test] + fn test_hard_sigmoid_config_default() { + let mut node = create_test_node(0.3, 0.6); + node.attrs.clear(); // Remove all attributes + let (alpha, beta) = hard_sigmoid_config(&node); + assert_eq!(alpha, 0.2); // Check default values + assert_eq!(beta, 0.5); + } +} diff --git a/crates/onnx-ir/src/node/leaky_relu.rs b/crates/onnx-ir/src/node/leaky_relu.rs new file mode 100644 index 0000000000..6dcad47fc9 --- /dev/null +++ b/crates/onnx-ir/src/node/leaky_relu.rs @@ -0,0 +1,69 @@ +use crate::ir::Node; + +/// Create a LeakyReluConfig from the alpha attribute of the node +pub fn leaky_relu_config(node: &Node) -> f64 { + let mut alpha = 0.01; + + for (key, value) in node.attrs.iter() { + if key.as_str() == "alpha" { + alpha = value.clone().into_f32() as f64 + } + } + + alpha +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{ArgType, Argument, AttributeValue, ElementType, NodeType, TensorType}; + use std::collections::HashMap; + + fn create_test_node(alpha: f32) -> Node { + let inputs = vec![Argument { + name: "X".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 4, + static_shape: None, + }), + value: None, + passed: true, + }]; + + let mut attrs = HashMap::new(); + attrs.insert("alpha".to_string(), AttributeValue::Float32(alpha)); + + Node { + node_type: NodeType::LeakyRelu, + name: "test_leaky_relu".to_string(), + inputs, + outputs: vec![Argument { + name: "Y".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 4, + static_shape: None, + }), + value: None, + passed: true, + }], + attrs, + } + } + + #[test] + fn test_leaky_relu_config_with_alpha() { + let node = create_test_node(0.2); + let alpha = leaky_relu_config(&node); + assert!((alpha - 0.2).abs() < 1e-6); + } + + #[test] + fn test_leaky_relu_config_default() { + let mut node = create_test_node(0.2); + node.attrs.clear(); // Remove all attributes + let alpha = leaky_relu_config(&node); + assert_eq!(alpha, 0.01); // Check default value + } +} diff --git a/crates/onnx-ir/src/node/mod.rs b/crates/onnx-ir/src/node/mod.rs index e3a4ac6df7..f212998958 100644 --- a/crates/onnx-ir/src/node/mod.rs +++ b/crates/onnx-ir/src/node/mod.rs @@ -2,6 +2,7 @@ pub mod argmax; pub mod avg_pool1d; pub mod avg_pool2d; pub mod batch_norm; +pub mod clip; pub mod concat; pub mod conv1d; pub mod conv2d; @@ -12,10 +13,20 @@ pub mod conv_transpose3d; pub mod dropout; pub mod flatten; pub mod gather; +pub mod gemm; +pub mod hard_sigmoid; pub mod layer_norm; +pub mod leaky_relu; pub mod linear; pub mod log_softmax; pub mod max_pool1d; pub mod max_pool2d; 
+pub mod one_hot; +pub mod reduce; +pub mod reshape; +pub mod resize; +pub mod shape; pub mod slice; pub mod softmax; +pub mod squeeze; +pub mod transpose; diff --git a/crates/onnx-ir/src/node/one_hot.rs b/crates/onnx-ir/src/node/one_hot.rs new file mode 100644 index 0000000000..496a357e57 --- /dev/null +++ b/crates/onnx-ir/src/node/one_hot.rs @@ -0,0 +1,140 @@ +use crate::ir::Node; + +pub fn one_hot_config(curr: &Node) -> (usize, [f32; 2], i64) { + let depth = curr.inputs[1] + .value + .clone() + .expect("OneHot: Only constant depth is currently supported") + .data + .into_i64(); + + let values = curr.inputs[2] + .value + .clone() + .expect("OneHot: Only constant on/off values is currently supported") + .data + .into_f32s(); + + let axis = curr + .attrs + .get("axis") + .map(|val| val.clone().into_i64()) + .unwrap_or(-1); + + (depth as usize, values.try_into().unwrap(), axis) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{ + ArgType, Argument, AttributeValue, Data, ElementType, NodeType, TensorData, TensorType, + }; + use std::collections::HashMap; + + fn create_test_node(depth: i64, values: Vec, axis: Option) -> Node { + let inputs = vec![ + Argument { + name: "indices".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Int64, + rank: 2, + static_shape: None, + }), + value: None, + passed: true, + }, + Argument { + name: "depth".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Int64, + rank: 0, + static_shape: None, + }), + value: Some(TensorData { + data: Data::Int64(depth), + shape: vec![], + }), + passed: true, + }, + Argument { + name: "values".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 1, + static_shape: None, + }), + value: Some(TensorData { + data: Data::Float32s(values), + shape: vec![2], // always [off_value, on_value] + }), + passed: true, + }, + ]; + + let mut attrs = HashMap::new(); + if let Some(axis_val) = axis { + attrs.insert("axis".to_string(), AttributeValue::Int64(axis_val)); + } + + Node { + node_type: NodeType::OneHot, + name: "test_one_hot".to_string(), + inputs, + outputs: vec![Argument { + name: "output".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 3, // rank increases by 1 + static_shape: None, + }), + value: None, + passed: true, + }], + attrs, + } + } + + #[test] + fn test_one_hot_config_basic() { + let node = create_test_node(5, vec![0.0, 1.0], None); + let (depth, values, axis) = one_hot_config(&node); + assert_eq!(depth, 5); + assert_eq!(values, [0.0, 1.0]); + assert_eq!(axis, -1); // default axis + } + + #[test] + fn test_one_hot_config_with_axis() { + let node = create_test_node(5, vec![0.0, 1.0], Some(1)); + let (depth, values, axis) = one_hot_config(&node); + assert_eq!(depth, 5); + assert_eq!(values, [0.0, 1.0]); + assert_eq!(axis, 1); + } + + #[test] + fn test_one_hot_config_custom_values() { + let node = create_test_node(10, vec![-1.0, 2.0], None); + let (depth, values, axis) = one_hot_config(&node); + assert_eq!(depth, 10); + assert_eq!(values, [-1.0, 2.0]); // custom off/on values + assert_eq!(axis, -1); + } + + #[test] + #[should_panic(expected = "Only constant depth is currently supported")] + fn test_one_hot_config_no_depth_value() { + let mut node = create_test_node(5, vec![0.0, 1.0], None); + node.inputs[1].value = None; // Remove depth value + let _ = one_hot_config(&node); + } + + #[test] + #[should_panic(expected = "Only constant on/off values is currently supported")] + fn 
test_one_hot_config_no_values() { + let mut node = create_test_node(5, vec![0.0, 1.0], None); + node.inputs[2].value = None; // Remove values + let _ = one_hot_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/reduce.rs b/crates/onnx-ir/src/node/reduce.rs new file mode 100644 index 0000000000..b0148581ea --- /dev/null +++ b/crates/onnx-ir/src/node/reduce.rs @@ -0,0 +1,381 @@ +use crate::ir::{ArgType, Node}; + +/// Create a ReduceMaxConfig from the attributes of the node +pub fn reduce_max_config(node: &Node) -> Option { + let mut axes = Vec::new(); + let mut keepdims = 1; + + let tensor = match node.inputs.first().unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + // Extract the attributes + for (key, value) in node.attrs.iter() { + match key.as_str() { + "axes" => axes = value.clone().into_i64s(), + "keepdims" => keepdims = value.clone().into_i64(), + _ => {} + } + } + + if axes.len() > 1 { + panic!("ReduceMax: reducing on multiple dimensions is not supported") + } + + if axes.is_empty() && keepdims == 1 { + panic!("ReduceMax: axes must be provided with keepdims") + } + + if !axes.is_empty() && keepdims == 0 { + // Not supported in Burn + panic!("ReduceMax: the reduce operation must preserve the reduced dimension") + } + + if axes.is_empty() { + None + } else { + let mut dim = axes[0]; + + if dim < 0 { + // Accepted range is [-r, r-1] where r = rank(data) but Burn only supports positive dim + dim += tensor.rank as i64; + } + Some(dim as usize) + } +} + +/// Create a ReduceMinConfig from the attributes of the node +pub fn reduce_min_config(node: &Node) -> Option { + let mut axes = Vec::new(); + let mut keepdims = 1; + + let tensor = match node.inputs.first().unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + // Extract the attributes + for (key, value) in node.attrs.iter() { + match key.as_str() { + "axes" => axes = value.clone().into_i64s(), + "keepdims" => keepdims = value.clone().into_i64(), + _ => {} + } + } + + if axes.len() > 1 { + panic!("ReduceMin: reducing on multiple dimensions is not supported") + } + + if axes.is_empty() && keepdims == 1 { + panic!("ReduceMin: axes must be provided with keepdims") + } + + if !axes.is_empty() && keepdims == 0 { + panic!("ReduceMin: the reduce operation must preserve the reduced dimension") + } + + if axes.is_empty() { + None + } else { + let mut dim = axes[0]; + + if dim < 0 { + dim += tensor.rank as i64; + } + Some(dim as usize) + } +} + +/// Create a ReduceMeanConfig from the attributes of the node +pub fn reduce_mean_config(node: &Node) -> Option { + let mut axes = Vec::new(); + let mut keepdims = 1; + + let tensor = match node.inputs.first().unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + // Extract the attributes + for (key, value) in node.attrs.iter() { + match key.as_str() { + "axes" => axes = value.clone().into_i64s(), + "keepdims" => keepdims = value.clone().into_i64(), + _ => {} + } + } + + if axes.len() > 1 { + panic!("ReduceMean: reducing on multiple dimensions is not supported") + } + + if axes.is_empty() && keepdims == 1 { + panic!("ReduceMean: axes must be provided with keepdims") + } + + if !axes.is_empty() && keepdims == 0 { + // Not supported in Burn + panic!("ReduceMean: the reduce operation must preserve the reduced dimension") + } + + if axes.is_empty() { + None + } else { + let mut dim = axes[0]; + + if dim < 0 { + // Accepted 
range is [-r, r-1] where r = rank(data) but Burn only supports positive dim + dim += tensor.rank as i64; + } + Some(dim as usize) + } +} + +/// Create a ReduceProdConfig from the attributes of the node +pub fn reduce_prod_config(node: &Node) -> Option { + let mut axes = Vec::new(); + let mut keepdims = 1; + + let tensor = match node.inputs.first().unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + // Extract the attributes + for (key, value) in node.attrs.iter() { + match key.as_str() { + "axes" => axes = value.clone().into_i64s(), + "keepdims" => keepdims = value.clone().into_i64(), + // TODO: handle noop_with_empty_axes (opset 18) + _ => {} + } + } + + if axes.len() > 1 { + panic!("ReduceProd: reducing on multiple dimensions is not supported") + } + + if axes.is_empty() && keepdims == 1 { + panic!("ReduceProd: axes must be provided with keepdims") + } + + if !axes.is_empty() && keepdims == 0 { + // Not supported in Burn + panic!("ReduceProd: the reduce operation must preserve the reduced dimension") + } + + if axes.is_empty() { + None + } else { + let mut dim = axes[0]; + + if dim < 0 { + // Accepted range is [-r, r-1] where r = rank(data) but Burn only supports positive dim + dim += tensor.rank as i64; + } + Some(dim as usize) + } +} + +/// Create a ReduceSumConfig from the attributes of the node +pub fn reduce_sum_config(node: &Node) -> Option { + let mut axes = Vec::new(); + let mut keepdims = 1; + + let tensor = match node.inputs.first().unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + // Extract the attributes + for (key, value) in node.attrs.iter() { + match key.as_str() { + "keepdims" => keepdims = value.clone().into_i64(), + "axes" => axes = value.clone().into_i64s(), + // TODO: handle noop_with_empty_axes + _ => {} + } + } + + // TODO: Handle case where axes are passed in. Will require its own ReduceSumNode instead of a UnaryNode. 
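+    // Since opset 13, ReduceSum carries `axes` as an optional second input rather than
+    // an attribute; when that input holds a constant value we read it here so older and
+    // newer models resolve to the same configuration.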
+ if let Some(value) = node + .inputs + .get(1) + .and_then(|argument| argument.value.as_ref()) + { + axes = value.clone().data.into_i64s(); + } + + if axes.len() > 1 { + panic!("ReduceSum: reducing on multiple dimensions is not supported") + } + + if axes.is_empty() && keepdims == 1 { + panic!("ReduceSum: axes must be provided with keepdims") + } + + if !axes.is_empty() && keepdims == 0 { + // Not supported in Burn + panic!("ReduceSum: the reduce operation must preserve the reduced dimension") + } + + if axes.is_empty() { + None + } else { + let mut dim = axes[0]; + + if dim < 0 { + // Accepted range is [-r, r-1] where r = rank(data) but Burn only supports positive dim + dim += tensor.rank as i64; + } + Some(dim as usize) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{ + Argument, AttributeValue, Data, ElementType, NodeType, TensorData, TensorType, + }; + use std::collections::HashMap; + + fn create_test_node( + node_type: NodeType, + axes: Option>, + keepdims: Option, + with_axes_input: bool, + ) -> Node { + let mut inputs = vec![Argument { + name: "data".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 3, + static_shape: None, + }), + value: None, + passed: true, + }]; + + // Add axes input if requested + if with_axes_input && axes.is_some() { + let axes_clone = axes.clone().unwrap(); + inputs.push(Argument { + name: "axes".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Int64, + rank: 1, + static_shape: None, + }), + value: Some(TensorData { + data: Data::Int64s(axes_clone.clone()), + shape: vec![axes_clone.len()], + }), + passed: true, + }); + } + + let mut attrs = HashMap::new(); + if !with_axes_input && axes.is_some() { + attrs.insert( + "axes".to_string(), + AttributeValue::Int64s(axes.clone().unwrap()), + ); + } + if let Some(kd) = keepdims { + attrs.insert("keepdims".to_string(), AttributeValue::Int64(kd)); + } + + let node_type_clone = node_type.clone(); + Node { + node_type, + name: format!("test_{:?}", node_type_clone).to_lowercase(), + inputs, + outputs: vec![Argument { + name: "reduced".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 3, + static_shape: None, + }), + value: None, + passed: true, + }], + attrs, + } + } + + #[test] + fn test_reduce_max_config_basic() { + let node = create_test_node(NodeType::ReduceMax, Some(vec![1]), Some(1), false); + let dim = reduce_max_config(&node); + assert_eq!(dim, Some(1)); + } + + #[test] + fn test_reduce_max_config_negative_axis() { + let node = create_test_node(NodeType::ReduceMax, Some(vec![-2]), Some(1), false); + let dim = reduce_max_config(&node); + assert_eq!(dim, Some(1)); // -2 + 3 = 1 + } + + #[test] + #[should_panic(expected = "ReduceMax: axes must be provided with keepdims")] + fn test_reduce_max_config_no_axes() { + let node = create_test_node(NodeType::ReduceMax, None, Some(1), false); + let _ = reduce_max_config(&node); + } + + #[test] + #[should_panic(expected = "ReduceMax: reducing on multiple dimensions is not supported")] + fn test_reduce_max_config_multiple_axes() { + let node = create_test_node(NodeType::ReduceMax, Some(vec![0, 1]), Some(1), false); + let _ = reduce_max_config(&node); + } + + #[test] + #[should_panic( + expected = "ReduceMax: the reduce operation must preserve the reduced dimension" + )] + fn test_reduce_max_config_no_keepdims() { + let node = create_test_node(NodeType::ReduceMax, Some(vec![1]), Some(0), false); + let _ = reduce_max_config(&node); + } + + #[test] + 
fn test_reduce_min_config_basic() {
+        let node = create_test_node(NodeType::ReduceMin, Some(vec![1]), Some(1), false);
+        let dim = reduce_min_config(&node);
+        assert_eq!(dim, Some(1));
+    }
+
+    #[test]
+    fn test_reduce_mean_config_basic() {
+        let node = create_test_node(NodeType::ReduceMean, Some(vec![1]), Some(1), false);
+        let dim = reduce_mean_config(&node);
+        assert_eq!(dim, Some(1));
+    }
+
+    #[test]
+    fn test_reduce_prod_config_basic() {
+        let node = create_test_node(NodeType::ReduceProd, Some(vec![1]), Some(1), false);
+        let dim = reduce_prod_config(&node);
+        assert_eq!(dim, Some(1));
+    }
+
+    #[test]
+    fn test_reduce_sum_config_basic() {
+        let node = create_test_node(NodeType::ReduceSum, Some(vec![1]), Some(1), false);
+        let dim = reduce_sum_config(&node);
+        assert_eq!(dim, Some(1));
+    }
+
+    #[test]
+    fn test_reduce_sum_config_with_input_axes() {
+        let node = create_test_node(NodeType::ReduceSum, Some(vec![1]), Some(1), true);
+        let dim = reduce_sum_config(&node);
+        assert_eq!(dim, Some(1));
+    }
+}
diff --git a/crates/onnx-ir/src/node/reshape.rs b/crates/onnx-ir/src/node/reshape.rs
new file mode 100644
index 0000000000..55aab093d6
--- /dev/null
+++ b/crates/onnx-ir/src/node/reshape.rs
@@ -0,0 +1,122 @@
+use crate::ir::{Node, TensorData};
+
+pub fn reshape_config(node: &Node) -> Vec<i64> {
+    let mut allowzero = 0;
+
+    for (key, value) in node.attrs.iter() {
+        if key.as_str() == "allowzero" {
+            allowzero = value.clone().into_i64()
+        }
+    }
+
+    // Burn does not support zero size shape (0 means false in ONNX)
+    // (see https://onnx.ai/onnx/operators/onnx__Reshape.html#attributes)
+    if allowzero != 0 {
+        panic!("Zero shape size is not supported");
+    }
+
+    // TODO: check "shape" attribute
+    if node.inputs.len() != 2 || node.inputs[1].value.is_none() {
+        panic!("Reshape: shape tensor must be present for {:?}", node);
+    }
+
+    match &node.inputs[1].value {
+        Some(TensorData { data, shape, .. }) => {
+            assert_eq!(shape.len(), 1, "Reshape: shape tensor must be 1D");
+            data.clone().into_i64s()
+        }
+        _ => panic!("Only tensor input is valid for shape"),
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::ir::{
+        ArgType, Argument, AttributeValue, Data, ElementType, NodeType, TensorData, TensorType,
+    };
+    use std::collections::HashMap;
+
+    fn create_test_node(allowzero: i64, shape_vec: Vec<i64>) -> Node {
+        let inputs = vec![
+            Argument {
+                name: "data".to_string(),
+                ty: ArgType::Tensor(TensorType {
+                    elem_type: ElementType::Float32,
+                    rank: 4,
+                    static_shape: None,
+                }),
+                value: None,
+                passed: true,
+            },
+            Argument {
+                name: "shape".to_string(),
+                ty: ArgType::Tensor(TensorType {
+                    elem_type: ElementType::Int64,
+                    rank: 1,
+                    static_shape: None,
+                }),
+                value: Some(TensorData {
+                    data: Data::Int64s(shape_vec),
+                    shape: vec![2],
+                }),
+                passed: true,
+            },
+        ];
+
+        let mut attrs = HashMap::new();
+        if allowzero != 0 {
+            attrs.insert("allowzero".to_string(), AttributeValue::Int64(allowzero));
+        }
+
+        Node {
+            node_type: NodeType::Reshape,
+            name: "test_reshape".to_string(),
+            inputs,
+            outputs: vec![Argument {
+                name: "reshaped".to_string(),
+                ty: ArgType::Tensor(TensorType {
+                    elem_type: ElementType::Float32,
+                    rank: 2,
+                    static_shape: None,
+                }),
+                value: None,
+                passed: true,
+            }],
+            attrs,
+        }
+    }
+
+    #[test]
+    fn test_reshape_config_basic() {
+        let node = create_test_node(0, vec![2, 3]);
+        let shape = reshape_config(&node);
+        assert_eq!(shape, vec![2, 3]);
+    }
+
+    #[test]
+    #[should_panic(expected = "Zero shape size is not supported")]
+    fn test_reshape_config_allowzero_not_supported() {
+        let node = create_test_node(1, vec![2, 3]);
+        let _ = reshape_config(&node);
+    }
+
+    #[test]
+    #[should_panic(expected = "shape tensor must be present")]
+    fn test_reshape_config_no_shape_input() {
+        let mut node = create_test_node(0, vec![2, 3]);
+        node.inputs.pop(); // Remove the shape input
+        let _ = reshape_config(&node);
+    }
+
+    #[test]
+    #[should_panic(expected = "shape tensor must be 1D")]
+    fn test_reshape_config_invalid_shape_dim() {
+        let mut node = create_test_node(0, vec![2, 3]);
+        // Modify the shape tensor's shape to be 2D
+        if let Some(tensor_data) = &mut node.inputs[1].value {
+            tensor_data.shape = vec![2, 1];
+        }
+        let _ = reshape_config(&node);
+    }
+}
diff --git a/crates/onnx-ir/src/node/resize.rs b/crates/onnx-ir/src/node/resize.rs
new file mode 100644
index 0000000000..51180dddaa
--- /dev/null
+++ b/crates/onnx-ir/src/node/resize.rs
@@ -0,0 +1,276 @@
+use crate::ir::{ArgType, Node, TensorData};
+
+pub fn resize_config(node: &Node) -> (String, Vec<f32>, Vec<usize>) {
+    let mut mode: String = "".to_string();
+
+    let mut scales: Vec<f32>;
+    let mut sizes: Vec<usize>;
+
+    let input = if let ArgType::Tensor(tensor) = &node
+        .inputs
+        .first()
+        .expect("Resize: Input tensor must be present")
+        .ty
+    {
+        tensor
+    } else {
+        panic!("Resize: input must be a tensor")
+    };
+
+    // Note: we are ignoring some attributes because results are approximately the same
+    // and we are not supporting all the attributes of the Resize operator.
+    // However, some attributes are important to be checked and we are checking
+    // against the default values of the attributes.
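+    // Per the ONNX Resize signature, the optional inputs are roi (input 1), scales
+    // (input 2) and sizes (input 3). Only constant scales/sizes values are read below;
+    // a non-empty roi is rejected, and the leading batch/channel entries are dropped
+    // before the spatial scales or sizes are returned.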
+ // TODO revisit this when we have more Resize operators in the model + for (key, value) in node.attrs.iter() { + match key.as_str() { + "antialias" => assert_eq!( + value.clone().into_i32(), + 0, + "Resize: antialias other than 0 is not supported" + ), + "axes" => panic!("Resize: custom axes attribute is not supported"), + "coordinate_transformation_mode" => { + log::warn!("Resize: coordinate_transformation_mode is ignored") + } + + "cubic_coeff_a" => log::warn!("Resize: cubic_coeff_a is ignored"), + "exclude_outside" => assert_eq!( + value.clone().into_i32(), + 0, + "Resize: exclude_outside other than 0 is not supported" + ), + "extrapolation_value" => assert_eq!( + value.clone().into_f32(), + 0.0, + "Resize: extrapolation_value other than 0.0 is not supported" + ), + "keep_aspect_ratio_policy" => { + assert_eq!( + value.clone().into_string().to_lowercase(), + "stretch", + "Resize: keep_aspect_ratio_policy other than 'stretch' is not supported" + ) + } + "mode" => mode = value.clone().into_string().to_lowercase(), + "nearest_mode" => log::warn!("Resize: nearest_mode is ignored"), + + _ => {} + } + } + + let roi: Vec = node + .inputs + .get(1) + .map(|input| { + if let Some(TensorData { data, .. }) = &input.value { + data.clone().into_f32s() + } else { + vec![] + } + }) + .unwrap_or_default(); + + scales = node + .inputs + .get(2) + .map(|input| { + if let Some(TensorData { data, .. }) = &input.value { + data.clone().into_f32s() + } else { + vec![] + } + }) + .unwrap_or_default(); + + sizes = node + .inputs + .get(3) + .map(|input| { + if let Some(TensorData { data, .. }) = &input.value { + data.clone() + .into_i64s() + .iter() + .map(|&x| x as usize) + .collect() + } else { + vec![] + } + }) + .unwrap_or_default(); + + if mode.is_empty() { + panic!("Resize: mode attribute is required") + } + + if !roi.is_empty() { + panic!("Resize: roi input is not supported") + } + + if scales.is_empty() && sizes.is_empty() { + panic!("Resize: either scales or sizes input is required") + } + + if !scales.is_empty() { + assert!(scales.len() == input.rank); + // ignore the fist two items from scales + // because they are the batch and channel dimensions + scales = scales.iter().skip(2).cloned().collect(); + } + + if !sizes.is_empty() { + assert!(sizes.len() == input.rank); + // ignore the fist two items from sizes + // because they are the batch and channel dimensions + sizes = sizes.iter().skip(2).cloned().collect(); + } + + (mode, scales, sizes) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{ + Argument, AttributeValue, Data, ElementType, NodeType, TensorData, TensorType, + }; + use std::collections::HashMap; + + fn create_test_node( + mode: &str, + scales: Option>, + sizes: Option>, + roi: Option>, + ) -> Node { + let mut inputs = vec![Argument { + name: "X".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 4, // N,C,H,W format + static_shape: None, + }), + value: None, + passed: true, + }]; + + // Add ROI input if provided + inputs.push(Argument { + name: "roi".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 1, + static_shape: None, + }), + value: roi.map(|data| TensorData { + data: Data::Float32s(data), + shape: vec![8], // For 4D input (start x, start y, end x, end y) + }), + passed: true, + }); + + // Add scales input if provided + inputs.push(Argument { + name: "scales".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 1, + static_shape: None, + }), + 
value: scales.map(|data| TensorData { + data: Data::Float32s(data), + shape: vec![4], // N,C,H,W scales + }), + passed: true, + }); + + // Add sizes input if provided + inputs.push(Argument { + name: "sizes".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Int64, + rank: 1, + static_shape: None, + }), + value: sizes.map(|data| TensorData { + data: Data::Int64s(data), + shape: vec![4], // N,C,H,W sizes + }), + passed: true, + }); + + let mut attrs = HashMap::new(); + attrs.insert("mode".to_string(), AttributeValue::String(mode.to_string())); + + Node { + node_type: NodeType::Resize, + name: "test_resize".to_string(), + inputs, + outputs: vec![Argument { + name: "Y".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 4, + static_shape: None, + }), + value: None, + passed: true, + }], + attrs, + } + } + + #[test] + fn test_resize_config_with_scales() { + let node = create_test_node( + "nearest", + Some(vec![1.0, 1.0, 2.0, 2.0]), // Keep N,C same, double H,W + None, + None, + ); + let (mode, scales, sizes) = resize_config(&node); + assert_eq!(mode, "nearest"); + assert_eq!(scales, vec![2.0, 2.0]); // Only the spatial scales (H,W) + assert!(sizes.is_empty()); + } + + #[test] + fn test_resize_config_with_sizes() { + let node = create_test_node( + "linear", + None, + Some(vec![1, 3, 224, 224]), // Fixed output size + None, + ); + let (mode, scales, sizes) = resize_config(&node); + assert_eq!(mode, "linear"); + assert!(scales.is_empty()); + assert_eq!(sizes, vec![224, 224]); // Only the spatial sizes (H,W) + } + + #[test] + #[should_panic(expected = "Resize: roi input is not supported")] + fn test_resize_config_with_roi() { + let node = create_test_node( + "nearest", + Some(vec![1.0, 1.0, 2.0, 2.0]), + None, + Some(vec![0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0]), // ROI values + ); + let _ = resize_config(&node); + } + + #[test] + #[should_panic(expected = "Resize: either scales or sizes input is required")] + fn test_resize_config_no_scales_or_sizes() { + let node = create_test_node("nearest", None, None, None); + let _ = resize_config(&node); + } + + #[test] + #[should_panic(expected = "Resize: mode attribute is required")] + fn test_resize_config_no_mode() { + let mut node = create_test_node("nearest", Some(vec![1.0, 1.0, 2.0, 2.0]), None, None); + node.attrs.clear(); // Remove all attributes including mode + let _ = resize_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/shape.rs b/crates/onnx-ir/src/node/shape.rs new file mode 100644 index 0000000000..3490ae1c63 --- /dev/null +++ b/crates/onnx-ir/src/node/shape.rs @@ -0,0 +1,141 @@ +use crate::ir::{ArgType, Node}; + +pub fn shape_config(curr: &Node) -> (usize, usize) { + if curr.inputs.len() != 1 { + panic!( + "Shape: multiple inputs are not supported (got {:?})", + curr.inputs.len() + ); + } + + // Extract the shape of the input tensor + let tensor = match curr.inputs.first().unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + // Default: all axes up to the last one (included) + let mut start_dim: i64 = 0; + let mut end_dim: i64 = tensor.rank as i64; + + // Extract the attributes + for (key, value) in curr.attrs.iter() { + match key.as_str() { + "start" => start_dim = value.clone().into_i64(), + "end" => end_dim = value.clone().into_i64(), + _ => {} + } + } + + // If dim is negative, it is counted from the end + if start_dim < 0 { + start_dim += tensor.rank as i64; + } + if end_dim < 0 { + end_dim += 
tensor.rank as i64; + } + + (start_dim as usize, end_dim as usize) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{Argument, AttributeValue, ElementType, NodeType, TensorType}; + use std::collections::HashMap; + + fn create_test_node(start: Option, end: Option, rank: usize) -> Node { + let inputs = vec![Argument { + name: "data".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank, + static_shape: None, + }), + value: None, + passed: true, + }]; + + let mut attrs = HashMap::new(); + if let Some(start_val) = start { + attrs.insert("start".to_string(), AttributeValue::Int64(start_val)); + } + if let Some(end_val) = end { + attrs.insert("end".to_string(), AttributeValue::Int64(end_val)); + } + + Node { + node_type: NodeType::Shape, + name: "test_shape".to_string(), + inputs, + outputs: vec![Argument { + name: "shape".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Int64, + rank: 1, + static_shape: None, + }), + value: None, + passed: true, + }], + attrs, + } + } + + #[test] + fn test_shape_config_defaults() { + let node = create_test_node(None, None, 4); + let (start, end) = shape_config(&node); + assert_eq!(start, 0); + assert_eq!(end, 4); + } + + #[test] + fn test_shape_config_with_start() { + let node = create_test_node(Some(1), None, 4); + let (start, end) = shape_config(&node); + assert_eq!(start, 1); + assert_eq!(end, 4); + } + + #[test] + fn test_shape_config_with_end() { + let node = create_test_node(None, Some(3), 4); + let (start, end) = shape_config(&node); + assert_eq!(start, 0); + assert_eq!(end, 3); + } + + #[test] + fn test_shape_config_with_start_and_end() { + let node = create_test_node(Some(1), Some(3), 4); + let (start, end) = shape_config(&node); + assert_eq!(start, 1); + assert_eq!(end, 3); + } + + #[test] + fn test_shape_config_negative_dims() { + let node = create_test_node(Some(-2), Some(-1), 4); + let (start, end) = shape_config(&node); + assert_eq!(start, 2); // -2 + 4 = 2 + assert_eq!(end, 3); // -1 + 4 = 3 + } + + #[test] + #[should_panic(expected = "Shape: multiple inputs are not supported")] + fn test_shape_config_multiple_inputs() { + let mut node = create_test_node(None, None, 4); + node.inputs.push(Argument { + name: "extra".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 4, + static_shape: None, + }), + value: None, + passed: true, + }); + let _ = shape_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/squeeze.rs b/crates/onnx-ir/src/node/squeeze.rs new file mode 100644 index 0000000000..8165eb2cf3 --- /dev/null +++ b/crates/onnx-ir/src/node/squeeze.rs @@ -0,0 +1,79 @@ +use crate::ir::{ArgType, Node}; + +pub fn squeeze_config(curr: &Node) -> Vec { + let axes = curr + .attrs + .iter() + .filter_map(|(key, value)| { + if key == "axes" { + Some(value.clone().into_i64s()) + } else { + None + } + }) + .next() + .unwrap_or_else(Vec::new); + + match curr.inputs.first().unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + axes +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{Argument, AttributeValue, ElementType, NodeType, TensorType}; + use std::collections::HashMap; + + fn create_test_node(axes: Option>, rank: usize) -> Node { + let inputs = vec![Argument { + name: "data".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank, + static_shape: None, + }), + value: None, + passed: true, + }]; + + let mut attrs = 
HashMap::new(); + if let Some(ref axes_val) = axes { + attrs.insert("axes".to_string(), AttributeValue::Int64s(axes_val.clone())); + } + + Node { + node_type: NodeType::Squeeze, + name: "test_squeeze".to_string(), + inputs, + outputs: vec![Argument { + name: "squeezed".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: rank - (axes.as_ref().map_or(0, |a| a.len())), + static_shape: None, + }), + value: None, + passed: true, + }], + attrs, + } + } + + #[test] + fn test_squeeze_config_with_axes() { + let node = create_test_node(Some(vec![0, 2]), 4); + let axes = squeeze_config(&node); + assert_eq!(axes, vec![0, 2]); + } + + #[test] + fn test_squeeze_config_no_axes() { + let node = create_test_node(None, 4); + let axes = squeeze_config(&node); + assert!(axes.is_empty()); + } +} diff --git a/crates/onnx-ir/src/node/transpose.rs b/crates/onnx-ir/src/node/transpose.rs new file mode 100644 index 0000000000..4b32f4316b --- /dev/null +++ b/crates/onnx-ir/src/node/transpose.rs @@ -0,0 +1,98 @@ +use crate::ir::{ArgType, Node}; + +pub fn transpose_config(curr: &Node) -> Vec { + if curr.inputs.len() != 1 { + panic!( + "Transpose: multiple inputs are not supported (got {:?})", + curr.inputs.len() + ); + } + + // Extract the shape of the input tensor + let tensor = match curr.inputs.first().unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + // Default: reverse the dimensions + let mut perm = (0..tensor.rank as i64).rev().collect::>(); + + if let Some(axes) = curr.attrs.get("perm") { + perm = axes.clone().into_i64s(); + } + + perm +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{Argument, AttributeValue, ElementType, NodeType, TensorType}; + use std::collections::HashMap; + + fn create_test_node(perm: Option>, rank: usize) -> Node { + let inputs = vec![Argument { + name: "data".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank, + static_shape: None, + }), + value: None, + passed: true, + }]; + + let mut attrs = HashMap::new(); + if let Some(perm_val) = perm { + attrs.insert("perm".to_string(), AttributeValue::Int64s(perm_val)); + } + + Node { + node_type: NodeType::Transpose, + name: "test_transpose".to_string(), + inputs, + outputs: vec![Argument { + name: "transposed".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank, + static_shape: None, + }), + value: None, + passed: true, + }], + attrs, + } + } + + #[test] + fn test_transpose_config_default() { + let node = create_test_node(None, 3); + let perm = transpose_config(&node); + assert_eq!(perm, vec![2, 1, 0]); // Default is to reverse the dimensions + } + + #[test] + fn test_transpose_config_with_perm() { + let node = create_test_node(Some(vec![0, 2, 1]), 3); + let perm = transpose_config(&node); + assert_eq!(perm, vec![0, 2, 1]); + } + + #[test] + #[should_panic(expected = "Transpose: multiple inputs are not supported")] + fn test_transpose_config_multiple_inputs() { + let mut node = create_test_node(None, 3); + node.inputs.push(Argument { + name: "extra".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 3, + static_shape: None, + }), + value: None, + passed: true, + }); + let _ = transpose_config(&node); + } +} diff --git a/crates/onnx-ir/src/op_configuration.rs b/crates/onnx-ir/src/op_configuration.rs index d07568cccc..012f325a1d 100644 --- a/crates/onnx-ir/src/op_configuration.rs +++ 
b/crates/onnx-ir/src/op_configuration.rs @@ -1,989 +1,17 @@ -// TODO Move op_configuration.rs from burn-import to onnx-ir #3091 -// See https://github.com/tracel-ai/burn/issues/3091 - -use crate::ir::{ArgType, Data, Node, TensorData}; - -// pub fn expand_config(node: &Node) -> ExpandShape { -// match &node.inputs[1].ty { -// ArgType::Tensor(tensor) => { -// assert_eq!(tensor.rank, 1, "Expand: shape tensor must be 1D"); -// assert!( -// matches!(tensor.elem_type, ElementType::Int64), -// "Expand: shape tensor must have element type int64" -// ); -// } -// ArgType::Shape(_) => { -// // Shapes are always 1-D int64 data, so nothing to assert here -// } -// _ => panic!("Only tensor input is valid for shape"), -// } - -// match &node.inputs[1].value { -// Some(TensorData { -// data: Data::Int64s(shape), -// .. -// }) => ExpandShape::Static(shape.clone()), -// None => { -// // we were unable to statically determine the input value, so we'll need to fetch it at runtime -// ExpandShape::Runtime(crate::burn::Type::from(&node.inputs[1])) -// } -// _ => panic!( -// "Shape data type must be int64, is {:?}", -// &node.inputs[1].value -// ), -// } -// } - -// /// Create a TileConfig from the attributes of the node -// pub fn tile_config(node: &Node) -> TileConfig { -// let repeat = node -// .inputs -// .get(1) -// .map(|input| { -// if let Some(TensorData { data, .. }) = &input.value { -// data.clone() -// .into_i64s() -// .iter() -// .map(|&x| x as usize) -// .collect() -// } else { -// vec![] -// } -// }) -// .unwrap_or_default(); -// TileConfig::new(repeat) -// } - -// /// Create a TopKConfig from the attributes of the node. -// pub fn top_k_config(node: &Node) -> TopKConfig { -// // extract the shape of the input data tensor -// let data_tensor = match node.inputs.first().unwrap().clone().ty { -// ArgType::Tensor(tensor) => tensor, -// _ => panic!("Only tensor input is valid"), -// }; - -// let k = match node.inputs.get(1) { -// Some(k_tensor) => k_tensor -// .clone() -// .value -// .expect("TopK: only constant 'k' tensor is currently supported") -// .data -// .into_i64s()[0], -// _ => node -// .attrs -// .get("k") -// .expect("TopK: number of top elements 'k' is missing") -// .clone() -// .into_i64(), -// }; - -// let mut axis = match node.attrs.get("axis") { -// Some(axis) => axis.clone().into_i64(), -// None => -1, -// }; - -// // if axis is negative, it is counted from the end -// if axis < 0 { -// axis += data_tensor.rank as i64; -// } - -// if let Some(largest) = node.attrs.get("largest") { -// if largest.clone().into_i64() != 1 { -// unimplemented!("TopK: only largest elements is supported") -// } -// }; - -// if let Some(sorted) = node.attrs.get("sorted") { -// if sorted.clone().into_i64() != 1 { -// unimplemented!("TopK: only sorted elements is supported") -// } -// }; - -// TopKConfig::new(axis as usize, k as usize) -// } - -// /// Create a TriluConfig from the attributes of the node -// pub fn trilu_config(node: &Node) -> TriluConfig { -// let mut upper = true; -// let mut diagonal = 0; -// for (key, value) in node.attrs.iter() { -// match key.as_str() { -// "upper" => upper = value.clone().into_i64() != 0, -// _ => {} -// } -// } -// // The second input of the Trilu node is the diagonal value, coming from a constant node -// if let Some(diagonal_arg) = node.inputs.get(1) { -// if let Some(TensorData { -// data: Data::Int64(diagonal_val), -// .. 
-// }) = &diagonal_arg.value -// { -// diagonal = *diagonal_val; -// } -// } -// TriluConfig::new(upper, diagonal) -// } - -// /// Create a PadConfig from the attributes of the node -// pub fn pad_config(node: &Node) -> PadConfig { -// fn get_pads_input(node: &Node) -> Vec { -// match &node.inputs[1].value { -// Some(TensorData { data, .. }) => data.clone().into_i64s(), -// _ => Vec::new(), -// } -// } -// fn get_pads(node: &Node) -> Vec { -// if node.inputs.is_empty() { -// panic!("Pad: must provide data as input") -// } -// if node.inputs.len() >= 4 { -// panic!("Pad: axes input is not supported") -// } - -// let input_dim = match &node.inputs.first().unwrap().ty { -// ArgType::Tensor(tensor) => tensor.rank, -// _ => panic!("Pad: Only tensor input is valid"), -// }; - -// //TODO : handle more possible attributes -// let mut pads: Vec = get_pads_input(node) -// .into_iter() -// .map(|x| x as usize) -// .collect(); - -// for (key, value) in node.attrs.iter() { -// match key.as_str() { -// "pads" => { -// pads = value -// .clone() -// .into_i64s() -// .iter() -// .map(|&x| { -// if x < 0 { -// panic!("Pad: Negative pad is not supported"); -// } -// x as usize -// }) -// .collect() -// } -// "mode" => { -// let mode = value.clone().into_string(); -// if mode != "constant" { -// panic!("only constant mode is supported, given mode is {}", mode); -// } -// } - -// _ => {} -// } -// } - -// if pads.is_empty() { -// panic!("Pad: pads should be given as attribute or as input"); -// } - -// if pads.len() != input_dim * 2 { -// panic!("Pad: pads should be a 1D tensor of shape [2 * num_axes]"); -// } -// // TODO: Burn's pad should support 1D tensor -// if input_dim < 2 { -// panic!("Pad: input tensor should be rank 2 or higher"); -// } - -// let left_index = input_dim - 1; -// let top_index = input_dim - 2; -// let right_index = pads.len() - 1; -// let bottom_index = pads.len() - 2; -// let index_list = [left_index, top_index, right_index, bottom_index]; - -// for (index, &item) in pads.iter().enumerate() { -// if !index_list.contains(&index) && item != 0 { -// panic!( -// "Pad: padding will only be applied to the last two dimensions but found non zero padding for other dimensions" -// ); -// } -// } - -// let left = pads[left_index]; -// let top = pads[top_index]; -// let right = pads[right_index]; -// let bottom = pads[bottom_index]; -// vec![left, right, top, bottom] -// } -// fn get_constant_value(node: &Node) -> f32 { -// // TODO: support int, boolean -// let mut constant_value = node.inputs -// .get(2) -// .and_then(|input| match &input.value.as_ref().expect("Value input must be present").data { -// Data::Float16s(constant_value) => { -// constant_value.first().map(|&f| f32::from(f)) -// } -// Data::Float32s(constant_value) => { -// constant_value.first().copied() -// } -// Data::Float64s(constant_value) => { -// constant_value.first().map(|&f| f as f32) -// } -// Data::Float16(constant_value) => Some(f32::from(*constant_value)), -// Data::Float32(constant_value) => Some(*constant_value), -// Data::Float64(constant_value) => Some(*constant_value as f32), -// _ => panic!("Pad: only float values are currently supported for constant value, submit an issue on github"), -// }) -// .unwrap_or(0.0); - -// if node.attrs.contains_key("value") { -// constant_value = node.attrs.get("value").map(|value| match value { -// AttributeValue::Float32(value) => *value, -// _ => panic!("Pad: only float32 values are currently supported for constant value as attribute, submit an issue on github"), -// 
}).expect("constant_value should have had a value now"); -// } -// constant_value -// } - -// let pads = get_pads(node); -// let constant_value = get_constant_value(node); - -// PadConfig::new(pads, constant_value) -// } - -// Create a LeakyReluConfig from the alpha attribute of the node -pub fn leaky_relu_config(node: &Node) -> f64 { - let mut alpha = 0.01; - - for (key, value) in node.attrs.iter() { - if key.as_str() == "alpha" { - alpha = value.clone().into_f32() as f64 - } - } - - alpha -} - -// Create a HardSigmoidConfig from the alpha and beta attributes of the node -pub fn hard_sigmoid_config(node: &Node) -> (f64, f64) { - let mut alpha = 0.2; - let mut beta = 0.5; - - for (key, value) in node.attrs.iter() { - match key.as_str() { - "alpha" => alpha = value.clone().into_f32() as f64, - "beta" => beta = value.clone().into_f32() as f64, - _ => {} - } - } - - (alpha, beta) -} - -pub fn reshape_config(node: &Node) -> Vec { - let mut allowzero = 0; - - for (key, value) in node.attrs.iter() { - if key.as_str() == "allowzero" { - allowzero = value.clone().into_i64() - } - } - - // Burn does not support zero size shape (0 means false in ONNX) - // (see https://onnx.ai/onnx/operators/onnx__Reshape.html#attributes) - if allowzero != 0 { - panic!("Zero shape size is not supported"); - } - - // TODO: check "shape" attribute - if node.inputs.len() != 2 || node.inputs[1].value.is_none() { - panic!("Reshape: shape tensor must be present for {:?}", node); - } - - match &node.inputs[1].value { - Some(TensorData { data, shape, .. }) => { - assert_eq!(shape.len(), 1, "Reshape: shape tensor must be 1D"); - data.clone().into_i64s() - } - _ => panic!("Only tensor input is valid for shape"), - } -} - -pub fn resize_config(node: &Node) -> (String, Vec, Vec) { - let mut mode: String = "".to_string(); - - let mut scales: Vec; - let mut sizes: Vec; - - let input = if let ArgType::Tensor(tensor) = &node - .inputs - .first() - .expect("Resize: Input tensor must be present") - .ty - { - tensor - } else { - panic!("Resize: input must be a tensor") - }; - - // Note: we are ignoring some attributes because results are approximately the same - // and we are not supporting all the attributes of the Resize operator. - // However, some attributes are important to be checked and we are checking - // against the default values of the attributes. 
- // TODO revisit this when we have more Resize operators in the model - for (key, value) in node.attrs.iter() { - match key.as_str() { - "antialias" => assert_eq!( - value.clone().into_i32(), - 0, - "Resize: antialias other than 0 is not supported" - ), - "axes" => panic!("Resize: custom axes attribute is not supported"), - "coordinate_transformation_mode" => { - log::warn!("Resize: coordinate_transformation_mode is ignored") - } - - "cubic_coeff_a" => log::warn!("Resize: cubic_coeff_a is ignored"), - "exclude_outside" => assert_eq!( - value.clone().into_i32(), - 0, - "Resize: exclude_outside other than 0 is not supported" - ), - "extrapolation_value" => assert_eq!( - value.clone().into_f32(), - 0.0, - "Resize: extrapolation_value other than 0.0 is not supported" - ), - "keep_aspect_ratio_policy" => { - assert_eq!( - value.clone().into_string().to_lowercase(), - "stretch", - "Resize: keep_aspect_ratio_policy other than 'stretch' is not supported" - ) - } - "mode" => mode = value.clone().into_string().to_lowercase(), - "nearest_mode" => log::warn!("Resize: nearest_mode is ignored"), - - _ => {} - } - } - - let roi: Vec = node - .inputs - .get(1) - .map(|input| { - if let Some(TensorData { data, .. }) = &input.value { - data.clone().into_f32s() - } else { - vec![] - } - }) - .unwrap_or_default(); - - scales = node - .inputs - .get(2) - .map(|input| { - if let Some(TensorData { data, .. }) = &input.value { - data.clone().into_f32s() - } else { - vec![] - } - }) - .unwrap_or_default(); - - sizes = node - .inputs - .get(3) - .map(|input| { - if let Some(TensorData { data, .. }) = &input.value { - data.clone() - .into_i64s() - .iter() - .map(|&x| x as usize) - .collect() - } else { - vec![] - } - }) - .unwrap_or_default(); - - if mode.is_empty() { - panic!("Resize: mode attribute is required") - } - - if !roi.is_empty() { - panic!("Resize: roi input is not supported") - } - - if scales.is_empty() && sizes.is_empty() { - panic!("Resize: either scales or sizes input is required") - } - - if !scales.is_empty() { - assert!(scales.len() == input.rank); - // ignore the fist two items from scales - // because they are the batch and channel dimensions - scales = scales.iter().skip(2).cloned().collect(); - } - - if !sizes.is_empty() { - assert!(sizes.len() == input.rank); - // ignore the fist two items from sizes - // because they are the batch and channel dimensions - sizes = sizes.iter().skip(2).cloned().collect(); - } - - (mode, scales, sizes) -} - -// //Note this function should only execute if the second input is a constant -// //if it wasn't and the output shape was known, unsqueeze has been remapped to reshape -// pub fn unsqueeze_config(node: &Node) -> UnsqueezeAxes { -// // Check if axes attribute exists -// for (key, value) in node.attrs.iter() { -// match key.as_str() { -// "axes" => return UnsqueezeAxes::Static(value.clone().into_i64s()), -// _ => {} -// } -// } - -// assert!( -// !node.inputs.is_empty(), -// "Unsqueeze: axes tensor must be present" -// ); - -// let input_value = &node.inputs[1]; - -// match &node.inputs[1].ty { -// ArgType::Tensor(tensor) => { -// assert_eq!(tensor.rank, 1, "Unsqueeze: axes tensor must be 1D"); -// if let Some(TensorData { -// data: Data::Int64s(shape), -// .. 
-// }) = input_value.value.as_ref() -// { -// UnsqueezeAxes::Static(shape.clone()) -// } else { -// UnsqueezeAxes::Runtime(crate::burn::Type::from(&node.inputs[1])) -// } -// } -// _ => panic!("Arg for unsqueeze must be tensor or scalar"), -// } -// } - -pub fn clip_config(node: &Node) -> (Option, Option) { - let mut min_result: Option = None; - let mut max_result: Option = None; - - // For Clip Opset 6+ , the min and max values are attributes - for (key, value) in node.attrs.iter() { - match key.as_str() { - "min" => { - let min = value.clone().into_f32() as f64; - min_result = Some(min); - } - "max" => { - let max = value.clone().into_f32(); - max_result = Some(max as f64); - } - _ => {} - } - } - - // For Clip Opset 11+ , the min and max values are inputs - // Get the min and max values from the input values - if min_result.is_none() && max_result.is_none() { - let min = &node.inputs[1].value; - let max = &node.inputs[2].value; - - if min_result.is_none() && min.is_some() { - let min = min.clone().unwrap().data.into_scalar(); - min_result = match min { - Data::Float16(min) => Some(f32::from(min) as f64), - Data::Float32(min) => Some(min as f64), - Data::Float64(min) => Some(min), - _ => panic!("Clip: only float min is supported"), - }; - } - - if max_result.is_none() && max.is_some() { - let max = max.clone().unwrap().data.into_scalar(); - max_result = match max { - Data::Float16(max) => Some(f32::from(max) as f64), - Data::Float32(max) => Some(max as f64), - Data::Float64(max) => Some(max), - _ => panic!("Clip: only float max is supported"), - }; - } - } - - if min_result.is_none() && max_result.is_none() { - panic!("Clip: min and max values must be either attributes or inputs"); - } - - (min_result, max_result) -} - -pub fn reduce_max_config(node: &Node) -> Option { - let mut axes = Vec::new(); - let mut keepdims = 1; - - let tensor = match node.inputs.first().unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // Extract the attributes - for (key, value) in node.attrs.iter() { - match key.as_str() { - "axes" => axes = value.clone().into_i64s(), - "keepdims" => keepdims = value.clone().into_i64(), - _ => {} - } - } - - if axes.len() > 1 { - panic!("ReduceMax: reducing on multiple dimensions is not supported") - } - - if axes.is_empty() && keepdims == 1 { - panic!("ReduceMax: axes must be provided with keepdims") - } - - if !axes.is_empty() && keepdims == 0 { - // Not supported in Burn - panic!("ReduceMax: the reduce operation must preserve the reduced dimension") - } - - if axes.is_empty() { - None - } else { - let mut dim = axes[0]; - - if dim < 0 { - // Accepted range is [-r, r-1] where r = rank(data) but Burn only supports positive dim - dim += tensor.rank as i64; - } - Some(dim as usize) - } -} - -pub fn reduce_min_config(node: &Node) -> Option { - let mut axes = Vec::new(); - let mut keepdims = 1; - - let tensor = match node.inputs.first().unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // Extract the attributes - for (key, value) in node.attrs.iter() { - match key.as_str() { - "axes" => axes = value.clone().into_i64s(), - "keepdims" => keepdims = value.clone().into_i64(), - _ => {} - } - } - - if axes.len() > 1 { - panic!("ReduceMin: reducing on multiple dimensions is not supported") - } - - if axes.is_empty() && keepdims == 1 { - panic!("ReduceMin: axes must be provided with keepdims") - } - - if !axes.is_empty() && keepdims == 0 { - panic!("ReduceMin: 
the reduce operation must preserve the reduced dimension") - } - - if axes.is_empty() { - None - } else { - let mut dim = axes[0]; - - if dim < 0 { - dim += tensor.rank as i64; - } - Some(dim as usize) - } -} - -pub fn reduce_mean_config(node: &Node) -> Option { - let mut axes = Vec::new(); - let mut keepdims = 1; - - let tensor = match node.inputs.first().unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // Extract the attributes - for (key, value) in node.attrs.iter() { - match key.as_str() { - "axes" => axes = value.clone().into_i64s(), - "keepdims" => keepdims = value.clone().into_i64(), - _ => {} - } - } - - if axes.len() > 1 { - panic!("ReduceMean: reducing on multiple dimensions is not supported") - } - - if axes.is_empty() && keepdims == 1 { - panic!("ReduceMean: axes must be provided with keepdims") - } - - if !axes.is_empty() && keepdims == 0 { - // Not supported in Burn - panic!("ReduceMean: the reduce operation must preserve the reduced dimension") - } - - if axes.is_empty() { - None - } else { - let mut dim = axes[0]; - - if dim < 0 { - // Accepted range is [-r, r-1] where r = rank(data) but Burn only supports positive dim - dim += tensor.rank as i64; - } - Some(dim as usize) - } -} - -pub fn reduce_prod_config(node: &Node) -> Option { - let mut axes = Vec::new(); - let mut keepdims = 1; - - let tensor = match node.inputs.first().unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // Extract the attributes - for (key, value) in node.attrs.iter() { - match key.as_str() { - "axes" => axes = value.clone().into_i64s(), - "keepdims" => keepdims = value.clone().into_i64(), - // TODO: handle noop_with_empty_axes (opset 18) - _ => {} - } - } - - if axes.len() > 1 { - panic!("ReduceProd: reducing on multiple dimensions is not supported") - } - - if axes.is_empty() && keepdims == 1 { - panic!("ReduceProd: axes must be provided with keepdims") - } - - if !axes.is_empty() && keepdims == 0 { - // Not supported in Burn - panic!("ReduceProd: the reduce operation must preserve the reduced dimension") - } - - if axes.is_empty() { - None - } else { - let mut dim = axes[0]; - - if dim < 0 { - // Accepted range is [-r, r-1] where r = rank(data) but Burn only supports positive dim - dim += tensor.rank as i64; - } - Some(dim as usize) - } -} - -pub fn reduce_sum_config(node: &Node) -> Option { - let mut axes = Vec::new(); - let mut keepdims = 1; - - let tensor = match node.inputs.first().unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // Extract the attributes - for (key, value) in node.attrs.iter() { - match key.as_str() { - "keepdims" => keepdims = value.clone().into_i64(), - "axes" => axes = value.clone().into_i64s(), - // TODO: handle noop_with_empty_axes - _ => {} - } - } - - // TODO: Handle case where axes are passed in. Will require its own ReduceSumNode instead of a UnaryNode. 
- if let Some(value) = node - .inputs - .get(1) - .and_then(|argument| argument.value.as_ref()) - { - axes = value.clone().data.into_i64s(); - } - - if axes.len() > 1 { - panic!("ReduceMean: reducing on multiple dimensions is not supported") - } - - if axes.is_empty() && keepdims == 1 { - panic!("ReduceMean: axes must be provided with keepdims") - } - - if !axes.is_empty() && keepdims == 0 { - // Not supported in Burn - panic!("ReduceMean: the reduce operation must preserve the reduced dimension") - } - - if axes.is_empty() { - None - } else { - let mut dim = axes[0]; - - if dim < 0 { - // Accepted range is [-r, r-1] where r = rank(data) but Burn only supports positive dim - dim += tensor.rank as i64; - } - Some(dim as usize) - } -} - -pub fn shape_config(curr: &Node) -> (usize, usize) { - if curr.inputs.len() != 1 { - panic!( - "Shape: multiple inputs are not supported (got {:?})", - curr.inputs.len() - ); - } - - // Extract the shape of the input tensor - let tensor = match curr.inputs.first().unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // Default: all axes up to the last one (included) - let mut start_dim: i64 = 0; - let mut end_dim: i64 = tensor.rank as i64; - - // Extract the attributes - for (key, value) in curr.attrs.iter() { - match key.as_str() { - "start" => start_dim = value.clone().into_i64(), - "end" => end_dim = value.clone().into_i64(), - _ => {} - } - } - - // If dim is negative, it is counted from the end - if start_dim < 0 { - start_dim += tensor.rank as i64; - } - if end_dim < 0 { - end_dim += tensor.rank as i64; - } - - (start_dim as usize, end_dim as usize) -} - -pub fn transpose_config(curr: &Node) -> Vec { - if curr.inputs.len() != 1 { - panic!( - "Transpose: multiple inputs are not supported (got {:?})", - curr.inputs.len() - ); - } - - // Extract the shape of the input tensor - let tensor = match curr.inputs.first().unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // Default: reverse the dimensions - let mut perm = (0..tensor.rank as i64).rev().collect::>(); - - if let Some(axes) = curr.attrs.get("perm") { - perm = axes.clone().into_i64s(); - } - - perm -} - -pub fn squeeze_config(curr: &Node) -> Vec { - let axes = curr - .attrs - .iter() - .filter_map(|(key, value)| { - if key == "axes" { - Some(value.clone().into_i64s()) - } else { - None - } - }) - .next() - .unwrap_or_else(Vec::new); - - match curr.inputs.first().unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - axes -} -// pub fn split_config(node: &Node) -> SplitConfig { -// // Initialize the axis to split along (default is 0 as per ONNX specification) -// let mut axis: i64 = 0; -// // Holds the uniform split size if calculated or provided -// let mut split_size: Option = None; -// // Holds the custom split sizes if provided as input -// let mut split_sizes: Option> = None; - -// // Extract the input tensor type to determine rank and shape -// let tensor = match node.inputs.first().unwrap().ty { -// ArgType::Tensor(ref tensor) => tensor, -// _ => panic!("Split: Input must be a valid tensor"), -// }; - -// // Optionally store the number of outputs if provided as an attribute -// let mut num_outputs: Option = None; - -// // Iterate through node attributes to extract relevant values -// for (key, value) in node.attrs.iter() { -// match key.as_str() { -// "axis" => axis = value.clone().into_i64(), -// "num_outputs" => 
num_outputs = Some(value.clone().into_i64() as usize), -// _ => {} -// } -// } - -// // Handle the case when num_outputs is provided to calculate uniform split size -// if let Some(num_outputs) = num_outputs { -// if num_outputs == 0 { -// panic!("Split: 'num_outputs' must be a positive value greater than zero"); -// } - -// let dim_size = tensor -// .static_shape -// .as_ref() -// .expect("Split: Static shape must be known to calculate split size")[axis as usize]; - -// // Calculate the split size considering any remainder for non-evenly divisible dimensions -// let calculated_split_size = -// dim_size / (num_outputs - (dim_size % num_outputs != 0) as usize); - -// if calculated_split_size == 0 { -// panic!( -// "Split: Calculated split size is zero. Please ensure 'num_outputs' is valid for the dimension size" -// ); -// } - -// // Assign the calculated split size -// split_size = Some(calculated_split_size); -// } - -// // Adjust axis if negative to count from the end as per ONNX spec -// if axis < 0 { -// axis += tensor.rank as i64; -// } - -// // Check for custom split sizes provided as a second input -// if node.inputs.len() > 1 && node.inputs[1].value.is_some() { -// let sizes = node.inputs[1] -// .value -// .as_ref() -// .unwrap() -// .data -// .clone() -// .into_usizes(); - -// if !sizes.is_empty() { -// split_sizes = Some(sizes); -// } -// } - -// // Ensure that only one of 'split_sizes' or 'num_outputs' is specified -// if split_sizes.is_some() && split_size.is_some() { -// panic!( -// "Split: Cannot specify both 'split' input and 'num_outputs' attribute simultaneously" -// ); -// } - -// // Infer split_size if neither custom split_sizes nor split_size is provided -// if split_sizes.is_none() && split_size.is_none() { -// let num_outputs = node.outputs.len(); -// let dim_size = tensor -// .static_shape -// .as_ref() -// .expect("Split: Static shape must be known to infer split size")[axis as usize]; - -// // Calculate inferred split size based on number of outputs -// let calculated_split_size = -// dim_size / (num_outputs - (dim_size % num_outputs != 0) as usize); - -// if calculated_split_size == 0 { -// panic!( -// "Split: Inferred split size is zero. 
Please ensure the number of outputs is valid for the dimension size" -// ); -// } - -// split_size = Some(calculated_split_size); -// } - -// // Return the configuration for splitting operation -// SplitConfig { -// axis: axis as usize, -// split_size, -// split_sizes, -// } -// } - -pub fn one_hot_config(curr: &Node) -> (usize, [f32; 2], i64) { - let depth = curr.inputs[1] - .value - .clone() - .expect("OneHot: Only constant depth is currently supported") - .data - .into_i64(); - - let values = curr.inputs[2] - .value - .clone() - .expect("OneHot: Only constant on/off values is currently supported") - .data - .into_f32s(); - - let axis = curr - .attrs - .get("axis") - .map(|val| val.clone().into_i64()) - .unwrap_or(-1); - - (depth as usize, values.try_into().unwrap(), axis) -} - -pub fn gemm_config(curr: &Node) -> (f32, f32, i64, i64) { - let alpha = curr - .attrs - .get("alpha") - .map(|val| val.clone().into_f32()) - .unwrap_or(1.0); - let beta = curr - .attrs - .get("beta") - .map(|val| val.clone().into_f32()) - .unwrap_or(1.0); - let trans_a = curr - .attrs - .get("transA") - .map(|val| val.clone().into_i64()) - .unwrap_or(0); - let trans_b = curr - .attrs - .get("transB") - .map(|val| val.clone().into_i64()) - .unwrap_or(0); - - (alpha, beta, trans_a, trans_b) -} +// Reexport functions from node modules for compatibility +// These should be deprecated and eventually removed in favor of using +// direct imports from the node modules + +pub use crate::node::clip::clip_config; +pub use crate::node::gemm::gemm_config; +pub use crate::node::hard_sigmoid::hard_sigmoid_config; +pub use crate::node::leaky_relu::leaky_relu_config; +pub use crate::node::one_hot::one_hot_config; +pub use crate::node::reduce::{ + reduce_max_config, reduce_mean_config, reduce_min_config, reduce_prod_config, reduce_sum_config, +}; +pub use crate::node::reshape::reshape_config; +pub use crate::node::resize::resize_config; +pub use crate::node::shape::shape_config; +pub use crate::node::squeeze::squeeze_config; +pub use crate::node::transpose::transpose_config; From d38a7e993e84cd9a623036d9286c9f4e85a9f867 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Wed, 30 Apr 2025 14:47:40 -0500 Subject: [PATCH 11/37] Break down reduce module into individual modules --- crates/onnx-ir/src/node/mod.rs | 6 +- crates/onnx-ir/src/node/reduce.rs | 381 ------------------------- crates/onnx-ir/src/node/reduce_max.rs | 128 +++++++++ crates/onnx-ir/src/node/reduce_mean.rs | 128 +++++++++ crates/onnx-ir/src/node/reduce_min.rs | 126 ++++++++ crates/onnx-ir/src/node/reduce_prod.rs | 129 +++++++++ crates/onnx-ir/src/node/reduce_sum.rs | 172 +++++++++++ crates/onnx-ir/src/op_configuration.rs | 8 +- 8 files changed, 693 insertions(+), 385 deletions(-) delete mode 100644 crates/onnx-ir/src/node/reduce.rs create mode 100644 crates/onnx-ir/src/node/reduce_max.rs create mode 100644 crates/onnx-ir/src/node/reduce_mean.rs create mode 100644 crates/onnx-ir/src/node/reduce_min.rs create mode 100644 crates/onnx-ir/src/node/reduce_prod.rs create mode 100644 crates/onnx-ir/src/node/reduce_sum.rs diff --git a/crates/onnx-ir/src/node/mod.rs b/crates/onnx-ir/src/node/mod.rs index f212998958..4e5ca6b836 100644 --- a/crates/onnx-ir/src/node/mod.rs +++ b/crates/onnx-ir/src/node/mod.rs @@ -22,7 +22,11 @@ pub mod log_softmax; pub mod max_pool1d; pub mod max_pool2d; pub mod one_hot; -pub mod reduce; +pub mod reduce_max; +pub mod reduce_mean; +pub mod reduce_min; +pub mod reduce_prod; +pub mod reduce_sum; pub mod 
reshape; pub mod resize; pub mod shape; diff --git a/crates/onnx-ir/src/node/reduce.rs b/crates/onnx-ir/src/node/reduce.rs deleted file mode 100644 index b0148581ea..0000000000 --- a/crates/onnx-ir/src/node/reduce.rs +++ /dev/null @@ -1,381 +0,0 @@ -use crate::ir::{ArgType, Node}; - -/// Create a ReduceMaxConfig from the attributes of the node -pub fn reduce_max_config(node: &Node) -> Option { - let mut axes = Vec::new(); - let mut keepdims = 1; - - let tensor = match node.inputs.first().unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // Extract the attributes - for (key, value) in node.attrs.iter() { - match key.as_str() { - "axes" => axes = value.clone().into_i64s(), - "keepdims" => keepdims = value.clone().into_i64(), - _ => {} - } - } - - if axes.len() > 1 { - panic!("ReduceMax: reducing on multiple dimensions is not supported") - } - - if axes.is_empty() && keepdims == 1 { - panic!("ReduceMax: axes must be provided with keepdims") - } - - if !axes.is_empty() && keepdims == 0 { - // Not supported in Burn - panic!("ReduceMax: the reduce operation must preserve the reduced dimension") - } - - if axes.is_empty() { - None - } else { - let mut dim = axes[0]; - - if dim < 0 { - // Accepted range is [-r, r-1] where r = rank(data) but Burn only supports positive dim - dim += tensor.rank as i64; - } - Some(dim as usize) - } -} - -/// Create a ReduceMinConfig from the attributes of the node -pub fn reduce_min_config(node: &Node) -> Option { - let mut axes = Vec::new(); - let mut keepdims = 1; - - let tensor = match node.inputs.first().unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // Extract the attributes - for (key, value) in node.attrs.iter() { - match key.as_str() { - "axes" => axes = value.clone().into_i64s(), - "keepdims" => keepdims = value.clone().into_i64(), - _ => {} - } - } - - if axes.len() > 1 { - panic!("ReduceMin: reducing on multiple dimensions is not supported") - } - - if axes.is_empty() && keepdims == 1 { - panic!("ReduceMin: axes must be provided with keepdims") - } - - if !axes.is_empty() && keepdims == 0 { - panic!("ReduceMin: the reduce operation must preserve the reduced dimension") - } - - if axes.is_empty() { - None - } else { - let mut dim = axes[0]; - - if dim < 0 { - dim += tensor.rank as i64; - } - Some(dim as usize) - } -} - -/// Create a ReduceMeanConfig from the attributes of the node -pub fn reduce_mean_config(node: &Node) -> Option { - let mut axes = Vec::new(); - let mut keepdims = 1; - - let tensor = match node.inputs.first().unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // Extract the attributes - for (key, value) in node.attrs.iter() { - match key.as_str() { - "axes" => axes = value.clone().into_i64s(), - "keepdims" => keepdims = value.clone().into_i64(), - _ => {} - } - } - - if axes.len() > 1 { - panic!("ReduceMean: reducing on multiple dimensions is not supported") - } - - if axes.is_empty() && keepdims == 1 { - panic!("ReduceMean: axes must be provided with keepdims") - } - - if !axes.is_empty() && keepdims == 0 { - // Not supported in Burn - panic!("ReduceMean: the reduce operation must preserve the reduced dimension") - } - - if axes.is_empty() { - None - } else { - let mut dim = axes[0]; - - if dim < 0 { - // Accepted range is [-r, r-1] where r = rank(data) but Burn only supports positive dim - dim += tensor.rank as i64; - } - Some(dim as usize) - } -} - 
-/// Create a ReduceProdConfig from the attributes of the node -pub fn reduce_prod_config(node: &Node) -> Option { - let mut axes = Vec::new(); - let mut keepdims = 1; - - let tensor = match node.inputs.first().unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // Extract the attributes - for (key, value) in node.attrs.iter() { - match key.as_str() { - "axes" => axes = value.clone().into_i64s(), - "keepdims" => keepdims = value.clone().into_i64(), - // TODO: handle noop_with_empty_axes (opset 18) - _ => {} - } - } - - if axes.len() > 1 { - panic!("ReduceProd: reducing on multiple dimensions is not supported") - } - - if axes.is_empty() && keepdims == 1 { - panic!("ReduceProd: axes must be provided with keepdims") - } - - if !axes.is_empty() && keepdims == 0 { - // Not supported in Burn - panic!("ReduceProd: the reduce operation must preserve the reduced dimension") - } - - if axes.is_empty() { - None - } else { - let mut dim = axes[0]; - - if dim < 0 { - // Accepted range is [-r, r-1] where r = rank(data) but Burn only supports positive dim - dim += tensor.rank as i64; - } - Some(dim as usize) - } -} - -/// Create a ReduceSumConfig from the attributes of the node -pub fn reduce_sum_config(node: &Node) -> Option { - let mut axes = Vec::new(); - let mut keepdims = 1; - - let tensor = match node.inputs.first().unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // Extract the attributes - for (key, value) in node.attrs.iter() { - match key.as_str() { - "keepdims" => keepdims = value.clone().into_i64(), - "axes" => axes = value.clone().into_i64s(), - // TODO: handle noop_with_empty_axes - _ => {} - } - } - - // TODO: Handle case where axes are passed in. Will require its own ReduceSumNode instead of a UnaryNode. 
- if let Some(value) = node - .inputs - .get(1) - .and_then(|argument| argument.value.as_ref()) - { - axes = value.clone().data.into_i64s(); - } - - if axes.len() > 1 { - panic!("ReduceSum: reducing on multiple dimensions is not supported") - } - - if axes.is_empty() && keepdims == 1 { - panic!("ReduceSum: axes must be provided with keepdims") - } - - if !axes.is_empty() && keepdims == 0 { - // Not supported in Burn - panic!("ReduceSum: the reduce operation must preserve the reduced dimension") - } - - if axes.is_empty() { - None - } else { - let mut dim = axes[0]; - - if dim < 0 { - // Accepted range is [-r, r-1] where r = rank(data) but Burn only supports positive dim - dim += tensor.rank as i64; - } - Some(dim as usize) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::ir::{ - Argument, AttributeValue, Data, ElementType, NodeType, TensorData, TensorType, - }; - use std::collections::HashMap; - - fn create_test_node( - node_type: NodeType, - axes: Option>, - keepdims: Option, - with_axes_input: bool, - ) -> Node { - let mut inputs = vec![Argument { - name: "data".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 3, - static_shape: None, - }), - value: None, - passed: true, - }]; - - // Add axes input if requested - if with_axes_input && axes.is_some() { - let axes_clone = axes.clone().unwrap(); - inputs.push(Argument { - name: "axes".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Int64, - rank: 1, - static_shape: None, - }), - value: Some(TensorData { - data: Data::Int64s(axes_clone.clone()), - shape: vec![axes_clone.len()], - }), - passed: true, - }); - } - - let mut attrs = HashMap::new(); - if !with_axes_input && axes.is_some() { - attrs.insert( - "axes".to_string(), - AttributeValue::Int64s(axes.clone().unwrap()), - ); - } - if let Some(kd) = keepdims { - attrs.insert("keepdims".to_string(), AttributeValue::Int64(kd)); - } - - let node_type_clone = node_type.clone(); - Node { - node_type, - name: format!("test_{:?}", node_type_clone).to_lowercase(), - inputs, - outputs: vec![Argument { - name: "reduced".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 3, - static_shape: None, - }), - value: None, - passed: true, - }], - attrs, - } - } - - #[test] - fn test_reduce_max_config_basic() { - let node = create_test_node(NodeType::ReduceMax, Some(vec![1]), Some(1), false); - let dim = reduce_max_config(&node); - assert_eq!(dim, Some(1)); - } - - #[test] - fn test_reduce_max_config_negative_axis() { - let node = create_test_node(NodeType::ReduceMax, Some(vec![-2]), Some(1), false); - let dim = reduce_max_config(&node); - assert_eq!(dim, Some(1)); // -2 + 3 = 1 - } - - #[test] - #[should_panic(expected = "ReduceMax: axes must be provided with keepdims")] - fn test_reduce_max_config_no_axes() { - let node = create_test_node(NodeType::ReduceMax, None, Some(1), false); - let _ = reduce_max_config(&node); - } - - #[test] - #[should_panic(expected = "ReduceMax: reducing on multiple dimensions is not supported")] - fn test_reduce_max_config_multiple_axes() { - let node = create_test_node(NodeType::ReduceMax, Some(vec![0, 1]), Some(1), false); - let _ = reduce_max_config(&node); - } - - #[test] - #[should_panic( - expected = "ReduceMax: the reduce operation must preserve the reduced dimension" - )] - fn test_reduce_max_config_no_keepdims() { - let node = create_test_node(NodeType::ReduceMax, Some(vec![1]), Some(0), false); - let _ = reduce_max_config(&node); - } - - #[test] - 
fn test_reduce_min_config_basic() { - let node = create_test_node(NodeType::ReduceMin, Some(vec![1]), Some(1), false); - let dim = reduce_min_config(&node); - assert_eq!(dim, Some(1)); - } - - #[test] - fn test_reduce_mean_config_basic() { - let node = create_test_node(NodeType::ReduceMean, Some(vec![1]), Some(1), false); - let dim = reduce_mean_config(&node); - assert_eq!(dim, Some(1)); - } - - #[test] - fn test_reduce_prod_config_basic() { - let node = create_test_node(NodeType::ReduceProd, Some(vec![1]), Some(1), false); - let dim = reduce_prod_config(&node); - assert_eq!(dim, Some(1)); - } - - #[test] - fn test_reduce_sum_config_basic() { - let node = create_test_node(NodeType::ReduceSum, Some(vec![1]), Some(1), false); - let dim = reduce_sum_config(&node); - assert_eq!(dim, Some(1)); - } - - #[test] - fn test_reduce_sum_config_with_input_axes() { - let node = create_test_node(NodeType::ReduceSum, Some(vec![1]), Some(1), true); - let dim = reduce_sum_config(&node); - assert_eq!(dim, Some(1)); - } -} diff --git a/crates/onnx-ir/src/node/reduce_max.rs b/crates/onnx-ir/src/node/reduce_max.rs new file mode 100644 index 0000000000..d3a60e4e4c --- /dev/null +++ b/crates/onnx-ir/src/node/reduce_max.rs @@ -0,0 +1,128 @@ +use crate::ir::{ArgType, Node}; + +/// Create a ReduceMaxConfig from the attributes of the node +pub fn reduce_max_config(node: &Node) -> Option { + let mut axes = Vec::new(); + let mut keepdims = 1; + + let tensor = match node.inputs.first().unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + // Extract the attributes + for (key, value) in node.attrs.iter() { + match key.as_str() { + "axes" => axes = value.clone().into_i64s(), + "keepdims" => keepdims = value.clone().into_i64(), + _ => {} + } + } + + if axes.len() > 1 { + panic!("ReduceMax: reducing on multiple dimensions is not supported") + } + + if axes.is_empty() && keepdims == 1 { + panic!("ReduceMax: axes must be provided with keepdims") + } + + if !axes.is_empty() && keepdims == 0 { + // Not supported in Burn + panic!("ReduceMax: the reduce operation must preserve the reduced dimension") + } + + if axes.is_empty() { + None + } else { + let mut dim = axes[0]; + + if dim < 0 { + // Accepted range is [-r, r-1] where r = rank(data) but Burn only supports positive dim + dim += tensor.rank as i64; + } + Some(dim as usize) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{Argument, AttributeValue, ElementType, NodeType, TensorType}; + use std::collections::HashMap; + + fn create_test_node(axes: Option>, keepdims: Option) -> Node { + let inputs = vec![Argument { + name: "data".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 3, + static_shape: None, + }), + value: None, + passed: true, + }]; + + let mut attrs = HashMap::new(); + if let Some(axes_val) = axes { + attrs.insert("axes".to_string(), AttributeValue::Int64s(axes_val.clone())); + } + if let Some(kd) = keepdims { + attrs.insert("keepdims".to_string(), AttributeValue::Int64(kd)); + } + + Node { + node_type: NodeType::ReduceMax, + name: "test_reduce_max".to_string(), + inputs, + outputs: vec![Argument { + name: "reduced".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 3, + static_shape: None, + }), + value: None, + passed: true, + }], + attrs, + } + } + + #[test] + fn test_reduce_max_config_basic() { + let node = create_test_node(Some(vec![1]), Some(1)); + let dim = reduce_max_config(&node); + assert_eq!(dim, 
Some(1)); + } + + #[test] + fn test_reduce_max_config_negative_axis() { + let node = create_test_node(Some(vec![-2]), Some(1)); + let dim = reduce_max_config(&node); + assert_eq!(dim, Some(1)); // -2 + 3 = 1 + } + + #[test] + #[should_panic(expected = "ReduceMax: axes must be provided with keepdims")] + fn test_reduce_max_config_no_axes() { + let node = create_test_node(None, Some(1)); + let _ = reduce_max_config(&node); + } + + #[test] + #[should_panic(expected = "ReduceMax: reducing on multiple dimensions is not supported")] + fn test_reduce_max_config_multiple_axes() { + let node = create_test_node(Some(vec![0, 1]), Some(1)); + let _ = reduce_max_config(&node); + } + + #[test] + #[should_panic( + expected = "ReduceMax: the reduce operation must preserve the reduced dimension" + )] + fn test_reduce_max_config_no_keepdims() { + let node = create_test_node(Some(vec![1]), Some(0)); + let _ = reduce_max_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/reduce_mean.rs b/crates/onnx-ir/src/node/reduce_mean.rs new file mode 100644 index 0000000000..bf3fc2e1b0 --- /dev/null +++ b/crates/onnx-ir/src/node/reduce_mean.rs @@ -0,0 +1,128 @@ +use crate::ir::{ArgType, Node}; + +/// Create a ReduceMeanConfig from the attributes of the node +pub fn reduce_mean_config(node: &Node) -> Option { + let mut axes = Vec::new(); + let mut keepdims = 1; + + let tensor = match node.inputs.first().unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + // Extract the attributes + for (key, value) in node.attrs.iter() { + match key.as_str() { + "axes" => axes = value.clone().into_i64s(), + "keepdims" => keepdims = value.clone().into_i64(), + _ => {} + } + } + + if axes.len() > 1 { + panic!("ReduceMean: reducing on multiple dimensions is not supported") + } + + if axes.is_empty() && keepdims == 1 { + panic!("ReduceMean: axes must be provided with keepdims") + } + + if !axes.is_empty() && keepdims == 0 { + // Not supported in Burn + panic!("ReduceMean: the reduce operation must preserve the reduced dimension") + } + + if axes.is_empty() { + None + } else { + let mut dim = axes[0]; + + if dim < 0 { + // Accepted range is [-r, r-1] where r = rank(data) but Burn only supports positive dim + dim += tensor.rank as i64; + } + Some(dim as usize) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{Argument, AttributeValue, ElementType, NodeType, TensorType}; + use std::collections::HashMap; + + fn create_test_node(axes: Option>, keepdims: Option) -> Node { + let inputs = vec![Argument { + name: "data".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 3, + static_shape: None, + }), + value: None, + passed: true, + }]; + + let mut attrs = HashMap::new(); + if let Some(axes_val) = axes { + attrs.insert("axes".to_string(), AttributeValue::Int64s(axes_val.clone())); + } + if let Some(kd) = keepdims { + attrs.insert("keepdims".to_string(), AttributeValue::Int64(kd)); + } + + Node { + node_type: NodeType::ReduceMean, + name: "test_reduce_mean".to_string(), + inputs, + outputs: vec![Argument { + name: "reduced".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 3, + static_shape: None, + }), + value: None, + passed: true, + }], + attrs, + } + } + + #[test] + fn test_reduce_mean_config_basic() { + let node = create_test_node(Some(vec![1]), Some(1)); + let dim = reduce_mean_config(&node); + assert_eq!(dim, Some(1)); + } + + #[test] + fn 
test_reduce_mean_config_negative_axis() { + let node = create_test_node(Some(vec![-2]), Some(1)); + let dim = reduce_mean_config(&node); + assert_eq!(dim, Some(1)); // -2 + 3 = 1 + } + + #[test] + #[should_panic(expected = "ReduceMean: axes must be provided with keepdims")] + fn test_reduce_mean_config_no_axes() { + let node = create_test_node(None, Some(1)); + let _ = reduce_mean_config(&node); + } + + #[test] + #[should_panic(expected = "ReduceMean: reducing on multiple dimensions is not supported")] + fn test_reduce_mean_config_multiple_axes() { + let node = create_test_node(Some(vec![0, 1]), Some(1)); + let _ = reduce_mean_config(&node); + } + + #[test] + #[should_panic( + expected = "ReduceMean: the reduce operation must preserve the reduced dimension" + )] + fn test_reduce_mean_config_no_keepdims() { + let node = create_test_node(Some(vec![1]), Some(0)); + let _ = reduce_mean_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/reduce_min.rs b/crates/onnx-ir/src/node/reduce_min.rs new file mode 100644 index 0000000000..5bfffb4c41 --- /dev/null +++ b/crates/onnx-ir/src/node/reduce_min.rs @@ -0,0 +1,126 @@ +use crate::ir::{ArgType, Node}; + +/// Create a ReduceMinConfig from the attributes of the node +pub fn reduce_min_config(node: &Node) -> Option { + let mut axes = Vec::new(); + let mut keepdims = 1; + + let tensor = match node.inputs.first().unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + // Extract the attributes + for (key, value) in node.attrs.iter() { + match key.as_str() { + "axes" => axes = value.clone().into_i64s(), + "keepdims" => keepdims = value.clone().into_i64(), + _ => {} + } + } + + if axes.len() > 1 { + panic!("ReduceMin: reducing on multiple dimensions is not supported") + } + + if axes.is_empty() && keepdims == 1 { + panic!("ReduceMin: axes must be provided with keepdims") + } + + if !axes.is_empty() && keepdims == 0 { + panic!("ReduceMin: the reduce operation must preserve the reduced dimension") + } + + if axes.is_empty() { + None + } else { + let mut dim = axes[0]; + + if dim < 0 { + dim += tensor.rank as i64; + } + Some(dim as usize) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{Argument, AttributeValue, ElementType, NodeType, TensorType}; + use std::collections::HashMap; + + fn create_test_node(axes: Option>, keepdims: Option) -> Node { + let inputs = vec![Argument { + name: "data".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 3, + static_shape: None, + }), + value: None, + passed: true, + }]; + + let mut attrs = HashMap::new(); + if let Some(axes_val) = axes { + attrs.insert("axes".to_string(), AttributeValue::Int64s(axes_val.clone())); + } + if let Some(kd) = keepdims { + attrs.insert("keepdims".to_string(), AttributeValue::Int64(kd)); + } + + Node { + node_type: NodeType::ReduceMin, + name: "test_reduce_min".to_string(), + inputs, + outputs: vec![Argument { + name: "reduced".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 3, + static_shape: None, + }), + value: None, + passed: true, + }], + attrs, + } + } + + #[test] + fn test_reduce_min_config_basic() { + let node = create_test_node(Some(vec![1]), Some(1)); + let dim = reduce_min_config(&node); + assert_eq!(dim, Some(1)); + } + + #[test] + fn test_reduce_min_config_negative_axis() { + let node = create_test_node(Some(vec![-2]), Some(1)); + let dim = reduce_min_config(&node); + assert_eq!(dim, Some(1)); // -2 + 3 = 1 + } + + 
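The negative-axis assertions in these reduce tests all come down to the same normalization step inside the config functions: an ONNX axis in [-r, r-1] is shifted by the input rank to obtain Burn's positive dimension index. A minimal standalone sketch of that rule follows; the helper name is illustrative only and is not part of the crate.

/// Normalize an ONNX axis in [-rank, rank-1] into a positive dimension index.
fn normalize_axis(axis: i64, rank: usize) -> usize {
    let dim = if axis < 0 { axis + rank as i64 } else { axis };
    assert!((0..rank as i64).contains(&dim), "axis out of range for rank {rank}");
    dim as usize
}

// Example: normalize_axis(-2, 3) == 1, matching the negative-axis assertions above.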
#[test] + #[should_panic(expected = "ReduceMin: axes must be provided with keepdims")] + fn test_reduce_min_config_no_axes() { + let node = create_test_node(None, Some(1)); + let _ = reduce_min_config(&node); + } + + #[test] + #[should_panic(expected = "ReduceMin: reducing on multiple dimensions is not supported")] + fn test_reduce_min_config_multiple_axes() { + let node = create_test_node(Some(vec![0, 1]), Some(1)); + let _ = reduce_min_config(&node); + } + + #[test] + #[should_panic( + expected = "ReduceMin: the reduce operation must preserve the reduced dimension" + )] + fn test_reduce_min_config_no_keepdims() { + let node = create_test_node(Some(vec![1]), Some(0)); + let _ = reduce_min_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/reduce_prod.rs b/crates/onnx-ir/src/node/reduce_prod.rs new file mode 100644 index 0000000000..3d4f8ce7f6 --- /dev/null +++ b/crates/onnx-ir/src/node/reduce_prod.rs @@ -0,0 +1,129 @@ +use crate::ir::{ArgType, Node}; + +/// Create a ReduceProdConfig from the attributes of the node +pub fn reduce_prod_config(node: &Node) -> Option { + let mut axes = Vec::new(); + let mut keepdims = 1; + + let tensor = match node.inputs.first().unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + // Extract the attributes + for (key, value) in node.attrs.iter() { + match key.as_str() { + "axes" => axes = value.clone().into_i64s(), + "keepdims" => keepdims = value.clone().into_i64(), + // TODO: handle noop_with_empty_axes (opset 18) + _ => {} + } + } + + if axes.len() > 1 { + panic!("ReduceProd: reducing on multiple dimensions is not supported") + } + + if axes.is_empty() && keepdims == 1 { + panic!("ReduceProd: axes must be provided with keepdims") + } + + if !axes.is_empty() && keepdims == 0 { + // Not supported in Burn + panic!("ReduceProd: the reduce operation must preserve the reduced dimension") + } + + if axes.is_empty() { + None + } else { + let mut dim = axes[0]; + + if dim < 0 { + // Accepted range is [-r, r-1] where r = rank(data) but Burn only supports positive dim + dim += tensor.rank as i64; + } + Some(dim as usize) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{Argument, AttributeValue, ElementType, NodeType, TensorType}; + use std::collections::HashMap; + + fn create_test_node(axes: Option>, keepdims: Option) -> Node { + let inputs = vec![Argument { + name: "data".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 3, + static_shape: None, + }), + value: None, + passed: true, + }]; + + let mut attrs = HashMap::new(); + if let Some(axes_val) = axes { + attrs.insert("axes".to_string(), AttributeValue::Int64s(axes_val.clone())); + } + if let Some(kd) = keepdims { + attrs.insert("keepdims".to_string(), AttributeValue::Int64(kd)); + } + + Node { + node_type: NodeType::ReduceProd, + name: "test_reduce_prod".to_string(), + inputs, + outputs: vec![Argument { + name: "reduced".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 3, + static_shape: None, + }), + value: None, + passed: true, + }], + attrs, + } + } + + #[test] + fn test_reduce_prod_config_basic() { + let node = create_test_node(Some(vec![1]), Some(1)); + let dim = reduce_prod_config(&node); + assert_eq!(dim, Some(1)); + } + + #[test] + fn test_reduce_prod_config_negative_axis() { + let node = create_test_node(Some(vec![-2]), Some(1)); + let dim = reduce_prod_config(&node); + assert_eq!(dim, Some(1)); // -2 + 3 = 1 + } + + #[test] + 
#[should_panic(expected = "ReduceProd: axes must be provided with keepdims")] + fn test_reduce_prod_config_no_axes() { + let node = create_test_node(None, Some(1)); + let _ = reduce_prod_config(&node); + } + + #[test] + #[should_panic(expected = "ReduceProd: reducing on multiple dimensions is not supported")] + fn test_reduce_prod_config_multiple_axes() { + let node = create_test_node(Some(vec![0, 1]), Some(1)); + let _ = reduce_prod_config(&node); + } + + #[test] + #[should_panic( + expected = "ReduceProd: the reduce operation must preserve the reduced dimension" + )] + fn test_reduce_prod_config_no_keepdims() { + let node = create_test_node(Some(vec![1]), Some(0)); + let _ = reduce_prod_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/reduce_sum.rs b/crates/onnx-ir/src/node/reduce_sum.rs new file mode 100644 index 0000000000..30fa6e17fa --- /dev/null +++ b/crates/onnx-ir/src/node/reduce_sum.rs @@ -0,0 +1,172 @@ +use crate::ir::{ArgType, Node}; + +/// Create a ReduceSumConfig from the attributes of the node +pub fn reduce_sum_config(node: &Node) -> Option { + let mut axes = Vec::new(); + let mut keepdims = 1; + + let tensor = match node.inputs.first().unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + // Extract the attributes + for (key, value) in node.attrs.iter() { + match key.as_str() { + "keepdims" => keepdims = value.clone().into_i64(), + "axes" => axes = value.clone().into_i64s(), + // TODO: handle noop_with_empty_axes + _ => {} + } + } + + // Process axes from additional input (if available) + if let Some(value) = node + .inputs + .get(1) + .and_then(|argument| argument.value.as_ref()) + { + axes = value.clone().data.into_i64s(); + } + + if axes.len() > 1 { + panic!("ReduceSum: reducing on multiple dimensions is not supported") + } + + if axes.is_empty() && keepdims == 1 { + panic!("ReduceSum: axes must be provided with keepdims") + } + + if !axes.is_empty() && keepdims == 0 { + // Not supported in Burn + panic!("ReduceSum: the reduce operation must preserve the reduced dimension") + } + + if axes.is_empty() { + None + } else { + let mut dim = axes[0]; + + if dim < 0 { + // Accepted range is [-r, r-1] where r = rank(data) but Burn only supports positive dim + dim += tensor.rank as i64; + } + Some(dim as usize) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{ + Argument, AttributeValue, Data, ElementType, NodeType, TensorData, TensorType, + }; + use std::collections::HashMap; + + fn create_test_node( + axes: Option>, + keepdims: Option, + with_axes_input: bool, + ) -> Node { + let mut inputs = vec![Argument { + name: "data".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 3, + static_shape: None, + }), + value: None, + passed: true, + }]; + + // Add axes input if requested + if with_axes_input && axes.is_some() { + let axes_clone = axes.clone().unwrap(); + inputs.push(Argument { + name: "axes".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Int64, + rank: 1, + static_shape: None, + }), + value: Some(TensorData { + data: Data::Int64s(axes_clone.clone()), + shape: vec![axes_clone.len()], + }), + passed: true, + }); + } + + let mut attrs = HashMap::new(); + if !with_axes_input && axes.is_some() { + attrs.insert( + "axes".to_string(), + AttributeValue::Int64s(axes.clone().unwrap()), + ); + } + if let Some(kd) = keepdims { + attrs.insert("keepdims".to_string(), AttributeValue::Int64(kd)); + } + + Node { + node_type: 
NodeType::ReduceSum, + name: "test_reduce_sum".to_string(), + inputs, + outputs: vec![Argument { + name: "reduced".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 3, + static_shape: None, + }), + value: None, + passed: true, + }], + attrs, + } + } + + #[test] + fn test_reduce_sum_config_basic() { + let node = create_test_node(Some(vec![1]), Some(1), false); + let dim = reduce_sum_config(&node); + assert_eq!(dim, Some(1)); + } + + #[test] + fn test_reduce_sum_config_with_input_axes() { + let node = create_test_node(Some(vec![1]), Some(1), true); + let dim = reduce_sum_config(&node); + assert_eq!(dim, Some(1)); + } + + #[test] + fn test_reduce_sum_config_negative_axis() { + let node = create_test_node(Some(vec![-2]), Some(1), false); + let dim = reduce_sum_config(&node); + assert_eq!(dim, Some(1)); // -2 + 3 = 1 + } + + #[test] + #[should_panic(expected = "ReduceSum: axes must be provided with keepdims")] + fn test_reduce_sum_config_no_axes() { + let node = create_test_node(None, Some(1), false); + let _ = reduce_sum_config(&node); + } + + #[test] + #[should_panic(expected = "ReduceSum: reducing on multiple dimensions is not supported")] + fn test_reduce_sum_config_multiple_axes() { + let node = create_test_node(Some(vec![0, 1]), Some(1), false); + let _ = reduce_sum_config(&node); + } + + #[test] + #[should_panic( + expected = "ReduceSum: the reduce operation must preserve the reduced dimension" + )] + fn test_reduce_sum_config_no_keepdims() { + let node = create_test_node(Some(vec![1]), Some(0), false); + let _ = reduce_sum_config(&node); + } +} diff --git a/crates/onnx-ir/src/op_configuration.rs b/crates/onnx-ir/src/op_configuration.rs index 012f325a1d..d453d84c9a 100644 --- a/crates/onnx-ir/src/op_configuration.rs +++ b/crates/onnx-ir/src/op_configuration.rs @@ -7,9 +7,11 @@ pub use crate::node::gemm::gemm_config; pub use crate::node::hard_sigmoid::hard_sigmoid_config; pub use crate::node::leaky_relu::leaky_relu_config; pub use crate::node::one_hot::one_hot_config; -pub use crate::node::reduce::{ - reduce_max_config, reduce_mean_config, reduce_min_config, reduce_prod_config, reduce_sum_config, -}; +pub use crate::node::reduce_max::reduce_max_config; +pub use crate::node::reduce_mean::reduce_mean_config; +pub use crate::node::reduce_min::reduce_min_config; +pub use crate::node::reduce_prod::reduce_prod_config; +pub use crate::node::reduce_sum::reduce_sum_config; pub use crate::node::reshape::reshape_config; pub use crate::node::resize::resize_config; pub use crate::node::shape::shape_config; From b21369bc13d3948d13383b9e01aaceddd87de641 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Wed, 30 Apr 2025 15:43:46 -0500 Subject: [PATCH 12/37] Move rank inference functions --- crates/onnx-ir/src/node/one_hot.rs | 22 ++- crates/onnx-ir/src/node/reduce_max.rs | 34 +++- crates/onnx-ir/src/node/reduce_mean.rs | 33 +++- crates/onnx-ir/src/node/reduce_min.rs | 34 +++- crates/onnx-ir/src/node/reduce_prod.rs | 34 +++- crates/onnx-ir/src/node/reduce_sum.rs | 38 +++- crates/onnx-ir/src/node/shape.rs | 17 ++ crates/onnx-ir/src/node/squeeze.rs | 38 +++- crates/onnx-ir/src/rank_inference.rs | 245 +------------------------ 9 files changed, 250 insertions(+), 245 deletions(-) diff --git a/crates/onnx-ir/src/node/one_hot.rs b/crates/onnx-ir/src/node/one_hot.rs index 496a357e57..2b01bee530 100644 --- a/crates/onnx-ir/src/node/one_hot.rs +++ b/crates/onnx-ir/src/node/one_hot.rs @@ -1,4 +1,4 @@ -use crate::ir::Node; +use 
crate::ir::{ArgType, Node, TensorType}; pub fn one_hot_config(curr: &Node) -> (usize, [f32; 2], i64) { let depth = curr.inputs[1] @@ -24,6 +24,26 @@ pub fn one_hot_config(curr: &Node) -> (usize, [f32; 2], i64) { (depth as usize, values.try_into().unwrap(), axis) } +/// Update output rank for OneHot (input rank + 1). +pub fn one_hot_output_shape(node: &mut Node) { + log::debug!("OneHot rank inference for node {}", node.name); + + let input_rank = match &node.inputs[0].ty { + ArgType::Tensor(tensor) => tensor.rank, + _ => panic!("OneHot: invalid input type"), + }; + log::debug!("OneHot input rank for {}: {}", node.name, input_rank); + + let output_rank = input_rank + 1; + log::debug!("OneHot output rank for {}: {}", node.name, output_rank); + + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type: node.outputs[0].ty.elem_type().clone(), + rank: output_rank, + static_shape: None, + }); +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/onnx-ir/src/node/reduce_max.rs b/crates/onnx-ir/src/node/reduce_max.rs index d3a60e4e4c..883914f1b3 100644 --- a/crates/onnx-ir/src/node/reduce_max.rs +++ b/crates/onnx-ir/src/node/reduce_max.rs @@ -1,4 +1,4 @@ -use crate::ir::{ArgType, Node}; +use crate::ir::{ArgType, AttributeValue, Node, TensorType}; /// Create a ReduceMaxConfig from the attributes of the node pub fn reduce_max_config(node: &Node) -> Option { @@ -45,6 +45,38 @@ pub fn reduce_max_config(node: &Node) -> Option { } } +/// Update output rank for ReduceMax based on axes. +pub fn reduce_max_update_outputs(node: &mut Node) { + log::debug!("ReduceMax rank inference for node {}", node.name); + + if node.inputs.len() != 1 { + panic!("ReduceMax: multiple inputs are not supported"); + } + let tensor = match &node.inputs[0].ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + log::debug!("ReduceMax input rank for {}: {}", node.name, tensor.rank); + + let dim_only = match node.attrs.get("axes") { + Some(value) => match &value { + AttributeValue::Int64(_) => true, + AttributeValue::Int64s(ints) => ints.len() == 1, + _ => false, + }, + None => false, + }; + + let output_rank = if dim_only { tensor.rank } else { 1 }; + log::debug!("ReduceMax output rank for {}: {}", node.name, output_rank); + + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type: tensor.elem_type.clone(), + rank: output_rank, + static_shape: None, + }); +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/onnx-ir/src/node/reduce_mean.rs b/crates/onnx-ir/src/node/reduce_mean.rs index bf3fc2e1b0..d3c0775226 100644 --- a/crates/onnx-ir/src/node/reduce_mean.rs +++ b/crates/onnx-ir/src/node/reduce_mean.rs @@ -1,4 +1,4 @@ -use crate::ir::{ArgType, Node}; +use crate::ir::{ArgType, AttributeValue, Node, TensorType}; /// Create a ReduceMeanConfig from the attributes of the node pub fn reduce_mean_config(node: &Node) -> Option { @@ -45,6 +45,37 @@ pub fn reduce_mean_config(node: &Node) -> Option { } } +/// Update output rank for ReduceMean based on axes. 
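+/// The input rank is preserved when a single reduction axis is specified; otherwise the output collapses to rank 1.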
+pub fn reduce_mean_update_outputs(node: &mut Node) { + log::debug!("ReduceMean rank inference for node {}", node.name); + + if node.inputs.len() != 1 { + panic!("ReduceMean: multiple inputs are not supported"); + } + let tensor = match &node.inputs[0].ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + let dim_only = match node.attrs.get("axes") { + Some(value) => match &value { + AttributeValue::Int64(_) => true, + AttributeValue::Int64s(ints) => ints.len() == 1, + _ => false, + }, + None => false, + }; + + let output_rank = if dim_only { tensor.rank } else { 1 }; + log::debug!("ReduceMean output rank for {}: {}", node.name, output_rank); + + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type: tensor.elem_type.clone(), + rank: output_rank, + static_shape: None, + }); +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/onnx-ir/src/node/reduce_min.rs b/crates/onnx-ir/src/node/reduce_min.rs index 5bfffb4c41..ccf28acb66 100644 --- a/crates/onnx-ir/src/node/reduce_min.rs +++ b/crates/onnx-ir/src/node/reduce_min.rs @@ -1,4 +1,4 @@ -use crate::ir::{ArgType, Node}; +use crate::ir::{ArgType, AttributeValue, Node, TensorType}; /// Create a ReduceMinConfig from the attributes of the node pub fn reduce_min_config(node: &Node) -> Option { @@ -43,6 +43,38 @@ pub fn reduce_min_config(node: &Node) -> Option { } } +/// Update output rank for ReduceMin based on axes. +pub fn reduce_min_update_outputs(node: &mut Node) { + log::debug!("ReduceMin rank inference for node {}", node.name); + + if node.inputs.len() != 1 { + panic!("ReduceMin: multiple inputs are not supported"); + } + let tensor = match &node.inputs[0].ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + log::debug!("ReduceMin input rank for {}: {}", node.name, tensor.rank); + + let dim_only = match node.attrs.get("axes") { + Some(value) => match &value { + AttributeValue::Int64(_) => true, + AttributeValue::Int64s(ints) => ints.len() == 1, + _ => false, + }, + None => false, + }; + + let output_rank = if dim_only { tensor.rank } else { 1 }; + log::debug!("ReduceMin output rank for {}: {}", node.name, output_rank); + + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type: tensor.elem_type.clone(), + rank: output_rank, + static_shape: None, + }); +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/onnx-ir/src/node/reduce_prod.rs b/crates/onnx-ir/src/node/reduce_prod.rs index 3d4f8ce7f6..2f4df96b84 100644 --- a/crates/onnx-ir/src/node/reduce_prod.rs +++ b/crates/onnx-ir/src/node/reduce_prod.rs @@ -1,4 +1,4 @@ -use crate::ir::{ArgType, Node}; +use crate::ir::{ArgType, AttributeValue, Node, TensorType}; /// Create a ReduceProdConfig from the attributes of the node pub fn reduce_prod_config(node: &Node) -> Option { @@ -46,6 +46,38 @@ pub fn reduce_prod_config(node: &Node) -> Option { } } +/// Update output rank for ReduceProd based on axes. 
+pub fn reduce_prod_update_outputs(node: &mut Node) { + log::debug!("ReduceProd rank inference for node {}", node.name); + + if node.inputs.len() != 1 { + panic!("ReduceProd: multiple inputs are not supported"); + } + let tensor = match &node.inputs[0].ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + log::debug!("ReduceProd input rank for {}: {}", node.name, tensor.rank); + + let dim_only = match node.attrs.get("axes") { + Some(value) => match &value { + AttributeValue::Int64(_) => true, + AttributeValue::Int64s(ints) => ints.len() == 1, + _ => false, + }, + None => false, + }; + + let output_rank = if dim_only { tensor.rank } else { 1 }; + log::debug!("ReduceProd output rank for {}: {}", node.name, output_rank); + + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type: tensor.elem_type.clone(), + rank: output_rank, + static_shape: None, + }); +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/onnx-ir/src/node/reduce_sum.rs b/crates/onnx-ir/src/node/reduce_sum.rs index 30fa6e17fa..6182d22c09 100644 --- a/crates/onnx-ir/src/node/reduce_sum.rs +++ b/crates/onnx-ir/src/node/reduce_sum.rs @@ -1,4 +1,4 @@ -use crate::ir::{ArgType, Node}; +use crate::ir::{ArgType, AttributeValue, Data, Node, TensorType}; /// Create a ReduceSumConfig from the attributes of the node pub fn reduce_sum_config(node: &Node) -> Option { @@ -55,6 +55,42 @@ pub fn reduce_sum_config(node: &Node) -> Option { } } +/// Update output rank for ReduceSum based on axes. +pub fn reduce_sum_update_outputs(node: &mut Node) { + log::debug!("ReduceSum rank inference for node {}", node.name); + + let tensor = match &node.inputs[0].ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + log::debug!("ReduceSum input rank for {}: {}", node.name, tensor.rank); + + let dim_only = match node.attrs.get("axes") { + Some(value) => match &value { + AttributeValue::Int64(_) => true, + AttributeValue::Int64s(ints) => ints.len() == 1, + _ => false, + }, + None => false, + } || match node.inputs.get(1).and_then(|arg| arg.value.as_ref()) { + Some(value) => match &value.data { + Data::Int64(_) => true, + Data::Int64s(ints) => ints.len() == 1, + _ => false, + }, + None => false, + }; + + let output_rank = if dim_only { tensor.rank } else { 1 }; + log::debug!("ReduceSum output rank for {}: {}", node.name, output_rank); + + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type: tensor.elem_type.clone(), + rank: output_rank, + static_shape: None, + }); +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/onnx-ir/src/node/shape.rs b/crates/onnx-ir/src/node/shape.rs index 3490ae1c63..694fc05436 100644 --- a/crates/onnx-ir/src/node/shape.rs +++ b/crates/onnx-ir/src/node/shape.rs @@ -38,6 +38,23 @@ pub fn shape_config(curr: &Node) -> (usize, usize) { (start_dim as usize, end_dim as usize) } +/// Update output type for Shape operation (rank 1). 
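+/// The output becomes `ArgType::Shape(end - start)`, using the range computed by `shape_config`.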
+pub fn shape_update_outputs(node: &mut Node) { + if node.inputs.len() != 1 { + panic!("Shape: multiple inputs are not supported: {:?}", node); + } + let (start, end) = shape_config(node); + let dim = end - start; + log::debug!( + "Shape operation for node {}: start={}, end={}, dim={}", + node.name, + start, + end, + dim + ); + node.outputs[0].ty = ArgType::Shape(dim); +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/onnx-ir/src/node/squeeze.rs b/crates/onnx-ir/src/node/squeeze.rs index 8165eb2cf3..9c528cd29e 100644 --- a/crates/onnx-ir/src/node/squeeze.rs +++ b/crates/onnx-ir/src/node/squeeze.rs @@ -1,4 +1,4 @@ -use crate::ir::{ArgType, Node}; +use crate::ir::{ArgType, Data, Node, TensorType}; pub fn squeeze_config(curr: &Node) -> Vec { let axes = curr @@ -22,6 +22,42 @@ pub fn squeeze_config(curr: &Node) -> Vec { axes } +/// Update output rank for Squeeze based on axes. +pub fn squeeze_update_output(node: &mut Node) { + log::debug!("Squeeze rank inference for node {}", node.name); + + let axes = if node.inputs.len() == 2 { + match &node.inputs[1].value { + Some(value) => match &value.data { + Data::Int64s(axes) => Some(axes.clone()), + _ => panic!("Squeeze: invalid input types"), + }, + None => None, + } + } else { + node.attrs.get("axes").cloned().map(|v| v.into_i64s()) + }; + + let axes = axes.unwrap_or_else(|| panic!("Squeeze must specify an axis")); + log::debug!("Squeeze axes for {}: {:?}", node.name, axes); + + let input_rank = match &node.inputs[0].ty { + ArgType::Tensor(tensor) => tensor.rank, + ty => panic!("Squeeze: invalid input type: {:?}", ty), + }; + + log::debug!("Squeeze input rank for {}: {}", node.name, input_rank); + + let output_rank = input_rank - axes.len(); + log::debug!("Squeeze output rank for {}: {}", node.name, output_rank); + + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type: node.inputs[0].ty.elem_type().clone(), + rank: output_rank, + static_shape: None, + }); +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/onnx-ir/src/rank_inference.rs b/crates/onnx-ir/src/rank_inference.rs index 7e4f8792bb..9bab387e70 100644 --- a/crates/onnx-ir/src/rank_inference.rs +++ b/crates/onnx-ir/src/rank_inference.rs @@ -5,9 +5,14 @@ use protobuf::Enum; use crate::{ ir::{ArgType, AttributeValue, Data, ElementType, Node, NodeType, TensorType}, - node::slice::slice_update_output_rank, + node::{ + one_hot::one_hot_output_shape, reduce_max::reduce_max_update_outputs, + reduce_mean::reduce_mean_update_outputs, reduce_min::reduce_min_update_outputs, + reduce_prod::reduce_prod_update_outputs, reduce_sum::reduce_sum_update_outputs, + shape::shape_update_outputs, slice::slice_update_output_rank, + squeeze::squeeze_update_output, + }, protos::tensor_proto::DataType, - util::shape_config, }; /// Infer the rank of each output tensor and update them based solely on rank inference. @@ -439,37 +444,6 @@ fn reshape_update_outputs(node: &mut Node) { }); } -/// Update output rank for ReduceMean based on axes. 
-fn reduce_mean_update_outputs(node: &mut Node) { - log::debug!("ReduceMean rank inference for node {}", node.name); - - if node.inputs.len() != 1 { - panic!("ReduceMean: multiple inputs are not supported"); - } - let tensor = match &node.inputs[0].ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - let dim_only = match node.attrs.get("axes") { - Some(value) => match &value { - AttributeValue::Int64(_) => true, - AttributeValue::Int64s(ints) => ints.len() == 1, - _ => false, - }, - None => false, - }; - - let output_rank = if dim_only { tensor.rank } else { 1 }; - log::debug!("ReduceMean output rank for {}: {}", node.name, output_rank); - - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type: tensor.elem_type.clone(), - rank: output_rank, - static_shape: None, - }); -} - /// Update output rank for ArgMax (same as input rank). fn argmax_update_outputs(node: &mut Node) { log::debug!("ArgMax rank inference for node {}", node.name); @@ -494,42 +468,6 @@ fn argmax_update_outputs(node: &mut Node) { log::debug!("ArgMax output rank for {}: {}", node.name, tensor.rank); } -/// Update output rank for Squeeze based on axes. -fn squeeze_update_output(node: &mut Node) { - log::debug!("Squeeze rank inference for node {}", node.name); - - let axes = if node.inputs.len() == 2 { - match &node.inputs[1].value { - Some(value) => match &value.data { - Data::Int64s(axes) => Some(axes.clone()), - _ => panic!("Squeeze: invalid input types"), - }, - None => None, - } - } else { - node.attrs.get("axes").cloned().map(|v| v.into_i64s()) - }; - - let axes = axes.unwrap_or_else(|| panic!("Squeeze must specify an axis")); - log::debug!("Squeeze axes for {}: {:?}", node.name, axes); - - let input_rank = match &node.inputs[0].ty { - ArgType::Tensor(tensor) => tensor.rank, - ty => panic!("Squeeze: invalid input type: {:?}", ty), - }; - - log::debug!("Squeeze input rank for {}: {}", node.name, input_rank); - - let output_rank = input_rank - axes.len(); - log::debug!("Squeeze output rank for {}: {}", node.name, output_rank); - - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type: node.inputs[0].ty.elem_type().clone(), - rank: output_rank, - static_shape: None, - }); -} - /// Update output rank for broadcasting operations (e.g., Add, Sub) to max input rank. fn same_as_input_broadcast(node: &mut Node) { log::debug!("Broadcasting operation for node {}", node.name); @@ -779,23 +717,6 @@ fn expand_update_outputs(node: &mut Node) { } } -/// Update output type for Shape operation (rank 1). -fn shape_update_outputs(node: &mut Node) { - if node.inputs.len() != 1 { - panic!("Shape: multiple inputs are not supported: {:?}", node); - } - let (start, end) = shape_config(node); - let dim = end - start; - log::debug!( - "Shape operation for node {}: start={}, end={}, dim={}", - node.name, - start, - end, - dim - ); - node.outputs[0].ty = ArgType::Shape(dim); -} - /// Update output type for Flatten operation (rank 2). fn flatten_update_outputs(node: &mut Node) { if node.inputs.len() != 1 { @@ -872,138 +793,6 @@ fn range_update_outputs(node: &mut Node) { log::debug!("Range output rank for {}: 1", node.name); } -/// Update output rank for ReduceMax based on axes. 
-fn reduce_max_update_outputs(node: &mut Node) { - log::debug!("ReduceMax rank inference for node {}", node.name); - - if node.inputs.len() != 1 { - panic!("ReduceMax: multiple inputs are not supported"); - } - let tensor = match &node.inputs[0].ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - log::debug!("ReduceMax input rank for {}: {}", node.name, tensor.rank); - - let dim_only = match node.attrs.get("axes") { - Some(value) => match &value { - AttributeValue::Int64(_) => true, - AttributeValue::Int64s(ints) => ints.len() == 1, - _ => false, - }, - None => false, - }; - - let output_rank = if dim_only { tensor.rank } else { 1 }; - log::debug!("ReduceMax output rank for {}: {}", node.name, output_rank); - - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type: tensor.elem_type.clone(), - rank: output_rank, - static_shape: None, - }); -} - -/// Update output rank for ReduceMin based on axes. -fn reduce_min_update_outputs(node: &mut Node) { - log::debug!("ReduceMin rank inference for node {}", node.name); - - if node.inputs.len() != 1 { - panic!("ReduceMin: multiple inputs are not supported"); - } - let tensor = match &node.inputs[0].ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - log::debug!("ReduceMin input rank for {}: {}", node.name, tensor.rank); - - let dim_only = match node.attrs.get("axes") { - Some(value) => match &value { - AttributeValue::Int64(_) => true, - AttributeValue::Int64s(ints) => ints.len() == 1, - _ => false, - }, - None => false, - }; - - let output_rank = if dim_only { tensor.rank } else { 1 }; - log::debug!("ReduceMin output rank for {}: {}", node.name, output_rank); - - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type: tensor.elem_type.clone(), - rank: output_rank, - static_shape: None, - }); -} - -/// Update output rank for ReduceProd based on axes. -fn reduce_prod_update_outputs(node: &mut Node) { - log::debug!("ReduceProd rank inference for node {}", node.name); - - if node.inputs.len() != 1 { - panic!("ReduceProd: multiple inputs are not supported"); - } - let tensor = match &node.inputs[0].ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - log::debug!("ReduceProd input rank for {}: {}", node.name, tensor.rank); - - let dim_only = match node.attrs.get("axes") { - Some(value) => match &value { - AttributeValue::Int64(_) => true, - AttributeValue::Int64s(ints) => ints.len() == 1, - _ => false, - }, - None => false, - }; - - let output_rank = if dim_only { tensor.rank } else { 1 }; - log::debug!("ReduceProd output rank for {}: {}", node.name, output_rank); - - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type: tensor.elem_type.clone(), - rank: output_rank, - static_shape: None, - }); -} - -/// Update output rank for ReduceSum based on axes. 
-fn reduce_sum_update_outputs(node: &mut Node) { - log::debug!("ReduceSum rank inference for node {}", node.name); - - let tensor = match &node.inputs[0].ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - log::debug!("ReduceSum input rank for {}: {}", node.name, tensor.rank); - - let dim_only = match node.attrs.get("axes") { - Some(value) => match &value { - AttributeValue::Int64(_) => true, - AttributeValue::Int64s(ints) => ints.len() == 1, - _ => false, - }, - None => false, - } || match node.inputs.get(1).and_then(|arg| arg.value.as_ref()) { - Some(value) => match &value.data { - Data::Int64(_) => true, - Data::Int64s(ints) => ints.len() == 1, - _ => false, - }, - None => false, - }; - - let output_rank = if dim_only { tensor.rank } else { 1 }; - log::debug!("ReduceSum output rank for {}: {}", node.name, output_rank); - - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type: tensor.elem_type.clone(), - rank: output_rank, - static_shape: None, - }); -} - /// Update output rank for Where to max input rank. fn where_update_outputs(node: &mut Node) { log::debug!("Where rank inference for node {}", node.name); @@ -1149,26 +938,6 @@ fn split_update_outputs(node: &mut Node) { } } -/// Update output rank for OneHot (input rank + 1). -fn one_hot_output_shape(node: &mut Node) { - log::debug!("OneHot rank inference for node {}", node.name); - - let input_rank = match &node.inputs[0].ty { - ArgType::Tensor(tensor) => tensor.rank, - _ => panic!("OneHot: invalid input type"), - }; - log::debug!("OneHot input rank for {}: {}", node.name, input_rank); - - let output_rank = input_rank + 1; - log::debug!("OneHot output rank for {}: {}", node.name, output_rank); - - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type: node.outputs[0].ty.elem_type().clone(), - rank: output_rank, - static_shape: None, - }); -} - fn gemm_output_shape(node: &mut Node) { log::debug!("Gemm rank inference for node {}", node.name); From 03ea3004ecbadc7f4bf54d65c8af40a6dabb3052 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Wed, 30 Apr 2025 17:25:32 -0500 Subject: [PATCH 13/37] Move rank updates to node module --- crates/onnx-ir/src/node/argmax.rs | 26 +- crates/onnx-ir/src/node/cast.rs | 137 +++ crates/onnx-ir/src/node/comparison.rs | 121 +++ crates/onnx-ir/src/node/concat.rs | 26 +- crates/onnx-ir/src/node/constant.rs | 143 +++ crates/onnx-ir/src/node/constant_of_shape.rs | 162 ++++ crates/onnx-ir/src/node/expand.rs | 154 ++++ crates/onnx-ir/src/node/flatten.rs | 23 +- crates/onnx-ir/src/node/gather.rs | 75 +- crates/onnx-ir/src/node/gemm.rs | 36 +- crates/onnx-ir/src/node/linear.rs | 21 +- crates/onnx-ir/src/node/matmul.rs | 139 +++ crates/onnx-ir/src/node/mod.rs | 13 + crates/onnx-ir/src/node/random.rs | 113 +++ crates/onnx-ir/src/node/random_like.rs | 122 +++ crates/onnx-ir/src/node/range.rs | 93 ++ crates/onnx-ir/src/node/reshape.rs | 52 +- crates/onnx-ir/src/node/split.rs | 110 +++ crates/onnx-ir/src/node/topk.rs | 124 +++ crates/onnx-ir/src/node/unsqueeze.rs | 202 +++++ crates/onnx-ir/src/node/where_op.rs | 161 ++++ crates/onnx-ir/src/rank_inference.rs | 883 +------------------ crates/onnx-ir/src/util.rs | 190 +++- 23 files changed, 2250 insertions(+), 876 deletions(-) create mode 100644 crates/onnx-ir/src/node/cast.rs create mode 100644 crates/onnx-ir/src/node/comparison.rs create mode 100644 crates/onnx-ir/src/node/constant.rs create mode 100644 crates/onnx-ir/src/node/constant_of_shape.rs create mode 100644 
crates/onnx-ir/src/node/expand.rs create mode 100644 crates/onnx-ir/src/node/matmul.rs create mode 100644 crates/onnx-ir/src/node/random.rs create mode 100644 crates/onnx-ir/src/node/random_like.rs create mode 100644 crates/onnx-ir/src/node/range.rs create mode 100644 crates/onnx-ir/src/node/split.rs create mode 100644 crates/onnx-ir/src/node/topk.rs create mode 100644 crates/onnx-ir/src/node/unsqueeze.rs create mode 100644 crates/onnx-ir/src/node/where_op.rs diff --git a/crates/onnx-ir/src/node/argmax.rs b/crates/onnx-ir/src/node/argmax.rs index ae4826233f..8657859681 100644 --- a/crates/onnx-ir/src/node/argmax.rs +++ b/crates/onnx-ir/src/node/argmax.rs @@ -1,4 +1,4 @@ -use crate::ir::{ArgType, Node}; +use crate::ir::{ArgType, ElementType, Node, TensorType}; /// Create argmax config from the attributes of the node pub fn argmax_config(node: &Node) -> usize { @@ -52,6 +52,30 @@ pub fn argmax_config(node: &Node) -> usize { axis as usize } +/// Update output rank for ArgMax (same as input rank). +pub fn argmax_update_outputs(node: &mut Node) { + log::debug!("ArgMax rank inference for node {}", node.name); + + if node.inputs.len() != 1 { + panic!("ArgMax: multiple inputs are not supported"); + } + let tensor = match &node.inputs[0].ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + log::debug!("ArgMax input rank for {}: {}", node.name, tensor.rank); + + // Note: argmax in burn does not support keepdims=false + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type: ElementType::Int64, + rank: tensor.rank, + static_shape: None, + }); + + log::debug!("ArgMax output rank for {}: {}", node.name, tensor.rank); +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/onnx-ir/src/node/cast.rs b/crates/onnx-ir/src/node/cast.rs new file mode 100644 index 0000000000..afb805c74a --- /dev/null +++ b/crates/onnx-ir/src/node/cast.rs @@ -0,0 +1,137 @@ +use crate::ir::{ArgType, AttributeValue, ElementType, Node, TensorType}; +use crate::protos::tensor_proto::DataType; +use protobuf::Enum; + +/// Update output type for Cast operations, preserving rank. 
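+/// A rank-0 tensor input is treated as a scalar; other tensors keep their rank, with the element type taken from the `to` attribute.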
+pub fn cast_update_outputs(node: &mut Node) { + if node.inputs.len() != 1 { + panic!("Cast: multiple inputs are not supported"); + } + let input = &mut node.inputs[0]; + let output = &mut node.outputs[0]; + + let elem_type = match node.attrs.get("to") { + Some(value) => match &value { + AttributeValue::Int64(type_id) => match DataType::from_i32(*type_id as i32).unwrap() { + DataType::FLOAT => ElementType::Float32, + DataType::INT32 => ElementType::Int32, + DataType::INT64 => ElementType::Int64, + DataType::DOUBLE => ElementType::Float64, + DataType::BOOL => ElementType::Bool, + _ => panic!("Cast: unsupported type"), + }, + _ => panic!("'to' attribute must be an Int64"), + }, + None => panic!("Cast node must have a 'to' attribute"), + }; + + match input.ty.clone() { + ArgType::Tensor(tensor) => { + if tensor.rank == 0 { + // treat 0-dim tensor as scalar + output.ty = ArgType::Scalar(elem_type); + input.ty = ArgType::Scalar(tensor.elem_type); + } else { + // Cast input and output are the same shape, but possibly different types + output.ty = ArgType::Tensor(TensorType { + elem_type, + rank: tensor.rank, + static_shape: None, + }); + } + } + ArgType::Scalar(_) => output.ty = ArgType::Scalar(elem_type), + _ => panic!("Cast: only scalar and tensor inputs are valid"), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{Argument, NodeType}; + use std::collections::HashMap; + + fn create_test_node(input_rank: usize, to_type: i64) -> Node { + let mut attrs = HashMap::new(); + attrs.insert("to".to_string(), AttributeValue::Int64(to_type)); + + let inputs = vec![Argument { + name: "X".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: input_rank, + static_shape: None, + }), + value: None, + passed: true, + }]; + + Node { + node_type: NodeType::Cast, + name: "test_cast".to_string(), + inputs, + outputs: vec![Argument { + name: "Y".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, // This will be overwritten + rank: 0, + static_shape: None, + }), + value: None, + passed: true, + }], + attrs, + } + } + + #[test] + fn test_cast_float_to_int64() { + let mut node = create_test_node(2, DataType::INT64.value() as i64); + cast_update_outputs(&mut node); + + match &node.outputs[0].ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.elem_type, ElementType::Int64); + assert_eq!(tensor.rank, 2); + } + _ => panic!("Expected tensor output"), + } + } + + #[test] + fn test_cast_scalar_handling() { + let mut node = create_test_node(0, DataType::BOOL.value() as i64); + cast_update_outputs(&mut node); + + match &node.outputs[0].ty { + ArgType::Scalar(elem_type) => { + assert_eq!(*elem_type, ElementType::Bool); + } + _ => panic!("Expected scalar output for 0-rank tensor"), + } + + match &node.inputs[0].ty { + ArgType::Scalar(elem_type) => { + assert_eq!(*elem_type, ElementType::Float32); + } + _ => panic!("Input should have been converted to scalar"), + } + } + + #[test] + #[should_panic(expected = "Cast: multiple inputs are not supported")] + fn test_cast_multiple_inputs() { + let mut node = create_test_node(2, DataType::INT64.value() as i64); + node.inputs.push(Argument { + name: "extra".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 1, + static_shape: None, + }), + value: None, + passed: true, + }); + cast_update_outputs(&mut node); + } +} diff --git a/crates/onnx-ir/src/node/comparison.rs b/crates/onnx-ir/src/node/comparison.rs new file mode 100644 index 0000000000..37bbb91bfc 
--- /dev/null +++ b/crates/onnx-ir/src/node/comparison.rs @@ -0,0 +1,121 @@ +use crate::ir::{ArgType, ElementType, Node, TensorType}; + +/// Update output type for comparison operations (e.g., Equal, Greater) to max input rank. +pub fn elementwise_comparison_outputs(node: &mut Node) { + log::debug!("Elementwise comparison for node {}", node.name); + + let max_rank = node.inputs.iter().fold(0, |acc, input| match &input.ty { + ArgType::Tensor(tensor) => acc.max(tensor.rank), + ArgType::Scalar(_) => acc, + _ => panic!("Invalid input type for comparison op"), + }); + + log::debug!("Max rank for comparison node {}: {}", node.name, max_rank); + + if max_rank == 0 { + node.outputs[0].ty = ArgType::Scalar(ElementType::Bool); + log::debug!("Scalar boolean result for node {}", node.name); + } else { + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type: ElementType::Bool, + rank: max_rank, + static_shape: None, + }); + log::debug!( + "Tensor boolean result for node {} with rank {}", + node.name, + max_rank + ); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{Argument, NodeType}; + use std::collections::HashMap; + + fn create_test_node(input1_rank: usize, input2_rank: usize) -> Node { + let inputs = vec![ + Argument { + name: "A".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: input1_rank, + static_shape: None, + }), + value: None, + passed: true, + }, + Argument { + name: "B".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: input2_rank, + static_shape: None, + }), + value: None, + passed: true, + }, + ]; + + let outputs = vec![Argument { + name: "result".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Bool, + rank: 0, // Will be updated + static_shape: None, + }), + value: None, + passed: true, + }]; + + Node { + node_type: NodeType::Equal, + name: "test_comparison".to_string(), + inputs, + outputs, + attrs: HashMap::new(), + } + } + + #[test] + fn test_comparison_rank_broadcasting() { + let mut node = create_test_node(2, 3); + elementwise_comparison_outputs(&mut node); + + match &node.outputs[0].ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.elem_type, ElementType::Bool); + assert_eq!(tensor.rank, 3); // max(2, 3) = 3 + } + _ => panic!("Expected tensor output"), + } + } + + #[test] + fn test_comparison_scalar_result() { + let mut node = create_test_node(0, 0); + + // Convert inputs to scalars + node.inputs[0].ty = ArgType::Scalar(ElementType::Float32); + node.inputs[1].ty = ArgType::Scalar(ElementType::Float32); + + elementwise_comparison_outputs(&mut node); + + match &node.outputs[0].ty { + ArgType::Scalar(elem_type) => { + assert_eq!(*elem_type, ElementType::Bool); + } + _ => panic!("Expected scalar output"), + } + } + + #[test] + #[should_panic(expected = "Invalid input type for comparison op")] + fn test_comparison_invalid_input() { + let mut node = create_test_node(2, 2); + node.inputs[0].ty = ArgType::Shape(2); + elementwise_comparison_outputs(&mut node); + } +} diff --git a/crates/onnx-ir/src/node/concat.rs b/crates/onnx-ir/src/node/concat.rs index 147b552869..b2b4c4b73e 100644 --- a/crates/onnx-ir/src/node/concat.rs +++ b/crates/onnx-ir/src/node/concat.rs @@ -1,4 +1,28 @@ -use crate::ir::{ArgType, Node}; +use crate::ir::{ArgType, Node, TensorType}; + +/// Update output rank for Concat (same as first tensor input). 
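+/// The element type is also taken from that first tensor input; the static shape is not propagated.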
+pub fn concat_update_outputs(node: &mut Node) { + log::debug!("Concat rank inference for node {}", node.name); + + let tensor = node + .inputs + .iter() + .find_map(|input| match &input.ty { + ArgType::Tensor(tensor) => Some(tensor.clone()), + _ => None, + }) + .unwrap(); + + log::debug!("Concat using input rank for {}: {}", node.name, tensor.rank); + + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type: tensor.elem_type, + rank: tensor.rank, + static_shape: None, + }); + + log::debug!("Concat output rank for {}: {}", node.name, tensor.rank); +} /// Create concat config from the attributes of the node pub fn concat_config(node: &Node) -> usize { diff --git a/crates/onnx-ir/src/node/constant.rs b/crates/onnx-ir/src/node/constant.rs new file mode 100644 index 0000000000..3bd48c0c18 --- /dev/null +++ b/crates/onnx-ir/src/node/constant.rs @@ -0,0 +1,143 @@ +use crate::ir::{ArgType, AttributeValue, ElementType, Node, TensorType}; + +/// Update output type for constant nodes based on attribute values, focusing on rank only. +pub fn constant_update_outputs(node: &mut Node) { + log::debug!("Constant rank inference for node {}", node.name); + + let keys = [ + "value", + "value_float", + "value_floats", + "value_int", + "value_ints", + "value_string", + "value_strings", + "sparse_value", + ]; + + let matched_value = keys.iter().find_map(|&key| node.attrs.get(key).cloned()); + log::debug!("Constant found attribute: {}", matched_value.is_some()); + + node.outputs[0].ty = match matched_value { + Some(value) => match &value { + AttributeValue::Tensor(tensor) if tensor.shape.is_empty() => { + log::debug!("Constant as scalar for {}", node.name); + ArgType::Scalar(tensor.elem_type()) + } + AttributeValue::Tensor(tensor) => { + log::debug!( + "Constant tensor with rank {} for {}", + tensor.shape.len(), + node.name + ); + ArgType::Tensor(TensorType { + elem_type: tensor.elem_type(), + rank: tensor.shape.len(), + static_shape: None, + }) + } + AttributeValue::Float32(_) => { + log::debug!("Constant Float32 scalar for {}", node.name); + ArgType::Scalar(ElementType::Float32) + } + AttributeValue::Float32s(_) => { + log::debug!("Constant Float32s tensor with rank 1 for {}", node.name); + ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 1, + static_shape: None, + }) + } + AttributeValue::Int64(_) => { + log::debug!("Constant Int64 scalar for {}", node.name); + ArgType::Scalar(ElementType::Int64) + } + AttributeValue::Int64s(_) => { + log::debug!("Constant Int64s tensor with rank 1 for {}", node.name); + ArgType::Tensor(TensorType { + elem_type: ElementType::Int64, + rank: 1, + static_shape: None, + }) + } + ty => panic!("Constant value of {:?} is not supported", ty), + }, + None => panic!("Constant node must have a value attribute"), + }; +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{Argument, NodeType, TensorData}; + use std::collections::HashMap; + + fn create_test_node() -> Node { + let inputs = vec![]; + + let attrs = HashMap::new(); + // Empty attrs initially + + Node { + node_type: NodeType::Constant, + name: "test_constant".to_string(), + inputs, + outputs: vec![Argument { + name: "output".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, // This will be overwritten + rank: 0, + static_shape: None, + }), + value: None, + passed: true, + }], + attrs, + } + } + + #[test] + fn test_constant_scalar_float() { + let mut node = create_test_node(); + node.attrs + .insert("value_float".to_string(), 
AttributeValue::Float32(6.14)); + + constant_update_outputs(&mut node); + + match &node.outputs[0].ty { + ArgType::Scalar(elem_type) => { + assert_eq!(*elem_type, ElementType::Float32); + } + _ => panic!("Expected scalar output"), + } + } + + #[test] + fn test_constant_tensor() { + let mut node = create_test_node(); + node.attrs.insert( + "value".to_string(), + AttributeValue::Tensor(TensorData { + shape: vec![2, 3], + data: crate::ir::Data::Float32s(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]), + }), + ); + + constant_update_outputs(&mut node); + + match &node.outputs[0].ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.elem_type, ElementType::Float32); + assert_eq!(tensor.rank, 2); + } + _ => panic!("Expected tensor output"), + } + } + + #[test] + #[should_panic(expected = "Constant node must have a value attribute")] + fn test_constant_missing_value() { + let mut node = create_test_node(); + constant_update_outputs(&mut node); + } +} diff --git a/crates/onnx-ir/src/node/constant_of_shape.rs b/crates/onnx-ir/src/node/constant_of_shape.rs new file mode 100644 index 0000000000..a97c807397 --- /dev/null +++ b/crates/onnx-ir/src/node/constant_of_shape.rs @@ -0,0 +1,162 @@ +use crate::ir::{ArgType, ElementType, Node, TensorType}; + +/// Updates the output rank for a ConstantOfShape node based on the rank of its input. +pub fn constant_of_shape_update_output(node: &mut Node) { + log::debug!("ConstantOfShape rank inference for node {}", node.name); + + let value_type = node + .attrs + .get("value") + .map(|v| v.clone().into_tensor().elem_type()) + .unwrap_or(ElementType::Float32); // If not given, defaults to 0 as float32 + log::debug!( + "ConstantOfShape value type for {}: {:?}", + node.name, + value_type + ); + + let rank = match &node.inputs[0].ty { + ArgType::Shape(rank) => { + log::debug!( + "ConstantOfShape input is Shape with rank {} for {}", + rank, + node.name + ); + *rank + } + ArgType::Tensor(tensor_type) => { + log::debug!("ConstantOfShape input is Tensor for {}", node.name); + let r = tensor_type + .static_shape + .as_ref() + .and_then(|shape| shape.first()) + .copied() + .expect( + "ConstantOfShape node must have a Tensor with a non-empty static shape value", + ); + log::debug!( + "ConstantOfShape derived rank from tensor: {} for {}", + r, + node.name + ); + r + } + _ => panic!("ConstantOfShape node requires a Tensor or Shape type as input"), + }; + + // Update the input type to be a shape + node.inputs[0].ty = ArgType::Shape(rank); + log::debug!( + "ConstantOfShape updated input to Shape({}) for {}", + rank, + node.name + ); + + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type: value_type, + rank, + static_shape: None, + }); + log::debug!("ConstantOfShape output rank for {}: {}", node.name, rank); +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{Argument, AttributeValue, Data, NodeType, TensorData}; + use std::collections::HashMap; + + fn create_test_node(input_ty: ArgType) -> Node { + let inputs = vec![Argument { + name: "shape".to_string(), + ty: input_ty, + value: None, + passed: true, + }]; + + let outputs = vec![Argument { + name: "output".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, // Will be updated + rank: 0, // Will be updated + static_shape: None, + }), + value: None, + passed: true, + }]; + + let attrs = HashMap::new(); + // Default value attribute not set initially + + Node { + node_type: NodeType::ConstantOfShape, + name: "test_constantofshape".to_string(), + inputs, + outputs, + attrs, + } + } + + 
#[test] + fn test_shape_input() { + let mut node = create_test_node(ArgType::Shape(3)); + + constant_of_shape_update_output(&mut node); + + match &node.outputs[0].ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.elem_type, ElementType::Float32); + assert_eq!(tensor.rank, 3); + } + _ => panic!("Expected tensor output"), + } + } + + #[test] + fn test_tensor_input_with_static_shape() { + let mut node = create_test_node(ArgType::Tensor(TensorType { + elem_type: ElementType::Int64, + rank: 1, + static_shape: Some(vec![4]), + })); + + constant_of_shape_update_output(&mut node); + + match &node.outputs[0].ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.elem_type, ElementType::Float32); + assert_eq!(tensor.rank, 4); + } + _ => panic!("Expected tensor output"), + } + } + + #[test] + fn test_custom_value_type() { + let mut node = create_test_node(ArgType::Shape(2)); + node.attrs.insert( + "value".to_string(), + AttributeValue::Tensor(TensorData { + shape: vec![], + data: Data::Int64s(vec![7]), // Int64 value + }), + ); + + constant_of_shape_update_output(&mut node); + + match &node.outputs[0].ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.elem_type, ElementType::Int64); + assert_eq!(tensor.rank, 2); + } + _ => panic!("Expected tensor output"), + } + } + + #[test] + #[should_panic(expected = "ConstantOfShape node requires a Tensor or Shape type as input")] + fn test_invalid_input_type() { + let mut node = create_test_node(ArgType::Scalar(ElementType::Float32)); + constant_of_shape_update_output(&mut node); + } +} diff --git a/crates/onnx-ir/src/node/expand.rs b/crates/onnx-ir/src/node/expand.rs new file mode 100644 index 0000000000..7e60f2a22b --- /dev/null +++ b/crates/onnx-ir/src/node/expand.rs @@ -0,0 +1,154 @@ +use crate::ir::{ArgType, Data, Node, TensorType}; + +/// Updates the output rank and shape for the Expand operation based on the provided shape input. +/// If the shape is a constant, the rank and static shape of the output are set accordingly. +/// If the shape is dynamic, the rank is inferred from the static shape of the shape input. +pub fn expand_update_outputs(node: &mut Node) { + let shape = if node.inputs.len() == 2 { + match &node.inputs[1].value { + Some(value) => match &value.data { + Data::Int64s(shape) => Some(shape.clone()), + _ => panic!("Expand operation encountered invalid input types"), + }, + None => None, + } + } else { + panic!("Expand operation requires exactly two inputs"); + }; + + let output = match &node.outputs[0].ty { + ArgType::Tensor(tensor) => tensor.clone(), + _ => panic!("Expand operation encountered invalid output types"), + }; + + if let Some(shape) = shape { + node.outputs[0].ty = ArgType::Tensor(TensorType { + rank: shape.len(), + static_shape: Some(shape.into_iter().map(|dim| dim as usize).collect()), + ..output + }); + } else { + // When the shape cannot be determined statically (i.e., the second argument 'shape' is passed dynamically), + // infer the rank from the static shape of the input tensor. 
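+ // The 1-D 'shape' input has one element per output dimension, so its static length (or the Shape rank) is used as the output rank.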
+        let output_rank = match &node.inputs[1].ty {
+            ArgType::Tensor(tensor) => tensor
+                .static_shape
+                .as_ref()
+                .expect("Shape input must have a static shape defined")
+                .first()
+                .copied()
+                .expect("Static shape must contain at least one element"),
+            ArgType::Shape(rank) => *rank,
+            _ => panic!("Shape input must be of tensor or shape type",),
+        };
+
+        node.outputs[0].ty = ArgType::Tensor(TensorType {
+            rank: output_rank,
+            static_shape: None, // The exact shape cannot be determined statically
+            ..output
+        });
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::ir::{Argument, ElementType, NodeType, TensorData};
+    use std::collections::HashMap;
+
+    fn create_test_node(input_rank: usize, shape_value: Option<Vec<i64>>) -> Node {
+        let inputs = vec![
+            Argument {
+                name: "input".to_string(),
+                ty: ArgType::Tensor(TensorType {
+                    elem_type: ElementType::Float32,
+                    rank: input_rank,
+                    static_shape: None,
+                }),
+                value: None,
+                passed: true,
+            },
+            Argument {
+                name: "shape".to_string(),
+                ty: if shape_value.is_some() {
+                    ArgType::Tensor(TensorType {
+                        elem_type: ElementType::Int64,
+                        rank: 1,
+                        static_shape: Some(vec![shape_value.as_ref().unwrap().len()]),
+                    })
+                } else {
+                    ArgType::Tensor(TensorType {
+                        elem_type: ElementType::Int64,
+                        rank: 1,
+                        static_shape: Some(vec![3]), // Example: a shape with 3 dimensions
+                    })
+                },
+                value: shape_value.map(|shape| TensorData {
+                    shape: vec![shape.len()],
+                    data: Data::Int64s(shape),
+                }),
+                passed: true,
+            },
+        ];
+
+        let outputs = vec![Argument {
+            name: "output".to_string(),
+            ty: ArgType::Tensor(TensorType {
+                elem_type: ElementType::Float32,
+                rank: 0, // Will be updated
+                static_shape: None,
+            }),
+            value: None,
+            passed: true,
+        }];
+
+        Node {
+            node_type: NodeType::Expand,
+            name: "test_expand".to_string(),
+            inputs,
+            outputs,
+            attrs: HashMap::new(),
+        }
+    }
+
+    #[test]
+    fn test_expand_with_constant_shape() {
+        let mut node = create_test_node(2, Some(vec![2, 3, 4]));
+
+        expand_update_outputs(&mut node);
+
+        match &node.outputs[0].ty {
+            ArgType::Tensor(tensor) => {
+                assert_eq!(tensor.elem_type, ElementType::Float32);
+                assert_eq!(tensor.rank, 3);
+                assert_eq!(tensor.static_shape, Some(vec![2, 3, 4]));
+            }
+            _ => panic!("Expected tensor output"),
+        }
+    }
+
+    #[test]
+    fn test_expand_with_dynamic_shape() {
+        let mut node = create_test_node(2, None);
+
+        expand_update_outputs(&mut node);
+
+        match &node.outputs[0].ty {
+            ArgType::Tensor(tensor) => {
+                assert_eq!(tensor.elem_type, ElementType::Float32);
+                assert_eq!(tensor.rank, 3);
+                assert_eq!(tensor.static_shape, None);
+            }
+            _ => panic!("Expected tensor output"),
+        }
+    }
+
+    #[test]
+    #[should_panic(expected = "Expand operation requires exactly two inputs")]
+    fn test_expand_with_incorrect_inputs() {
+        let mut node = create_test_node(2, Some(vec![2, 3, 4]));
+        node.inputs.pop(); // Remove one input
+
+        expand_update_outputs(&mut node);
+    }
+}
diff --git a/crates/onnx-ir/src/node/flatten.rs b/crates/onnx-ir/src/node/flatten.rs
index 373c2a560f..99073c9b98 100644
--- a/crates/onnx-ir/src/node/flatten.rs
+++ b/crates/onnx-ir/src/node/flatten.rs
@@ -1,4 +1,25 @@
-use crate::ir::{ArgType, Node};
+use crate::ir::{ArgType, Node, TensorType};
+
+/// Update output type for Flatten operation (rank 2).
+pub fn flatten_update_outputs(node: &mut Node) { + if node.inputs.len() != 1 { + panic!("Flatten: multiple inputs are not supported"); + } + let tensor = node + .inputs + .iter() + .find_map(|input| match &input.ty { + ArgType::Tensor(tensor) => Some(tensor), + _ => None, + }) + .unwrap(); + + // Flatten to a 2D tensor + node.outputs[0].ty = ArgType::Tensor(TensorType { + rank: 2, + ..tensor.clone() + }); +} /// Create a FlattenConfig from the attributes of the node pub fn flatten_config(curr: &Node) -> usize { diff --git a/crates/onnx-ir/src/node/gather.rs b/crates/onnx-ir/src/node/gather.rs index f4a63756cf..cf80da01d6 100644 --- a/crates/onnx-ir/src/node/gather.rs +++ b/crates/onnx-ir/src/node/gather.rs @@ -1,4 +1,77 @@ -use crate::ir::{ArgType, Node}; +use crate::ir::{ArgType, ElementType, Node, TensorType}; + +/// Update output rank for Gather based on input and indices ranks. +pub fn gather_update_outputs(node: &mut Node) { + log::debug!("Gather rank inference for node {}", node.name); + + if node.inputs.len() != 2 { + panic!("Gather requires two inputs: data and indices"); + } + + let indices_rank = match &node.inputs[1].ty { + ArgType::Tensor(tensor) => tensor.rank, + ArgType::Scalar(_) => 0, + _ => panic!("Only tensor indices is valid, got {:?}", node.inputs[1].ty), + }; + log::debug!("Gather indices rank for {}: {}", node.name, indices_rank); + + match &node.inputs[0].ty { + ArgType::Tensor(input_tensor) => { + log::debug!( + "Gather input tensor rank for {}: {}", + node.name, + input_tensor.rank + ); + // Output of rank q+(r-1), where q is rank of indices tensor and r is rank of input + let output_rank = indices_rank + input_tensor.rank - 1; + log::debug!("Gather output rank for {}: {}", node.name, output_rank); + + if output_rank == 0 { + node.outputs[0].ty = ArgType::Scalar(input_tensor.elem_type.clone()); + log::debug!("Gather result for {} is scalar", node.name); + } else { + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type: input_tensor.elem_type.clone(), + rank: output_rank, + static_shape: None, + }); + log::debug!( + "Gather result for {} is tensor with rank {}", + node.name, + output_rank + ); + } + } + ArgType::Shape(_) => { + log::debug!("Gather input is shape for {}", node.name); + let shape_rank = 1; + // Output of rank q+(r-1), where q is rank of indices tensor and r is rank of input + let output_rank = indices_rank + shape_rank - 1; + log::debug!( + "Gather output rank for {} with shape input: {}", + node.name, + output_rank + ); + + if output_rank == 0 { + node.outputs[0].ty = ArgType::Scalar(ElementType::Int64); + log::debug!("Gather result for {} is scalar (from shape)", node.name); + } else { + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type: ElementType::Int64, + rank: output_rank, + static_shape: None, + }); + log::debug!( + "Gather result for {} is tensor with rank {} (from shape)", + node.name, + output_rank + ); + } + } + ty => panic!("Only tensor/shape input is valid but received: {:?}", ty), + } +} /// Create a GatherConfig from the attributes of the node pub fn gather_config(curr: &Node) -> usize { diff --git a/crates/onnx-ir/src/node/gemm.rs b/crates/onnx-ir/src/node/gemm.rs index ef1c5dc90b..4424b66d16 100644 --- a/crates/onnx-ir/src/node/gemm.rs +++ b/crates/onnx-ir/src/node/gemm.rs @@ -1,4 +1,38 @@ -use crate::ir::Node; +use crate::ir::{ArgType, Node, TensorType}; +use core::cmp::max; + +/// Update output shape for Gemm operation based on input ranks. 
+pub fn gemm_output_shape(node: &mut Node) { + log::debug!("Gemm rank inference for node {}", node.name); + + let a_rank = match &node.inputs[0].ty { + ArgType::Tensor(tensor) => tensor.rank, + _ => panic!("Input A should be a tensor!"), + }; + let b_rank = match &node.inputs[1].ty { + ArgType::Tensor(tensor) => tensor.rank, + _ => panic!("Input B should be a tensor!"), + }; + + log::debug!( + "Gemm input ranks for {}: a_rank={}, b_rank={}", + node.name, + a_rank, + b_rank + ); + + let output_rank = max(a_rank, b_rank); + log::debug!("Gemm output rank for {}: {}", node.name, output_rank); + + node.outputs[0].ty = ArgType::Tensor(TensorType { + rank: output_rank, + static_shape: None, + elem_type: match &node.inputs[0].ty { + ArgType::Tensor(t) => t.elem_type.clone(), + _ => panic!("Unexpected type for input A"), + }, + }); +} pub fn gemm_config(curr: &Node) -> (f32, f32, i64, i64) { let alpha = curr diff --git a/crates/onnx-ir/src/node/linear.rs b/crates/onnx-ir/src/node/linear.rs index 59cb043ff8..1aceeb3f91 100644 --- a/crates/onnx-ir/src/node/linear.rs +++ b/crates/onnx-ir/src/node/linear.rs @@ -1,6 +1,25 @@ -use crate::ir::Node; +use crate::ir::{ArgType, Node, TensorType}; use burn::nn::LinearConfig; +/// Update output rank for Linear operations (same as input rank). +pub fn linear_update_outputs(node: &mut Node) { + log::debug!("Linear rank inference for node {}", node.name); + + if let ArgType::Tensor(tensor) = &node.inputs[0].ty { + log::debug!("Linear input rank for {}: {}", node.name, tensor.rank); + + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type: tensor.elem_type.clone(), + rank: tensor.rank, + static_shape: None, + }); + + log::debug!("Linear output rank for {}: {}", node.name, tensor.rank); + } else { + panic!("Only tensor input is valid"); + } +} + /// Create a LinearConfig from the attributes of the node pub fn linear_config(node: &Node) -> LinearConfig { if node.inputs.len() < 2 { diff --git a/crates/onnx-ir/src/node/matmul.rs b/crates/onnx-ir/src/node/matmul.rs new file mode 100644 index 0000000000..3cfc7de62f --- /dev/null +++ b/crates/onnx-ir/src/node/matmul.rs @@ -0,0 +1,139 @@ +use crate::ir::{ArgType, Node, TensorType}; +use core::cmp::max; + +/// Update output rank for MatMul based on input ranks. 
+pub fn matmul_update_outputs(node: &mut Node) { + log::debug!("MatMul rank inference for node {}", node.name); + + match (&node.inputs[0].ty, &node.inputs[1].ty) { + (ArgType::Tensor(a), ArgType::Tensor(b)) => { + log::debug!( + "MatMul input ranks for {}: a.rank={}, b.rank={}", + node.name, + a.rank, + b.rank + ); + + let mut out_rank = max(a.rank, b.rank); + if (a.rank >= 2 && b.rank == 1) || (a.rank == 1 && b.rank >= 2) { + out_rank -= 1; + log::debug!( + "MatMul special case for node {}: reducing output rank", + node.name + ); + } + + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type: a.elem_type.clone(), + rank: out_rank, + static_shape: None, + }); + + log::debug!("MatMul output rank for {}: {}", node.name, out_rank); + } + _ => panic!("Only tensor inputs are valid"), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{Argument, ElementType, NodeType}; + use std::collections::HashMap; + + fn create_test_node(a_rank: usize, b_rank: usize) -> Node { + let inputs = vec![ + Argument { + name: "A".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: a_rank, + static_shape: None, + }), + value: None, + passed: true, + }, + Argument { + name: "B".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: b_rank, + static_shape: None, + }), + value: None, + passed: true, + }, + ]; + + let outputs = vec![Argument { + name: "C".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 0, // Will be updated + static_shape: None, + }), + value: None, + passed: true, + }]; + + Node { + node_type: NodeType::MatMul, + name: "test_matmul".to_string(), + inputs, + outputs, + attrs: HashMap::new(), + } + } + + #[test] + fn test_matmul_standard_case() { + let mut node = create_test_node(2, 2); + matmul_update_outputs(&mut node); + + match &node.outputs[0].ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.elem_type, ElementType::Float32); + assert_eq!(tensor.rank, 2); + } + _ => panic!("Expected tensor output"), + } + } + + #[test] + fn test_matmul_broadcasting() { + let mut node = create_test_node(3, 2); + matmul_update_outputs(&mut node); + + match &node.outputs[0].ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.elem_type, ElementType::Float32); + assert_eq!(tensor.rank, 3); + } + _ => panic!("Expected tensor output"), + } + } + + #[test] + fn test_matmul_vector_matrix() { + // When multiplying a vector (rank 1) by a matrix (rank 2) + // the result should have rank 1 (vector) + let mut node = create_test_node(1, 2); + matmul_update_outputs(&mut node); + + match &node.outputs[0].ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.elem_type, ElementType::Float32); + assert_eq!(tensor.rank, 1); + } + _ => panic!("Expected tensor output"), + } + } + + #[test] + #[should_panic(expected = "Only tensor inputs are valid")] + fn test_matmul_invalid_input() { + let mut node = create_test_node(2, 2); + node.inputs[0].ty = ArgType::Scalar(ElementType::Float32); + matmul_update_outputs(&mut node); + } +} diff --git a/crates/onnx-ir/src/node/mod.rs b/crates/onnx-ir/src/node/mod.rs index 4e5ca6b836..a0c1eac985 100644 --- a/crates/onnx-ir/src/node/mod.rs +++ b/crates/onnx-ir/src/node/mod.rs @@ -2,8 +2,12 @@ pub mod argmax; pub mod avg_pool1d; pub mod avg_pool2d; pub mod batch_norm; +pub mod cast; pub mod clip; +pub mod comparison; pub mod concat; +pub mod constant; +pub mod constant_of_shape; pub mod conv1d; pub mod conv2d; pub mod conv3d; @@ -11,6 +15,7 @@ pub mod 
conv_transpose1d;
 pub mod conv_transpose2d;
 pub mod conv_transpose3d;
 pub mod dropout;
+pub mod expand;
 pub mod flatten;
 pub mod gather;
 pub mod gemm;
@@ -19,9 +24,13 @@ pub mod layer_norm;
 pub mod leaky_relu;
 pub mod linear;
 pub mod log_softmax;
+pub mod matmul;
 pub mod max_pool1d;
 pub mod max_pool2d;
 pub mod one_hot;
+pub mod random;
+pub mod random_like;
+pub mod range;
 pub mod reduce_max;
 pub mod reduce_mean;
 pub mod reduce_min;
@@ -32,5 +41,9 @@ pub mod resize;
 pub mod shape;
 pub mod slice;
 pub mod softmax;
+pub mod split;
 pub mod squeeze;
+pub mod topk;
 pub mod transpose;
+pub mod unsqueeze;
+pub mod where_op;
diff --git a/crates/onnx-ir/src/node/random.rs b/crates/onnx-ir/src/node/random.rs
new file mode 100644
index 0000000000..32ada33208
--- /dev/null
+++ b/crates/onnx-ir/src/node/random.rs
@@ -0,0 +1,113 @@
+use crate::ir::{ArgType, ElementType, Node, TensorType};
+use crate::protos::tensor_proto::DataType;
+use protobuf::Enum;
+
+/// Update output rank for Random operations with explicit shape attribute.
+pub fn random_update_output(node: &mut Node) {
+    log::debug!("Random rank inference for node {}", node.name);
+
+    let dtype = node
+        .attrs
+        .get("dtype")
+        .map(|val| DataType::from_i32(val.clone().into_i32()).unwrap())
+        .unwrap_or(DataType::FLOAT);
+    log::debug!("Random dtype for {}: {:?}", node.name, dtype);
+
+    let shape = node
+        .attrs
+        .get("shape")
+        .expect("required shape attribute missing")
+        .clone()
+        .into_i64s();
+    log::debug!("Random shape for {}: {:?}", node.name, shape);
+
+    let elem_type = match dtype {
+        DataType::FLOAT => ElementType::Float32,
+        DataType::DOUBLE => ElementType::Float64,
+        _ => panic!("tensor with type {dtype:?} not supported for random output"),
+    };
+
+    let rank = shape.len();
+    log::debug!("Random output rank for {}: {}", node.name, rank);
+
+    node.outputs[0].ty = ArgType::Tensor(TensorType {
+        elem_type,
+        rank,
+        static_shape: None,
+    });
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::ir::{Argument, AttributeValue, NodeType};
+    use std::collections::HashMap;
+
+    fn create_test_node(dtype: i32, shape: Vec<i64>) -> Node {
+        let mut attrs = HashMap::new();
+        attrs.insert("dtype".to_string(), AttributeValue::Int64(dtype as i64));
+        attrs.insert("shape".to_string(), AttributeValue::Int64s(shape.clone()));
+
+        let outputs = vec![Argument {
+            name: "output".to_string(),
+            ty: ArgType::Tensor(TensorType {
+                elem_type: ElementType::Float32, // Will be updated
+                rank: 0,                         // Will be updated
+                static_shape: None,
+            }),
+            value: None,
+            passed: true,
+        }];
+
+        Node {
+            node_type: NodeType::RandomNormal,
+            name: "test_random".to_string(),
+            inputs: vec![],
+            outputs,
+            attrs,
+        }
+    }
+
+    #[test]
+    fn test_random_normal_float() {
+        let mut node = create_test_node(DataType::FLOAT.value(), vec![2, 3, 4]);
+        random_update_output(&mut node);
+
+        match &node.outputs[0].ty {
+            ArgType::Tensor(tensor) => {
+                assert_eq!(tensor.elem_type, ElementType::Float32);
+                assert_eq!(tensor.rank, 3);
+            }
+            _ => panic!("Expected tensor output"),
+        }
+    }
+
+    #[test]
+    fn test_random_normal_double() {
+        let mut node = create_test_node(DataType::DOUBLE.value(), vec![5]);
+        random_update_output(&mut node);
+
+        match &node.outputs[0].ty {
+            ArgType::Tensor(tensor) => {
+                assert_eq!(tensor.elem_type, ElementType::Float64);
+                assert_eq!(tensor.rank, 1);
+            }
+            _ => panic!("Expected tensor output"),
+        }
+    }
+
+    #[test]
+    #[should_panic(expected = "required shape attribute missing")]
+    fn test_random_normal_missing_shape() {
+        let mut node = create_test_node(DataType::FLOAT.value(), vec![2, 3]);
+        node.attrs.remove("shape");
+        random_update_output(&mut node);
+    }
+
+    #[test]
+    #[should_panic(expected = "tensor with type INT32 not supported for random output")]
+    fn test_random_normal_unsupported_type() {
+        let mut node = create_test_node(DataType::INT32.value(), vec![2, 3]);
+        random_update_output(&mut node);
+    }
+}
diff --git a/crates/onnx-ir/src/node/random_like.rs b/crates/onnx-ir/src/node/random_like.rs
new file mode 100644
index 0000000000..c9ccb01acc
--- /dev/null
+++ b/crates/onnx-ir/src/node/random_like.rs
@@ -0,0 +1,122 @@
+use crate::ir::{ArgType, ElementType, Node, TensorType};
+use crate::protos::tensor_proto::DataType;
+use protobuf::Enum;
+
+/// Update output rank for RandomLike operations based on input rank.
+pub fn random_like_update_output(node: &mut Node) {
+    log::debug!("RandomLike rank inference for node {}", node.name);
+
+    let dtype = node
+        .attrs
+        .get("dtype")
+        .map(|val| DataType::from_i32(val.clone().into_i32()).unwrap())
+        .unwrap_or(DataType::FLOAT);
+    log::debug!("RandomLike dtype for {}: {:?}", node.name, dtype);
+
+    let elem_type = match dtype {
+        DataType::FLOAT => ElementType::Float32,
+        DataType::FLOAT16 => ElementType::Float16,
+        DataType::DOUBLE => ElementType::Float64,
+        _ => panic!("Tensor with type {dtype:?} not supported for random output"),
+    };
+
+    if let ArgType::Tensor(tensor) = &node.inputs[0].ty {
+        log::debug!("RandomLike input rank for {}: {}", node.name, tensor.rank);
+
+        node.outputs[0].ty = ArgType::Tensor(TensorType {
+            elem_type,
+            rank: tensor.rank,
+            static_shape: tensor.static_shape.clone(),
+        });
+
+        log::debug!("RandomLike output rank for {}: {}", node.name, tensor.rank);
+    } else {
+        panic!("Only tensor input is valid");
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::ir::{Argument, AttributeValue, NodeType};
+    use std::collections::HashMap;
+
+    fn create_test_node(dtype: i32, input_rank: usize, static_shape: Option<Vec<usize>>) -> Node {
+        let mut attrs = HashMap::new();
+        attrs.insert("dtype".to_string(), AttributeValue::Int64(dtype as i64));
+
+        let inputs = vec![Argument {
+            name: "input".to_string(),
+            ty: ArgType::Tensor(TensorType {
+                elem_type: ElementType::Float32,
+                rank: input_rank,
+                static_shape,
+            }),
+            value: None,
+            passed: true,
+        }];
+
+        let outputs = vec![Argument {
+            name: "output".to_string(),
+            ty: ArgType::Tensor(TensorType {
+                elem_type: ElementType::Float32, // Will be updated
+                rank: 0,                         // Will be updated
+                static_shape: None,
+            }),
+            value: None,
+            passed: true,
+        }];
+
+        Node {
+            node_type: NodeType::RandomNormalLike,
+            name: "test_random_like".to_string(),
+            inputs,
+            outputs,
+            attrs,
+        }
+    }
+
+    #[test]
+    fn test_random_like_float() {
+        let mut node = create_test_node(DataType::FLOAT.value(), 3, None);
+        random_like_update_output(&mut node);
+
+        match &node.outputs[0].ty {
+            ArgType::Tensor(tensor) => {
+                assert_eq!(tensor.elem_type, ElementType::Float32);
+                assert_eq!(tensor.rank, 3);
+            }
+            _ => panic!("Expected tensor output"),
+        }
+    }
+
+    #[test]
+    fn test_random_like_double() {
+        let mut node = create_test_node(DataType::DOUBLE.value(), 2, Some(vec![5, 10]));
+        random_like_update_output(&mut node);
+
+        match &node.outputs[0].ty {
+            ArgType::Tensor(tensor) => {
+                assert_eq!(tensor.elem_type, ElementType::Float64);
+                assert_eq!(tensor.rank, 2);
+                assert_eq!(tensor.static_shape, Some(vec![5, 10]));
+            }
+            _ => panic!("Expected tensor output"),
+        }
+    }
+
+    #[test]
+    #[should_panic(expected = "Only tensor input is valid")]
+    fn 
test_random_like_invalid_input() { + let mut node = create_test_node(DataType::FLOAT.value(), 2, None); + node.inputs[0].ty = ArgType::Scalar(ElementType::Float32); + random_like_update_output(&mut node); + } + + #[test] + #[should_panic(expected = "Tensor with type INT32 not supported for random output")] + fn test_random_like_unsupported_type() { + let mut node = create_test_node(DataType::INT32.value(), 2, None); + random_like_update_output(&mut node); + } +} diff --git a/crates/onnx-ir/src/node/range.rs b/crates/onnx-ir/src/node/range.rs new file mode 100644 index 0000000000..8fe97e43f5 --- /dev/null +++ b/crates/onnx-ir/src/node/range.rs @@ -0,0 +1,93 @@ +use crate::ir::{ArgType, ElementType, Node, TensorType}; + +/// Update output rank for Range (always rank 1). +pub fn range_update_outputs(node: &mut Node) { + log::debug!("Range rank inference for node {}", node.name); + + if node.inputs.len() != 3 { + panic!("Range: expected 3 inputs, found {}", node.inputs.len()); + } + log::debug!( + "Range operation always produces rank 1 tensor for {}", + node.name + ); + + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type: ElementType::Int64, + rank: 1, + static_shape: None, + }); + + log::debug!("Range output rank for {}: 1", node.name); +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{Argument, NodeType}; + use std::collections::HashMap; + + fn create_test_node() -> Node { + let inputs = vec![ + Argument { + name: "start".to_string(), + ty: ArgType::Scalar(ElementType::Int64), + value: None, + passed: true, + }, + Argument { + name: "limit".to_string(), + ty: ArgType::Scalar(ElementType::Int64), + value: None, + passed: true, + }, + Argument { + name: "delta".to_string(), + ty: ArgType::Scalar(ElementType::Int64), + value: None, + passed: true, + }, + ]; + + let outputs = vec![Argument { + name: "output".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Int64, + rank: 0, // Will be updated + static_shape: None, + }), + value: None, + passed: true, + }]; + + Node { + node_type: NodeType::Range, + name: "test_range".to_string(), + inputs, + outputs, + attrs: HashMap::new(), + } + } + + #[test] + fn test_range_output() { + let mut node = create_test_node(); + range_update_outputs(&mut node); + + match &node.outputs[0].ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.elem_type, ElementType::Int64); + assert_eq!(tensor.rank, 1); + } + _ => panic!("Expected tensor output"), + } + } + + #[test] + #[should_panic(expected = "Range: expected 3 inputs, found 2")] + fn test_range_missing_inputs() { + let mut node = create_test_node(); + node.inputs.pop(); + range_update_outputs(&mut node); + } +} diff --git a/crates/onnx-ir/src/node/reshape.rs b/crates/onnx-ir/src/node/reshape.rs index 55aab093d6..142ee40dc7 100644 --- a/crates/onnx-ir/src/node/reshape.rs +++ b/crates/onnx-ir/src/node/reshape.rs @@ -1,4 +1,54 @@ -use crate::ir::{Node, TensorData}; +use crate::ir::{ArgType, Data, Node, TensorData, TensorType}; + +/// Update output rank for Reshape based on shape input if constant, otherwise use input rank. 
+pub fn reshape_update_outputs(node: &mut Node) {
+    log::debug!("Reshape rank inference for node {}", node.name);
+
+    let shape = if node.inputs.len() == 2 {
+        log::debug!("Reshape node {} has shape as second input", node.name);
+        match &node.inputs[1].value {
+            Some(value) => match &value.data {
+                Data::Int64s(shape) => {
+                    log::debug!("Reshape node {} has constant shape: {:?}", node.name, shape);
+                    Some(shape.clone())
+                }
+                _ => panic!("Reshape: invalid input types"),
+            },
+            None => {
+                log::debug!(
+                    "Reshape node {} has dynamic shape as second input",
+                    node.name
+                );
+                None
+            }
+        }
+    } else {
+        log::debug!("Reshape node {} using shape from attributes", node.name);
+        node.attrs.get("shape").cloned().map(|v| {
+            let shape = v.into_i64s();
+            log::debug!("Reshape node {} shape attribute: {:?}", node.name, shape);
+            shape
+        })
+    };
+
+    let output = match &node.outputs[0].ty {
+        ArgType::Tensor(tensor) => tensor.clone(),
+        _ => panic!("Reshape: invalid output types"),
+    };
+
+    let rank = match &shape {
+        Some(s) => s.len(),
+        None => output.rank,
+    };
+
+    log::debug!("Reshape output rank for node {}: {}", node.name, rank);
+
+    node.outputs[0].ty = ArgType::Tensor(TensorType {
+        rank,
+        static_shape: None,
+        ..output
+    });
+}
 
 pub fn reshape_config(node: &Node) -> Vec<i64> {
     let mut allowzero = 0;
diff --git a/crates/onnx-ir/src/node/split.rs b/crates/onnx-ir/src/node/split.rs
new file mode 100644
index 0000000000..82514f5057
--- /dev/null
+++ b/crates/onnx-ir/src/node/split.rs
@@ -0,0 +1,110 @@
+use crate::ir::{ArgType, Node, TensorType};
+
+/// Update output rank for Split (same as input).
+pub fn split_update_outputs(node: &mut Node) {
+    log::debug!("Split rank inference for node {}", node.name);
+
+    let tensor = match &node.inputs[0].ty {
+        ArgType::Tensor(tensor) => tensor,
+        _ => panic!("Split: Input must be a tensor"),
+    };
+    log::debug!("Split input rank for {}: {}", node.name, tensor.rank);
+    log::debug!(
+        "Split will generate {} outputs for {}",
+        node.outputs.len(),
+        node.name
+    );
+
+    for (i, output_arg) in node.outputs.iter_mut().enumerate() {
+        output_arg.ty = ArgType::Tensor(TensorType {
+            elem_type: tensor.elem_type.clone(),
+            rank: tensor.rank,
+            static_shape: None,
+        });
+        log::debug!("Split output {} rank for {}: {}", i, node.name, tensor.rank);
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::ir::{Argument, ElementType, NodeType};
+    use std::collections::HashMap;
+
+    fn create_test_node(input_rank: usize, num_outputs: usize) -> Node {
+        let inputs = vec![Argument {
+            name: "input".to_string(),
+            ty: ArgType::Tensor(TensorType {
+                elem_type: ElementType::Float32,
+                rank: input_rank,
+                static_shape: None,
+            }),
+            value: None,
+            passed: true,
+        }];
+
+        let mut outputs = Vec::new();
+        for i in 0..num_outputs {
+            outputs.push(Argument {
+                name: format!("output_{}", i),
+                ty: ArgType::Tensor(TensorType {
+                    elem_type: ElementType::Float32,
+                    rank: 0, // Will be updated
+                    static_shape: None,
+                }),
+                value: None,
+                passed: true,
+            });
+        }
+
+        let attrs = HashMap::new();
+
+        Node {
+            node_type: NodeType::Split,
+            name: "test_split".to_string(),
+            inputs,
+            outputs,
+            attrs,
+        }
+    }
+
+    #[test]
+    fn test_split_single_output() {
+        let mut node = create_test_node(3, 1);
+        split_update_outputs(&mut node);
+
+        assert_eq!(node.outputs.len(), 1);
+        match &node.outputs[0].ty {
+            ArgType::Tensor(tensor) => {
+                assert_eq!(tensor.elem_type, ElementType::Float32);
+                assert_eq!(tensor.rank, 3);
+            }
+            _ => panic!("Expected tensor output"),
+        }
+    }
+
+    #[test]
+    fn 
test_split_multiple_outputs() { + let mut node = create_test_node(4, 3); + split_update_outputs(&mut node); + + assert_eq!(node.outputs.len(), 3); + for output in &node.outputs { + match &output.ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.elem_type, ElementType::Float32); + assert_eq!(tensor.rank, 4); + } + _ => panic!("Expected tensor output"), + } + } + } + + #[test] + #[should_panic(expected = "Split: Input must be a tensor")] + fn test_split_invalid_input() { + let mut node = create_test_node(3, 2); + node.inputs[0].ty = ArgType::Scalar(ElementType::Float32); + split_update_outputs(&mut node); + } +} diff --git a/crates/onnx-ir/src/node/topk.rs b/crates/onnx-ir/src/node/topk.rs new file mode 100644 index 0000000000..e1a0ba429d --- /dev/null +++ b/crates/onnx-ir/src/node/topk.rs @@ -0,0 +1,124 @@ +use crate::ir::{ArgType, ElementType, Node, TensorType}; + +/// Update output rank for TopK (same as input rank). +pub fn top_k_update_output(node: &mut Node) { + log::debug!("TopK rank inference for node {}", node.name); + + let rank = match &node.inputs[0].ty { + ArgType::Tensor(tensor) => tensor.rank, + _ => panic!("TopK: invalid input type"), + }; + log::debug!("TopK input rank for {}: {}", node.name, rank); + + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type: node.inputs[0].ty.elem_type().clone(), + rank, + static_shape: None, + }); + node.outputs[1].ty = ArgType::Tensor(TensorType { + elem_type: ElementType::Int64, + rank, + static_shape: None, + }); + + log::debug!( + "TopK output rank for {}: {} (both outputs)", + node.name, + rank + ); +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{Argument, NodeType}; + use std::collections::HashMap; + + fn create_test_node(input_rank: usize) -> Node { + let inputs = vec![ + Argument { + name: "X".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: input_rank, + static_shape: None, + }), + value: None, + passed: true, + }, + Argument { + name: "K".to_string(), + ty: ArgType::Scalar(ElementType::Int64), + value: None, + passed: true, + }, + ]; + + let outputs = vec![ + Argument { + name: "Values".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 0, // Will be updated + static_shape: None, + }), + value: None, + passed: true, + }, + Argument { + name: "Indices".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Int64, + rank: 0, // Will be updated + static_shape: None, + }), + value: None, + passed: true, + }, + ]; + + let attrs = HashMap::new(); + + Node { + node_type: NodeType::TopK, + name: "test_topk".to_string(), + inputs, + outputs, + attrs, + } + } + + #[test] + fn test_topk_basic() { + let mut node = create_test_node(3); + top_k_update_output(&mut node); + + assert_eq!(node.outputs.len(), 2); + + // Check first output (values) + match &node.outputs[0].ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.elem_type, ElementType::Float32); + assert_eq!(tensor.rank, 3); + } + _ => panic!("Expected tensor output for values"), + } + + // Check second output (indices) + match &node.outputs[1].ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.elem_type, ElementType::Int64); + assert_eq!(tensor.rank, 3); + } + _ => panic!("Expected tensor output for indices"), + } + } + + #[test] + #[should_panic(expected = "TopK: invalid input type")] + fn test_topk_invalid_input() { + let mut node = create_test_node(3); + node.inputs[0].ty = ArgType::Scalar(ElementType::Float32); + top_k_update_output(&mut node); 
+    }
+}
diff --git a/crates/onnx-ir/src/node/unsqueeze.rs b/crates/onnx-ir/src/node/unsqueeze.rs
new file mode 100644
index 0000000000..8ea66adad6
--- /dev/null
+++ b/crates/onnx-ir/src/node/unsqueeze.rs
@@ -0,0 +1,202 @@
+use crate::ir::{ArgType, Data, Node, TensorType};
+
+/// Update output rank for Unsqueeze based on axes.
+/// Update the output tensor dimension based on the "axes" attribute or the second input
+pub fn unsqueeze_update_output(node: &mut Node) {
+    log::debug!("Unsqueeze rank inference for node {}", node.name);
+
+    let axes = if node.inputs.len() == 2 {
+        match &node.inputs[1].value {
+            Some(value) => match &value.data {
+                Data::Int64s(a) => Some(a.clone()),
+                _ => panic!("Unsqueeze: invalid input types"),
+            },
+            None => None,
+        }
+    } else {
+        let axes = node.attrs.get("axes").cloned().map(|v| {
+            let axes = v.into_i64s();
+            log::debug!(
+                "Unsqueeze axes from attribute for {}: {:?}",
+                node.name,
+                axes
+            );
+            axes
+        });
+        axes
+    };
+
+    let input_rank = match &node.inputs[0].ty {
+        ArgType::Tensor(tensor) => tensor.rank,
+        ArgType::Scalar(_) => {
+            0 // treat scalar as 0-dim tensor
+        }
+        _ => panic!("Unsqueeze: invalid input type"),
+    };
+
+    let output_elem = match &node.outputs[0].ty {
+        ArgType::Tensor(_) => node.inputs[0].ty.elem_type().clone(),
+        ArgType::Scalar(elem_type) => elem_type.clone(),
+        _ => panic!("Unsqueeze: invalid output type"),
+    };
+
+    let output_rank = if let Some(axes) = axes {
+        input_rank + axes.len()
+    } else if let ArgType::Tensor(tensor) = &node.inputs[1].ty {
+        if let Some(static_shape) = &tensor.static_shape {
+            input_rank + *static_shape.first().expect("Empty shape")
+        } else {
+            panic!("Unsqueeze: should have static shape")
+        }
+    } else {
+        panic!("Unsqueeze: missing axes information")
+    };
+
+    node.outputs[0].ty = ArgType::Tensor(TensorType {
+        rank: output_rank,
+        static_shape: None, // shape is tracked and calculated at runtime
+        elem_type: output_elem,
+    });
+
+    log::debug!("Unsqueeze output rank for {}: {}", node.name, output_rank);
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::ir::{Argument, AttributeValue, ElementType, NodeType, TensorData};
+    use std::collections::HashMap;
+
+    fn create_test_node_with_attr(input_rank: usize, axes: Vec<i64>) -> Node {
+        let inputs = vec![Argument {
+            name: "X".to_string(),
+            ty: ArgType::Tensor(TensorType {
+                elem_type: ElementType::Float32,
+                rank: input_rank,
+                static_shape: None,
+            }),
+            value: None,
+            passed: true,
+        }];
+
+        let outputs = vec![Argument {
+            name: "Y".to_string(),
+            ty: ArgType::Tensor(TensorType {
+                elem_type: ElementType::Float32,
+                rank: 0, // Will be updated
+                static_shape: None,
+            }),
+            value: None,
+            passed: true,
+        }];
+
+        let mut attrs = HashMap::new();
+        attrs.insert("axes".to_string(), AttributeValue::Int64s(axes.clone()));
+
+        Node {
+            node_type: NodeType::Unsqueeze,
+            name: "test_unsqueeze".to_string(),
+            inputs,
+            outputs,
+            attrs,
+        }
+    }
+
+    fn create_test_node_with_input(input_rank: usize, axes: Vec<i64>) -> Node {
+        let inputs = vec![
+            Argument {
+                name: "X".to_string(),
+                ty: ArgType::Tensor(TensorType {
+                    elem_type: ElementType::Float32,
+                    rank: input_rank,
+                    static_shape: None,
+                }),
+                value: None,
+                passed: true,
+            },
+            Argument {
+                name: "axes".to_string(),
+                ty: ArgType::Tensor(TensorType {
+                    elem_type: ElementType::Int64,
+                    rank: 1,
+                    static_shape: Some(vec![axes.len()]),
+                }),
+                value: Some(TensorData {
+                    data: Data::Int64s(axes),
+                    shape: vec![1],
+                }),
+                passed: true,
+            },
+        ];
+
+        let outputs = vec![Argument {
+            name: "Y".to_string(),
+            ty: 
ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 0, // Will be updated + static_shape: None, + }), + value: None, + passed: true, + }]; + + Node { + node_type: NodeType::Unsqueeze, + name: "test_unsqueeze".to_string(), + inputs, + outputs, + attrs: HashMap::new(), + } + } + + #[test] + fn test_unsqueeze_with_attr() { + let mut node = create_test_node_with_attr(2, vec![0, 3]); + unsqueeze_update_output(&mut node); + + match &node.outputs[0].ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.elem_type, ElementType::Float32); + assert_eq!(tensor.rank, 4); // 2 + 2 = 4 + } + _ => panic!("Expected tensor output"), + } + } + + #[test] + fn test_unsqueeze_with_input() { + let mut node = create_test_node_with_input(3, vec![1, 2, 4]); + unsqueeze_update_output(&mut node); + + match &node.outputs[0].ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.elem_type, ElementType::Float32); + assert_eq!(tensor.rank, 6); // 3 + 3 = 6 + } + _ => panic!("Expected tensor output"), + } + } + + #[test] + fn test_unsqueeze_scalar() { + let mut node = create_test_node_with_attr(0, vec![0]); + node.inputs[0].ty = ArgType::Scalar(ElementType::Float32); + unsqueeze_update_output(&mut node); + + match &node.outputs[0].ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.elem_type, ElementType::Float32); + assert_eq!(tensor.rank, 1); // 0 + 1 = 1 + } + _ => panic!("Expected tensor output"), + } + } + + #[test] + #[should_panic(expected = "Unsqueeze: invalid input type")] + fn test_unsqueeze_invalid_input() { + let mut node = create_test_node_with_attr(2, vec![0]); + node.inputs[0].ty = ArgType::Shape(1); + unsqueeze_update_output(&mut node); + } +} diff --git a/crates/onnx-ir/src/node/where_op.rs b/crates/onnx-ir/src/node/where_op.rs new file mode 100644 index 0000000000..4a9b67767a --- /dev/null +++ b/crates/onnx-ir/src/node/where_op.rs @@ -0,0 +1,161 @@ +use crate::ir::{ArgType, ElementType, Node, TensorType}; +use core::cmp::max; + +/// Update output rank for Where to max input rank. +pub fn where_update_outputs(node: &mut Node) { + log::debug!("Where rank inference for node {}", node.name); + + let condition = &node.inputs[0].ty; + let x = &node.inputs[1].ty; + let y = &node.inputs[2].ty; + let elem_type = x.elem_type().clone(); + assert_eq!( + *condition.elem_type(), + ElementType::Bool, + "Where condition must be boolean!" + ); + assert_eq!( + elem_type, + *y.elem_type(), + "Where x and y have different element types!" 
+ ); + + log::debug!( + "Where input ranks for {}: condition={}, x={}, y={}", + node.name, + condition.rank(), + x.rank(), + y.rank() + ); + + let output_rank = max(condition.rank(), max(x.rank(), y.rank())); + log::debug!("Where output rank for {}: {}", node.name, output_rank); + + if output_rank == 0 { + node.outputs[0].ty = ArgType::Scalar(elem_type); + log::debug!("Where result for {} is scalar", node.name); + } else { + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type, + rank: output_rank, + static_shape: None, + }); + log::debug!( + "Where result for {} is tensor with rank {}", + node.name, + output_rank + ); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{Argument, NodeType}; + use std::collections::HashMap; + + fn create_test_node(condition_rank: usize, x_rank: usize, y_rank: usize) -> Node { + let inputs = vec![ + Argument { + name: "condition".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Bool, + rank: condition_rank, + static_shape: None, + }), + value: None, + passed: true, + }, + Argument { + name: "X".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: x_rank, + static_shape: None, + }), + value: None, + passed: true, + }, + Argument { + name: "Y".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: y_rank, + static_shape: None, + }), + value: None, + passed: true, + }, + ]; + + let outputs = vec![Argument { + name: "output".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 0, // Will be updated + static_shape: None, + }), + value: None, + passed: true, + }]; + + Node { + node_type: NodeType::Where, + name: "test_where".to_string(), + inputs, + outputs, + attrs: HashMap::new(), + } + } + + #[test] + fn test_where_basic() { + let mut node = create_test_node(2, 3, 2); + where_update_outputs(&mut node); + + match &node.outputs[0].ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.elem_type, ElementType::Float32); + assert_eq!(tensor.rank, 3); // max(2, max(3, 2)) = 3 + } + _ => panic!("Expected tensor output"), + } + } + + #[test] + fn test_where_scalar_result() { + let mut node = create_test_node(0, 0, 0); + where_update_outputs(&mut node); + + match &node.outputs[0].ty { + ArgType::Scalar(elem_type) => { + assert_eq!(*elem_type, ElementType::Float32); + } + _ => panic!("Expected scalar output"), + } + } + + #[test] + #[should_panic(expected = "Where condition must be boolean!")] + fn test_where_invalid_condition() { + let mut node = create_test_node(2, 2, 2); + node.inputs[0].ty = ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, // Not boolean + rank: 2, + static_shape: None, + }); + where_update_outputs(&mut node); + } + + #[test] + #[should_panic(expected = "Where x and y have different element types!")] + fn test_where_mismatched_types() { + let mut node = create_test_node(2, 2, 2); + node.inputs[2].ty = ArgType::Tensor(TensorType { + elem_type: ElementType::Int64, // Different from X + rank: 2, + static_shape: None, + }); + where_update_outputs(&mut node); + } +} diff --git a/crates/onnx-ir/src/rank_inference.rs b/crates/onnx-ir/src/rank_inference.rs index 9bab387e70..1e8adeedd8 100644 --- a/crates/onnx-ir/src/rank_inference.rs +++ b/crates/onnx-ir/src/rank_inference.rs @@ -1,18 +1,21 @@ -use core::cmp::max; -use core::panic; - -use protobuf::Enum; - use crate::{ - ir::{ArgType, AttributeValue, Data, ElementType, Node, NodeType, TensorType}, + ir::{Node, NodeType}, node::{ - 
one_hot::one_hot_output_shape, reduce_max::reduce_max_update_outputs, - reduce_mean::reduce_mean_update_outputs, reduce_min::reduce_min_update_outputs, - reduce_prod::reduce_prod_update_outputs, reduce_sum::reduce_sum_update_outputs, - shape::shape_update_outputs, slice::slice_update_output_rank, - squeeze::squeeze_update_output, + argmax::argmax_update_outputs, cast::cast_update_outputs, + comparison::elementwise_comparison_outputs, concat::concat_update_outputs, + constant::constant_update_outputs, constant_of_shape::constant_of_shape_update_output, + expand::expand_update_outputs, flatten::flatten_update_outputs, + gather::gather_update_outputs, gemm::gemm_output_shape, linear::linear_update_outputs, + matmul::matmul_update_outputs, one_hot::one_hot_output_shape, random::random_update_output, + random_like::random_like_update_output, range::range_update_outputs, + reduce_max::reduce_max_update_outputs, reduce_mean::reduce_mean_update_outputs, + reduce_min::reduce_min_update_outputs, reduce_prod::reduce_prod_update_outputs, + reduce_sum::reduce_sum_update_outputs, reshape::reshape_update_outputs, + shape::shape_update_outputs, slice::slice_update_output_rank, split::split_update_outputs, + squeeze::squeeze_update_output, topk::top_k_update_output, + unsqueeze::unsqueeze_update_output, where_op::where_update_outputs, }, - protos::tensor_proto::DataType, + util::{same_as_input, same_as_input_broadcast, temporary_pass_through_stub}, }; /// Infer the rank of each output tensor and update them based solely on rank inference. @@ -113,859 +116,3 @@ pub fn rank_inference(node: &mut Node) { node.outputs ); } - -/// Update output type for constant nodes based on attribute values, focusing on rank only. -fn constant_update_outputs(node: &mut Node) { - log::debug!("Constant rank inference for node {}", node.name); - - let keys = [ - "value", - "value_float", - "value_floats", - "value_int", - "value_ints", - "value_string", - "value_strings", - "sparse_value", - ]; - - let matched_value = keys.iter().find_map(|&key| node.attrs.get(key).cloned()); - log::debug!("Constant found attribute: {}", matched_value.is_some()); - - node.outputs[0].ty = match matched_value { - Some(value) => match &value { - AttributeValue::Tensor(tensor) if tensor.shape.is_empty() => { - log::debug!("Constant as scalar for {}", node.name); - ArgType::Scalar(tensor.elem_type()) - } - AttributeValue::Tensor(tensor) => { - log::debug!( - "Constant tensor with rank {} for {}", - tensor.shape.len(), - node.name - ); - ArgType::Tensor(TensorType { - elem_type: tensor.elem_type(), - rank: tensor.shape.len(), - static_shape: None, - }) - } - AttributeValue::Float32(_) => { - log::debug!("Constant Float32 scalar for {}", node.name); - ArgType::Scalar(ElementType::Float32) - } - AttributeValue::Float32s(_) => { - log::debug!("Constant Float32s tensor with rank 1 for {}", node.name); - ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 1, - static_shape: None, - }) - } - AttributeValue::Int64(_) => { - log::debug!("Constant Int64 scalar for {}", node.name); - ArgType::Scalar(ElementType::Int64) - } - AttributeValue::Int64s(_) => { - log::debug!("Constant Int64s tensor with rank 1 for {}", node.name); - ArgType::Tensor(TensorType { - elem_type: ElementType::Int64, - rank: 1, - static_shape: None, - }) - } - ty => panic!("Constant value of {:?} is not supported", ty), - }, - None => panic!("Constant node must have a value attribute"), - }; -} - -/// Updates the output rank for a ConstantOfShape node based on the rank 
of its input. -fn constant_of_shape_update_output(node: &mut Node) { - log::debug!("ConstantOfShape rank inference for node {}", node.name); - - let value_type = node - .attrs - .get("value") - .map(|v| v.clone().into_tensor().elem_type()) - .unwrap_or(ElementType::Float32); // If not given, defaults to 0 as float32 - log::debug!( - "ConstantOfShape value type for {}: {:?}", - node.name, - value_type - ); - - let rank = match &node.inputs[0].ty { - ArgType::Shape(rank) => { - log::debug!( - "ConstantOfShape input is Shape with rank {} for {}", - rank, - node.name - ); - *rank - } - ArgType::Tensor(tensor_type) => { - log::debug!("ConstantOfShape input is Tensor for {}", node.name); - let r = tensor_type - .static_shape - .as_ref() - .and_then(|shape| shape.first()) - .copied() - .expect( - "ConstantOfShape node must have a Tensor with a non-empty static shape value", - ); - log::debug!( - "ConstantOfShape derived rank from tensor: {} for {}", - r, - node.name - ); - r - } - _ => panic!("ConstantOfShape node requires a Tensor or Shape type as input"), - }; - - // Update the input type to be a shape - node.inputs[0].ty = ArgType::Shape(rank); - log::debug!( - "ConstantOfShape updated input to Shape({}) for {}", - rank, - node.name - ); - - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type: value_type, - rank, - static_shape: None, - }); - log::debug!("ConstantOfShape output rank for {}: {}", node.name, rank); -} - -/// Update output rank for Random operations with explicit shape attribute. -fn random_update_output(node: &mut Node) { - log::debug!("Random rank inference for node {}", node.name); - - let dtype = node - .attrs - .get("dtype") - .map(|val| DataType::from_i32(val.clone().into_i32()).unwrap()) - .unwrap_or(DataType::FLOAT); - log::debug!("Random dtype for {}: {:?}", node.name, dtype); - - let shape = node - .attrs - .get("shape") - .expect("required shape attribute missing") - .clone() - .into_i64s(); - log::debug!("Random shape for {}: {:?}", node.name, shape); - - let elem_type = match dtype { - DataType::FLOAT => ElementType::Float32, - DataType::DOUBLE => ElementType::Float64, - _ => panic!("tensor with type {dtype:?} not supported for random output"), - }; - - let rank = shape.len(); - log::debug!("Random output rank for {}: {}", node.name, rank); - - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type, - rank, - static_shape: None, - }); -} - -/// Update output rank for RandomLike operations based on input rank. -fn random_like_update_output(node: &mut Node) { - log::debug!("RandomLike rank inference for node {}", node.name); - - let dtype = node - .attrs - .get("dtype") - .map(|val| DataType::from_i32(val.clone().into_i32()).unwrap()) - .unwrap_or(DataType::FLOAT); - log::debug!("RandomLike dtype for {}: {:?}", node.name, dtype); - - let elem_type = match dtype { - DataType::FLOAT => ElementType::Float32, - DataType::FLOAT16 => ElementType::Float16, - DataType::DOUBLE => ElementType::Float64, - _ => panic!("Tensor with type {dtype:?} not supported for random output"), - }; - - if let ArgType::Tensor(tensor) = &node.inputs[0].ty { - log::debug!("RandomLike input rank for {}: {}", node.name, tensor.rank); - - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type, - rank: tensor.rank, - static_shape: tensor.static_shape.clone(), - }); - - log::debug!("RandomLike output rank for {}: {}", node.name, tensor.rank); - } else { - panic!("Only tensor input is valid"); - } -} - -/// Update output rank for Linear operations (same as input rank). 
-fn linear_update_outputs(node: &mut Node) { - log::debug!("Linear rank inference for node {}", node.name); - - if let ArgType::Tensor(tensor) = &node.inputs[0].ty { - log::debug!("Linear input rank for {}: {}", node.name, tensor.rank); - - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type: tensor.elem_type.clone(), - rank: tensor.rank, - static_shape: None, - }); - - log::debug!("Linear output rank for {}: {}", node.name, tensor.rank); - } else { - panic!("Only tensor input is valid"); - } -} - -/// Update output type for Cast operations, preserving rank. -fn cast_update_outputs(node: &mut Node) { - if node.inputs.len() != 1 { - panic!("Cast: multiple inputs are not supported"); - } - let input = &mut node.inputs[0]; - let output = &mut node.outputs[0]; - - let elem_type = match node.attrs.get("to") { - Some(value) => match &value { - AttributeValue::Int64(type_id) => match DataType::from_i32(*type_id as i32).unwrap() { - DataType::FLOAT => ElementType::Float32, - DataType::INT32 => ElementType::Int32, - DataType::INT64 => ElementType::Int64, - DataType::DOUBLE => ElementType::Float64, - DataType::BOOL => ElementType::Bool, - _ => panic!("Cast: unsupported type"), - }, - _ => panic!("'to' attribute must be an Int64"), - }, - None => panic!("Cast node must have a 'to' attribute"), - }; - - match input.ty.clone() { - ArgType::Tensor(tensor) => { - if tensor.rank == 0 { - // treat 0-dim tensor as scalar - output.ty = ArgType::Scalar(elem_type); - input.ty = ArgType::Scalar(tensor.elem_type); - } else { - // Cast input and output are the same shape, but possibly different types - output.ty = ArgType::Tensor(TensorType { - elem_type, - rank: tensor.rank, - static_shape: None, - }); - } - } - ArgType::Scalar(_) => output.ty = ArgType::Scalar(elem_type), - _ => panic!("Cast: only scalar and tensor inputs are valid"), - } -} - -/// Update output rank for Concat (same as first tensor input). -fn concat_update_outputs(node: &mut Node) { - log::debug!("Concat rank inference for node {}", node.name); - - let tensor = node - .inputs - .iter() - .find_map(|input| match &input.ty { - ArgType::Tensor(tensor) => Some(tensor.clone()), - _ => None, - }) - .unwrap(); - - log::debug!("Concat using input rank for {}: {}", node.name, tensor.rank); - - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type: tensor.elem_type, - rank: tensor.rank, - static_shape: None, - }); - - log::debug!("Concat output rank for {}: {}", node.name, tensor.rank); -} - -/// Update output rank for Reshape based on shape input if constant, otherwise use input rank. 
-fn reshape_update_outputs(node: &mut Node) { - log::debug!("Reshape rank inference for node {}", node.name); - - let shape = if node.inputs.len() == 2 { - log::debug!("Reshape node {} has shape as second input", node.name); - match &node.inputs[1].value { - Some(value) => match &value.data { - Data::Int64s(shape) => { - log::debug!("Reshape node {} has constant shape: {:?}", node.name, shape); - Some(shape.clone()) - } - _ => panic!("Reshape: invalid input types"), - }, - None => { - log::debug!( - "Reshape node {} has dynamic shape as second input", - node.name - ); - None - } - } - } else { - log::debug!("Reshape node {} using shape from attributes", node.name); - node.attrs.get("shape").cloned().map(|v| { - let shape = v.into_i64s(); - log::debug!("Reshape node {} shape attribute: {:?}", node.name, shape); - shape - }) - }; - - let output = match &node.outputs[0].ty { - ArgType::Tensor(tensor) => tensor.clone(), - _ => panic!("Reshape: invalid output types"), - }; - - let rank = match &shape { - Some(s) => s.len(), - None => output.rank, - }; - - log::debug!("Reshape output rank for node {}: {}", node.name, rank); - - node.outputs[0].ty = ArgType::Tensor(TensorType { - rank, - static_shape: None, - ..output - }); -} - -/// Update output rank for ArgMax (same as input rank). -fn argmax_update_outputs(node: &mut Node) { - log::debug!("ArgMax rank inference for node {}", node.name); - - if node.inputs.len() != 1 { - panic!("ArgMax: multiple inputs are not supported"); - } - let tensor = match &node.inputs[0].ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - log::debug!("ArgMax input rank for {}: {}", node.name, tensor.rank); - - // Note: argmax in burn does not support keepdims=false - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type: ElementType::Int64, - rank: tensor.rank, - static_shape: None, - }); - - log::debug!("ArgMax output rank for {}: {}", node.name, tensor.rank); -} - -/// Update output rank for broadcasting operations (e.g., Add, Sub) to max input rank. -fn same_as_input_broadcast(node: &mut Node) { - log::debug!("Broadcasting operation for node {}", node.name); - - let max_rank = node.inputs.iter().fold(0, |acc, input| match &input.ty { - ArgType::Tensor(tensor) => acc.max(tensor.rank), - ArgType::Scalar(_) => acc, - _ => panic!("Unsupported input type for broadcasting operation"), - }); - - log::debug!("Max rank for broadcasting node {}: {}", node.name, max_rank); - - if max_rank == 0 { - node.outputs[0].ty = ArgType::Scalar(node.inputs[0].ty.elem_type().clone()); - log::debug!("Scalar result for node {}", node.name); - } else { - let elem_type = node - .inputs - .iter() - .find_map(|input| match &input.ty { - ArgType::Tensor(tensor) => Some(tensor.elem_type.clone()), - _ => None, - }) - .unwrap_or_else(|| node.inputs[0].ty.elem_type().clone()); - - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type, - rank: max_rank, - static_shape: None, - // Removed call to set_broadcasting_output_shape - }); - log::debug!( - "Tensor result for node {} with rank {}", - node.name, - max_rank - ); - } -} - -/// Update output rank for Unsqueeze based on axes. 
-/// Update the output tensor dimension based on the "axes" attribute or the second input -fn unsqueeze_update_output(node: &mut Node) { - log::debug!("Unsqueeze rank inference for node {}", node.name); - - let axes = if node.inputs.len() == 2 { - match &node.inputs[1].value { - Some(value) => match &value.data { - Data::Int64s(a) => Some(a.clone()), - _ => panic!("Unsqueeze: invalid input types"), - }, - None => None, - } - } else { - let axes = node.attrs.get("axes").cloned().map(|v| { - let axes = v.into_i64s(); - log::debug!( - "Unsqueeze axes from attribute for {}: {:?}", - node.name, - axes - ); - axes - }); - axes - }; - - let input_rank = match &node.inputs[0].ty { - ArgType::Tensor(tensor) => tensor.rank, - ArgType::Scalar(_) => { - 0 // treat scalar as 0-dim tensor - } - _ => panic!("Unsqueeze: invalid input type"), - }; - - let output_elem = match &node.outputs[0].ty { - ArgType::Tensor(_) => node.inputs[0].ty.elem_type().clone(), - ArgType::Scalar(elem_type) => elem_type.clone(), - _ => panic!("Unsqueeze: invalid output type"), - }; - - let output_rank = if let Some(axes) = axes { - input_rank + axes.len() - } else if let ArgType::Tensor(tensor) = &node.inputs[1].ty { - if let Some(static_shape) = &tensor.static_shape { - input_rank + *static_shape.first().expect("Empty shape") - } else { - panic!("Unsqueeze: should have static shape") - } - } else { - panic!("Unsqueeze: missing axes information") - }; - - node.outputs[0].ty = ArgType::Tensor(TensorType { - rank: output_rank, - static_shape: None, // shape is tracked and calculated at runtime - elem_type: output_elem, - }); - - log::debug!("Unsqueeze output rank for {}: {}", node.name, output_rank); -} - -/// Preserve input rank for operations like Relu, Sigmoid, etc. -fn same_as_input(node: &mut Node) { - log::debug!("Copying input type to output for node {}", node.name); - - if let ArgType::Tensor(tensor) = &node.inputs[0].ty { - log::debug!("Input rank for {}: {}", node.name, tensor.rank); - } else if let ArgType::Scalar(_) = &node.inputs[0].ty { - log::debug!("Input is scalar for {}", node.name); - } - - node.outputs[0].ty = node.inputs[0].ty.clone(); - log::debug!("Output type is same as input for {}", node.name); -} - -/// Update output rank for TopK (same as input rank). -fn top_k_update_output(node: &mut Node) { - log::debug!("TopK rank inference for node {}", node.name); - - let rank = match &node.inputs[0].ty { - ArgType::Tensor(tensor) => tensor.rank, - _ => panic!("TopK: invalid input type"), - }; - log::debug!("TopK input rank for {}: {}", node.name, rank); - - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type: node.inputs[0].ty.elem_type().clone(), - rank, - static_shape: None, - }); - node.outputs[1].ty = ArgType::Tensor(TensorType { - elem_type: ElementType::Int64, - rank, - static_shape: None, - }); - - log::debug!( - "TopK output rank for {}: {} (both outputs)", - node.name, - rank - ); -} - -/// Temporary stub preserves input type for unhandled operations. 
-fn temporary_pass_through_stub(node: &mut Node) { - log::warn!( - "Must implement rank inference for node type {:?} (name: {})", - node.node_type, - node.name - ); - - if let Some(input_rank) = node.inputs.first().map(|input| match &input.ty { - ArgType::Tensor(tensor) => tensor.rank, - ArgType::Scalar(_) => 0, - _ => 0, - }) { - log::debug!( - "Passing through input rank {} for unhandled node {}", - input_rank, - node.name - ); - } - - node.outputs[0].ty = node.inputs[0].ty.clone(); - log::debug!( - "Using pass-through inference for unhandled node type {:?} ({})", - node.node_type, - node.name - ); -} - -/// Update output type for comparison operations (e.g., Equal, Greater) to max input rank. -fn elementwise_comparison_outputs(node: &mut Node) { - log::debug!("Elementwise comparison for node {}", node.name); - - let max_rank = node.inputs.iter().fold(0, |acc, input| match &input.ty { - ArgType::Tensor(tensor) => acc.max(tensor.rank), - ArgType::Scalar(_) => acc, - _ => panic!("Invalid input type for comparison op"), - }); - - log::debug!("Max rank for comparison node {}: {}", node.name, max_rank); - - if max_rank == 0 { - node.outputs[0].ty = ArgType::Scalar(ElementType::Bool); - log::debug!("Scalar boolean result for node {}", node.name); - } else { - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type: ElementType::Bool, - rank: max_rank, - static_shape: None, - }); - log::debug!( - "Tensor boolean result for node {} with rank {}", - node.name, - max_rank - ); - } -} - -/// Updates the output rank and shape for the Expand operation based on the provided shape input. -/// If the shape is a constant, the rank and static shape of the output are set accordingly. -/// If the shape is dynamic, the rank is inferred from the static shape of the shape input. -fn expand_update_outputs(node: &mut Node) { - let shape = if node.inputs.len() == 2 { - match &node.inputs[1].value { - Some(value) => match &value.data { - Data::Int64s(shape) => Some(shape.clone()), - _ => panic!("Expand operation encountered invalid input types"), - }, - None => None, - } - } else { - panic!("Expand operation requires exactly two inputs"); - }; - - let output = match &node.outputs[0].ty { - ArgType::Tensor(tensor) => tensor.clone(), - _ => panic!("Expand operation encountered invalid output types"), - }; - - if let Some(shape) = shape { - node.outputs[0].ty = ArgType::Tensor(TensorType { - rank: shape.len(), - static_shape: Some(shape.into_iter().map(|dim| dim as usize).collect()), - ..output - }); - } else { - // When the shape cannot be determined statically (i.e., the second argument 'shape' is passed dynamically), - // infer the rank from the static shape of the input tensor. - let output_rank = match &node.inputs[1].ty { - ArgType::Tensor(tensor) => tensor - .static_shape - .as_ref() - .expect("Shape input must have a static shape defined") - .first() - .copied() - .expect("Static shape must contain at least one element"), - ArgType::Shape(rank) => *rank, - _ => panic!("Shape input must be of tensor or shape type",), - }; - - node.outputs[0].ty = ArgType::Tensor(TensorType { - rank: output_rank, - static_shape: None, // The exact shape cannot be determined statically - ..output - }); - } -} - -/// Update output type for Flatten operation (rank 2). 
-fn flatten_update_outputs(node: &mut Node) { - if node.inputs.len() != 1 { - panic!("Flatten: multiple inputs are not supported"); - } - let tensor = node - .inputs - .iter() - .find_map(|input| match &input.ty { - ArgType::Tensor(tensor) => Some(tensor), - _ => None, - }) - .unwrap(); - - // Flatten to a 2D tensor - node.outputs[0].ty = ArgType::Tensor(TensorType { - rank: 2, - ..tensor.clone() - }); -} - -/// Update output rank for MatMul based on input ranks. -fn matmul_update_outputs(node: &mut Node) { - log::debug!("MatMul rank inference for node {}", node.name); - - match (&node.inputs[0].ty, &node.inputs[1].ty) { - (ArgType::Tensor(a), ArgType::Tensor(b)) => { - log::debug!( - "MatMul input ranks for {}: a.rank={}, b.rank={}", - node.name, - a.rank, - b.rank - ); - - let mut out_rank = max(a.rank, b.rank); - if (a.rank >= 2 && b.rank == 1) || (a.rank == 1 && b.rank >= 2) { - out_rank -= 1; - log::debug!( - "MatMul special case for node {}: reducing output rank", - node.name - ); - } - - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type: a.elem_type.clone(), - rank: out_rank, - static_shape: None, - }); - - log::debug!("MatMul output rank for {}: {}", node.name, out_rank); - } - _ => panic!("Only tensor inputs are valid"), - } -} - -/// Update output rank for Range (always rank 1). -fn range_update_outputs(node: &mut Node) { - log::debug!("Range rank inference for node {}", node.name); - - if node.inputs.len() != 3 { - panic!("Range: expected 3 inputs, found {}", node.inputs.len()); - } - log::debug!( - "Range operation always produces rank 1 tensor for {}", - node.name - ); - - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type: ElementType::Int64, - rank: 1, - static_shape: None, - }); - - log::debug!("Range output rank for {}: 1", node.name); -} - -/// Update output rank for Where to max input rank. -fn where_update_outputs(node: &mut Node) { - log::debug!("Where rank inference for node {}", node.name); - - let condition = &node.inputs[0].ty; - let x = &node.inputs[1].ty; - let y = &node.inputs[2].ty; - let elem_type = x.elem_type().clone(); - assert_eq!( - *condition.elem_type(), - ElementType::Bool, - "Where condition must be boolean!" - ); - assert_eq!( - elem_type, - *y.elem_type(), - "Where x and y have different element types!" - ); - - log::debug!( - "Where input ranks for {}: condition={}, x={}, y={}", - node.name, - condition.rank(), - x.rank(), - y.rank() - ); - - let output_rank = max(condition.rank(), max(x.rank(), y.rank())); - log::debug!("Where output rank for {}: {}", node.name, output_rank); - - if output_rank == 0 { - node.outputs[0].ty = ArgType::Scalar(elem_type); - log::debug!("Where result for {} is scalar", node.name); - } else { - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type, - rank: output_rank, - static_shape: None, - }); - log::debug!( - "Where result for {} is tensor with rank {}", - node.name, - output_rank - ); - } -} - -/// Update output rank for Gather based on input and indices ranks. 
-fn gather_update_outputs(node: &mut Node) { - log::debug!("Gather rank inference for node {}", node.name); - - if node.inputs.len() != 2 { - panic!("Gather requires two inputs: data and indices"); - } - - let indices_rank = match &node.inputs[1].ty { - ArgType::Tensor(tensor) => tensor.rank, - ArgType::Scalar(_) => 0, - _ => panic!("Only tensor indices is valid, got {:?}", node.inputs[1].ty), - }; - log::debug!("Gather indices rank for {}: {}", node.name, indices_rank); - - match &node.inputs[0].ty { - ArgType::Tensor(input_tensor) => { - log::debug!( - "Gather input tensor rank for {}: {}", - node.name, - input_tensor.rank - ); - // Output of rank q+(r-1), where q is rank of indices tensor and r is rank of input - let output_rank = indices_rank + input_tensor.rank - 1; - log::debug!("Gather output rank for {}: {}", node.name, output_rank); - - if output_rank == 0 { - node.outputs[0].ty = ArgType::Scalar(input_tensor.elem_type.clone()); - log::debug!("Gather result for {} is scalar", node.name); - } else { - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type: input_tensor.elem_type.clone(), - rank: output_rank, - static_shape: None, - }); - log::debug!( - "Gather result for {} is tensor with rank {}", - node.name, - output_rank - ); - } - } - ArgType::Shape(_) => { - log::debug!("Gather input is shape for {}", node.name); - let shape_rank = 1; - // Output of rank q+(r-1), where q is rank of indices tensor and r is rank of input - let output_rank = indices_rank + shape_rank - 1; - log::debug!( - "Gather output rank for {} with shape input: {}", - node.name, - output_rank - ); - - if output_rank == 0 { - node.outputs[0].ty = ArgType::Scalar(ElementType::Int64); - log::debug!("Gather result for {} is scalar (from shape)", node.name); - } else { - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type: ElementType::Int64, - rank: output_rank, - static_shape: None, - }); - log::debug!( - "Gather result for {} is tensor with rank {} (from shape)", - node.name, - output_rank - ); - } - } - ty => panic!("Only tensor/shape input is valid but received: {:?}", ty), - } -} - -/// Update output rank for Split (same as input). 
-fn split_update_outputs(node: &mut Node) { - log::debug!("Split rank inference for node {}", node.name); - - let tensor = match &node.inputs[0].ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Split: Input must be a tensor"), - }; - log::debug!("Split input rank for {}: {}", node.name, tensor.rank); - log::debug!( - "Split will generate {} outputs for {}", - node.outputs.len(), - node.name - ); - - for (i, output_arg) in node.outputs.iter_mut().enumerate() { - output_arg.ty = ArgType::Tensor(TensorType { - elem_type: tensor.elem_type.clone(), - rank: tensor.rank, - static_shape: None, - }); - log::debug!("Split output {} rank for {}: {}", i, node.name, tensor.rank); - } -} - -fn gemm_output_shape(node: &mut Node) { - log::debug!("Gemm rank inference for node {}", node.name); - - let a_rank = match &node.inputs[0].ty { - ArgType::Tensor(tensor) => tensor.rank, - _ => panic!("Input A should be a tensor!"), - }; - let b_rank = match &node.inputs[1].ty { - ArgType::Tensor(tensor) => tensor.rank, - _ => panic!("Input B should be a tensor!"), - }; - - log::debug!( - "Gemm input ranks for {}: a_rank={}, b_rank={}", - node.name, - a_rank, - b_rank - ); - - let output_rank = max(a_rank, b_rank); - log::debug!("Gemm output rank for {}: {}", node.name, output_rank); - - node.outputs[0].ty = ArgType::Tensor(TensorType { - rank: output_rank, - static_shape: None, - elem_type: match &node.inputs[0].ty { - ArgType::Tensor(t) => t.elem_type.clone(), - _ => panic!("Unexpected type for input A"), - }, - }); -} diff --git a/crates/onnx-ir/src/util.rs b/crates/onnx-ir/src/util.rs index 6843b3c9be..fe63f651a1 100644 --- a/crates/onnx-ir/src/util.rs +++ b/crates/onnx-ir/src/util.rs @@ -1,4 +1,5 @@ -use crate::ir::{ArgType, Node}; +use crate::ir::{ArgType, Node, TensorType}; + use crate::protos::OperatorSetIdProto; pub fn shape_config(curr: &Node) -> (usize, usize) { @@ -81,3 +82,190 @@ pub fn verify_opsets(opsets: &[OperatorSetIdProto], min_version: i64) -> bool { } true } + +/// Preserve input rank for operations like Relu, Sigmoid, etc. +pub fn same_as_input(node: &mut Node) { + log::debug!("Copying input type to output for node {}", node.name); + + if let ArgType::Tensor(tensor) = &node.inputs[0].ty { + log::debug!("Input rank for {}: {}", node.name, tensor.rank); + } else if let ArgType::Scalar(_) = &node.inputs[0].ty { + log::debug!("Input is scalar for {}", node.name); + } + + node.outputs[0].ty = node.inputs[0].ty.clone(); + log::debug!("Output type is same as input for {}", node.name); +} + +/// Update output rank for broadcasting operations (e.g., Add, Sub) to max input rank. 
+pub fn same_as_input_broadcast(node: &mut Node) { + log::debug!("Broadcasting operation for node {}", node.name); + + let max_rank = node.inputs.iter().fold(0, |acc, input| match &input.ty { + ArgType::Tensor(tensor) => acc.max(tensor.rank), + ArgType::Scalar(_) => acc, + _ => panic!("Unsupported input type for broadcasting operation"), + }); + + log::debug!("Max rank for broadcasting node {}: {}", node.name, max_rank); + + if max_rank == 0 { + node.outputs[0].ty = ArgType::Scalar(node.inputs[0].ty.elem_type().clone()); + log::debug!("Scalar result for node {}", node.name); + } else { + let elem_type = node + .inputs + .iter() + .find_map(|input| match &input.ty { + ArgType::Tensor(tensor) => Some(tensor.elem_type.clone()), + _ => None, + }) + .unwrap_or_else(|| node.inputs[0].ty.elem_type().clone()); + + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type, + rank: max_rank, + static_shape: None, + }); + log::debug!( + "Tensor result for node {} with rank {}", + node.name, + max_rank + ); + } +} + +/// Temporary stub preserves input type for unhandled operations. +pub fn temporary_pass_through_stub(node: &mut Node) { + log::warn!( + "Must implement rank inference for node type {:?} (name: {})", + node.node_type, + node.name + ); + + if let Some(input_rank) = node.inputs.first().map(|input| match &input.ty { + ArgType::Tensor(tensor) => tensor.rank, + ArgType::Scalar(_) => 0, + _ => 0, + }) { + log::debug!( + "Passing through input rank {} for unhandled node {}", + input_rank, + node.name + ); + } + + node.outputs[0].ty = node.inputs[0].ty.clone(); + log::debug!( + "Using pass-through inference for unhandled node type {:?} ({})", + node.node_type, + node.name + ); +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{Argument, ElementType, NodeType}; + use std::collections::HashMap; + + fn create_test_node(op_type: NodeType, input_ranks: Vec) -> Node { + let mut inputs = Vec::new(); + + for (i, rank) in input_ranks.iter().enumerate() { + inputs.push(Argument { + name: format!("input_{}", i), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: *rank, + static_shape: None, + }), + value: None, + passed: true, + }); + } + + let outputs = vec![Argument { + name: "output".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 0, // Will be updated + static_shape: None, + }), + value: None, + passed: true, + }]; + + Node { + node_type: op_type.clone(), + name: format!("test_{:?}", op_type).to_lowercase(), + inputs, + outputs, + attrs: HashMap::new(), + } + } + + #[test] + fn test_same_as_input() { + let mut node = create_test_node(NodeType::Relu, vec![3]); + same_as_input(&mut node); + + match &node.outputs[0].ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.elem_type, ElementType::Float32); + assert_eq!(tensor.rank, 3); + } + _ => panic!("Expected tensor output"), + } + } + + #[test] + fn test_same_as_input_broadcast_max_rank() { + let mut node = create_test_node(NodeType::Add, vec![2, 4, 3]); + same_as_input_broadcast(&mut node); + + match &node.outputs[0].ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.elem_type, ElementType::Float32); + assert_eq!(tensor.rank, 4); // max(2, 4, 3) = 4 + } + _ => panic!("Expected tensor output"), + } + } + + #[test] + fn test_same_as_input_broadcast_with_scalar() { + let mut node = create_test_node(NodeType::Add, vec![3]); + // Add a scalar input + node.inputs.push(Argument { + name: "scalar_input".to_string(), + ty: ArgType::Scalar(ElementType::Float32), + 
value: None, + passed: true, + }); + + same_as_input_broadcast(&mut node); + + match &node.outputs[0].ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.elem_type, ElementType::Float32); + assert_eq!(tensor.rank, 3); // Scalar doesn't affect rank + } + _ => panic!("Expected tensor output"), + } + } + + #[test] + fn test_temporary_pass_through_stub() { + let mut node = create_test_node(NodeType::Identity, vec![5]); + temporary_pass_through_stub(&mut node); + + match &node.outputs[0].ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.elem_type, ElementType::Float32); + assert_eq!(tensor.rank, 5); + } + _ => panic!("Expected tensor output"), + } + } +} From c60cde3061fe5d4be681dcdbd5ecac74b3dbace1 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Wed, 30 Apr 2025 17:34:52 -0500 Subject: [PATCH 14/37] Add documentation --- crates/onnx-ir/src/node/mod.rs | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/crates/onnx-ir/src/node/mod.rs b/crates/onnx-ir/src/node/mod.rs index a0c1eac985..1e5a28649a 100644 --- a/crates/onnx-ir/src/node/mod.rs +++ b/crates/onnx-ir/src/node/mod.rs @@ -1,3 +1,12 @@ +//! Node module contains implementations of ONNX operations. +//! +//! Each submodule implements a specific ONNX operation, providing: +//! - Operation configuration and parameters +//! - Rank inference functionality +//! +//! This modular structure allows for clean separation of operation implementations +//! and facilitates easier maintenance and extension of the ONNX operation set. + pub mod argmax; pub mod avg_pool1d; pub mod avg_pool2d; From a66a43b8143da7b72e5ed1492ee6a1d9198547ab Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Wed, 30 Apr 2025 17:46:08 -0500 Subject: [PATCH 15/37] Repoint config function from onnx-ir --- .../burn-import/src/onnx/op_configuration.rs | 1343 ----------------- crates/burn-import/src/onnx/to_burn.rs | 26 +- 2 files changed, 16 insertions(+), 1353 deletions(-) diff --git a/crates/burn-import/src/onnx/op_configuration.rs b/crates/burn-import/src/onnx/op_configuration.rs index 84b5b2601b..27b15b34ac 100644 --- a/crates/burn-import/src/onnx/op_configuration.rs +++ b/crates/burn-import/src/onnx/op_configuration.rs @@ -1,328 +1,12 @@ // TODO Move op_configuration.rs from burn-import to onnx-ir #3091 // See https://github.com/tracel-ai/burn/issues/3091 -use burn::nn::{ - BatchNormConfig, DropoutConfig, LayerNormConfig, LinearConfig, PaddingConfig2d, - PaddingConfig3d, - conv::{Conv2dConfig, Conv3dConfig, ConvTranspose2dConfig, ConvTranspose3dConfig}, - pool::{AvgPool2dConfig, MaxPool2dConfig}, -}; - use crate::burn::node::{ expand::ExpandShape, pad::PadConfig, split::SplitConfig, tile::TileConfig, top_k::TopKConfig, trilu::TriluConfig, unsqueeze::UnsqueezeAxes, }; use onnx_ir::ir::{ArgType, AttributeValue, Data, ElementType, Node, TensorData}; -// Conv1dConfig implementation moved to onnx-ir::node::conv1d - -/// Create a Conv2dConfig from the attributes of the node -pub fn conv2d_config(curr: &Node) -> Conv2dConfig { - let mut kernel_shape = Vec::new(); // TODO default inferred from weight tensor per spec - let mut strides = vec![1, 1]; - let mut pads = vec![0, 0, 0, 0]; - let mut dilations = vec![1, 1]; - let mut group: usize = 1; - - let weight_shape = curr.inputs[1] - .value - .as_ref() - .expect("Conv2d: weight tensor must be present") - .shape - .clone(); - - // check if the bias is present - let bias = curr.inputs.len() == 3; - - for (key, value) in 
curr.attrs.iter() { - match key.as_str() { - "kernel_shape" => kernel_shape = value.clone().into_i64s(), - "strides" => strides = value.clone().into_i64s(), - "pads" => pads = value.clone().into_i64s(), - "dilations" => dilations = value.clone().into_i64s(), - "group" => group = value.clone().into_i64() as usize, - _ => {} - } - } - - // the channels are inverted in the weight tensor - let channels_in = weight_shape[1] * group; - let channels_out = weight_shape[0]; - - let padding = padding_config_2d(&pads); - - Conv2dConfig::new( - [channels_in, channels_out], - [kernel_shape[0] as usize, kernel_shape[1] as usize], - ) - .with_stride([strides[0] as usize, strides[1] as usize]) - .with_dilation([dilations[0] as usize, dilations[1] as usize]) - .with_groups(group) - .with_bias(bias) - .with_padding(padding) -} - -/// Create a Conv3dConfig from the attributes of the node -pub fn conv3d_config(curr: &Node) -> Conv3dConfig { - let mut kernel_shape = Vec::new(); // TODO default inferred from weight tensor per spec - let mut strides = vec![1, 1, 1]; - let mut pads = vec![0, 0, 0, 0, 0, 0]; - let mut dilations = vec![1, 1, 1]; - let mut group: usize = 1; - - let weight_shape = curr.inputs[1] - .value - .as_ref() - .expect("Conv3d: weight tensor must be present") - .shape - .clone(); - - // check if the bias is present - let bias = curr.inputs.len() == 3; - - for (key, value) in curr.attrs.iter() { - match key.as_str() { - "kernel_shape" => kernel_shape = value.clone().into_i64s(), - "strides" => strides = value.clone().into_i64s(), - "pads" => pads = value.clone().into_i64s(), - "dilations" => dilations = value.clone().into_i64s(), - "group" => group = value.clone().into_i64() as usize, - _ => {} - } - } - - // the channels are inverted in the weight tensor - let channels_in = weight_shape[1] * group; - let channels_out = weight_shape[0]; - - let padding = padding_config_3d(&pads); - - Conv3dConfig::new( - [channels_in, channels_out], - [ - kernel_shape[0] as usize, - kernel_shape[1] as usize, - kernel_shape[2] as usize, - ], - ) - .with_stride([ - strides[0] as usize, - strides[1] as usize, - strides[2] as usize, - ]) - .with_dilation([ - dilations[0] as usize, - dilations[1] as usize, - dilations[2] as usize, - ]) - .with_groups(group) - .with_bias(bias) - .with_padding(padding) -} - -// MaxPool1dConfig implementation moved to onnx-ir::node::max_pool1d - -/// Create a MaxPool2dConfig from the attributes of the node -pub fn max_pool2d_config(curr: &Node) -> MaxPool2dConfig { - let mut kernel_shape = Vec::new(); - let mut strides = vec![1, 1]; - let mut pads = vec![0, 0, 0, 0]; - let mut dilations = vec![1, 1]; - - for (key, value) in curr.attrs.iter() { - match key.as_str() { - "kernel_shape" => kernel_shape = value.clone().into_i64s(), - "strides" => strides = value.clone().into_i64s(), - "pads" => pads = value.clone().into_i64s(), - "dilations" => dilations = value.clone().into_i64s(), - _ => {} - } - } - - let padding = padding_config_2d(&pads); - - MaxPool2dConfig::new([kernel_shape[0] as usize, kernel_shape[1] as usize]) - .with_strides([strides[0] as usize, strides[1] as usize]) - .with_padding(padding) - .with_dilation([dilations[0] as usize, dilations[1] as usize]) -} - -// ConvTranspose1dConfig implementation moved to onnx-ir::node::conv_transpose1d - -pub fn conv_transpose2d_config(curr: &Node) -> ConvTranspose2dConfig { - let mut attrs = curr.attrs.clone(); - let kernel_shape = attrs - .remove("kernel_shape") - .map(AttributeValue::into_i64s) - .unwrap_or_default(); - let stride = 
attrs - .remove("strides") - .map(AttributeValue::into_i64s) - .unwrap_or_else(|| vec![1, 1]); - let pads = attrs - .remove("pads") - .map(AttributeValue::into_i64s) - .unwrap_or_else(|| vec![0, 0, 0, 0]); - let dilations = attrs - .remove("dilations") - .map(AttributeValue::into_i64s) - .unwrap_or_else(|| vec![1, 1]); - let group = attrs - .remove("group") - .map(AttributeValue::into_i64) - .unwrap_or(1) as usize; - let output_padding = attrs - .remove("output_padding") - .map(AttributeValue::into_i64s) - .unwrap_or_else(|| vec![0, 0]); - - // Trick with remove + empty check is simplest way to not forget some attribute for runtime: - if !attrs.is_empty() { - panic!("Not all attributes are used: {attrs:?}"); - } - // Check the pads are symmetric. - let [left, top, right, bottom] = [pads[0], pads[1], pads[2], pads[3]]; - if left < 0 || top < 0 || right < 0 || bottom < 0 { - panic!("Negative pad values are not supported"); - } else if (left != right) || (top != bottom) { - panic!("Asymmetric padding is not supported"); - } - - let weight_shape = curr.inputs[1] - .value - .as_ref() - .expect("ConvTranspose2d: weight tensor must be present") - .shape - .clone(); - - // check if the bias is present - let bias = curr.inputs.len() == 3; - - // the channels are inverted in the weight tensor - let channels: [usize; 2] = [weight_shape[1] * group, weight_shape[0]]; - - ConvTranspose2dConfig::new( - channels, - [kernel_shape[0] as usize, kernel_shape[1] as usize], - ) - .with_stride([stride[0] as usize, stride[1] as usize]) - .with_padding([pads[0] as usize, pads[1] as usize]) - .with_dilation([dilations[0] as usize, dilations[1] as usize]) - .with_padding_out([output_padding[0] as usize, output_padding[1] as usize]) - .with_groups(group) - .with_bias(bias) -} - -pub fn conv_transpose3d_config(curr: &Node) -> ConvTranspose3dConfig { - let mut attrs = curr.attrs.clone(); - let kernel_shape = attrs - .remove("kernel_shape") - .map(AttributeValue::into_i64s) - .unwrap_or_default(); - let stride = attrs - .remove("strides") - .map(AttributeValue::into_i64s) - .unwrap_or_else(|| vec![1, 1, 1]); - let pads = attrs - .remove("pads") - .map(AttributeValue::into_i64s) - .unwrap_or_else(|| vec![0, 0, 0, 0, 0, 0]); - let dilations = attrs - .remove("dilations") - .map(AttributeValue::into_i64s) - .unwrap_or_else(|| vec![1, 1, 1]); - let group = attrs - .remove("group") - .map(AttributeValue::into_i64) - .unwrap_or(1) as usize; - let output_padding = attrs - .remove("output_padding") - .map(AttributeValue::into_i64s) - .unwrap_or_else(|| vec![0, 0, 0]); - - // Trick with remove + empty check is simplest way to not forget some attribute for runtime: - if !attrs.is_empty() { - panic!("Not all attributes are used: {attrs:?}"); - } - // Check the pads are symmetric. 
- let [left, top, front, right, bottom, back] = - [pads[0], pads[1], pads[2], pads[3], pads[4], pads[5]]; - - if left < 0 || top < 0 || front < 0 || right < 0 || bottom < 0 || back < 0 { - panic!("Negative pad values are not supported"); - } else if (left != right) || (top != bottom) || (front != back) { - panic!("Asymmetric padding is not supported"); - } - - let weight_shape = curr.inputs[1] - .value - .as_ref() - .expect("ConvTranspose3d: weight tensor must be present") - .shape - .clone(); - - // check if the bias is present - let bias = curr.inputs.len() == 3; - - // the channels are inverted in the weight tensor - let channels: [usize; 2] = [weight_shape[1] * group, weight_shape[0]]; - - ConvTranspose3dConfig::new( - channels, - [ - kernel_shape[0] as usize, - kernel_shape[1] as usize, - kernel_shape[2] as usize, - ], - ) - .with_stride([stride[0] as usize, stride[1] as usize, stride[2] as usize]) - .with_padding([pads[0] as usize, pads[1] as usize, pads[2] as usize]) - .with_dilation([ - dilations[0] as usize, - dilations[1] as usize, - dilations[2] as usize, - ]) - .with_padding_out([ - output_padding[0] as usize, - output_padding[1] as usize, - output_padding[2] as usize, - ]) - .with_groups(group) - .with_bias(bias) -} - -// AvgPool1dConfig implementation moved to onnx-ir::node::avg_pool1d -/// Create a AvgPool2dConfig from the attributes of the node -pub fn avg_pool2d_config(curr: &Node) -> AvgPool2dConfig { - let mut kernel_shape = Vec::new(); - let mut strides = vec![1, 1]; - let mut pads = vec![0, 0, 0, 0]; - let mut count_include_pad: i64 = 0; - let mut ceil_mode: i64 = 0; - - for (key, value) in curr.attrs.iter() { - match key.as_str() { - "kernel_shape" => kernel_shape = value.clone().into_i64s(), - "strides" => strides = value.clone().into_i64s(), - "pads" => pads = value.clone().into_i64s(), - "count_include_pad" => count_include_pad = value.clone().into_i64(), - "ceil_mode" => ceil_mode = value.clone().into_i64(), - _ => {} - } - } - - if ceil_mode == 1 { - panic!("ceil_mode is not supported"); - } - - let padding = padding_config_2d(&pads); - - AvgPool2dConfig::new([kernel_shape[0] as usize, kernel_shape[1] as usize]) - .with_strides([strides[0] as usize, strides[1] as usize]) - .with_padding(padding) - .with_count_include_pad(count_include_pad == 1) -} - pub fn expand_config(node: &Node) -> ExpandShape { match &node.inputs[1].ty { ArgType::Tensor(tensor) => { @@ -354,352 +38,6 @@ pub fn expand_config(node: &Node) -> ExpandShape { } } -/// Create a FlattenConfig from the attributes of the node -pub fn flatten_config(curr: &Node) -> usize { - // the begin dimension is the first dimension (Default: 1 per ONNX spec) - let mut axis: i64 = 1; - - // check if the node has only one input - if curr.inputs.len() != 1 { - panic!( - "Flatten: multiple inputs are not supported (got {:?})", - curr.inputs.len() - ); - } - - // extract the shape of the input tensor - let tensor = match curr.inputs.first().unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // check if the input tensor has at least 2 dimensions - if tensor.rank < 2 { - panic!( - "Flatten: input tensor must have at least 2 dimensions (got {:?})", - tensor.rank - ); - } - - // extract the attributes - for (key, value) in curr.attrs.iter() { - match key.as_str() { - "axis" => axis = value.clone().into_i64(), - _ => {} - } - } - - // if beg_dim is negative, it is counted from the end - if axis < 0 { - axis += tensor.rank as i64; - } - - axis as usize -} - -/// 
Create a GatherConfig from the attributes of the node -pub fn gather_config(curr: &Node) -> usize { - // Default: 0 per ONNX spec - let mut dim: i64 = 0; - - // check if the node has only one input - if curr.inputs.len() != 2 { - panic!("Gather: index tensor must be present"); - } - - // extract the shape of the input tensor - let input_dim = match curr.inputs.first().unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor.rank as i64, - ArgType::Shape(_shape) => 1, //Shape is always 1-D - other => panic!("Only tensor or shape input is valid, got {:?}", other), - }; - - // extract the attributes - for (key, value) in curr.attrs.iter() { - match key.as_str() { - "axis" => dim = value.clone().into_i64(), - _ => {} - } - } - - // if dim is negative, it is counted from the end - if dim < 0 { - dim += input_dim; - } - - dim as usize -} - -/// Create a LinearConfig from the attributes of the node -pub fn linear_config(node: &Node) -> LinearConfig { - if node.inputs.len() < 2 { - panic!("Linear: missing weight tensor"); - } - - let weight_shape = node.inputs[1] - .value - .as_ref() - .expect("Linear: weight tensor must be present") - .shape - .clone(); - - // check if the weight tensor has at least 2 dimensions - if weight_shape.len() < 2 { - panic!( - "Linear: weight tensor must have at least 2 dimensions (got {:?})", - weight_shape.len() - ); - } - - let (in_size, out_size) = (weight_shape[0], weight_shape[1]); - - // check if the bias is present - let bias = node.inputs.len() == 3 && node.inputs[2].value.is_some(); - - LinearConfig::new(in_size, out_size).with_bias(bias) -} - -/// Create a DropoutConfig from an attribute and state of the node -pub fn dropout_config(node: &Node) -> DropoutConfig { - // Opset 7 and older store probability as an attribute - if node.attrs.contains_key("ratio") { - let prob = node.attrs.get("ratio").unwrap().clone().into_f32(); - return DropoutConfig::new(prob as f64); - } - - if node.inputs.len() < 2 { - panic!("Dropout configuration must have at least 2 inputs"); - } - - let ratio = node.inputs[1] - .value - .clone() - .expect("Dropout ratio must be passed in the second input") - .data - .into_scalar(); - - let prob = match ratio { - Data::Float16(ratio) => f64::from(f32::from(ratio)), - Data::Float32(ratio) => ratio as f64, - Data::Float64(ratio) => ratio, - _ => panic!("Dropout ratio must be a float"), - }; - - DropoutConfig::new(prob) -} - -/// Create log_softmax config from the attributes of the node -pub fn log_softmax_config(node: &Node) -> usize { - // the axis is the last dimension (Default: 1 per ONNX spec) - let mut axis: i64 = -1; - - // check if the node has only one input - if node.inputs.len() != 1 { - panic!( - "LogSoftmax: multiple inputs are not supported (got {:?})", - node.inputs.len() - ); - } - - // extract the shape of the input tensor - let tensor = match node.inputs.first().unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // extract the attributes - for (key, value) in node.attrs.iter() { - match key.as_str() { - "axis" => axis = value.clone().into_i64(), - _ => {} - } - } - - // if axis is negative, it is counted from the end - if axis < 0 { - axis += tensor.rank as i64; - } - - axis as usize -} - -/// Create softmax config from the attributes of the node -pub fn softmax_config(node: &Node) -> usize { - // the axis is the last dimension (Default: 1 per ONNX spec) - let mut axis: i64 = -1; - - // check if the node has only one input - if node.inputs.len() != 1 { - panic!( - 
"Softmax: multiple inputs are not supported (got {:?})", - node.inputs.len() - ); - } - - // extract the shape of the input tensor - let tensor = match node.inputs.first().unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // extract the attributes - for (key, value) in node.attrs.iter() { - match key.as_str() { - "axis" => axis = value.clone().into_i64(), - _ => {} - } - } - - // if axis is negative, it is counted from the end - if axis < 0 { - axis += tensor.rank as i64; - } - - axis as usize -} - -/// Create argmax config from the attributes of the node -pub fn argmax_config(node: &Node) -> usize { - let mut axis: i64 = 0; - - // check if the node has only one input - if node.inputs.len() != 1 { - panic!( - "Argmax: multiple inputs are not supported (got {:?})", - node.inputs.len() - ); - } - - // extract the shape of the input tensor - let tensor = match node.inputs.first().unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // extract the attributes - for (key, value) in node.attrs.iter() { - match key.as_str() { - "axis" => axis = value.clone().into_i64(), - "select_last_index" => { - // not all params are supported in burn - if value.clone().into_i64() != 0 { - log::warn!( - "only select_last_index=0 is supported for argmax in burn. Ignoring supplied value (got {:?})", - value - ); - } - } - "keepdims" => { - // not all params are supported in burn - if value.clone().into_i64() != 1 { - panic!( - "Only keepdims=1 is supported for argmax in burn (got {:?})", - value - ); - } - } - _ => {} - } - } - - // if axis is negative, it is counted from the end - if axis < 0 { - axis += tensor.rank as i64; - } - - axis as usize -} - -/// Create concat config from the attributes of the node -pub fn concat_config(node: &Node) -> usize { - // the axis is the last dimension (Default: 1 per ONNX spec) - let mut axis: i64 = 1; - - // extract the shape of the input tensor - let tensor = match node.inputs.first().unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // extract the attributes - for (key, value) in node.attrs.iter() { - match key.as_str() { - "axis" => axis = value.clone().into_i64(), - _ => {} - } - } - - // if axis is negative, it is counted from the end - if axis < 0 { - axis += tensor.rank as i64; - } - - axis as usize -} - -/// Create a BatchNormConfig from the attributes of the node -pub fn batch_norm_config(node: &Node) -> BatchNormConfig { - let weight_shape = node.inputs[1] - .value - .as_ref() - .expect("BatchNorm: weight tensor must be present") - .shape - .clone(); - - let num_features = weight_shape[0]; - - let mut epsilon = 0f32; - let mut momentum = 0f32; - - for (key, value) in node.attrs.iter() { - match key.as_str() { - "momentum" => momentum = value.clone().into_f32(), - "epsilon" => epsilon = value.clone().into_f32(), - _ => {} - } - } - - BatchNormConfig::new(num_features) - .with_epsilon(epsilon as f64) - .with_momentum(momentum as f64) -} - -/// Create a LayerNormConfig from the attributes of the node -pub fn layer_norm_config(node: &Node) -> (LayerNormConfig, bool) { - let weight_shape = node.inputs[1] - .value - .as_ref() - .expect("LayerNorm: weight tensor must be present") - .shape - .clone(); - - let num_features = weight_shape[0]; - - // When `stash_type` is `1` (default), perform operations in 32-bit float and - // cast the results back to original dtype - let mut stash_type = 1; 
- let mut axis = -1; - let mut epsilon = 1e-5; - - for (key, value) in node.attrs.iter() { - match key.as_str() { - "axis" => axis = value.clone().into_i64(), - "epsilon" => epsilon = value.clone().into_f32(), - "stash_type" => stash_type = value.clone().into_i64(), - _ => {} - } - } - - if axis != -1 && axis != weight_shape.len() as i64 - 1 { - panic!("LayerNorm: normalization is only supported on the last axis right now") - } - - ( - LayerNormConfig::new(num_features).with_epsilon(epsilon as f64), - stash_type == 1, - ) -} - /// Create a TileConfig from the attributes of the node pub fn tile_config(node: &Node) -> TileConfig { let repeat = node @@ -912,273 +250,6 @@ pub fn pad_config(node: &Node) -> PadConfig { PadConfig::new(pads, constant_value) } -// padding_config_1d moved to onnx-ir::node::conv1d - -/// Calculate the padding configuration for a 2D operations such as Convolution and Pooling. -/// -/// # Arguments -/// -/// * `pads` - The padding values -/// -/// # Panics -/// -/// * If the padding is negative -/// * If the padding is not symmetric -/// -/// # Returns -/// -/// * The padding configuration -/// -/// # Remarks -/// -/// This function is used when the padding is specified as a list of integers, -/// and not used when the padding is specified as a string, e.g. "SAME_UPPER". -fn padding_config_2d(pads: &[i64]) -> PaddingConfig2d { - let [left, top, right, bottom] = [pads[0], pads[1], pads[2], pads[3]]; - - if left < 0 || top < 0 || right < 0 || bottom < 0 { - panic!("Negative pad values are not supported"); - } else if (left != right) || (top != bottom) { - panic!("Asymmetric padding is not supported"); - } else if left == 0 && top == 0 && right == 0 && bottom == 0 { - // i.e [0, 0, 0, 0] - PaddingConfig2d::Valid - } else if left == right && top == bottom { - // i.e [2, 3, 2, 3] - PaddingConfig2d::Explicit(left as usize, top as usize) - } else { - // Unaccounted for padding configuration - panic!("Padding configuration ({:?}) not supported", pads); - } -} - -/// Calculate the padding configuration for a 3D operations such as Convolution and Pooling. -/// -/// # Arguments -/// -/// * `pads` - The padding values -/// -/// # Panics -/// -/// * If the padding is negative -/// * If the padding is not symmetric -/// -/// # Returns -/// -/// * The padding configuration -/// -/// # Remarks -/// -/// This function is used when the padding is specified as a list of integers, -/// and not used when the padding is specified as a string, e.g. "SAME_UPPER". 
-fn padding_config_3d(pads: &[i64]) -> PaddingConfig3d { - let [left, top, front, right, bottom, back] = - [pads[0], pads[1], pads[2], pads[3], pads[4], pads[5]]; - - if left < 0 || top < 0 || front < 0 || right < 0 || bottom < 0 || back < 0 { - panic!("Negative pad values are not supported"); - } else if (left != right) || (top != bottom) || (front != back) { - panic!("Asymmetric padding is not supported"); - } else if left == 0 && top == 0 && front == 0 && right == 0 && bottom == 0 && back == 0 { - // i.e [0, 0, 0, 0] - PaddingConfig3d::Valid - } else if left == right && top == bottom && front == back { - // i.e [2, 3, 2, 3] - PaddingConfig3d::Explicit(left as usize, top as usize, front as usize) - } else { - // Unaccounted for padding configuration - panic!("Padding configuration ({:?}) not supported", pads); - } -} - -// Create a LeakyReluConfig from the alpha attribute of the node -pub fn leaky_relu_config(node: &Node) -> f64 { - let mut alpha = 0.01; - - for (key, value) in node.attrs.iter() { - match key.as_str() { - "alpha" => alpha = value.clone().into_f32() as f64, - _ => {} - } - } - - alpha -} - -// Create a HardSigmoidConfig from the alpha and beta attributes of the node -pub fn hard_sigmoid_config(node: &Node) -> (f64, f64) { - let mut alpha = 0.2; - let mut beta = 0.5; - - for (key, value) in node.attrs.iter() { - match key.as_str() { - "alpha" => alpha = value.clone().into_f32() as f64, - "beta" => beta = value.clone().into_f32() as f64, - _ => {} - } - } - - (alpha, beta) -} - -pub fn reshape_config(node: &Node) -> Vec { - let mut allowzero = 0; - - for (key, value) in node.attrs.iter() { - match key.as_str() { - "allowzero" => allowzero = value.clone().into_i64(), - _ => {} - } - } - - // Burn does not support zero size shape (0 means false in ONNX) - // (see https://onnx.ai/onnx/operators/onnx__Reshape.html#attributes) - if allowzero != 0 { - panic!("Zero shape size is not supported"); - } - - // TODO: check "shape" attribute - if node.inputs.len() != 2 || node.inputs[1].value.is_none() { - panic!("Reshape: shape tensor must be present for {:?}", node); - } - - match &node.inputs[1].value { - Some(TensorData { data, shape, .. }) => { - assert_eq!(shape.len(), 1, "Reshape: shape tensor must be 1D"); - data.clone().into_i64s() - } - _ => panic!("Only tensor input is valid for shape"), - } -} - -pub fn resize_config(node: &Node) -> (String, Vec, Vec) { - let mut mode: String = "".to_string(); - - let mut scales: Vec; - let mut sizes: Vec; - - let input = if let ArgType::Tensor(tensor) = &node - .inputs - .first() - .expect("Resize: Input tensor must be present") - .ty - { - tensor - } else { - panic!("Resize: input must be a tensor") - }; - - // Note: we are ignoring some attributes because results are approximately the same - // and we are not supporting all the attributes of the Resize operator. - // However, some attributes are important to be checked and we are checking - // against the default values of the attributes. 
- // TODO revisit this when we have more Resize operators in the model - for (key, value) in node.attrs.iter() { - match key.as_str() { - "antialias" => assert_eq!( - value.clone().into_i32(), - 0, - "Resize: antialias other than 0 is not supported" - ), - "axes" => panic!("Resize: custom axes attribute is not supported"), - "coordinate_transformation_mode" => { - log::warn!("Resize: coordinate_transformation_mode is ignored") - } - - "cubic_coeff_a" => log::warn!("Resize: cubic_coeff_a is ignored"), - "exclude_outside" => assert_eq!( - value.clone().into_i32(), - 0, - "Resize: exclude_outside other than 0 is not supported" - ), - "extrapolation_value" => assert_eq!( - value.clone().into_f32(), - 0.0, - "Resize: extrapolation_value other than 0.0 is not supported" - ), - "keep_aspect_ratio_policy" => { - assert_eq!( - value.clone().into_string().to_lowercase(), - "stretch", - "Resize: keep_aspect_ratio_policy other than 'stretch' is not supported" - ) - } - "mode" => mode = value.clone().into_string().to_lowercase(), - "nearest_mode" => log::warn!("Resize: nearest_mode is ignored"), - - _ => {} - } - } - - let roi: Vec = node - .inputs - .get(1) - .map(|input| { - if let Some(TensorData { data, .. }) = &input.value { - data.clone().into_f32s() - } else { - vec![] - } - }) - .unwrap_or_default(); - - scales = node - .inputs - .get(2) - .map(|input| { - if let Some(TensorData { data, .. }) = &input.value { - data.clone().into_f32s() - } else { - vec![] - } - }) - .unwrap_or_default(); - - sizes = node - .inputs - .get(3) - .map(|input| { - if let Some(TensorData { data, .. }) = &input.value { - data.clone() - .into_i64s() - .iter() - .map(|&x| x as usize) - .collect() - } else { - vec![] - } - }) - .unwrap_or_default(); - - if mode.is_empty() { - panic!("Resize: mode attribute is required") - } - - if !roi.is_empty() { - panic!("Resize: roi input is not supported") - } - - if scales.is_empty() && sizes.is_empty() { - panic!("Resize: either scales or sizes input is required") - } - - if !scales.is_empty() { - assert!(scales.len() == input.rank); - // ignore the fist two items from scales - // because they are the batch and channel dimensions - scales = scales.iter().skip(2).cloned().collect(); - } - - if !sizes.is_empty() { - assert!(sizes.len() == input.rank); - // ignore the fist two items from sizes - // because they are the batch and channel dimensions - sizes = sizes.iter().skip(2).cloned().collect(); - } - - (mode, scales, sizes) -} - //Note this function should only execute if the second input is a constant //if it wasn't and the output shape was known, unsqueeze has been remapped to reshape pub fn unsqueeze_config(node: &Node) -> UnsqueezeAxes { @@ -1214,371 +285,6 @@ pub fn unsqueeze_config(node: &Node) -> UnsqueezeAxes { } } -pub fn clip_config(node: &Node) -> (Option, Option) { - let mut min_result: Option = None; - let mut max_result: Option = None; - - // For Clip Opset 6+ , the min and max values are attributes - for (key, value) in node.attrs.iter() { - match key.as_str() { - "min" => { - let min = value.clone().into_f32() as f64; - min_result = Some(min); - } - "max" => { - let max = value.clone().into_f32(); - max_result = Some(max as f64); - } - _ => {} - } - } - - // For Clip Opset 11+ , the min and max values are inputs - // Get the min and max values from the input values - if min_result.is_none() && max_result.is_none() { - let min = &node.inputs[1].value; - let max = &node.inputs[2].value; - - if min_result.is_none() && min.is_some() { - let min = 
min.clone().unwrap().data.into_scalar(); - min_result = match min { - Data::Float16(min) => Some(f32::from(min) as f64), - Data::Float32(min) => Some(min as f64), - Data::Float64(min) => Some(min), - _ => panic!("Clip: only float min is supported"), - }; - } - - if max_result.is_none() && max.is_some() { - let max = max.clone().unwrap().data.into_scalar(); - max_result = match max { - Data::Float16(max) => Some(f32::from(max) as f64), - Data::Float32(max) => Some(max as f64), - Data::Float64(max) => Some(max), - _ => panic!("Clip: only float max is supported"), - }; - } - } - - if min_result.is_none() && max_result.is_none() { - panic!("Clip: min and max values must be either attributes or inputs"); - } - - (min_result, max_result) -} - -pub fn reduce_max_config(node: &Node) -> Option { - let mut axes = Vec::new(); - let mut keepdims = 1; - - let tensor = match node.inputs.first().unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // Extract the attributes - for (key, value) in node.attrs.iter() { - match key.as_str() { - "axes" => axes = value.clone().into_i64s(), - "keepdims" => keepdims = value.clone().into_i64(), - _ => {} - } - } - - if axes.len() > 1 { - panic!("ReduceMax: reducing on multiple dimensions is not supported") - } - - if axes.is_empty() && keepdims == 1 { - panic!("ReduceMax: axes must be provided with keepdims") - } - - if !axes.is_empty() && keepdims == 0 { - // Not supported in Burn - panic!("ReduceMax: the reduce operation must preserve the reduced dimension") - } - - if axes.is_empty() { - None - } else { - let mut dim = axes[0]; - - if dim < 0 { - // Accepted range is [-r, r-1] where r = rank(data) but Burn only supports positive dim - dim += tensor.rank as i64; - } - Some(dim as usize) - } -} - -pub fn reduce_min_config(node: &Node) -> Option { - let mut axes = Vec::new(); - let mut keepdims = 1; - - let tensor = match node.inputs.first().unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // Extract the attributes - for (key, value) in node.attrs.iter() { - match key.as_str() { - "axes" => axes = value.clone().into_i64s(), - "keepdims" => keepdims = value.clone().into_i64(), - _ => {} - } - } - - if axes.len() > 1 { - panic!("ReduceMin: reducing on multiple dimensions is not supported") - } - - if axes.is_empty() && keepdims == 1 { - panic!("ReduceMin: axes must be provided with keepdims") - } - - if !axes.is_empty() && keepdims == 0 { - panic!("ReduceMin: the reduce operation must preserve the reduced dimension") - } - - if axes.is_empty() { - None - } else { - let mut dim = axes[0]; - - if dim < 0 { - dim += tensor.rank as i64; - } - Some(dim as usize) - } -} - -pub fn reduce_mean_config(node: &Node) -> Option { - let mut axes = Vec::new(); - let mut keepdims = 1; - - let tensor = match node.inputs.first().unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // Extract the attributes - for (key, value) in node.attrs.iter() { - match key.as_str() { - "axes" => axes = value.clone().into_i64s(), - "keepdims" => keepdims = value.clone().into_i64(), - _ => {} - } - } - - if axes.len() > 1 { - panic!("ReduceMean: reducing on multiple dimensions is not supported") - } - - if axes.is_empty() && keepdims == 1 { - panic!("ReduceMean: axes must be provided with keepdims") - } - - if !axes.is_empty() && keepdims == 0 { - // Not supported in Burn - panic!("ReduceMean: the reduce operation must 
preserve the reduced dimension") - } - - if axes.is_empty() { - None - } else { - let mut dim = axes[0]; - - if dim < 0 { - // Accepted range is [-r, r-1] where r = rank(data) but Burn only supports positive dim - dim += tensor.rank as i64; - } - Some(dim as usize) - } -} - -pub fn reduce_prod_config(node: &Node) -> Option { - let mut axes = Vec::new(); - let mut keepdims = 1; - - let tensor = match node.inputs.first().unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // Extract the attributes - for (key, value) in node.attrs.iter() { - match key.as_str() { - "axes" => axes = value.clone().into_i64s(), - "keepdims" => keepdims = value.clone().into_i64(), - // TODO: handle noop_with_empty_axes (opset 18) - _ => {} - } - } - - if axes.len() > 1 { - panic!("ReduceProd: reducing on multiple dimensions is not supported") - } - - if axes.is_empty() && keepdims == 1 { - panic!("ReduceProd: axes must be provided with keepdims") - } - - if !axes.is_empty() && keepdims == 0 { - // Not supported in Burn - panic!("ReduceProd: the reduce operation must preserve the reduced dimension") - } - - if axes.is_empty() { - None - } else { - let mut dim = axes[0]; - - if dim < 0 { - // Accepted range is [-r, r-1] where r = rank(data) but Burn only supports positive dim - dim += tensor.rank as i64; - } - Some(dim as usize) - } -} - -pub fn reduce_sum_config(node: &Node) -> Option { - let mut axes = Vec::new(); - let mut keepdims = 1; - - let tensor = match node.inputs.first().unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // Extract the attributes - for (key, value) in node.attrs.iter() { - match key.as_str() { - "keepdims" => keepdims = value.clone().into_i64(), - "axes" => axes = value.clone().into_i64s(), - // TODO: handle noop_with_empty_axes - _ => {} - } - } - - // TODO: Handle case where axes are passed in. Will require its own ReduceSumNode instead of a UnaryNode. 
- if let Some(value) = node - .inputs - .get(1) - .and_then(|argument| argument.value.as_ref()) - { - axes = value.clone().data.into_i64s(); - } - - if axes.len() > 1 { - panic!("ReduceMean: reducing on multiple dimensions is not supported") - } - - if axes.is_empty() && keepdims == 1 { - panic!("ReduceMean: axes must be provided with keepdims") - } - - if !axes.is_empty() && keepdims == 0 { - // Not supported in Burn - panic!("ReduceMean: the reduce operation must preserve the reduced dimension") - } - - if axes.is_empty() { - None - } else { - let mut dim = axes[0]; - - if dim < 0 { - // Accepted range is [-r, r-1] where r = rank(data) but Burn only supports positive dim - dim += tensor.rank as i64; - } - Some(dim as usize) - } -} - -pub fn shape_config(curr: &Node) -> (usize, usize) { - if curr.inputs.len() != 1 { - panic!( - "Shape: multiple inputs are not supported (got {:?})", - curr.inputs.len() - ); - } - - // Extract the shape of the input tensor - let tensor = match curr.inputs.first().unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // Default: all axes up to the last one (included) - let mut start_dim: i64 = 0; - let mut end_dim: i64 = tensor.rank as i64; - - // Extract the attributes - for (key, value) in curr.attrs.iter() { - match key.as_str() { - "start" => start_dim = value.clone().into_i64(), - "end" => end_dim = value.clone().into_i64(), - _ => {} - } - } - - // If dim is negative, it is counted from the end - if start_dim < 0 { - start_dim += tensor.rank as i64; - } - if end_dim < 0 { - end_dim += tensor.rank as i64; - } - - (start_dim as usize, end_dim as usize) -} - -pub fn transpose_config(curr: &Node) -> Vec { - if curr.inputs.len() != 1 { - panic!( - "Transpose: multiple inputs are not supported (got {:?})", - curr.inputs.len() - ); - } - - // Extract the shape of the input tensor - let tensor = match curr.inputs.first().unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // Default: reverse the dimensions - let mut perm = (0..tensor.rank as i64).rev().collect::>(); - - if let Some(axes) = curr.attrs.get("perm") { - perm = axes.clone().into_i64s(); - } - - perm -} - -pub fn squeeze_config(curr: &Node) -> Vec { - let axes = curr - .attrs - .iter() - .filter_map(|(key, value)| { - if key == "axes" { - Some(value.clone().into_i64s()) - } else { - None - } - }) - .next() - .unwrap_or_else(Vec::new); - - match curr.inputs.first().unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - axes -} pub fn split_config(node: &Node) -> SplitConfig { // Initialize the axis to split along (default is 0 as per ONNX specification) let mut axis: i64 = 0; @@ -1685,52 +391,3 @@ pub fn split_config(node: &Node) -> SplitConfig { split_sizes, } } - -pub fn one_hot_config(curr: &Node) -> (usize, [f32; 2], i64) { - let depth = curr.inputs[1] - .value - .clone() - .expect("OneHot: Only constant depth is currently supported") - .data - .into_i64(); - - let values = curr.inputs[2] - .value - .clone() - .expect("OneHot: Only constant on/off values is currently supported") - .data - .into_f32s(); - - let axis = curr - .attrs - .get("axis") - .map(|val| val.clone().into_i64()) - .unwrap_or(-1); - - (depth as usize, values.try_into().unwrap(), axis) -} - -pub fn gemm_config(curr: &Node) -> (f32, f32, i64, i64) { - let alpha = curr - .attrs - .get("alpha") - .map(|val| val.clone().into_f32()) - .unwrap_or(1.0); - let 
beta = curr - .attrs - .get("beta") - .map(|val| val.clone().into_f32()) - .unwrap_or(1.0); - let trans_a = curr - .attrs - .get("transA") - .map(|val| val.clone().into_i64()) - .unwrap_or(0); - let trans_b = curr - .attrs - .get("transB") - .map(|val| val.clone().into_i64()) - .unwrap_or(0); - - (alpha, beta, trans_a, trans_b) -} diff --git a/crates/burn-import/src/onnx/to_burn.rs b/crates/burn-import/src/onnx/to_burn.rs index da924042ac..9cc3faa330 100644 --- a/crates/burn-import/src/onnx/to_burn.rs +++ b/crates/burn-import/src/onnx/to_burn.rs @@ -70,14 +70,8 @@ use crate::{ }; use super::op_configuration::{ - argmax_config, avg_pool2d_config, batch_norm_config, clip_config, concat_config, - conv_transpose2d_config, conv_transpose3d_config, conv2d_config, conv3d_config, dropout_config, - expand_config, flatten_config, gather_config, gemm_config, hard_sigmoid_config, - layer_norm_config, leaky_relu_config, linear_config, log_softmax_config, max_pool2d_config, - one_hot_config, pad_config, reduce_max_config, reduce_mean_config, reduce_min_config, - reduce_prod_config, reduce_sum_config, reshape_config, resize_config, shape_config, - softmax_config, split_config, squeeze_config, tile_config, top_k_config, transpose_config, - trilu_config, unsqueeze_config, + expand_config, pad_config, split_config, tile_config, top_k_config, trilu_config, + unsqueeze_config, }; use onnx_ir::{ convert_constant_value, @@ -86,10 +80,22 @@ use onnx_ir::{ TensorType as OnnxTensorType, }, node::{ - avg_pool1d::avg_pool1d_config, conv_transpose1d::conv_transpose1d_config, - conv1d::conv1d_config, max_pool1d::max_pool1d_config, slice::slice_config, + argmax::argmax_config, avg_pool1d::avg_pool1d_config, avg_pool2d::avg_pool2d_config, + batch_norm::batch_norm_config, concat::concat_config, + conv_transpose1d::conv_transpose1d_config, conv_transpose2d::conv_transpose2d_config, + conv_transpose3d::conv_transpose3d_config, conv1d::conv1d_config, conv2d::conv2d_config, + conv3d::conv3d_config, dropout::dropout_config, flatten::flatten_config, + gather::gather_config, layer_norm::layer_norm_config, linear::linear_config, + log_softmax::log_softmax_config, max_pool1d::max_pool1d_config, + max_pool2d::max_pool2d_config, slice::slice_config, softmax::softmax_config, + }, + op_configuration::{ + clip_config, gemm_config, hard_sigmoid_config, leaky_relu_config, one_hot_config, + reduce_max_config, reduce_mean_config, reduce_min_config, reduce_prod_config, + reduce_sum_config, reshape_config, resize_config, squeeze_config, transpose_config, }, parse_onnx, + util::shape_config, }; pub use crate::burn::graph::RecordType; From 4834dfffcc870f1c2cf6d25a1f47c820e1cfce14 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Thu, 1 May 2025 11:33:37 -0500 Subject: [PATCH 16/37] Remove burn types from onnx-ir --- Cargo.lock | 1 - crates/burn-core/src/nn/padding.rs | 21 +- crates/burn-import/src/burn/codegen.rs | 8 +- .../burn-import/src/burn/node/avg_pool1d.rs | 6 +- .../burn-import/src/burn/node/avg_pool2d.rs | 6 +- crates/burn-import/src/burn/node/base.rs | 6 +- .../burn-import/src/burn/node/batch_norm.rs | 3 +- crates/burn-import/src/burn/node/conv1d.rs | 9 +- crates/burn-import/src/burn/node/conv2d.rs | 6 +- crates/burn-import/src/burn/node/conv3d.rs | 6 +- .../src/burn/node/conv_transpose_2d.rs | 5 +- .../src/burn/node/conv_transpose_3d.rs | 5 +- crates/burn-import/src/burn/node/dropout.rs | 5 +- .../burn-import/src/burn/node/layer_norm.rs | 3 +- 
crates/burn-import/src/burn/node/linear.rs | 3 +- .../burn-import/src/burn/node/max_pool1d.rs | 9 +- .../burn-import/src/burn/node/max_pool2d.rs | 6 +- crates/burn-import/src/onnx/to_burn.rs | 54 +--- crates/onnx-ir/Cargo.toml | 8 - crates/onnx-ir/src/node/avg_pool1d.rs | 37 ++- crates/onnx-ir/src/node/avg_pool2d.rs | 84 ++++--- crates/onnx-ir/src/node/batch_norm.rs | 35 ++- crates/onnx-ir/src/node/conv1d.rs | 103 ++++---- crates/onnx-ir/src/node/conv2d.rs | 105 +++++--- crates/onnx-ir/src/node/conv3d.rs | 106 ++++---- crates/onnx-ir/src/node/conv_transpose1d.rs | 3 - crates/onnx-ir/src/node/conv_transpose2d.rs | 74 +++++- crates/onnx-ir/src/node/conv_transpose3d.rs | 74 +++++- crates/onnx-ir/src/node/dropout.rs | 15 +- crates/onnx-ir/src/node/layer_norm.rs | 26 +- crates/onnx-ir/src/node/linear.rs | 29 ++- crates/onnx-ir/src/node/max_pool1d.rs | 42 +++- crates/onnx-ir/src/node/max_pool2d.rs | 84 ++++--- crates/onnx-ir/src/node/mod.rs | 1 + crates/onnx-ir/src/node/padding.rs | 231 ++++++++++++++++++ 35 files changed, 872 insertions(+), 347 deletions(-) create mode 100644 crates/onnx-ir/src/node/padding.rs diff --git a/Cargo.lock b/Cargo.lock index 1a890ff4fa..cb3d74d4e0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4474,7 +4474,6 @@ dependencies = [ name = "onnx-ir" version = "0.18.0" dependencies = [ - "burn", "bytemuck", "half", "log", diff --git a/crates/burn-core/src/nn/padding.rs b/crates/burn-core/src/nn/padding.rs index 0e847c8532..5c61322073 100644 --- a/crates/burn-core/src/nn/padding.rs +++ b/crates/burn-core/src/nn/padding.rs @@ -7,12 +7,11 @@ use crate::config::Config; /// Padding configuration for 1D operators. #[derive(Config, Debug, PartialEq)] pub enum PaddingConfig1d { - /// Dynamically calculate the amount of padding necessary to ensure that the output size will be - /// the same as the input. + /// Dynamically calculates padding to ensure output size matches input size. Same, - /// Same as no padding. + /// No padding applied. Valid, - /// Applies the specified amount of padding to all inputs. + /// Applies a specific amount of padding to all inputs. Explicit(usize), } @@ -35,12 +34,11 @@ impl PaddingConfig1d { /// Padding configuration for 2D operators. #[derive(Config, Debug, PartialEq)] pub enum PaddingConfig2d { - /// Dynamically calculate the amount of padding necessary to ensure that the output size will be - /// the same as the input. + /// Dynamically calculates padding to preserve input dimensions in output. Same, - /// Same as no padding. + /// No padding applied. Valid, - /// Applies the specified amount of padding to all inputs. + /// Applies specified padding values to height and width dimensions. Explicit(usize, usize), } @@ -70,12 +68,11 @@ impl PaddingConfig2d { /// Padding configuration for 3D operators. #[derive(Config, Debug, PartialEq)] pub enum PaddingConfig3d { - /// Dynamically calculate the amount of padding necessary to ensure that the output size will be - /// the same as the input. + /// Dynamically calculates padding to preserve input dimensions in output. Same, - /// Same as no padding. + /// No padding applied. Valid, - /// Applies the specified amount of padding to all inputs. + /// Applies specified padding values to depth, height, and width dimensions. 
Explicit(usize, usize, usize), } diff --git a/crates/burn-import/src/burn/codegen.rs b/crates/burn-import/src/burn/codegen.rs index 7f511dafd4..dff8ff4875 100644 --- a/crates/burn-import/src/burn/codegen.rs +++ b/crates/burn-import/src/burn/codegen.rs @@ -1,10 +1,7 @@ +use onnx_ir::node::padding::{PaddingConfig1d, PaddingConfig2d, PaddingConfig3d}; use proc_macro2::TokenStream; use quote::quote; -use burn::nn::PaddingConfig1d; -use burn::nn::PaddingConfig2d; -use burn::nn::PaddingConfig3d; - fn convert_primitive(primitive: T) -> TokenStream { let value = format!("{:?}", primitive); @@ -76,7 +73,6 @@ impl ToTokens for f32 { impl ToTokens for PaddingConfig1d { fn to_tokens(&self) -> TokenStream { match self { - Self::Same => quote! { PaddingConfig1d::Same }, Self::Valid => quote! { PaddingConfig1d::Valid }, Self::Explicit(padding) => { let padding = padding.to_tokens(); @@ -90,7 +86,6 @@ impl ToTokens for PaddingConfig1d { impl ToTokens for PaddingConfig2d { fn to_tokens(&self) -> TokenStream { match self { - Self::Same => quote! { PaddingConfig2d::Same }, Self::Valid => quote! { PaddingConfig2d::Valid }, Self::Explicit(padding1, padding2) => { let padding1 = padding1.to_tokens(); @@ -105,7 +100,6 @@ impl ToTokens for PaddingConfig2d { impl ToTokens for PaddingConfig3d { fn to_tokens(&self) -> TokenStream { match self { - Self::Same => quote! { PaddingConfig3d::Same }, Self::Valid => quote! { PaddingConfig3d::Valid }, Self::Explicit(padding1, padding2, padding3) => { let padding1 = padding1.to_tokens(); diff --git a/crates/burn-import/src/burn/node/avg_pool1d.rs b/crates/burn-import/src/burn/node/avg_pool1d.rs index 2a8310bb9c..680609290f 100644 --- a/crates/burn-import/src/burn/node/avg_pool1d.rs +++ b/crates/burn-import/src/burn/node/avg_pool1d.rs @@ -1,7 +1,8 @@ +use onnx_ir::node::avg_pool1d::AvgPool1dConfig; use proc_macro2::TokenStream; use quote::quote; -use burn::{nn::pool::AvgPool1dConfig, record::PrecisionSettings}; +use burn::record::PrecisionSettings; use super::{Node, NodeCodegen}; use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type}; @@ -93,7 +94,8 @@ impl NodeCodegen for AvgPool1dNode { mod tests { use super::*; use crate::burn::{TensorType, graph::BurnGraph, node::test::assert_tokens}; - use burn::{nn::PaddingConfig1d, record::FullPrecisionSettings}; + use burn::record::FullPrecisionSettings; + use onnx_ir::node::padding::PaddingConfig1d; #[test] fn test_codegen() { diff --git a/crates/burn-import/src/burn/node/avg_pool2d.rs b/crates/burn-import/src/burn/node/avg_pool2d.rs index 2e84a5170b..f929693e51 100644 --- a/crates/burn-import/src/burn/node/avg_pool2d.rs +++ b/crates/burn-import/src/burn/node/avg_pool2d.rs @@ -1,7 +1,8 @@ +use onnx_ir::node::avg_pool2d::AvgPool2dConfig; use proc_macro2::TokenStream; use quote::quote; -use burn::{nn::pool::AvgPool2dConfig, record::PrecisionSettings}; +use burn::record::PrecisionSettings; use super::{Node, NodeCodegen}; use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type}; @@ -97,7 +98,8 @@ mod tests { graph::BurnGraph, node::{avg_pool2d::AvgPool2dNode, test::assert_tokens}, }; - use burn::{nn::PaddingConfig2d, nn::pool::AvgPool2dConfig, record::FullPrecisionSettings}; + use burn::record::FullPrecisionSettings; + use onnx_ir::node::padding::PaddingConfig2d; #[test] fn test_codegen() { diff --git a/crates/burn-import/src/burn/node/base.rs b/crates/burn-import/src/burn/node/base.rs index d77df6172f..e63ff723fc 100644 --- a/crates/burn-import/src/burn/node/base.rs +++ 
b/crates/burn-import/src/burn/node/base.rs @@ -308,10 +308,8 @@ pub(crate) mod tests { graph::BurnGraph, node::{NodeCodegen, conv2d::Conv2dNode, matmul::MatmulNode, test::assert_tokens}, }; - use burn::{ - nn::PaddingConfig2d, nn::conv::Conv2dConfig, record::FullPrecisionSettings, - tensor::TensorData, - }; + use burn::{record::FullPrecisionSettings, tensor::TensorData}; + use onnx_ir::node::{conv2d::Conv2dConfig, padding::PaddingConfig2d}; use proc_macro2::TokenStream; use quote::quote; diff --git a/crates/burn-import/src/burn/node/batch_norm.rs b/crates/burn-import/src/burn/node/batch_norm.rs index 41f7194f4f..6733dd91f9 100644 --- a/crates/burn-import/src/burn/node/batch_norm.rs +++ b/crates/burn-import/src/burn/node/batch_norm.rs @@ -2,10 +2,11 @@ use super::{Node, NodeCodegen, SerializationBackend}; use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type}; use burn::{ module::{ConstantRecord, Param, ParamId}, - nn::{BatchNormConfig, BatchNormRecord}, + nn::BatchNormRecord, record::{PrecisionSettings, Record}, tensor::{Tensor, TensorData}, }; +use onnx_ir::node::batch_norm::BatchNormConfig; use proc_macro2::TokenStream; use quote::quote; use serde::Serialize; diff --git a/crates/burn-import/src/burn/node/conv1d.rs b/crates/burn-import/src/burn/node/conv1d.rs index 65a8343697..1eeb61d4f5 100644 --- a/crates/burn-import/src/burn/node/conv1d.rs +++ b/crates/burn-import/src/burn/node/conv1d.rs @@ -2,10 +2,11 @@ use super::{Node, NodeCodegen, SerializationBackend}; use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type}; use burn::{ module::{ConstantRecord, Param, ParamId}, - nn::conv::{Conv1dConfig, Conv1dRecord}, + nn::conv::Conv1dRecord, record::{PrecisionSettings, Record}, tensor::{Tensor, TensorData}, }; +use onnx_ir::node::conv1d::Conv1dConfig; use proc_macro2::TokenStream; use quote::quote; use serde::Serialize; @@ -135,10 +136,8 @@ mod tests { graph::BurnGraph, node::{conv1d::Conv1dNode, test::assert_tokens}, }; - use burn::{ - nn::{PaddingConfig1d, conv::Conv1dConfig}, - record::FullPrecisionSettings, - }; + use burn::record::FullPrecisionSettings; + use onnx_ir::node::padding::PaddingConfig1d; #[test] fn test_codegen() { diff --git a/crates/burn-import/src/burn/node/conv2d.rs b/crates/burn-import/src/burn/node/conv2d.rs index 7de8c97cb6..8051e0bd73 100644 --- a/crates/burn-import/src/burn/node/conv2d.rs +++ b/crates/burn-import/src/burn/node/conv2d.rs @@ -2,10 +2,11 @@ use super::{Node, NodeCodegen, SerializationBackend}; use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type}; use burn::{ module::{ConstantRecord, Param, ParamId}, - nn::conv::{Conv2dConfig, Conv2dRecord}, + nn::conv::Conv2dRecord, record::{PrecisionSettings, Record}, tensor::{Tensor, TensorData}, }; +use onnx_ir::node::conv2d::Conv2dConfig; use proc_macro2::TokenStream; use quote::quote; use serde::Serialize; @@ -134,7 +135,8 @@ mod tests { graph::BurnGraph, node::{conv2d::Conv2dNode, test::assert_tokens}, }; - use burn::{nn::PaddingConfig2d, nn::conv::Conv2dConfig, record::FullPrecisionSettings}; + use burn::record::FullPrecisionSettings; + use onnx_ir::node::padding::PaddingConfig2d; #[test] fn test_codegen() { diff --git a/crates/burn-import/src/burn/node/conv3d.rs b/crates/burn-import/src/burn/node/conv3d.rs index 9cad0f6e95..b0855434a3 100644 --- a/crates/burn-import/src/burn/node/conv3d.rs +++ b/crates/burn-import/src/burn/node/conv3d.rs @@ -2,10 +2,11 @@ use super::{Node, NodeCodegen, SerializationBackend}; use crate::burn::{BurnImports, 
OtherType, Scope, TensorType, ToTokens, Type}; use burn::{ module::{ConstantRecord, Param, ParamId}, - nn::conv::{Conv3dConfig, Conv3dRecord}, + nn::conv::Conv3dRecord, record::{PrecisionSettings, Record}, tensor::{Tensor, TensorData}, }; +use onnx_ir::node::conv3d::Conv3dConfig; use proc_macro2::TokenStream; use quote::quote; use serde::Serialize; @@ -134,7 +135,8 @@ mod tests { graph::BurnGraph, node::{conv3d::Conv3dNode, test::assert_tokens}, }; - use burn::{nn::PaddingConfig3d, nn::conv::Conv3dConfig, record::FullPrecisionSettings}; + use burn::record::FullPrecisionSettings; + use onnx_ir::node::padding::PaddingConfig3d; #[test] fn test_codegen() { diff --git a/crates/burn-import/src/burn/node/conv_transpose_2d.rs b/crates/burn-import/src/burn/node/conv_transpose_2d.rs index 2fa77ca83c..851e884c11 100644 --- a/crates/burn-import/src/burn/node/conv_transpose_2d.rs +++ b/crates/burn-import/src/burn/node/conv_transpose_2d.rs @@ -2,10 +2,11 @@ use super::{Node, NodeCodegen, SerializationBackend}; use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type}; use burn::{ module::{ConstantRecord, Param, ParamId}, - nn::conv::{ConvTranspose2dConfig, ConvTranspose2dRecord}, + nn::conv::ConvTranspose2dRecord, record::{PrecisionSettings, Record}, tensor::{Tensor, TensorData}, }; +use onnx_ir::node::conv_transpose2d::ConvTranspose2dConfig; use proc_macro2::TokenStream; use quote::quote; use serde::Serialize; @@ -137,7 +138,7 @@ mod tests { graph::BurnGraph, node::{conv_transpose_2d::ConvTranspose2dNode, test::assert_tokens}, }; - use burn::{nn::conv::ConvTranspose2dConfig, record::FullPrecisionSettings}; + use burn::record::FullPrecisionSettings; #[test] fn test_codegen() { diff --git a/crates/burn-import/src/burn/node/conv_transpose_3d.rs b/crates/burn-import/src/burn/node/conv_transpose_3d.rs index 098dbcaffe..f8a8906a07 100644 --- a/crates/burn-import/src/burn/node/conv_transpose_3d.rs +++ b/crates/burn-import/src/burn/node/conv_transpose_3d.rs @@ -2,10 +2,11 @@ use super::{Node, NodeCodegen, SerializationBackend}; use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type}; use burn::{ module::{ConstantRecord, Param, ParamId}, - nn::conv::{ConvTranspose3dConfig, ConvTranspose3dRecord}, + nn::conv::ConvTranspose3dRecord, record::{PrecisionSettings, Record}, tensor::{Tensor, TensorData}, }; +use onnx_ir::node::conv_transpose3d::ConvTranspose3dConfig; use proc_macro2::TokenStream; use quote::quote; use serde::Serialize; @@ -137,7 +138,7 @@ mod tests { graph::BurnGraph, node::{conv_transpose_3d::ConvTranspose3dNode, test::assert_tokens}, }; - use burn::{nn::conv::ConvTranspose3dConfig, record::FullPrecisionSettings}; + use burn::record::FullPrecisionSettings; #[test] fn test_codegen() { diff --git a/crates/burn-import/src/burn/node/dropout.rs b/crates/burn-import/src/burn/node/dropout.rs index f6a8d37eaa..ee9d1ed8dd 100644 --- a/crates/burn-import/src/burn/node/dropout.rs +++ b/crates/burn-import/src/burn/node/dropout.rs @@ -1,7 +1,8 @@ +use onnx_ir::node::dropout::DropoutConfig; use proc_macro2::TokenStream; use quote::quote; -use burn::{nn::DropoutConfig, record::PrecisionSettings}; +use burn::record::PrecisionSettings; use super::{Node, NodeCodegen}; use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type}; @@ -83,7 +84,7 @@ impl NodeCodegen for DropoutNode { mod tests { use super::*; use crate::burn::{TensorType, graph::BurnGraph, node::test::assert_tokens}; - use burn::{nn::DropoutConfig, record::FullPrecisionSettings}; + use 
burn::record::FullPrecisionSettings; #[test] fn test_codegen() { diff --git a/crates/burn-import/src/burn/node/layer_norm.rs b/crates/burn-import/src/burn/node/layer_norm.rs index b9fd6d5d7c..867ce1252e 100644 --- a/crates/burn-import/src/burn/node/layer_norm.rs +++ b/crates/burn-import/src/burn/node/layer_norm.rs @@ -2,10 +2,11 @@ use super::{Node, NodeCodegen, SerializationBackend}; use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type}; use burn::{ module::{ConstantRecord, Param, ParamId}, - nn::{LayerNormConfig, LayerNormRecord}, + nn::LayerNormRecord, record::{PrecisionSettings, Record}, tensor::{Tensor, TensorData}, }; +use onnx_ir::node::layer_norm::LayerNormConfig; use proc_macro2::TokenStream; use quote::quote; use serde::Serialize; diff --git a/crates/burn-import/src/burn/node/linear.rs b/crates/burn-import/src/burn/node/linear.rs index e8828295f6..8510d66131 100644 --- a/crates/burn-import/src/burn/node/linear.rs +++ b/crates/burn-import/src/burn/node/linear.rs @@ -2,10 +2,11 @@ use super::{Node, NodeCodegen, SerializationBackend}; use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type}; use burn::{ module::{Param, ParamId}, - nn::{LinearConfig, LinearRecord}, + nn::LinearRecord, record::{PrecisionSettings, Record}, tensor::{Tensor, TensorData}, }; +use onnx_ir::node::linear::LinearConfig; use proc_macro2::TokenStream; use quote::quote; use serde::Serialize; diff --git a/crates/burn-import/src/burn/node/max_pool1d.rs b/crates/burn-import/src/burn/node/max_pool1d.rs index a77e068779..c092fef7f9 100644 --- a/crates/burn-import/src/burn/node/max_pool1d.rs +++ b/crates/burn-import/src/burn/node/max_pool1d.rs @@ -1,7 +1,8 @@ +use onnx_ir::node::max_pool1d::MaxPool1dConfig; use proc_macro2::TokenStream; use quote::quote; -use burn::{nn::pool::MaxPool1dConfig, record::PrecisionSettings}; +use burn::record::PrecisionSettings; use super::{Node, NodeCodegen}; use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type}; @@ -92,10 +93,8 @@ impl NodeCodegen for MaxPool1dNode { mod tests { use super::*; use crate::burn::{TensorType, graph::BurnGraph, node::test::assert_tokens}; - use burn::{ - nn::{PaddingConfig1d, pool::MaxPool1dConfig}, - record::FullPrecisionSettings, - }; + use burn::record::FullPrecisionSettings; + use onnx_ir::node::padding::PaddingConfig1d; #[test] fn test_codegen() { diff --git a/crates/burn-import/src/burn/node/max_pool2d.rs b/crates/burn-import/src/burn/node/max_pool2d.rs index 3e3ae6ec09..15224c2af8 100644 --- a/crates/burn-import/src/burn/node/max_pool2d.rs +++ b/crates/burn-import/src/burn/node/max_pool2d.rs @@ -1,7 +1,8 @@ +use onnx_ir::node::max_pool2d::MaxPool2dConfig; use proc_macro2::TokenStream; use quote::quote; -use burn::{nn::pool::MaxPool2dConfig, record::PrecisionSettings}; +use burn::record::PrecisionSettings; use super::{Node, NodeCodegen}; use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type}; @@ -96,7 +97,8 @@ mod tests { graph::BurnGraph, node::{max_pool2d::MaxPool2dNode, test::assert_tokens}, }; - use burn::{nn::PaddingConfig2d, nn::pool::MaxPool2dConfig, record::FullPrecisionSettings}; + use burn::record::FullPrecisionSettings; + use onnx_ir::node::padding::PaddingConfig2d; #[test] fn test_codegen() { diff --git a/crates/burn-import/src/onnx/to_burn.rs b/crates/burn-import/src/onnx/to_burn.rs index 9cc3faa330..14a16d020e 100644 --- a/crates/burn-import/src/onnx/to_burn.rs +++ b/crates/burn-import/src/onnx/to_burn.rs @@ -1033,27 +1033,7 @@ impl ParsedOnnxGraph { 
let output = TensorType::from(node.outputs.first().unwrap()); // Get configuration from onnx-ir - let onnx_config = conv1d_config(&node); - - // Convert onnx-ir padding to burn padding - let burn_padding = match onnx_config.padding { - onnx_ir::node::conv1d::PaddingConfig1d::Valid => burn::nn::PaddingConfig1d::Valid, - onnx_ir::node::conv1d::PaddingConfig1d::Explicit(size) => { - burn::nn::PaddingConfig1d::Explicit(size) - } - }; - - // Convert to burn Conv1dConfig - let config = burn::nn::conv::Conv1dConfig::new( - onnx_config.channels_in, - onnx_config.channels_out, - onnx_config.kernel_size, - ) - .with_stride(onnx_config.stride) - .with_dilation(onnx_config.dilation) - .with_groups(onnx_config.groups) - .with_bias(onnx_config.bias) - .with_padding(burn_padding); + let config = conv1d_config(&node); let bias = node.inputs.len() == 3; let weight = extract_data_serialize::(1, &node).unwrap(); @@ -1103,21 +1083,7 @@ impl ParsedOnnxGraph { let output = TensorType::from(node.outputs.first().unwrap()); // Get configuration from onnx-ir - let onnx_config = max_pool1d_config(&node); - - // Convert onnx-ir padding to burn padding - let burn_padding = match onnx_config.padding { - onnx_ir::node::conv1d::PaddingConfig1d::Valid => burn::nn::PaddingConfig1d::Valid, - onnx_ir::node::conv1d::PaddingConfig1d::Explicit(size) => { - burn::nn::PaddingConfig1d::Explicit(size) - } - }; - - // Convert to burn MaxPool1dConfig - let config = burn::nn::pool::MaxPool1dConfig::new(onnx_config.kernel_size) - .with_stride(onnx_config.stride) - .with_padding(burn_padding) - .with_dilation(onnx_config.dilation); + let config = max_pool1d_config(&node); let name = &node.name; MaxPool1dNode::new(name, input, output, config) @@ -1223,21 +1189,7 @@ impl ParsedOnnxGraph { let output = TensorType::from(node.outputs.first().unwrap()); // Get configuration from onnx-ir - let onnx_config = avg_pool1d_config(&node); - - // Convert onnx-ir padding to burn padding - let burn_padding = match onnx_config.padding { - onnx_ir::node::conv1d::PaddingConfig1d::Valid => burn::nn::PaddingConfig1d::Valid, - onnx_ir::node::conv1d::PaddingConfig1d::Explicit(size) => { - burn::nn::PaddingConfig1d::Explicit(size) - } - }; - - // Convert to burn AvgPool1dConfig - let config = burn::nn::pool::AvgPool1dConfig::new(onnx_config.kernel_size) - .with_stride(onnx_config.stride) - .with_padding(burn_padding) - .with_count_include_pad(onnx_config.count_include_pad); + let config = avg_pool1d_config(&node); let name = &node.name; AvgPool1dNode::new(name, input, output, config) diff --git a/crates/onnx-ir/Cargo.toml b/crates/onnx-ir/Cargo.toml index 931c5a3652..bc720ebda0 100644 --- a/crates/onnx-ir/Cargo.toml +++ b/crates/onnx-ir/Cargo.toml @@ -14,14 +14,6 @@ version.workspace = true [dependencies] - -# REMOVE burn specific crates -burn = { path = "../burn", version = "0.18.0", default-features = false, features = [ - "std", -] } -# burn-import = { path = "../burn-import", version = "0.18.0" } - - bytemuck = { workspace = true } half = { workspace = true } log = { workspace = true } diff --git a/crates/onnx-ir/src/node/avg_pool1d.rs b/crates/onnx-ir/src/node/avg_pool1d.rs index 8b45b473f8..4b2b7e9045 100644 --- a/crates/onnx-ir/src/node/avg_pool1d.rs +++ b/crates/onnx-ir/src/node/avg_pool1d.rs @@ -1,7 +1,6 @@ -use crate::ir::Node; +use crate::{ir::Node, node::padding::padding_config_1d}; -// Reuse PaddingConfig1d from conv1d module -pub use super::conv1d::PaddingConfig1d; +use super::padding::PaddingConfig1d; /// Configuration for AvgPool1d operations 
extracted from ONNX nodes #[derive(Debug, Clone)] @@ -16,6 +15,36 @@ pub struct AvgPool1dConfig { pub count_include_pad: bool, } +impl AvgPool1dConfig { + /// Create a new AvgPool1dConfig + pub fn new(kernel_size: usize) -> Self { + Self { + kernel_size, + stride: 1, + padding: PaddingConfig1d::Valid, + count_include_pad: true, + } + } + + /// Set the stride + pub fn with_stride(mut self, stride: usize) -> Self { + self.stride = stride; + self + } + + /// Set the padding configuration + pub fn with_padding(mut self, padding: PaddingConfig1d) -> Self { + self.padding = padding; + self + } + + /// Set whether to include padding in the average calculation + pub fn with_count_include_pad(mut self, count_include_pad: bool) -> Self { + self.count_include_pad = count_include_pad; + self + } +} + /// Create an AvgPool1dConfig from the attributes of the node pub fn avg_pool1d_config(curr: &Node) -> AvgPool1dConfig { let mut kernel_shape = Vec::new(); @@ -46,7 +75,7 @@ pub fn avg_pool1d_config(curr: &Node) -> AvgPool1dConfig { panic!("ceil_mode is not supported"); } - let padding = super::conv1d::padding_config_1d(&pads); + let padding = padding_config_1d(&pads); AvgPool1dConfig { kernel_size: kernel_shape[0] as usize, diff --git a/crates/onnx-ir/src/node/avg_pool2d.rs b/crates/onnx-ir/src/node/avg_pool2d.rs index 7603c2489a..c7ec45c1e2 100644 --- a/crates/onnx-ir/src/node/avg_pool2d.rs +++ b/crates/onnx-ir/src/node/avg_pool2d.rs @@ -1,6 +1,48 @@ use crate::ir::Node; -use burn::nn::PaddingConfig2d; -use burn::nn::pool::AvgPool2dConfig; +use crate::node::padding::{PaddingConfig2d, padding_config_2d}; + +/// Configuration for AvgPool2d operations +#[derive(Debug, Clone)] +pub struct AvgPool2dConfig { + /// Kernel size [height, width] + pub kernel_size: [usize; 2], + /// Stride [height, width] + pub strides: [usize; 2], + /// Padding configuration + pub padding: PaddingConfig2d, + /// Whether to include padding in the average calculation + pub count_include_pad: bool, +} + +impl AvgPool2dConfig { + /// Create a new AvgPool2dConfig + pub fn new(kernel_size: [usize; 2]) -> Self { + Self { + kernel_size, + strides: [1, 1], + padding: PaddingConfig2d::Valid, + count_include_pad: true, + } + } + + /// Set the strides + pub fn with_strides(mut self, strides: [usize; 2]) -> Self { + self.strides = strides; + self + } + + /// Set the padding configuration + pub fn with_padding(mut self, padding: PaddingConfig2d) -> Self { + self.padding = padding; + self + } + + /// Set whether to include padding in the average calculation + pub fn with_count_include_pad(mut self, count_include_pad: bool) -> Self { + self.count_include_pad = count_include_pad; + self + } +} /// Create a AvgPool2dConfig from the attributes of the node pub fn avg_pool2d_config(curr: &Node) -> AvgPool2dConfig { @@ -33,44 +75,6 @@ pub fn avg_pool2d_config(curr: &Node) -> AvgPool2dConfig { .with_count_include_pad(count_include_pad == 1) } -/// Calculate the padding configuration for a 2D operations such as Convolution and Pooling. -/// -/// # Arguments -/// -/// * `pads` - The padding values -/// -/// # Panics -/// -/// * If the padding is negative -/// * If the padding is not symmetric -/// -/// # Returns -/// -/// * The padding configuration -/// -/// # Remarks -/// -/// This function is used when the padding is specified as a list of integers, -/// and not used when the padding is specified as a string, e.g. "SAME_UPPER". 
-fn padding_config_2d(pads: &[i64]) -> PaddingConfig2d { - let [left, top, right, bottom] = [pads[0], pads[1], pads[2], pads[3]]; - - if left < 0 || top < 0 || right < 0 || bottom < 0 { - panic!("Negative pad values are not supported"); - } else if (left != right) || (top != bottom) { - panic!("Asymmetric padding is not supported"); - } else if left == 0 && top == 0 && right == 0 && bottom == 0 { - // i.e [0, 0, 0, 0] - PaddingConfig2d::Valid - } else if left == right && top == bottom { - // i.e [2, 3, 2, 3] - PaddingConfig2d::Explicit(left as usize, top as usize) - } else { - // Unaccounted for padding configuration - panic!("Padding configuration ({:?}) not supported", pads); - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/crates/onnx-ir/src/node/batch_norm.rs b/crates/onnx-ir/src/node/batch_norm.rs index 0e23cf021d..a6ef3b2a94 100644 --- a/crates/onnx-ir/src/node/batch_norm.rs +++ b/crates/onnx-ir/src/node/batch_norm.rs @@ -1,5 +1,38 @@ use crate::ir::Node; -use burn::nn::BatchNormConfig; + +/// Configuration for BatchNorm operations +#[derive(Debug, Clone)] +pub struct BatchNormConfig { + /// Number of features (channels) + pub num_features: usize, + /// Small constant added for numerical stability + pub epsilon: f64, + /// Momentum for running statistics + pub momentum: f64, +} + +impl BatchNormConfig { + /// Create a new BatchNormConfig + pub fn new(num_features: usize) -> Self { + Self { + num_features, + epsilon: 1e-5, + momentum: 0.1, + } + } + + /// Set the epsilon value + pub fn with_epsilon(mut self, epsilon: f64) -> Self { + self.epsilon = epsilon; + self + } + + /// Set the momentum value + pub fn with_momentum(mut self, momentum: f64) -> Self { + self.momentum = momentum; + self + } +} /// Create a BatchNormConfig from the attributes of the node pub fn batch_norm_config(node: &Node) -> BatchNormConfig { diff --git a/crates/onnx-ir/src/node/conv1d.rs b/crates/onnx-ir/src/node/conv1d.rs index 3ad98e212c..017352a5be 100644 --- a/crates/onnx-ir/src/node/conv1d.rs +++ b/crates/onnx-ir/src/node/conv1d.rs @@ -1,23 +1,6 @@ use crate::ir::Node; -use std::fmt; -/// Padding configuration for 1D operations such as convolution -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum PaddingConfig1d { - /// No padding (valid padding) - Valid, - /// Explicit padding with a specific size - Explicit(usize), -} - -impl fmt::Display for PaddingConfig1d { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - PaddingConfig1d::Valid => write!(f, "Valid"), - PaddingConfig1d::Explicit(size) => write!(f, "Explicit({})", size), - } - } -} +use super::padding::{PaddingConfig1d, padding_config_1d}; /// Configuration for Conv1d operations extracted from ONNX nodes #[derive(Debug, Clone)] @@ -40,6 +23,52 @@ pub struct Conv1dConfig { pub padding: PaddingConfig1d, } +impl Conv1dConfig { + /// Create a new Conv1dConfig + pub fn new(channels_in: usize, channels_out: usize, kernel_size: usize) -> Self { + Self { + channels_in, + channels_out, + kernel_size, + stride: 1, + padding: PaddingConfig1d::Valid, + dilation: 1, + groups: 1, + bias: true, + } + } + + /// Set the stride + pub fn with_stride(mut self, stride: usize) -> Self { + self.stride = stride; + self + } + + /// Set the padding configuration + pub fn with_padding(mut self, padding: PaddingConfig1d) -> Self { + self.padding = padding; + self + } + + /// Set the dilation + pub fn with_dilation(mut self, dilation: usize) -> Self { + self.dilation = dilation; + self + } + + /// Set the number of groups + pub fn 
with_groups(mut self, groups: usize) -> Self { + self.groups = groups; + self + } + + /// Set whether bias is used + pub fn with_bias(mut self, bias: bool) -> Self { + self.bias = bias; + self + } +} + /// Create a Conv1dConfig from the attributes of the node pub fn conv1d_config(curr: &Node) -> Conv1dConfig { let mut kernel_shape = Vec::new(); // TODO default inferred from weight tensor per spec @@ -87,44 +116,6 @@ pub fn conv1d_config(curr: &Node) -> Conv1dConfig { } } -/// Calculate the padding configuration for a 1D operations such as Convolution and Pooling. -/// -/// # Arguments -/// -/// * `pads` - The padding values -/// -/// # Panics -/// -/// * If the padding is negative -/// * If the padding is not symmetric -/// -/// # Returns -/// -/// * The padding configuration -/// -/// # Remarks -/// -/// This function is used when the padding is specified as a list of integers, -/// and not used when the padding is specified as a string, e.g. "SAME_UPPER". -pub fn padding_config_1d(pads: &[i64]) -> PaddingConfig1d { - let [left, right] = [pads[0], pads[1]]; - - if left < 0 || right < 0 { - panic!("Negative pad values are not supported"); - } else if left != right { - panic!("Asymmetric padding is not supported"); - } else if left == 0 && right == 0 { - // i.e. [0, 0] - PaddingConfig1d::Valid - } else if left == right { - // i.e. [2, 2] - PaddingConfig1d::Explicit(left as usize) - } else { - // Unaccounted for padding configuration - panic!("Padding configuration ({:?}) not supported", pads); - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/crates/onnx-ir/src/node/conv2d.rs b/crates/onnx-ir/src/node/conv2d.rs index 283d885157..072f86a58f 100644 --- a/crates/onnx-ir/src/node/conv2d.rs +++ b/crates/onnx-ir/src/node/conv2d.rs @@ -1,6 +1,69 @@ use crate::ir::Node; -use burn::nn::PaddingConfig2d; -use burn::nn::conv::Conv2dConfig; +use crate::node::padding::{PaddingConfig2d, padding_config_2d}; + +/// Configuration for Conv2d operations +#[derive(Debug, Clone)] +pub struct Conv2dConfig { + /// Channels [in, out] + pub channels: [usize; 2], + /// Kernel size [height, width] + pub kernel_size: [usize; 2], + /// Stride [height, width] + pub stride: [usize; 2], + /// Padding configuration + pub padding: PaddingConfig2d, + /// Dilation [height, width] + pub dilation: [usize; 2], + /// Number of groups + pub groups: usize, + /// Whether bias is used + pub bias: bool, +} + +impl Conv2dConfig { + /// Create a new Conv2dConfig + pub fn new(channels: [usize; 2], kernel_size: [usize; 2]) -> Self { + Self { + channels, + kernel_size, + stride: [1, 1], + padding: PaddingConfig2d::Valid, + dilation: [1, 1], + groups: 1, + bias: true, + } + } + + /// Set the stride + pub fn with_stride(mut self, stride: [usize; 2]) -> Self { + self.stride = stride; + self + } + + /// Set the padding configuration + pub fn with_padding(mut self, padding: PaddingConfig2d) -> Self { + self.padding = padding; + self + } + + /// Set the dilation + pub fn with_dilation(mut self, dilation: [usize; 2]) -> Self { + self.dilation = dilation; + self + } + + /// Set the number of groups + pub fn with_groups(mut self, groups: usize) -> Self { + self.groups = groups; + self + } + + /// Set whether bias is used + pub fn with_bias(mut self, bias: bool) -> Self { + self.bias = bias; + self + } +} /// Create a Conv2dConfig from the attributes of the node pub fn conv2d_config(curr: &Node) -> Conv2dConfig { @@ -48,44 +111,6 @@ pub fn conv2d_config(curr: &Node) -> Conv2dConfig { .with_padding(padding) } -/// Calculate the padding 
configuration for a 2D operations such as Convolution and Pooling. -/// -/// # Arguments -/// -/// * `pads` - The padding values -/// -/// # Panics -/// -/// * If the padding is negative -/// * If the padding is not symmetric -/// -/// # Returns -/// -/// * The padding configuration -/// -/// # Remarks -/// -/// This function is used when the padding is specified as a list of integers, -/// and not used when the padding is specified as a string, e.g. "SAME_UPPER". -fn padding_config_2d(pads: &[i64]) -> PaddingConfig2d { - let [left, top, right, bottom] = [pads[0], pads[1], pads[2], pads[3]]; - - if left < 0 || top < 0 || right < 0 || bottom < 0 { - panic!("Negative pad values are not supported"); - } else if (left != right) || (top != bottom) { - panic!("Asymmetric padding is not supported"); - } else if left == 0 && top == 0 && right == 0 && bottom == 0 { - // i.e [0, 0, 0, 0] - PaddingConfig2d::Valid - } else if left == right && top == bottom { - // i.e [2, 3, 2, 3] - PaddingConfig2d::Explicit(left as usize, top as usize) - } else { - // Unaccounted for padding configuration - panic!("Padding configuration ({:?}) not supported", pads); - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/crates/onnx-ir/src/node/conv3d.rs b/crates/onnx-ir/src/node/conv3d.rs index c3ed5b5ce3..e55093f9d1 100644 --- a/crates/onnx-ir/src/node/conv3d.rs +++ b/crates/onnx-ir/src/node/conv3d.rs @@ -1,6 +1,69 @@ use crate::ir::Node; -use burn::nn::PaddingConfig3d; -use burn::nn::conv::Conv3dConfig; +use crate::node::padding::{PaddingConfig3d, padding_config_3d}; + +/// Configuration for Conv3d operations. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Conv3dConfig { + /// Input and output channels [in, out]. + pub channels: [usize; 2], + /// Size of the kernel. + pub kernel_size: [usize; 3], + /// Stride of the convolutional kernel. + pub stride: [usize; 3], + /// Dilation of the convolutional kernel. + pub dilation: [usize; 3], + /// Groups. + pub groups: usize, + /// Use bias. + pub bias: bool, + /// Padding. + pub padding: PaddingConfig3d, +} + +impl Conv3dConfig { + /// Create a new configuration for a Conv3d. + pub fn new(channels: [usize; 2], kernel_size: [usize; 3]) -> Self { + Self { + channels, + kernel_size, + stride: [1, 1, 1], + dilation: [1, 1, 1], + groups: 1, + bias: true, + padding: PaddingConfig3d::Valid, + } + } + + /// Set the stride. + pub fn with_stride(mut self, stride: [usize; 3]) -> Self { + self.stride = stride; + self + } + + /// Set the dilation. + pub fn with_dilation(mut self, dilation: [usize; 3]) -> Self { + self.dilation = dilation; + self + } + + /// Set the groups. + pub fn with_groups(mut self, groups: usize) -> Self { + self.groups = groups; + self + } + + /// Set whether to use bias. + pub fn with_bias(mut self, bias: bool) -> Self { + self.bias = bias; + self + } + + /// Set the padding. + pub fn with_padding(mut self, padding: PaddingConfig3d) -> Self { + self.padding = padding; + self + } +} /// Create a Conv3dConfig from the attributes of the node pub fn conv3d_config(curr: &Node) -> Conv3dConfig { @@ -60,45 +123,6 @@ pub fn conv3d_config(curr: &Node) -> Conv3dConfig { .with_padding(padding) } -/// Calculate the padding configuration for a 3D operations such as Convolution and Pooling. 
-/// -/// # Arguments -/// -/// * `pads` - The padding values -/// -/// # Panics -/// -/// * If the padding is negative -/// * If the padding is not symmetric -/// -/// # Returns -/// -/// * The padding configuration -/// -/// # Remarks -/// -/// This function is used when the padding is specified as a list of integers, -/// and not used when the padding is specified as a string, e.g. "SAME_UPPER". -fn padding_config_3d(pads: &[i64]) -> PaddingConfig3d { - let [left, top, front, right, bottom, back] = - [pads[0], pads[1], pads[2], pads[3], pads[4], pads[5]]; - - if left < 0 || top < 0 || front < 0 || right < 0 || bottom < 0 || back < 0 { - panic!("Negative pad values are not supported"); - } else if (left != right) || (top != bottom) || (front != back) { - panic!("Asymmetric padding is not supported"); - } else if left == 0 && top == 0 && front == 0 && right == 0 && bottom == 0 && back == 0 { - // i.e [0, 0, 0, 0] - PaddingConfig3d::Valid - } else if left == right && top == bottom && front == back { - // i.e [2, 3, 2, 3] - PaddingConfig3d::Explicit(left as usize, top as usize, front as usize) - } else { - // Unaccounted for padding configuration - panic!("Padding configuration ({:?}) not supported", pads); - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/crates/onnx-ir/src/node/conv_transpose1d.rs b/crates/onnx-ir/src/node/conv_transpose1d.rs index e6d046b4e3..7c402880f6 100644 --- a/crates/onnx-ir/src/node/conv_transpose1d.rs +++ b/crates/onnx-ir/src/node/conv_transpose1d.rs @@ -1,8 +1,5 @@ use crate::ir::{AttributeValue, Node}; -// Reuse PaddingConfig1d from conv1d module -pub use super::conv1d::PaddingConfig1d; - /// Configuration for ConvTranspose1d operations extracted from ONNX nodes #[derive(Debug, Clone)] pub struct ConvTranspose1dConfig { diff --git a/crates/onnx-ir/src/node/conv_transpose2d.rs b/crates/onnx-ir/src/node/conv_transpose2d.rs index a0dd9e975a..b373525218 100644 --- a/crates/onnx-ir/src/node/conv_transpose2d.rs +++ b/crates/onnx-ir/src/node/conv_transpose2d.rs @@ -1,5 +1,77 @@ use crate::ir::{AttributeValue, Node}; -use burn::nn::conv::ConvTranspose2dConfig; + +/// Configuration for ConvTranspose2d operations. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ConvTranspose2dConfig { + /// Input and output channels [in, out]. + pub channels: [usize; 2], + /// Size of the kernel. + pub kernel_size: [usize; 2], + /// Stride of the convolutional kernel. + pub stride: [usize; 2], + /// Dilation of the convolutional kernel. + pub dilation: [usize; 2], + /// Padding. + pub padding: [usize; 2], + /// Output padding. + pub padding_out: [usize; 2], + /// Groups. + pub groups: usize, + /// Use bias. + pub bias: bool, +} + +impl ConvTranspose2dConfig { + /// Create a new configuration for a ConvTranspose2d. + pub fn new(channels: [usize; 2], kernel_size: [usize; 2]) -> Self { + Self { + channels, + kernel_size, + stride: [1, 1], + dilation: [1, 1], + padding: [0, 0], + padding_out: [0, 0], + groups: 1, + bias: true, + } + } + + /// Set the stride. + pub fn with_stride(mut self, stride: [usize; 2]) -> Self { + self.stride = stride; + self + } + + /// Set the dilation. + pub fn with_dilation(mut self, dilation: [usize; 2]) -> Self { + self.dilation = dilation; + self + } + + /// Set the padding. + pub fn with_padding(mut self, padding: [usize; 2]) -> Self { + self.padding = padding; + self + } + + /// Set the output padding. + pub fn with_padding_out(mut self, padding_out: [usize; 2]) -> Self { + self.padding_out = padding_out; + self + } + + /// Set the groups. 
+ pub fn with_groups(mut self, groups: usize) -> Self { + self.groups = groups; + self + } + + /// Set whether to use bias. + pub fn with_bias(mut self, bias: bool) -> Self { + self.bias = bias; + self + } +} /// Create a ConvTranspose2dConfig from the attributes of the node pub fn conv_transpose2d_config(curr: &Node) -> ConvTranspose2dConfig { diff --git a/crates/onnx-ir/src/node/conv_transpose3d.rs b/crates/onnx-ir/src/node/conv_transpose3d.rs index d4326f6ddb..b003c5a045 100644 --- a/crates/onnx-ir/src/node/conv_transpose3d.rs +++ b/crates/onnx-ir/src/node/conv_transpose3d.rs @@ -1,5 +1,77 @@ use crate::ir::{AttributeValue, Node}; -use burn::nn::conv::ConvTranspose3dConfig; + +/// Configuration for ConvTranspose3d operations. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ConvTranspose3dConfig { + /// Input and output channels [in, out]. + pub channels: [usize; 2], + /// Size of the kernel. + pub kernel_size: [usize; 3], + /// Stride of the convolutional kernel. + pub stride: [usize; 3], + /// Dilation of the convolutional kernel. + pub dilation: [usize; 3], + /// Padding. + pub padding: [usize; 3], + /// Output padding. + pub padding_out: [usize; 3], + /// Groups. + pub groups: usize, + /// Use bias. + pub bias: bool, +} + +impl ConvTranspose3dConfig { + /// Create a new configuration for a ConvTranspose3d. + pub fn new(channels: [usize; 2], kernel_size: [usize; 3]) -> Self { + Self { + channels, + kernel_size, + stride: [1, 1, 1], + dilation: [1, 1, 1], + padding: [0, 0, 0], + padding_out: [0, 0, 0], + groups: 1, + bias: true, + } + } + + /// Set the stride. + pub fn with_stride(mut self, stride: [usize; 3]) -> Self { + self.stride = stride; + self + } + + /// Set the dilation. + pub fn with_dilation(mut self, dilation: [usize; 3]) -> Self { + self.dilation = dilation; + self + } + + /// Set the padding. + pub fn with_padding(mut self, padding: [usize; 3]) -> Self { + self.padding = padding; + self + } + + /// Set the output padding. + pub fn with_padding_out(mut self, padding_out: [usize; 3]) -> Self { + self.padding_out = padding_out; + self + } + + /// Set the groups. + pub fn with_groups(mut self, groups: usize) -> Self { + self.groups = groups; + self + } + + /// Set whether to use bias. 
+ pub fn with_bias(mut self, bias: bool) -> Self { + self.bias = bias; + self + } +} /// Create a ConvTranspose3dConfig from the attributes of the node pub fn conv_transpose3d_config(curr: &Node) -> ConvTranspose3dConfig { diff --git a/crates/onnx-ir/src/node/dropout.rs b/crates/onnx-ir/src/node/dropout.rs index 40fc1f03ef..ef5820c0c1 100644 --- a/crates/onnx-ir/src/node/dropout.rs +++ b/crates/onnx-ir/src/node/dropout.rs @@ -1,5 +1,18 @@ use crate::ir::{Data, Node}; -use burn::nn::DropoutConfig; + +/// Configuration for Dropout operations +#[derive(Debug, Clone)] +pub struct DropoutConfig { + /// Probability of dropping out a unit + pub prob: f64, +} + +impl DropoutConfig { + /// Create a new DropoutConfig + pub fn new(prob: f64) -> Self { + Self { prob } + } +} /// Create a DropoutConfig from an attribute and state of the node pub fn dropout_config(node: &Node) -> DropoutConfig { diff --git a/crates/onnx-ir/src/node/layer_norm.rs b/crates/onnx-ir/src/node/layer_norm.rs index 90d8538985..9911f9b6c1 100644 --- a/crates/onnx-ir/src/node/layer_norm.rs +++ b/crates/onnx-ir/src/node/layer_norm.rs @@ -1,5 +1,29 @@ use crate::ir::Node; -use burn::nn::LayerNormConfig; + +/// Configuration for LayerNorm operations +#[derive(Debug, Clone)] +pub struct LayerNormConfig { + /// Number of features/model dimension + pub d_model: usize, + /// Small constant added for numerical stability + pub epsilon: f64, +} + +impl LayerNormConfig { + /// Create a new LayerNormConfig + pub fn new(d_model: usize) -> Self { + Self { + d_model, + epsilon: 1e-5, + } + } + + /// Set the epsilon value + pub fn with_epsilon(mut self, epsilon: f64) -> Self { + self.epsilon = epsilon; + self + } +} /// Create a LayerNormConfig from the attributes of the node pub fn layer_norm_config(node: &Node) -> (LayerNormConfig, bool) { diff --git a/crates/onnx-ir/src/node/linear.rs b/crates/onnx-ir/src/node/linear.rs index 1aceeb3f91..ecff7b9a25 100644 --- a/crates/onnx-ir/src/node/linear.rs +++ b/crates/onnx-ir/src/node/linear.rs @@ -1,5 +1,32 @@ use crate::ir::{ArgType, Node, TensorType}; -use burn::nn::LinearConfig; + +/// Configuration for Linear operations +#[derive(Debug, Clone)] +pub struct LinearConfig { + /// Input dimension (features) + pub d_input: usize, + /// Output dimension (features) + pub d_output: usize, + /// Whether bias is used + pub bias: bool, +} + +impl LinearConfig { + /// Create a new LinearConfig + pub fn new(d_input: usize, d_output: usize) -> Self { + Self { + d_input, + d_output, + bias: true, + } + } + + /// Set whether bias is used + pub fn with_bias(mut self, bias: bool) -> Self { + self.bias = bias; + self + } +} /// Update output rank for Linear operations (same as input rank). 
pub fn linear_update_outputs(node: &mut Node) { diff --git a/crates/onnx-ir/src/node/max_pool1d.rs b/crates/onnx-ir/src/node/max_pool1d.rs index b4e8ede8a4..3523cf91f9 100644 --- a/crates/onnx-ir/src/node/max_pool1d.rs +++ b/crates/onnx-ir/src/node/max_pool1d.rs @@ -1,7 +1,6 @@ -use crate::ir::Node; +use crate::{ir::Node, node::padding::padding_config_1d}; -// Reuse PaddingConfig1d from conv1d module -pub use super::conv1d::PaddingConfig1d; +use super::padding::PaddingConfig1d; /// Configuration for MaxPool1d operations extracted from ONNX nodes #[derive(Debug, Clone)] @@ -16,6 +15,36 @@ pub struct MaxPool1dConfig { pub padding: PaddingConfig1d, } +impl MaxPool1dConfig { + /// Create a new MaxPool1dConfig + pub fn new(kernel_size: usize) -> Self { + Self { + kernel_size, + stride: 1, + padding: PaddingConfig1d::Valid, + dilation: 1, + } + } + + /// Set the stride + pub fn with_stride(mut self, stride: usize) -> Self { + self.stride = stride; + self + } + + /// Set the padding configuration + pub fn with_padding(mut self, padding: PaddingConfig1d) -> Self { + self.padding = padding; + self + } + + /// Set the dilation + pub fn with_dilation(mut self, dilation: usize) -> Self { + self.dilation = dilation; + self + } +} + /// Create a MaxPool1dConfig from the attributes of the node pub fn max_pool1d_config(curr: &Node) -> MaxPool1dConfig { let mut kernel_shape = Vec::new(); @@ -41,7 +70,7 @@ pub fn max_pool1d_config(curr: &Node) -> MaxPool1dConfig { assert_eq!(dilation.len(), 1, "MaxPool1d: dilation must have length 1"); assert_eq!(stride.len(), 1, "MaxPool1d: stride must have length 1"); - let padding = super::conv1d::padding_config_1d(&pads); + let padding = padding_config_1d(&pads); MaxPool1dConfig { kernel_size: kernel_shape[0] as usize, @@ -54,7 +83,10 @@ pub fn max_pool1d_config(curr: &Node) -> MaxPool1dConfig { #[cfg(test)] mod tests { use super::*; - use crate::ir::{ArgType, Argument, AttributeValue, ElementType, NodeType, TensorType}; + use crate::{ + ir::{ArgType, Argument, AttributeValue, ElementType, NodeType, TensorType}, + node::padding::PaddingConfig1d, + }; use std::collections::HashMap; fn create_test_node( diff --git a/crates/onnx-ir/src/node/max_pool2d.rs b/crates/onnx-ir/src/node/max_pool2d.rs index a5f4e3a9e1..62c3c71971 100644 --- a/crates/onnx-ir/src/node/max_pool2d.rs +++ b/crates/onnx-ir/src/node/max_pool2d.rs @@ -1,6 +1,48 @@ use crate::ir::Node; -use burn::nn::PaddingConfig2d; -use burn::nn::pool::MaxPool2dConfig; +use crate::node::padding::{PaddingConfig2d, padding_config_2d}; + +/// Configuration for MaxPool2d operations +#[derive(Debug, Clone)] +pub struct MaxPool2dConfig { + /// Kernel size [height, width] + pub kernel_size: [usize; 2], + /// Stride [height, width] + pub strides: [usize; 2], + /// Padding configuration + pub padding: PaddingConfig2d, + /// Dilation [height, width] + pub dilation: [usize; 2], +} + +impl MaxPool2dConfig { + /// Create a new MaxPool2dConfig + pub fn new(kernel_size: [usize; 2]) -> Self { + Self { + kernel_size, + strides: [1, 1], + padding: PaddingConfig2d::Valid, + dilation: [1, 1], + } + } + + /// Set the strides + pub fn with_strides(mut self, strides: [usize; 2]) -> Self { + self.strides = strides; + self + } + + /// Set the padding configuration + pub fn with_padding(mut self, padding: PaddingConfig2d) -> Self { + self.padding = padding; + self + } + + /// Set the dilation + pub fn with_dilation(mut self, dilation: [usize; 2]) -> Self { + self.dilation = dilation; + self + } +} /// Create a MaxPool2dConfig from the attributes 
of the node pub fn max_pool2d_config(curr: &Node) -> MaxPool2dConfig { @@ -27,44 +69,6 @@ pub fn max_pool2d_config(curr: &Node) -> MaxPool2dConfig { .with_dilation([dilations[0] as usize, dilations[1] as usize]) } -/// Calculate the padding configuration for a 2D operations such as Convolution and Pooling. -/// -/// # Arguments -/// -/// * `pads` - The padding values -/// -/// # Panics -/// -/// * If the padding is negative -/// * If the padding is not symmetric -/// -/// # Returns -/// -/// * The padding configuration -/// -/// # Remarks -/// -/// This function is used when the padding is specified as a list of integers, -/// and not used when the padding is specified as a string, e.g. "SAME_UPPER". -fn padding_config_2d(pads: &[i64]) -> PaddingConfig2d { - let [left, top, right, bottom] = [pads[0], pads[1], pads[2], pads[3]]; - - if left < 0 || top < 0 || right < 0 || bottom < 0 { - panic!("Negative pad values are not supported"); - } else if (left != right) || (top != bottom) { - panic!("Asymmetric padding is not supported"); - } else if left == 0 && top == 0 && right == 0 && bottom == 0 { - // i.e [0, 0, 0, 0] - PaddingConfig2d::Valid - } else if left == right && top == bottom { - // i.e [2, 3, 2, 3] - PaddingConfig2d::Explicit(left as usize, top as usize) - } else { - // Unaccounted for padding configuration - panic!("Padding configuration ({:?}) not supported", pads); - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/crates/onnx-ir/src/node/mod.rs b/crates/onnx-ir/src/node/mod.rs index 1e5a28649a..4e560c7453 100644 --- a/crates/onnx-ir/src/node/mod.rs +++ b/crates/onnx-ir/src/node/mod.rs @@ -37,6 +37,7 @@ pub mod matmul; pub mod max_pool1d; pub mod max_pool2d; pub mod one_hot; +pub mod padding; pub mod random; pub mod random_like; pub mod range; diff --git a/crates/onnx-ir/src/node/padding.rs b/crates/onnx-ir/src/node/padding.rs new file mode 100644 index 0000000000..bfe3f68071 --- /dev/null +++ b/crates/onnx-ir/src/node/padding.rs @@ -0,0 +1,231 @@ +use std::fmt; + +/// Padding configuration for 1D operations such as convolution +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum PaddingConfig1d { + /// No padding (valid padding) + Valid, + /// Explicit padding with a specific size + Explicit(usize), +} + +impl fmt::Display for PaddingConfig1d { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + PaddingConfig1d::Valid => write!(f, "Valid"), + PaddingConfig1d::Explicit(size) => write!(f, "Explicit({})", size), + } + } +} + +/// Calculate the padding configuration for a 1D operations such as Convolution and Pooling. +/// +/// # Arguments +/// +/// * `pads` - The padding values +/// +/// # Panics +/// +/// * If the padding is negative +/// * If the padding is not symmetric +/// +/// # Returns +/// +/// * The padding configuration +/// +/// # Remarks +/// +/// This function is used when the padding is specified as a list of integers, +/// and not used when the padding is specified as a string, e.g. "SAME_UPPER". +pub fn padding_config_1d(pads: &[i64]) -> PaddingConfig1d { + let [left, right] = [pads[0], pads[1]]; + + if left < 0 || right < 0 { + panic!("Negative pad values are not supported"); + } else if left != right { + panic!("Asymmetric padding is not supported"); + } else if left == 0 && right == 0 { + // i.e. [0, 0] + PaddingConfig1d::Valid + } else if left == right { + // i.e. 
[2, 2] + PaddingConfig1d::Explicit(left as usize) + } else { + // Unaccounted for padding configuration + panic!("Padding configuration ({:?}) not supported", pads); + } +} + +/// Padding configuration for 2D operations such as convolution +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum PaddingConfig2d { + /// No padding (valid padding) + Valid, + /// Explicit padding with specific width and height + Explicit(usize, usize), +} + +impl fmt::Display for PaddingConfig2d { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + PaddingConfig2d::Valid => write!(f, "Valid"), + PaddingConfig2d::Explicit(width, height) => { + write!(f, "Explicit({}, {})", width, height) + } + } + } +} + +/// Padding configuration for 3D operations such as convolution +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum PaddingConfig3d { + /// No padding (valid padding) + Valid, + /// Explicit padding with specific width, height, and depth + Explicit(usize, usize, usize), +} + +impl fmt::Display for PaddingConfig3d { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + PaddingConfig3d::Valid => write!(f, "Valid"), + PaddingConfig3d::Explicit(width, height, depth) => { + write!(f, "Explicit({}, {}, {})", width, height, depth) + } + } + } +} + +/// Calculate the padding configuration for a 2D operations such as Convolution and Pooling. +/// +/// # Arguments +/// +/// * `pads` - The padding values [top, left, bottom, right] +/// +/// # Panics +/// +/// * If the padding is negative +/// * If the padding is not symmetric +/// +/// # Returns +/// +/// * The padding configuration +/// +/// # Remarks +/// +/// This function is used when the padding is specified as a list of integers, +/// and not used when the padding is specified as a string, e.g. "SAME_UPPER". +pub fn padding_config_2d(pads: &[i64]) -> PaddingConfig2d { + let [top, left, bottom, right] = [pads[0], pads[1], pads[2], pads[3]]; + + if left < 0 || right < 0 || top < 0 || bottom < 0 { + panic!("Negative pad values are not supported"); + } else if left != right || top != bottom { + panic!("Asymmetric padding is not supported"); + } else if left == 0 && right == 0 && top == 0 && bottom == 0 { + PaddingConfig2d::Valid + } else if left == right && top == bottom { + PaddingConfig2d::Explicit(top as usize, left as usize) + } else { + // Unaccounted for padding configuration + panic!("Padding configuration ({:?}) not supported", pads); + } +} + +/// Calculate the padding configuration for a 3D operations such as Convolution and Pooling. +/// +/// # Arguments +/// +/// * `pads` - The padding values [front, top, left, back, bottom, right] +/// +/// # Panics +/// +/// * If the padding is negative +/// * If the padding is not symmetric +/// +/// # Returns +/// +/// * The padding configuration +/// +/// # Remarks +/// +/// This function is used when the padding is specified as a list of integers, +/// and not used when the padding is specified as a string, e.g. "SAME_UPPER". 
+pub fn padding_config_3d(pads: &[i64]) -> PaddingConfig3d { + let [front, top, left, back, bottom, right] = + [pads[0], pads[1], pads[2], pads[3], pads[4], pads[5]]; + + if left < 0 || right < 0 || top < 0 || bottom < 0 || front < 0 || back < 0 { + panic!("Negative pad values are not supported"); + } else if left != right || top != bottom || front != back { + panic!("Asymmetric padding is not supported"); + } else if left == 0 && right == 0 && top == 0 && bottom == 0 && front == 0 && back == 0 { + PaddingConfig3d::Valid + } else if left == right && top == bottom && front == back { + PaddingConfig3d::Explicit(front as usize, top as usize, left as usize) + } else { + // Unaccounted for padding configuration + panic!("Padding configuration ({:?}) not supported", pads); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_padding_config_2d_valid() { + let pads = vec![0, 0, 0, 0]; + let config = padding_config_2d(&pads); + assert!(matches!(config, PaddingConfig2d::Valid)); + } + + #[test] + fn test_padding_config_2d_explicit() { + let pads = vec![2, 2, 3, 3]; + let config = padding_config_2d(&pads); + assert!(matches!(config, PaddingConfig2d::Explicit(2, 3))); + } + + #[test] + #[should_panic(expected = "Asymmetric padding is not supported")] + fn test_padding_config_2d_asymmetric() { + let pads = vec![1, 2, 3, 3]; + let _ = padding_config_2d(&pads); + } + + #[test] + #[should_panic(expected = "Negative pad values are not supported")] + fn test_padding_config_2d_negative() { + let pads = vec![-1, -1, -1, -1]; + let _ = padding_config_2d(&pads); + } + + #[test] + fn test_padding_config_3d_valid() { + let pads = vec![0, 0, 0, 0, 0, 0]; + let config = padding_config_3d(&pads); + assert!(matches!(config, PaddingConfig3d::Valid)); + } + + #[test] + fn test_padding_config_3d_explicit() { + let pads = vec![2, 2, 3, 3, 1, 1]; + let config = padding_config_3d(&pads); + assert!(matches!(config, PaddingConfig3d::Explicit(2, 3, 1))); + } + + #[test] + #[should_panic(expected = "Asymmetric padding is not supported")] + fn test_padding_config_3d_asymmetric() { + let pads = vec![1, 2, 3, 3, 1, 1]; + let _ = padding_config_3d(&pads); + } + + #[test] + #[should_panic(expected = "Negative pad values are not supported")] + fn test_padding_config_3d_negative() { + let pads = vec![-1, -1, -1, -1, -1, -1]; + let _ = padding_config_3d(&pads); + } +} From e1d968209cd5be3e1e79c9d6b46b085d9c115ece Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Thu, 1 May 2025 11:40:05 -0500 Subject: [PATCH 17/37] Fixed left over tests --- crates/onnx-ir/src/node/padding.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/crates/onnx-ir/src/node/padding.rs b/crates/onnx-ir/src/node/padding.rs index bfe3f68071..3fba910027 100644 --- a/crates/onnx-ir/src/node/padding.rs +++ b/crates/onnx-ir/src/node/padding.rs @@ -182,15 +182,15 @@ mod tests { #[test] fn test_padding_config_2d_explicit() { - let pads = vec![2, 2, 3, 3]; + let pads = vec![2, 2, 2, 2]; let config = padding_config_2d(&pads); - assert!(matches!(config, PaddingConfig2d::Explicit(2, 3))); + assert!(matches!(config, PaddingConfig2d::Explicit(2, 2))); } #[test] #[should_panic(expected = "Asymmetric padding is not supported")] fn test_padding_config_2d_asymmetric() { - let pads = vec![1, 2, 3, 3]; + let pads = vec![2, 3, 2, 2]; let _ = padding_config_2d(&pads); } @@ -210,7 +210,7 @@ mod tests { #[test] fn test_padding_config_3d_explicit() { - let pads = vec![2, 2, 3, 3, 1, 1]; + let 
pads = vec![2, 3, 1, 2, 3, 1]; let config = padding_config_3d(&pads); assert!(matches!(config, PaddingConfig3d::Explicit(2, 3, 1))); } @@ -218,7 +218,7 @@ mod tests { #[test] #[should_panic(expected = "Asymmetric padding is not supported")] fn test_padding_config_3d_asymmetric() { - let pads = vec![1, 2, 3, 3, 1, 1]; + let pads = vec![2, 3, 1, 3, 3, 1]; let _ = padding_config_3d(&pads); } From 29a08965800a6f595449ccce45fd94348bb5fc9e Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Thu, 1 May 2025 14:19:26 -0500 Subject: [PATCH 18/37] No default init for config structs --- .../burn-import/src/burn/node/avg_pool1d.rs | 9 +- .../burn-import/src/burn/node/avg_pool2d.rs | 9 +- crates/burn-import/src/burn/node/base.rs | 20 ++- .../burn-import/src/burn/node/batch_norm.rs | 6 +- crates/burn-import/src/burn/node/conv1d.rs | 11 +- crates/burn-import/src/burn/node/conv2d.rs | 10 +- crates/burn-import/src/burn/node/conv3d.rs | 10 +- .../src/burn/node/conv_transpose_2d.rs | 17 +- .../src/burn/node/conv_transpose_3d.rs | 17 +- crates/onnx-ir/src/node/avg_pool1d.rs | 31 +--- crates/onnx-ir/src/node/avg_pool2d.rs | 41 ++--- crates/onnx-ir/src/node/batch_norm.rs | 22 +-- crates/onnx-ir/src/node/conv1d.rs | 52 ++---- crates/onnx-ir/src/node/conv2d.rs | 60 +++---- crates/onnx-ir/src/node/conv3d.rs | 76 +++------ crates/onnx-ir/src/node/conv_transpose1d.rs | 157 +++++++++--------- crates/onnx-ir/src/node/conv_transpose2d.rs | 72 +++----- crates/onnx-ir/src/node/conv_transpose3d.rs | 88 ++++------ crates/onnx-ir/src/op_configuration.rs | 9 + 19 files changed, 324 insertions(+), 393 deletions(-) diff --git a/crates/burn-import/src/burn/node/avg_pool1d.rs b/crates/burn-import/src/burn/node/avg_pool1d.rs index 680609290f..f22dffdceb 100644 --- a/crates/burn-import/src/burn/node/avg_pool1d.rs +++ b/crates/burn-import/src/burn/node/avg_pool1d.rs @@ -105,9 +105,12 @@ mod tests { "avg_pool1d", TensorType::new_float("input", 3), TensorType::new_float("output", 3), - AvgPool1dConfig::new(3) - .with_stride(1) - .with_padding(PaddingConfig1d::Valid), + AvgPool1dConfig::new( + 3, // kernel_size + 1, // stride + PaddingConfig1d::Valid, // padding + true, // count_include_pad + ), )); graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); diff --git a/crates/burn-import/src/burn/node/avg_pool2d.rs b/crates/burn-import/src/burn/node/avg_pool2d.rs index f929693e51..ed3b78be5b 100644 --- a/crates/burn-import/src/burn/node/avg_pool2d.rs +++ b/crates/burn-import/src/burn/node/avg_pool2d.rs @@ -109,9 +109,12 @@ mod tests { "avg_pool2d", TensorType::new_float("input", 4), TensorType::new_float("output", 4), - AvgPool2dConfig::new([3, 3]) - .with_strides([1, 1]) - .with_padding(PaddingConfig2d::Valid), + AvgPool2dConfig::new( + [3, 3], // kernel_size + [1, 1], // strides + PaddingConfig2d::Valid, // padding + true, // count_include_pad + ), )); graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); diff --git a/crates/burn-import/src/burn/node/base.rs b/crates/burn-import/src/burn/node/base.rs index e63ff723fc..71f614f1af 100644 --- a/crates/burn-import/src/burn/node/base.rs +++ b/crates/burn-import/src/burn/node/base.rs @@ -371,7 +371,15 @@ pub(crate) mod tests { TensorType::new_float("tensor4", 4), TensorData::from([2f32]), None, - Conv2dConfig::new([3, 3], [3, 3]).with_padding(PaddingConfig2d::Valid), + Conv2dConfig::new( + [3, 3], // kernel_size + [3, 3], // stride + [1, 1], // dilation + PaddingConfig2d::Valid, // 
padding + [1, 1], // output_padding + 1, // groups + true // bias + ), )); graph.register_input_output( @@ -444,7 +452,15 @@ pub(crate) mod tests { TensorType::new_float("tensor4", 4), TensorData::from([2f32]), None, - Conv2dConfig::new([3, 3], [3, 3]).with_padding(PaddingConfig2d::Valid), + Conv2dConfig::new( + [3, 3], // kernel_size + [3, 3], // stride + [1, 1], // dilation + PaddingConfig2d::Valid, // padding + [1, 1], // output_padding + 1, // groups + true // bias + ), )); graph.register(MatmulNode::new( TensorType::new_float("tensor3", 4), diff --git a/crates/burn-import/src/burn/node/batch_norm.rs b/crates/burn-import/src/burn/node/batch_norm.rs index 6733dd91f9..ed9221a15f 100644 --- a/crates/burn-import/src/burn/node/batch_norm.rs +++ b/crates/burn-import/src/burn/node/batch_norm.rs @@ -172,7 +172,11 @@ mod tests { TensorData::from([2f32]), TensorData::from([2f32]), TensorData::from([2f32]), - BatchNormConfig::new(128), + BatchNormConfig::new( + 128, // num_features + 0.00001, // epsilon + 0.1 // momentum + ), )); graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); diff --git a/crates/burn-import/src/burn/node/conv1d.rs b/crates/burn-import/src/burn/node/conv1d.rs index 1eeb61d4f5..717ddc7cfc 100644 --- a/crates/burn-import/src/burn/node/conv1d.rs +++ b/crates/burn-import/src/burn/node/conv1d.rs @@ -149,7 +149,16 @@ mod tests { TensorType::new_float("output", 4), TensorData::from([2f32]), None, - Conv1dConfig::new(3, 3, 3).with_padding(PaddingConfig1d::Valid), + Conv1dConfig::new( + 3, // channels_in + 3, // channels_out + 3, // kernel_size + 1, // stride + PaddingConfig1d::Valid, // padding + 1, // dilation + 1, // groups + true // bias + ), )); graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); diff --git a/crates/burn-import/src/burn/node/conv2d.rs b/crates/burn-import/src/burn/node/conv2d.rs index 8051e0bd73..039932fa89 100644 --- a/crates/burn-import/src/burn/node/conv2d.rs +++ b/crates/burn-import/src/burn/node/conv2d.rs @@ -148,7 +148,15 @@ mod tests { TensorType::new_float("output", 4), TensorData::from([2f32]), None, - Conv2dConfig::new([3, 3], [3, 3]).with_padding(PaddingConfig2d::Valid), + Conv2dConfig::new( + [3, 3], // kernel_size + [3, 3], // stride + [1, 1], // dilation + PaddingConfig2d::Valid, // padding + [1, 1], // output_padding + 1, // groups + true // bias + ), )); graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); diff --git a/crates/burn-import/src/burn/node/conv3d.rs b/crates/burn-import/src/burn/node/conv3d.rs index b0855434a3..50f6fdf384 100644 --- a/crates/burn-import/src/burn/node/conv3d.rs +++ b/crates/burn-import/src/burn/node/conv3d.rs @@ -148,7 +148,15 @@ mod tests { TensorType::new_float("output", 5), TensorData::from([2f32]), None, - Conv3dConfig::new([3, 3], [3, 3, 3]).with_padding(PaddingConfig3d::Valid), + Conv3dConfig::new( + [3, 3], // kernel_size + [3, 3, 3], // stride + [1, 1, 1], // dilation + [1, 1, 1], // output_padding + 1, // groups + true, // bias + PaddingConfig3d::Valid // padding + ), )); graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); diff --git a/crates/burn-import/src/burn/node/conv_transpose_2d.rs b/crates/burn-import/src/burn/node/conv_transpose_2d.rs index 851e884c11..e1880c6127 100644 --- a/crates/burn-import/src/burn/node/conv_transpose_2d.rs +++ b/crates/burn-import/src/burn/node/conv_transpose_2d.rs @@ -150,7 +150,16 @@ mod tests { TensorType::new_float("output", 4), 
TensorData::from([2f32]), None, - ConvTranspose2dConfig::new([3, 3], [3, 3]).with_padding([0, 0]), + ConvTranspose2dConfig::new( + [3, 3], // kernel_size + [1, 1], // stride + [0, 0], // dilation + [0, 0], // padding + [0, 0], // output_padding + [0, 0], // padding_out + 1, // groups + true, // bias + ), )); graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); @@ -173,11 +182,11 @@ mod tests { impl Model { #[allow(unused_variables)] pub fn new(device: &B::Device) -> Self { - let conv_transpose_2d = ConvTranspose2dConfig::new([3, 3], [3, 3]) - .with_stride([1, 1]) + let conv_transpose_2d = ConvTranspose2dConfig::new([3, 3], [1, 1]) + .with_stride([0, 0]) .with_padding([0, 0]) .with_padding_out([0, 0]) - .with_dilation([1, 1]) + .with_dilation([0, 0]) .with_groups(1) .with_bias(true) .init(device); diff --git a/crates/burn-import/src/burn/node/conv_transpose_3d.rs b/crates/burn-import/src/burn/node/conv_transpose_3d.rs index f8a8906a07..ae56aa689d 100644 --- a/crates/burn-import/src/burn/node/conv_transpose_3d.rs +++ b/crates/burn-import/src/burn/node/conv_transpose_3d.rs @@ -150,7 +150,16 @@ mod tests { TensorType::new_float("output", 5), TensorData::from([2f32]), None, - ConvTranspose3dConfig::new([3, 3], [3, 3, 3]).with_padding([0, 0, 0]), + ConvTranspose3dConfig::new( + [3, 3], // kernel_size + [1, 1, 1], // stride + [0, 0, 0], // dilation + [0, 0, 0], // padding + [0, 0, 0], // output_padding + [0, 0, 0], // output_padding additional + 1, // groups + true, // bias + ), )); graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); @@ -173,11 +182,11 @@ mod tests { impl Model { #[allow(unused_variables)] pub fn new(device: &B::Device) -> Self { - let conv_transpose_3d = ConvTranspose3dConfig::new([3, 3], [3, 3, 3]) - .with_stride([1, 1, 1]) + let conv_transpose_3d = ConvTranspose3dConfig::new([3, 3], [1, 1, 1]) + .with_stride([0, 0, 0]) .with_padding([0, 0, 0]) .with_padding_out([0, 0, 0]) - .with_dilation([1, 1, 1]) + .with_dilation([0, 0, 0]) .with_groups(1) .with_bias(true) .init(device); diff --git a/crates/onnx-ir/src/node/avg_pool1d.rs b/crates/onnx-ir/src/node/avg_pool1d.rs index 4b2b7e9045..af5bc6dabe 100644 --- a/crates/onnx-ir/src/node/avg_pool1d.rs +++ b/crates/onnx-ir/src/node/avg_pool1d.rs @@ -17,32 +17,19 @@ pub struct AvgPool1dConfig { impl AvgPool1dConfig { /// Create a new AvgPool1dConfig - pub fn new(kernel_size: usize) -> Self { + pub fn new( + kernel_size: usize, + stride: usize, + padding: PaddingConfig1d, + count_include_pad: bool, + ) -> Self { Self { kernel_size, - stride: 1, - padding: PaddingConfig1d::Valid, - count_include_pad: true, + stride, + padding, + count_include_pad, } } - - /// Set the stride - pub fn with_stride(mut self, stride: usize) -> Self { - self.stride = stride; - self - } - - /// Set the padding configuration - pub fn with_padding(mut self, padding: PaddingConfig1d) -> Self { - self.padding = padding; - self - } - - /// Set whether to include padding in the average calculation - pub fn with_count_include_pad(mut self, count_include_pad: bool) -> Self { - self.count_include_pad = count_include_pad; - self - } } /// Create an AvgPool1dConfig from the attributes of the node diff --git a/crates/onnx-ir/src/node/avg_pool2d.rs b/crates/onnx-ir/src/node/avg_pool2d.rs index c7ec45c1e2..4e95cc8228 100644 --- a/crates/onnx-ir/src/node/avg_pool2d.rs +++ b/crates/onnx-ir/src/node/avg_pool2d.rs @@ -16,32 +16,19 @@ pub struct AvgPool2dConfig { impl AvgPool2dConfig { /// Create a new AvgPool2dConfig - 
pub fn new(kernel_size: [usize; 2]) -> Self { + pub fn new( + kernel_size: [usize; 2], + strides: [usize; 2], + padding: PaddingConfig2d, + count_include_pad: bool, + ) -> Self { Self { kernel_size, - strides: [1, 1], - padding: PaddingConfig2d::Valid, - count_include_pad: true, + strides, + padding, + count_include_pad, } } - - /// Set the strides - pub fn with_strides(mut self, strides: [usize; 2]) -> Self { - self.strides = strides; - self - } - - /// Set the padding configuration - pub fn with_padding(mut self, padding: PaddingConfig2d) -> Self { - self.padding = padding; - self - } - - /// Set whether to include padding in the average calculation - pub fn with_count_include_pad(mut self, count_include_pad: bool) -> Self { - self.count_include_pad = count_include_pad; - self - } } /// Create a AvgPool2dConfig from the attributes of the node @@ -69,10 +56,12 @@ pub fn avg_pool2d_config(curr: &Node) -> AvgPool2dConfig { let padding = padding_config_2d(&pads); - AvgPool2dConfig::new([kernel_shape[0] as usize, kernel_shape[1] as usize]) - .with_strides([strides[0] as usize, strides[1] as usize]) - .with_padding(padding) - .with_count_include_pad(count_include_pad == 1) + AvgPool2dConfig::new( + [kernel_shape[0] as usize, kernel_shape[1] as usize], + [strides[0] as usize, strides[1] as usize], + padding, + count_include_pad == 1, + ) } #[cfg(test)] diff --git a/crates/onnx-ir/src/node/batch_norm.rs b/crates/onnx-ir/src/node/batch_norm.rs index a6ef3b2a94..c3e83ff3d4 100644 --- a/crates/onnx-ir/src/node/batch_norm.rs +++ b/crates/onnx-ir/src/node/batch_norm.rs @@ -13,25 +13,13 @@ pub struct BatchNormConfig { impl BatchNormConfig { /// Create a new BatchNormConfig - pub fn new(num_features: usize) -> Self { + pub fn new(num_features: usize, epsilon: f64, momentum: f64) -> Self { Self { num_features, - epsilon: 1e-5, - momentum: 0.1, + epsilon, + momentum, } } - - /// Set the epsilon value - pub fn with_epsilon(mut self, epsilon: f64) -> Self { - self.epsilon = epsilon; - self - } - - /// Set the momentum value - pub fn with_momentum(mut self, momentum: f64) -> Self { - self.momentum = momentum; - self - } } /// Create a BatchNormConfig from the attributes of the node @@ -56,9 +44,7 @@ pub fn batch_norm_config(node: &Node) -> BatchNormConfig { } } - BatchNormConfig::new(num_features) - .with_epsilon(epsilon as f64) - .with_momentum(momentum as f64) + BatchNormConfig::new(num_features, epsilon as f64, momentum as f64) } #[cfg(test)] diff --git a/crates/onnx-ir/src/node/conv1d.rs b/crates/onnx-ir/src/node/conv1d.rs index 017352a5be..cc8420faf6 100644 --- a/crates/onnx-ir/src/node/conv1d.rs +++ b/crates/onnx-ir/src/node/conv1d.rs @@ -25,48 +25,28 @@ pub struct Conv1dConfig { impl Conv1dConfig { /// Create a new Conv1dConfig - pub fn new(channels_in: usize, channels_out: usize, kernel_size: usize) -> Self { + #[allow(clippy::too_many_arguments)] + pub fn new( + channels_in: usize, + channels_out: usize, + kernel_size: usize, + stride: usize, + padding: PaddingConfig1d, + dilation: usize, + groups: usize, + bias: bool, + ) -> Self { Self { channels_in, channels_out, kernel_size, - stride: 1, - padding: PaddingConfig1d::Valid, - dilation: 1, - groups: 1, - bias: true, + stride, + padding, + dilation, + groups, + bias, } } - - /// Set the stride - pub fn with_stride(mut self, stride: usize) -> Self { - self.stride = stride; - self - } - - /// Set the padding configuration - pub fn with_padding(mut self, padding: PaddingConfig1d) -> Self { - self.padding = padding; - self - } - - /// Set the dilation 
- pub fn with_dilation(mut self, dilation: usize) -> Self { - self.dilation = dilation; - self - } - - /// Set the number of groups - pub fn with_groups(mut self, groups: usize) -> Self { - self.groups = groups; - self - } - - /// Set whether bias is used - pub fn with_bias(mut self, bias: bool) -> Self { - self.bias = bias; - self - } } /// Create a Conv1dConfig from the attributes of the node diff --git a/crates/onnx-ir/src/node/conv2d.rs b/crates/onnx-ir/src/node/conv2d.rs index 072f86a58f..1899546932 100644 --- a/crates/onnx-ir/src/node/conv2d.rs +++ b/crates/onnx-ir/src/node/conv2d.rs @@ -22,47 +22,25 @@ pub struct Conv2dConfig { impl Conv2dConfig { /// Create a new Conv2dConfig - pub fn new(channels: [usize; 2], kernel_size: [usize; 2]) -> Self { + pub fn new( + channels: [usize; 2], + kernel_size: [usize; 2], + stride: [usize; 2], + padding: PaddingConfig2d, + dilation: [usize; 2], + groups: usize, + bias: bool, + ) -> Self { Self { channels, kernel_size, - stride: [1, 1], - padding: PaddingConfig2d::Valid, - dilation: [1, 1], - groups: 1, - bias: true, + stride, + padding, + dilation, + groups, + bias, } } - - /// Set the stride - pub fn with_stride(mut self, stride: [usize; 2]) -> Self { - self.stride = stride; - self - } - - /// Set the padding configuration - pub fn with_padding(mut self, padding: PaddingConfig2d) -> Self { - self.padding = padding; - self - } - - /// Set the dilation - pub fn with_dilation(mut self, dilation: [usize; 2]) -> Self { - self.dilation = dilation; - self - } - - /// Set the number of groups - pub fn with_groups(mut self, groups: usize) -> Self { - self.groups = groups; - self - } - - /// Set whether bias is used - pub fn with_bias(mut self, bias: bool) -> Self { - self.bias = bias; - self - } } /// Create a Conv2dConfig from the attributes of the node @@ -103,12 +81,12 @@ pub fn conv2d_config(curr: &Node) -> Conv2dConfig { Conv2dConfig::new( [channels_in, channels_out], [kernel_shape[0] as usize, kernel_shape[1] as usize], + [strides[0] as usize, strides[1] as usize], + padding, + [dilations[0] as usize, dilations[1] as usize], + group, + bias, ) - .with_stride([strides[0] as usize, strides[1] as usize]) - .with_dilation([dilations[0] as usize, dilations[1] as usize]) - .with_groups(group) - .with_bias(bias) - .with_padding(padding) } #[cfg(test)] diff --git a/crates/onnx-ir/src/node/conv3d.rs b/crates/onnx-ir/src/node/conv3d.rs index e55093f9d1..97f71951f2 100644 --- a/crates/onnx-ir/src/node/conv3d.rs +++ b/crates/onnx-ir/src/node/conv3d.rs @@ -22,47 +22,25 @@ pub struct Conv3dConfig { impl Conv3dConfig { /// Create a new configuration for a Conv3d. - pub fn new(channels: [usize; 2], kernel_size: [usize; 3]) -> Self { + pub fn new( + channels: [usize; 2], + kernel_size: [usize; 3], + stride: [usize; 3], + dilation: [usize; 3], + groups: usize, + bias: bool, + padding: PaddingConfig3d, + ) -> Self { Self { channels, kernel_size, - stride: [1, 1, 1], - dilation: [1, 1, 1], - groups: 1, - bias: true, - padding: PaddingConfig3d::Valid, + stride, + dilation, + groups, + bias, + padding, } } - - /// Set the stride. - pub fn with_stride(mut self, stride: [usize; 3]) -> Self { - self.stride = stride; - self - } - - /// Set the dilation. - pub fn with_dilation(mut self, dilation: [usize; 3]) -> Self { - self.dilation = dilation; - self - } - - /// Set the groups. - pub fn with_groups(mut self, groups: usize) -> Self { - self.groups = groups; - self - } - - /// Set whether to use bias. 
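A minimal sketch of the pattern this patch applies across the onnx-ir config structs: the with_* builder methods being deleted here go away, and every field is passed explicitly to new(), so nothing is silently defaulted. The argument order follows the Conv2dConfig::new signature shown earlier in this diff; the onnx-ir module paths for Conv2dConfig and PaddingConfig2d are assumptions for illustration.

use onnx_ir::node::conv2d::Conv2dConfig;
use onnx_ir::node::padding::PaddingConfig2d; // assumed location of PaddingConfig2d in onnx-ir

fn build_conv2d_config() -> Conv2dConfig {
    // Before this patch: Conv2dConfig::new([3, 16], [3, 3]).with_stride([2, 2]).with_bias(false)
    Conv2dConfig::new(
        [3, 16],                // channels [in, out]
        [3, 3],                 // kernel_size
        [2, 2],                 // stride
        PaddingConfig2d::Valid, // padding
        [1, 1],                 // dilation
        1,                      // groups
        false,                  // bias
    )
}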
- pub fn with_bias(mut self, bias: bool) -> Self { - self.bias = bias; - self - } - - /// Set the padding. - pub fn with_padding(mut self, padding: PaddingConfig3d) -> Self { - self.padding = padding; - self - } } /// Create a Conv3dConfig from the attributes of the node @@ -107,20 +85,20 @@ pub fn conv3d_config(curr: &Node) -> Conv3dConfig { kernel_shape[1] as usize, kernel_shape[2] as usize, ], + [ + strides[0] as usize, + strides[1] as usize, + strides[2] as usize, + ], + [ + dilations[0] as usize, + dilations[1] as usize, + dilations[2] as usize, + ], + group, + bias, + padding, ) - .with_stride([ - strides[0] as usize, - strides[1] as usize, - strides[2] as usize, - ]) - .with_dilation([ - dilations[0] as usize, - dilations[1] as usize, - dilations[2] as usize, - ]) - .with_groups(group) - .with_bias(bias) - .with_padding(padding) } #[cfg(test)] diff --git a/crates/onnx-ir/src/node/conv_transpose1d.rs b/crates/onnx-ir/src/node/conv_transpose1d.rs index 7c402880f6..e82b3d8132 100644 --- a/crates/onnx-ir/src/node/conv_transpose1d.rs +++ b/crates/onnx-ir/src/node/conv_transpose1d.rs @@ -23,86 +23,93 @@ pub struct ConvTranspose1dConfig { pub padding_out: usize, } -/// Create a ConvTranspose1dConfig from the attributes of the node -pub fn conv_transpose1d_config(curr: &Node) -> ConvTranspose1dConfig { - let mut attrs = curr.attrs.clone(); - - // Extract kernel_shape, default to an empty vector if not present - let kernel_shape = attrs - .remove("kernel_shape") - .map(AttributeValue::into_i64s) - .unwrap_or_default(); - - // Extract strides, default to 1 if not present - let stride = attrs - .remove("strides") - .map(AttributeValue::into_i64s) - .unwrap_or_else(|| vec![1]); - - // Extract padding, default to 0 if not present - let pads = attrs - .remove("pads") - .map(AttributeValue::into_i64s) - .unwrap_or_else(|| vec![0, 0]); - - // Extract dilations, default to 1 if not present - let dilations = attrs - .remove("dilations") - .map(AttributeValue::into_i64s) - .unwrap_or_else(|| vec![1]); - - // Extract group attribute, default to 1 - let group = attrs - .remove("group") - .map(AttributeValue::into_i64) - .unwrap_or(1) as usize; - - // Extract output_padding, default to 0 if not present - let output_padding = attrs - .remove("output_padding") - .map(AttributeValue::into_i64s) - .unwrap_or_else(|| vec![0]); - - // Ensure no unused attributes remain - if !attrs.is_empty() { - panic!("Not all attributes are used: {attrs:?}"); - } +impl ConvTranspose1dConfig { + /// Create a new ConvTranspose1dConfig from the attributes of the node + pub fn new(curr: &Node) -> Self { + let mut attrs = curr.attrs.clone(); + + // Extract kernel_shape, default to an empty vector if not present + let kernel_shape = attrs + .remove("kernel_shape") + .map(AttributeValue::into_i64s) + .unwrap_or_default(); + + // Extract strides, default to 1 if not present + let stride = attrs + .remove("strides") + .map(AttributeValue::into_i64s) + .unwrap_or_else(|| vec![1]); + + // Extract padding, default to 0 if not present + let pads = attrs + .remove("pads") + .map(AttributeValue::into_i64s) + .unwrap_or_else(|| vec![0, 0]); + + // Extract dilations, default to 1 if not present + let dilations = attrs + .remove("dilations") + .map(AttributeValue::into_i64s) + .unwrap_or_else(|| vec![1]); + + // Extract group attribute, default to 1 + let group = attrs + .remove("group") + .map(AttributeValue::into_i64) + .unwrap_or(1) as usize; + + // Extract output_padding, default to 0 if not present + let output_padding = attrs + 
.remove("output_padding") + .map(AttributeValue::into_i64s) + .unwrap_or_else(|| vec![0]); + + // Ensure no unused attributes remain + if !attrs.is_empty() { + panic!("Not all attributes are used: {attrs:?}"); + } - // Check the pads are symmetric - if pads.len() != 2 || pads[0] != pads[1] { - panic!( - "Asymmetric padding is not supported for ConvTranspose1d: {:?}", - pads - ); - } + // Check the pads are symmetric + if pads.len() != 2 || pads[0] != pads[1] { + panic!( + "Asymmetric padding is not supported for ConvTranspose1d: {:?}", + pads + ); + } - let weight_shape = curr.inputs[1] - .value - .as_ref() - .expect("ConvTranspose1d: weight tensor must be present") - .shape - .clone(); - - // Check if bias is present (third input) - let bias = curr.inputs.len() == 3; - - // Extract channels from the weight tensor shape [out_channels, in_channels] - let channels_in = weight_shape[1] * group; - let channels_out = weight_shape[0]; - - ConvTranspose1dConfig { - channels_in, - channels_out, - kernel_size: kernel_shape[0] as usize, - stride: stride[0] as usize, - padding: pads[0] as usize, - dilation: dilations[0] as usize, - padding_out: output_padding[0] as usize, - groups: group, - bias, + let weight_shape = curr.inputs[1] + .value + .as_ref() + .expect("ConvTranspose1d: weight tensor must be present") + .shape + .clone(); + + // Check if bias is present (third input) + let bias = curr.inputs.len() == 3; + + // Extract channels from the weight tensor shape [out_channels, in_channels] + let channels_in = weight_shape[1] * group; + let channels_out = weight_shape[0]; + + Self { + channels_in, + channels_out, + kernel_size: kernel_shape[0] as usize, + stride: stride[0] as usize, + padding: pads[0] as usize, + dilation: dilations[0] as usize, + padding_out: output_padding[0] as usize, + groups: group, + bias, + } } } +/// Create a ConvTranspose1dConfig from the attributes of the node +pub fn conv_transpose1d_config(curr: &Node) -> ConvTranspose1dConfig { + ConvTranspose1dConfig::new(curr) +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/onnx-ir/src/node/conv_transpose2d.rs b/crates/onnx-ir/src/node/conv_transpose2d.rs index b373525218..568a459b06 100644 --- a/crates/onnx-ir/src/node/conv_transpose2d.rs +++ b/crates/onnx-ir/src/node/conv_transpose2d.rs @@ -23,54 +23,28 @@ pub struct ConvTranspose2dConfig { impl ConvTranspose2dConfig { /// Create a new configuration for a ConvTranspose2d. - pub fn new(channels: [usize; 2], kernel_size: [usize; 2]) -> Self { + #[allow(clippy::too_many_arguments)] + pub fn new( + channels: [usize; 2], + kernel_size: [usize; 2], + stride: [usize; 2], + dilation: [usize; 2], + padding: [usize; 2], + padding_out: [usize; 2], + groups: usize, + bias: bool, + ) -> Self { Self { channels, kernel_size, - stride: [1, 1], - dilation: [1, 1], - padding: [0, 0], - padding_out: [0, 0], - groups: 1, - bias: true, + stride, + dilation, + padding, + padding_out, + groups, + bias, } } - - /// Set the stride. - pub fn with_stride(mut self, stride: [usize; 2]) -> Self { - self.stride = stride; - self - } - - /// Set the dilation. - pub fn with_dilation(mut self, dilation: [usize; 2]) -> Self { - self.dilation = dilation; - self - } - - /// Set the padding. - pub fn with_padding(mut self, padding: [usize; 2]) -> Self { - self.padding = padding; - self - } - - /// Set the output padding. - pub fn with_padding_out(mut self, padding_out: [usize; 2]) -> Self { - self.padding_out = padding_out; - self - } - - /// Set the groups. 
- pub fn with_groups(mut self, groups: usize) -> Self { - self.groups = groups; - self - } - - /// Set whether to use bias. - pub fn with_bias(mut self, bias: bool) -> Self { - self.bias = bias; - self - } } /// Create a ConvTranspose2dConfig from the attributes of the node @@ -129,13 +103,13 @@ pub fn conv_transpose2d_config(curr: &Node) -> ConvTranspose2dConfig { ConvTranspose2dConfig::new( channels, [kernel_shape[0] as usize, kernel_shape[1] as usize], + [stride[0] as usize, stride[1] as usize], + [dilations[0] as usize, dilations[1] as usize], + [pads[0] as usize, pads[1] as usize], + [output_padding[0] as usize, output_padding[1] as usize], + group, + bias, ) - .with_stride([stride[0] as usize, stride[1] as usize]) - .with_padding([pads[0] as usize, pads[1] as usize]) - .with_dilation([dilations[0] as usize, dilations[1] as usize]) - .with_padding_out([output_padding[0] as usize, output_padding[1] as usize]) - .with_groups(group) - .with_bias(bias) } #[cfg(test)] diff --git a/crates/onnx-ir/src/node/conv_transpose3d.rs b/crates/onnx-ir/src/node/conv_transpose3d.rs index b003c5a045..0bac8cd1fa 100644 --- a/crates/onnx-ir/src/node/conv_transpose3d.rs +++ b/crates/onnx-ir/src/node/conv_transpose3d.rs @@ -23,54 +23,28 @@ pub struct ConvTranspose3dConfig { impl ConvTranspose3dConfig { /// Create a new configuration for a ConvTranspose3d. - pub fn new(channels: [usize; 2], kernel_size: [usize; 3]) -> Self { + #[allow(clippy::too_many_arguments)] + pub fn new( + channels: [usize; 2], + kernel_size: [usize; 3], + stride: [usize; 3], + dilation: [usize; 3], + padding: [usize; 3], + padding_out: [usize; 3], + groups: usize, + bias: bool, + ) -> Self { Self { channels, kernel_size, - stride: [1, 1, 1], - dilation: [1, 1, 1], - padding: [0, 0, 0], - padding_out: [0, 0, 0], - groups: 1, - bias: true, + stride, + dilation, + padding, + padding_out, + groups, + bias, } } - - /// Set the stride. - pub fn with_stride(mut self, stride: [usize; 3]) -> Self { - self.stride = stride; - self - } - - /// Set the dilation. - pub fn with_dilation(mut self, dilation: [usize; 3]) -> Self { - self.dilation = dilation; - self - } - - /// Set the padding. - pub fn with_padding(mut self, padding: [usize; 3]) -> Self { - self.padding = padding; - self - } - - /// Set the output padding. - pub fn with_padding_out(mut self, padding_out: [usize; 3]) -> Self { - self.padding_out = padding_out; - self - } - - /// Set the groups. - pub fn with_groups(mut self, groups: usize) -> Self { - self.groups = groups; - self - } - - /// Set whether to use bias. 
- pub fn with_bias(mut self, bias: bool) -> Self { - self.bias = bias; - self - } } /// Create a ConvTranspose3dConfig from the attributes of the node @@ -135,21 +109,21 @@ pub fn conv_transpose3d_config(curr: &Node) -> ConvTranspose3dConfig { kernel_shape[1] as usize, kernel_shape[2] as usize, ], + [stride[0] as usize, stride[1] as usize, stride[2] as usize], + [ + dilations[0] as usize, + dilations[1] as usize, + dilations[2] as usize, + ], + [pads[0] as usize, pads[1] as usize, pads[2] as usize], + [ + output_padding[0] as usize, + output_padding[1] as usize, + output_padding[2] as usize, + ], + group, + bias, ) - .with_stride([stride[0] as usize, stride[1] as usize, stride[2] as usize]) - .with_padding([pads[0] as usize, pads[1] as usize, pads[2] as usize]) - .with_dilation([ - dilations[0] as usize, - dilations[1] as usize, - dilations[2] as usize, - ]) - .with_padding_out([ - output_padding[0] as usize, - output_padding[1] as usize, - output_padding[2] as usize, - ]) - .with_groups(group) - .with_bias(bias) } #[cfg(test)] diff --git a/crates/onnx-ir/src/op_configuration.rs b/crates/onnx-ir/src/op_configuration.rs index d453d84c9a..430a49a18c 100644 --- a/crates/onnx-ir/src/op_configuration.rs +++ b/crates/onnx-ir/src/op_configuration.rs @@ -2,7 +2,16 @@ // These should be deprecated and eventually removed in favor of using // direct imports from the node modules +pub use crate::node::avg_pool1d::avg_pool1d_config; +pub use crate::node::avg_pool2d::avg_pool2d_config; +pub use crate::node::batch_norm::batch_norm_config; pub use crate::node::clip::clip_config; +pub use crate::node::conv1d::conv1d_config; +pub use crate::node::conv2d::conv2d_config; +pub use crate::node::conv3d::conv3d_config; +pub use crate::node::conv_transpose1d::conv_transpose1d_config; +pub use crate::node::conv_transpose2d::conv_transpose2d_config; +pub use crate::node::conv_transpose3d::conv_transpose3d_config; pub use crate::node::gemm::gemm_config; pub use crate::node::hard_sigmoid::hard_sigmoid_config; pub use crate::node::leaky_relu::leaky_relu_config; From 995b7eb0e7487ddf1c44e179f22885a677e542e7 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Thu, 1 May 2025 14:23:37 -0500 Subject: [PATCH 19/37] Fix format --- .../burn-import/src/burn/node/avg_pool1d.rs | 6 ++--- .../burn-import/src/burn/node/avg_pool2d.rs | 8 +++---- crates/burn-import/src/burn/node/base.rs | 24 +++++++++---------- .../burn-import/src/burn/node/batch_norm.rs | 2 +- crates/burn-import/src/burn/node/conv1d.rs | 14 +++++------ crates/burn-import/src/burn/node/conv2d.rs | 12 +++++----- crates/burn-import/src/burn/node/conv3d.rs | 14 +++++------ 7 files changed, 40 insertions(+), 40 deletions(-) diff --git a/crates/burn-import/src/burn/node/avg_pool1d.rs b/crates/burn-import/src/burn/node/avg_pool1d.rs index f22dffdceb..0002175c54 100644 --- a/crates/burn-import/src/burn/node/avg_pool1d.rs +++ b/crates/burn-import/src/burn/node/avg_pool1d.rs @@ -106,10 +106,10 @@ mod tests { TensorType::new_float("input", 3), TensorType::new_float("output", 3), AvgPool1dConfig::new( - 3, // kernel_size - 1, // stride + 3, // kernel_size + 1, // stride PaddingConfig1d::Valid, // padding - true, // count_include_pad + true, // count_include_pad ), )); diff --git a/crates/burn-import/src/burn/node/avg_pool2d.rs b/crates/burn-import/src/burn/node/avg_pool2d.rs index ed3b78be5b..af70737e4d 100644 --- a/crates/burn-import/src/burn/node/avg_pool2d.rs +++ b/crates/burn-import/src/burn/node/avg_pool2d.rs @@ -110,10 
+110,10 @@ mod tests { TensorType::new_float("input", 4), TensorType::new_float("output", 4), AvgPool2dConfig::new( - [3, 3], // kernel_size - [1, 1], // strides - PaddingConfig2d::Valid, // padding - true, // count_include_pad + [3, 3], // kernel_size + [1, 1], // strides + PaddingConfig2d::Valid, // padding + true, // count_include_pad ), )); diff --git a/crates/burn-import/src/burn/node/base.rs b/crates/burn-import/src/burn/node/base.rs index 71f614f1af..1b47261db0 100644 --- a/crates/burn-import/src/burn/node/base.rs +++ b/crates/burn-import/src/burn/node/base.rs @@ -372,13 +372,13 @@ pub(crate) mod tests { TensorData::from([2f32]), None, Conv2dConfig::new( - [3, 3], // kernel_size - [3, 3], // stride - [1, 1], // dilation + [3, 3], // kernel_size + [3, 3], // stride + [1, 1], // dilation PaddingConfig2d::Valid, // padding - [1, 1], // output_padding - 1, // groups - true // bias + [1, 1], // output_padding + 1, // groups + true, // bias ), )); @@ -453,13 +453,13 @@ pub(crate) mod tests { TensorData::from([2f32]), None, Conv2dConfig::new( - [3, 3], // kernel_size - [3, 3], // stride - [1, 1], // dilation + [3, 3], // kernel_size + [3, 3], // stride + [1, 1], // dilation PaddingConfig2d::Valid, // padding - [1, 1], // output_padding - 1, // groups - true // bias + [1, 1], // output_padding + 1, // groups + true, // bias ), )); graph.register(MatmulNode::new( diff --git a/crates/burn-import/src/burn/node/batch_norm.rs b/crates/burn-import/src/burn/node/batch_norm.rs index ed9221a15f..76bce6a8e6 100644 --- a/crates/burn-import/src/burn/node/batch_norm.rs +++ b/crates/burn-import/src/burn/node/batch_norm.rs @@ -175,7 +175,7 @@ mod tests { BatchNormConfig::new( 128, // num_features 0.00001, // epsilon - 0.1 // momentum + 0.1, // momentum ), )); diff --git a/crates/burn-import/src/burn/node/conv1d.rs b/crates/burn-import/src/burn/node/conv1d.rs index 717ddc7cfc..5b517f103f 100644 --- a/crates/burn-import/src/burn/node/conv1d.rs +++ b/crates/burn-import/src/burn/node/conv1d.rs @@ -150,14 +150,14 @@ mod tests { TensorData::from([2f32]), None, Conv1dConfig::new( - 3, // channels_in - 3, // channels_out - 3, // kernel_size - 1, // stride + 3, // channels_in + 3, // channels_out + 3, // kernel_size + 1, // stride PaddingConfig1d::Valid, // padding - 1, // dilation - 1, // groups - true // bias + 1, // dilation + 1, // groups + true, // bias ), )); diff --git a/crates/burn-import/src/burn/node/conv2d.rs b/crates/burn-import/src/burn/node/conv2d.rs index 039932fa89..9e87a2afcb 100644 --- a/crates/burn-import/src/burn/node/conv2d.rs +++ b/crates/burn-import/src/burn/node/conv2d.rs @@ -149,13 +149,13 @@ mod tests { TensorData::from([2f32]), None, Conv2dConfig::new( - [3, 3], // kernel_size - [3, 3], // stride - [1, 1], // dilation + [3, 3], // kernel_size + [3, 3], // stride + [1, 1], // dilation PaddingConfig2d::Valid, // padding - [1, 1], // output_padding - 1, // groups - true // bias + [1, 1], // output_padding + 1, // groups + true, // bias ), )); diff --git a/crates/burn-import/src/burn/node/conv3d.rs b/crates/burn-import/src/burn/node/conv3d.rs index 50f6fdf384..0bc3766ad2 100644 --- a/crates/burn-import/src/burn/node/conv3d.rs +++ b/crates/burn-import/src/burn/node/conv3d.rs @@ -149,13 +149,13 @@ mod tests { TensorData::from([2f32]), None, Conv3dConfig::new( - [3, 3], // kernel_size - [3, 3, 3], // stride - [1, 1, 1], // dilation - [1, 1, 1], // output_padding - 1, // groups - true, // bias - PaddingConfig3d::Valid // padding + [3, 3], // kernel_size + [3, 3, 3], // stride + [1, 1, 1], // 
dilation + [1, 1, 1], // output_padding + 1, // groups + true, // bias + PaddingConfig3d::Valid, // padding ), )); From fa67a32f7e5968fa603a5d8434c8607d43dc6cf7 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Thu, 1 May 2025 14:23:57 -0500 Subject: [PATCH 20/37] remove op_configuration.rs --- crates/burn-import/src/onnx/to_burn.rs | 16 +++++++-------- crates/onnx-ir/src/lib.rs | 1 - crates/onnx-ir/src/op_configuration.rs | 28 -------------------------- 3 files changed, 8 insertions(+), 37 deletions(-) delete mode 100644 crates/onnx-ir/src/op_configuration.rs diff --git a/crates/burn-import/src/onnx/to_burn.rs b/crates/burn-import/src/onnx/to_burn.rs index 14a16d020e..693a1185d3 100644 --- a/crates/burn-import/src/onnx/to_burn.rs +++ b/crates/burn-import/src/onnx/to_burn.rs @@ -81,18 +81,18 @@ use onnx_ir::{ }, node::{ argmax::argmax_config, avg_pool1d::avg_pool1d_config, avg_pool2d::avg_pool2d_config, - batch_norm::batch_norm_config, concat::concat_config, + batch_norm::batch_norm_config, clip::clip_config, concat::concat_config, conv_transpose1d::conv_transpose1d_config, conv_transpose2d::conv_transpose2d_config, conv_transpose3d::conv_transpose3d_config, conv1d::conv1d_config, conv2d::conv2d_config, conv3d::conv3d_config, dropout::dropout_config, flatten::flatten_config, - gather::gather_config, layer_norm::layer_norm_config, linear::linear_config, + gather::gather_config, gemm::gemm_config, hard_sigmoid::hard_sigmoid_config, + layer_norm::layer_norm_config, leaky_relu::leaky_relu_config, linear::linear_config, log_softmax::log_softmax_config, max_pool1d::max_pool1d_config, - max_pool2d::max_pool2d_config, slice::slice_config, softmax::softmax_config, - }, - op_configuration::{ - clip_config, gemm_config, hard_sigmoid_config, leaky_relu_config, one_hot_config, - reduce_max_config, reduce_mean_config, reduce_min_config, reduce_prod_config, - reduce_sum_config, reshape_config, resize_config, squeeze_config, transpose_config, + max_pool2d::max_pool2d_config, one_hot::one_hot_config, reduce_max::reduce_max_config, + reduce_mean::reduce_mean_config, reduce_min::reduce_min_config, + reduce_prod::reduce_prod_config, reduce_sum::reduce_sum_config, reshape::reshape_config, + resize::resize_config, slice::slice_config, softmax::softmax_config, + squeeze::squeeze_config, transpose::transpose_config, }, parse_onnx, util::shape_config, diff --git a/crates/onnx-ir/src/lib.rs b/crates/onnx-ir/src/lib.rs index 0a4d0e847c..b195b88adb 100644 --- a/crates/onnx-ir/src/lib.rs +++ b/crates/onnx-ir/src/lib.rs @@ -3,7 +3,6 @@ mod from_onnx; pub mod ir; pub mod node; mod node_remap; -pub mod op_configuration; mod proto_conversion; mod protos; mod rank_inference; diff --git a/crates/onnx-ir/src/op_configuration.rs b/crates/onnx-ir/src/op_configuration.rs deleted file mode 100644 index 430a49a18c..0000000000 --- a/crates/onnx-ir/src/op_configuration.rs +++ /dev/null @@ -1,28 +0,0 @@ -// Reexport functions from node modules for compatibility -// These should be deprecated and eventually removed in favor of using -// direct imports from the node modules - -pub use crate::node::avg_pool1d::avg_pool1d_config; -pub use crate::node::avg_pool2d::avg_pool2d_config; -pub use crate::node::batch_norm::batch_norm_config; -pub use crate::node::clip::clip_config; -pub use crate::node::conv1d::conv1d_config; -pub use crate::node::conv2d::conv2d_config; -pub use crate::node::conv3d::conv3d_config; -pub use crate::node::conv_transpose1d::conv_transpose1d_config; -pub use 
crate::node::conv_transpose2d::conv_transpose2d_config; -pub use crate::node::conv_transpose3d::conv_transpose3d_config; -pub use crate::node::gemm::gemm_config; -pub use crate::node::hard_sigmoid::hard_sigmoid_config; -pub use crate::node::leaky_relu::leaky_relu_config; -pub use crate::node::one_hot::one_hot_config; -pub use crate::node::reduce_max::reduce_max_config; -pub use crate::node::reduce_mean::reduce_mean_config; -pub use crate::node::reduce_min::reduce_min_config; -pub use crate::node::reduce_prod::reduce_prod_config; -pub use crate::node::reduce_sum::reduce_sum_config; -pub use crate::node::reshape::reshape_config; -pub use crate::node::resize::resize_config; -pub use crate::node::shape::shape_config; -pub use crate::node::squeeze::squeeze_config; -pub use crate::node::transpose::transpose_config; From 8e8c5a3558e64f5bed9251e46c2b5e223c9212c9 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Thu, 1 May 2025 19:09:11 -0500 Subject: [PATCH 21/37] Decouple burn-import types from op_configuration --- crates/burn-import/src/burn/node/expand.rs | 68 +++++++----- crates/burn-import/src/burn/node/pad.rs | 12 +- crates/burn-import/src/burn/node/split.rs | 9 +- crates/burn-import/src/burn/node/tile.rs | 11 +- crates/burn-import/src/burn/node/top_k.rs | 12 +- crates/burn-import/src/burn/node/trilu.rs | 12 +- crates/burn-import/src/burn/node/unsqueeze.rs | 11 +- crates/burn-import/src/onnx/mod.rs | 2 +- .../burn-import/src/onnx/op_configuration.rs | 104 +++++++++++++++--- crates/onnx-ir/src/lib.rs | 2 +- 10 files changed, 147 insertions(+), 96 deletions(-) diff --git a/crates/burn-import/src/burn/node/expand.rs b/crates/burn-import/src/burn/node/expand.rs index 0fd56124a7..43df4d37cd 100644 --- a/crates/burn-import/src/burn/node/expand.rs +++ b/crates/burn-import/src/burn/node/expand.rs @@ -1,5 +1,8 @@ use super::{Node, NodeCodegen}; -use crate::burn::{Scope, TensorType, ToTokens, Type}; +use crate::{ + burn::{Scope, TensorType, ToTokens, Type}, + onnx::op_configuration::ExpandShape, +}; use burn::record::PrecisionSettings; use proc_macro2::TokenStream; use quote::quote; @@ -11,12 +14,6 @@ pub struct ExpandNode { pub shape: ExpandShape, } -#[derive(Debug, Clone)] -pub enum ExpandShape { - Static(Vec), - Runtime(Type), -} - impl NodeCodegen for ExpandNode { fn output_types(&self) -> Vec { vec![Type::Tensor(self.output.clone())] @@ -28,10 +25,9 @@ impl NodeCodegen for ExpandNode { // if it is dynamic, the shape will be our 2nd: match &self.shape { ExpandShape::Static(_) => vec![input], - ExpandShape::Runtime(rt_type) => vec![input, rt_type.clone()], + ExpandShape::Runtime(rt_type) => vec![input, Type::from(rt_type)], } } - fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { let input = scope.tensor_use_owned(&self.input, node_position); let output = &self.output.name; @@ -39,22 +35,20 @@ impl NodeCodegen for ExpandNode { let shape = match &self.shape { ExpandShape::Static(static_shape) => static_shape.to_tokens(), - ExpandShape::Runtime(Type::Tensor(shape_tensor)) => { - // Since we don't take ownership of the shape_tensor, `tensor_use_owned` is not needed here. - let tensor_name = &shape_tensor.name; - // The shape of the tensor is statically validated to be rank one during input parsing. - // The tensor must be downloaded from device to CPU for the expand operation. - // Additionally, it needs to be converted to an array for use in BroadcastArgs. - quote! 
{ - TryInto::<[B::IntElem; #output_rank]>::try_into(#tensor_name.to_data().as_slice::().unwrap()).unwrap() + ExpandShape::Runtime(ty) => match Type::from(ty) { + Type::Tensor(shape_tensor) => { + let tensor_name = &shape_tensor.name; + quote! { + TryInto::<[B::IntElem; #output_rank]>::try_into(#tensor_name.to_data().as_slice::().unwrap()).unwrap() + } } - } - ExpandShape::Runtime(Type::Shape(shape)) => { - // Shape implements BroadcastArgs, allowing it to be passed directly to the expand method. - let shape_name = &shape.name; - quote! { #shape_name } - } - _ => panic!("Invalid shape source {:?}", self.shape), + Type::Shape(shape) => { + // Shape implements BroadcastArgs, allowing it to be passed directly to the expand method. + let shape_name = &shape.name; + quote! { #shape_name } + } + b => panic!("Invalid shape source {:?}", b), + }, }; quote! { @@ -70,10 +64,11 @@ impl NodeCodegen for ExpandNode { #[cfg(test)] mod tests { use burn::record::FullPrecisionSettings; + use onnx_ir::{ArgType, Argument, ElementType}; use super::*; use crate::burn::{ - ShapeType, TensorType, + TensorType, graph::BurnGraph, node::{expand::ExpandNode, test::assert_tokens}, }; @@ -121,7 +116,6 @@ mod tests { assert_tokens(graph.codegen(), expected); } - #[test] fn test_codegen_expand_shape() { let mut graph = BurnGraph::::default(); @@ -129,7 +123,12 @@ mod tests { graph.register(ExpandNode::new( TensorType::new_float("tensor1", 4), TensorType::new_float("tensor2", 4), - ExpandShape::Runtime(Type::Shape(ShapeType::new("shape1", 4))), + ExpandShape::Runtime(Argument { + name: "shape1".to_string(), + ty: ArgType::Shape(4), + value: None, + passed: false, + }), )); graph.register_input_output( @@ -177,12 +176,21 @@ mod tests { fn test_codegen_expand_tensor() { let mut graph = BurnGraph::::default(); - let shape_tensor_type = TensorType::new_int("tensor3", 4); + // let shape_tensor_type = TensorType::new_int("tensor3", 1); graph.register(ExpandNode::new( TensorType::new_float("tensor1", 4), TensorType::new_float("tensor2", 4), - ExpandShape::Runtime(Type::Tensor(shape_tensor_type)), + ExpandShape::Runtime(Argument { + name: "tensor3".to_string(), + ty: ArgType::Tensor(onnx_ir::TensorType { + elem_type: ElementType::Int32, + rank: 1, + static_shape: None, + }), + value: None, + passed: false, + }), )); graph.register_input_output( @@ -215,7 +223,7 @@ mod tests { pub fn forward( &self, tensor1: Tensor, - tensor3: Tensor, + tensor3: Tensor, ) -> Tensor { let tensor2 = tensor1.expand( TryInto::<[B::IntElem; 4usize]>::try_into(tensor3.to_data().as_slice::().unwrap()) diff --git a/crates/burn-import/src/burn/node/pad.rs b/crates/burn-import/src/burn/node/pad.rs index 0374a8832d..9fe2f2eca7 100644 --- a/crates/burn-import/src/burn/node/pad.rs +++ b/crates/burn-import/src/burn/node/pad.rs @@ -1,18 +1,14 @@ use std::str::FromStr; use super::{Node, NodeCodegen}; -use crate::burn::{Scope, TensorType, ToTokens, Type}; -use burn::config::Config; +use crate::{ + burn::{Scope, TensorType, ToTokens, Type}, + onnx::op_configuration::PadConfig, +}; use burn::record::PrecisionSettings; use proc_macro2::TokenStream; use quote::quote; -#[derive(Config, Debug)] -pub struct PadConfig { - pub pads: Vec, - pub constant_value: f32, -} - #[derive(Debug, Clone, new)] pub struct PadNode { pub input: TensorType, diff --git a/crates/burn-import/src/burn/node/split.rs b/crates/burn-import/src/burn/node/split.rs index 2980959099..4b00072a2a 100644 --- a/crates/burn-import/src/burn/node/split.rs +++ b/crates/burn-import/src/burn/node/split.rs @@ 
-1,17 +1,10 @@ use super::{Node, NodeCodegen}; use crate::burn::{Scope, TensorType, ToTokens, Type}; -use burn::config::Config; +use crate::onnx::op_configuration::SplitConfig; use burn::record::PrecisionSettings; use proc_macro2::TokenStream; use quote::quote; -#[derive(Config, Debug)] -pub struct SplitConfig { - pub axis: usize, - pub split_size: Option, - pub split_sizes: Option>, -} - #[derive(Debug, Clone, new)] pub struct SplitNode { pub input: TensorType, diff --git a/crates/burn-import/src/burn/node/tile.rs b/crates/burn-import/src/burn/node/tile.rs index 68d31fae21..900498d253 100644 --- a/crates/burn-import/src/burn/node/tile.rs +++ b/crates/burn-import/src/burn/node/tile.rs @@ -1,15 +1,12 @@ use super::{Node, NodeCodegen}; -use crate::burn::{Scope, TensorType, ToTokens, Type}; -use burn::config::Config; +use crate::{ + burn::{Scope, TensorType, ToTokens, Type}, + onnx::op_configuration::TileConfig, +}; use burn::record::PrecisionSettings; use proc_macro2::TokenStream; use quote::quote; -#[derive(Config, Debug)] -pub struct TileConfig { - pub repeats: Vec, -} - #[derive(Debug, Clone, new)] pub struct TileNode { pub input: TensorType, diff --git a/crates/burn-import/src/burn/node/top_k.rs b/crates/burn-import/src/burn/node/top_k.rs index e79a14c51d..d7efb4c924 100644 --- a/crates/burn-import/src/burn/node/top_k.rs +++ b/crates/burn-import/src/burn/node/top_k.rs @@ -1,16 +1,12 @@ use super::{Node, NodeCodegen}; -use crate::burn::{Scope, TensorType, Type}; -use burn::config::Config; +use crate::{ + burn::{Scope, TensorType, Type}, + onnx::op_configuration::TopKConfig, +}; use burn::record::PrecisionSettings; use proc_macro2::TokenStream; use quote::{ToTokens, quote}; -#[derive(Config, Debug)] -pub struct TopKConfig { - pub axis: usize, - pub k: usize, -} - #[derive(Debug, Clone, new)] pub struct TopKNode { pub input: TensorType, diff --git a/crates/burn-import/src/burn/node/trilu.rs b/crates/burn-import/src/burn/node/trilu.rs index 18c0ac8196..a71911031e 100644 --- a/crates/burn-import/src/burn/node/trilu.rs +++ b/crates/burn-import/src/burn/node/trilu.rs @@ -1,16 +1,12 @@ use super::{Node, NodeCodegen}; -use crate::burn::{Scope, TensorType, ToTokens, Type}; -use burn::config::Config; +use crate::{ + burn::{Scope, TensorType, ToTokens, Type}, + onnx::op_configuration::TriluConfig, +}; use burn::record::PrecisionSettings; use proc_macro2::TokenStream; use quote::quote; -#[derive(Config, Debug)] -pub struct TriluConfig { - pub upper: bool, - pub diagonal: i64, -} - #[derive(Debug, Clone, new)] pub struct TriluNode { pub input: TensorType, diff --git a/crates/burn-import/src/burn/node/unsqueeze.rs b/crates/burn-import/src/burn/node/unsqueeze.rs index afe22cf220..0b0ed0ac85 100644 --- a/crates/burn-import/src/burn/node/unsqueeze.rs +++ b/crates/burn-import/src/burn/node/unsqueeze.rs @@ -1,5 +1,8 @@ use super::{Node, NodeCodegen}; -use crate::burn::{BurnImports, Scope, TensorType, ToTokens, Type}; +use crate::{ + burn::{BurnImports, Scope, TensorType, ToTokens, Type}, + onnx::op_configuration::UnsqueezeAxes, +}; use burn::record::PrecisionSettings; use proc_macro2::TokenStream; use quote::quote; @@ -11,12 +14,6 @@ pub struct UnsqueezeNode { pub axes: UnsqueezeAxes, } -#[derive(Debug, Clone)] -pub enum UnsqueezeAxes { - Static(Vec), - Runtime(Type), -} - impl NodeCodegen for UnsqueezeNode { fn output_types(&self) -> Vec { vec![Type::Tensor(self.output.clone())] diff --git a/crates/burn-import/src/onnx/mod.rs b/crates/burn-import/src/onnx/mod.rs index b0b67d79c3..29a99a6680 100644 --- 
a/crates/burn-import/src/onnx/mod.rs +++ b/crates/burn-import/src/onnx/mod.rs @@ -1,3 +1,3 @@ -mod op_configuration; +pub mod op_configuration; mod to_burn; pub use to_burn::*; diff --git a/crates/burn-import/src/onnx/op_configuration.rs b/crates/burn-import/src/onnx/op_configuration.rs index 27b15b34ac..0dd7f39fd6 100644 --- a/crates/burn-import/src/onnx/op_configuration.rs +++ b/crates/burn-import/src/onnx/op_configuration.rs @@ -1,12 +1,32 @@ -// TODO Move op_configuration.rs from burn-import to onnx-ir #3091 -// See https://github.com/tracel-ai/burn/issues/3091 +//! This file contains the configuration structs and functions for various ONNX operations. -use crate::burn::node::{ - expand::ExpandShape, pad::PadConfig, split::SplitConfig, tile::TileConfig, top_k::TopKConfig, - trilu::TriluConfig, unsqueeze::UnsqueezeAxes, -}; -use onnx_ir::ir::{ArgType, AttributeValue, Data, ElementType, Node, TensorData}; +// use crate::burn::Type; +use onnx_ir::ir::{ArgType, Argument, AttributeValue, Data, ElementType, Node, TensorData}; +use crate::burn::Type; + +/// Configuration for the Pad operation. +#[derive(Debug, Clone, new)] +pub struct PadConfig { + /// The paddings to be applied to each dimension. + pub pads: Vec, + /// The constant value to fill the padded areas with. + pub constant_value: f32, +} + +/// Shape information for the Expand operation. +#[derive(Debug, Clone)] +pub enum ExpandShape { + /// Static shape information known at compile time. + Static(Vec), + /// Runtime shape that will be determined during execution. + Runtime(Argument), +} + +/// Creates an ExpandShape configuration from the given Node. +/// +/// Extracts shape information from the node's second input to determine +/// whether to use static or runtime shape expansion. pub fn expand_config(node: &Node) -> ExpandShape { match &node.inputs[1].ty { ArgType::Tensor(tensor) => { @@ -29,7 +49,7 @@ pub fn expand_config(node: &Node) -> ExpandShape { }) => ExpandShape::Static(shape.clone()), None => { // we were unable to statically determine the input value, so we'll need to fetch it at runtime - ExpandShape::Runtime(crate::burn::Type::from(&node.inputs[1])) + ExpandShape::Runtime(node.inputs[1].clone()) } _ => panic!( "Shape data type must be int64, is {:?}", @@ -38,7 +58,34 @@ pub fn expand_config(node: &Node) -> ExpandShape { } } -/// Create a TileConfig from the attributes of the node +/// Configuration for the Split operation. +#[derive(Clone, Debug, new)] +pub struct SplitConfig { + /// The axis along which to split the input tensor. + pub axis: usize, + /// The uniform size of each split when splitting evenly. + pub split_size: Option, + /// Custom sizes for each split when splitting unevenly. + pub split_sizes: Option>, +} + +/// Configuration for the Tile operation. +#[derive(Debug, Clone, new)] +pub struct TileConfig { + /// The number of times to repeat each dimension. + pub repeats: Vec, +} + +/// Axes specification for the Unsqueeze operation. +#[derive(Debug, Clone, new)] +pub enum UnsqueezeAxes { + /// Static axes known at compile time. + Static(Vec), + /// Runtime axes that will be determined during execution. + Runtime(Type), +} + +/// Creates a TileConfig from the node attributes and inputs. pub fn tile_config(node: &Node) -> TileConfig { let repeat = node .inputs @@ -58,9 +105,18 @@ pub fn tile_config(node: &Node) -> TileConfig { TileConfig::new(repeat) } -/// Create a TopKConfig from the attributes of the node. +/// Configuration for the TopK operation. 
+#[derive(Debug, Clone, new)] +pub struct TopKConfig { + /// The axis along which to perform the top-k selection. + pub axis: usize, + /// The number of top elements to select. + pub k: usize, +} + +/// Creates a TopKConfig from the node attributes and inputs. pub fn top_k_config(node: &Node) -> TopKConfig { - // extract the shape of the input data tensor + // Extract the shape of the input data tensor let data_tensor = match node.inputs.first().unwrap().clone().ty { ArgType::Tensor(tensor) => tensor, _ => panic!("Only tensor input is valid"), @@ -86,7 +142,7 @@ pub fn top_k_config(node: &Node) -> TopKConfig { None => -1, }; - // if axis is negative, it is counted from the end + // If axis is negative, it is counted from the end if axis < 0 { axis += data_tensor.rank as i64; } @@ -106,7 +162,16 @@ pub fn top_k_config(node: &Node) -> TopKConfig { TopKConfig::new(axis as usize, k as usize) } -/// Create a TriluConfig from the attributes of the node +/// Configuration for the Trilu operation. +#[derive(Debug, Clone, new)] +pub struct TriluConfig { + /// Whether to return the upper triangular matrix. + pub upper: bool, + /// The diagonal offset. + pub diagonal: i64, +} + +/// Creates a TriluConfig from the node attributes and inputs. pub fn trilu_config(node: &Node) -> TriluConfig { let mut upper = true; let mut diagonal = 0; @@ -129,7 +194,7 @@ pub fn trilu_config(node: &Node) -> TriluConfig { TriluConfig::new(upper, diagonal) } -/// Create a PadConfig from the attributes of the node +/// Creates a PadConfig from the node attributes and inputs. pub fn pad_config(node: &Node) -> PadConfig { fn get_pads_input(node: &Node) -> Vec { match &node.inputs[1].value { @@ -150,7 +215,7 @@ pub fn pad_config(node: &Node) -> PadConfig { _ => panic!("Pad: Only tensor input is valid"), }; - //TODO : handle more possible attributes + // TODO: Handle more possible attributes let mut pads: Vec = get_pads_input(node) .into_iter() .map(|x| x as usize) @@ -215,7 +280,7 @@ pub fn pad_config(node: &Node) -> PadConfig { vec![left, right, top, bottom] } fn get_constant_value(node: &Node) -> f32 { - // TODO: support int, boolean + // TODO: Support int, boolean let mut constant_value = node.inputs .get(2) .and_then(|input| match &input.value.as_ref().expect("Value input must be present").data { @@ -250,8 +315,10 @@ pub fn pad_config(node: &Node) -> PadConfig { PadConfig::new(pads, constant_value) } -//Note this function should only execute if the second input is a constant -//if it wasn't and the output shape was known, unsqueeze has been remapped to reshape +/// Creates UnsqueezeAxes configuration from the node attributes. +/// +/// Note: This function should only execute if the second input is a constant. +/// If it wasn't and the output shape was known, unsqueeze has been remapped to reshape. pub fn unsqueeze_config(node: &Node) -> UnsqueezeAxes { // Check if axes attribute exists for (key, value) in node.attrs.iter() { @@ -285,6 +352,7 @@ pub fn unsqueeze_config(node: &Node) -> UnsqueezeAxes { } } +/// Creates a SplitConfig from the node attributes and inputs. 
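A minimal sketch of how the runtime variants above are consumed, assuming the ExpandShape enum defined earlier in this file: the Runtime case now carries an onnx-ir Argument instead of a burn-import Type, and codegen converts it with Type::from only when emitting code.

fn describe_shape(shape: &ExpandShape) -> String {
    match shape {
        // Shape fully known at import time.
        ExpandShape::Static(dims) => format!("static shape {dims:?}"),
        // Shape supplied by another graph input at runtime.
        ExpandShape::Runtime(arg) => format!("shape read from input `{}`", arg.name),
    }
}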
pub fn split_config(node: &Node) -> SplitConfig { // Initialize the axis to split along (default is 0 as per ONNX specification) let mut axis: i64 = 0; diff --git a/crates/onnx-ir/src/lib.rs b/crates/onnx-ir/src/lib.rs index b195b88adb..7c7340482e 100644 --- a/crates/onnx-ir/src/lib.rs +++ b/crates/onnx-ir/src/lib.rs @@ -10,4 +10,4 @@ pub mod util; pub use from_onnx::convert_constant_value; pub use from_onnx::parse_onnx; -pub use ir::OnnxGraph; +pub use ir::*; From 5fbfbec16b2432ea6e5e69f456c997111f62f428 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Thu, 1 May 2025 20:05:50 -0500 Subject: [PATCH 22/37] Moved remaining configs from burn-import to onnx-ir --- crates/burn-import/src/burn/node/expand.rs | 6 +- crates/burn-import/src/burn/node/pad.rs | 6 +- crates/burn-import/src/burn/node/split.rs | 2 +- crates/burn-import/src/burn/node/tile.rs | 6 +- crates/burn-import/src/burn/node/top_k.rs | 6 +- crates/burn-import/src/burn/node/trilu.rs | 6 +- crates/burn-import/src/burn/node/unsqueeze.rs | 38 +- crates/burn-import/src/onnx/mod.rs | 1 - .../burn-import/src/onnx/op_configuration.rs | 461 ------------------ crates/burn-import/src/onnx/to_burn.rs | 24 +- crates/onnx-ir/src/node/expand.rs | 190 +++++++- crates/onnx-ir/src/node/mod.rs | 3 + crates/onnx-ir/src/node/pad.rs | 400 +++++++++++++++ crates/onnx-ir/src/node/split.rs | 328 ++++++++++++- crates/onnx-ir/src/node/tile.rs | 197 ++++++++ crates/onnx-ir/src/node/topk.rs | 260 +++++++++- crates/onnx-ir/src/node/trilu.rs | 233 +++++++++ crates/onnx-ir/src/node/unsqueeze.rs | 178 ++++++- 18 files changed, 1772 insertions(+), 573 deletions(-) delete mode 100644 crates/burn-import/src/onnx/op_configuration.rs create mode 100644 crates/onnx-ir/src/node/pad.rs create mode 100644 crates/onnx-ir/src/node/tile.rs create mode 100644 crates/onnx-ir/src/node/trilu.rs diff --git a/crates/burn-import/src/burn/node/expand.rs b/crates/burn-import/src/burn/node/expand.rs index 43df4d37cd..c0afecfe6a 100644 --- a/crates/burn-import/src/burn/node/expand.rs +++ b/crates/burn-import/src/burn/node/expand.rs @@ -1,9 +1,7 @@ use super::{Node, NodeCodegen}; -use crate::{ - burn::{Scope, TensorType, ToTokens, Type}, - onnx::op_configuration::ExpandShape, -}; +use crate::burn::{Scope, TensorType, ToTokens, Type}; use burn::record::PrecisionSettings; +use onnx_ir::node::expand::ExpandShape; use proc_macro2::TokenStream; use quote::quote; diff --git a/crates/burn-import/src/burn/node/pad.rs b/crates/burn-import/src/burn/node/pad.rs index 9fe2f2eca7..5b76421d05 100644 --- a/crates/burn-import/src/burn/node/pad.rs +++ b/crates/burn-import/src/burn/node/pad.rs @@ -1,11 +1,9 @@ use std::str::FromStr; use super::{Node, NodeCodegen}; -use crate::{ - burn::{Scope, TensorType, ToTokens, Type}, - onnx::op_configuration::PadConfig, -}; +use crate::burn::{Scope, TensorType, ToTokens, Type}; use burn::record::PrecisionSettings; +use onnx_ir::node::pad::PadConfig; use proc_macro2::TokenStream; use quote::quote; diff --git a/crates/burn-import/src/burn/node/split.rs b/crates/burn-import/src/burn/node/split.rs index 4b00072a2a..5943f94b6b 100644 --- a/crates/burn-import/src/burn/node/split.rs +++ b/crates/burn-import/src/burn/node/split.rs @@ -1,7 +1,7 @@ use super::{Node, NodeCodegen}; use crate::burn::{Scope, TensorType, ToTokens, Type}; -use crate::onnx::op_configuration::SplitConfig; use burn::record::PrecisionSettings; +use onnx_ir::node::split::SplitConfig; use proc_macro2::TokenStream; use quote::quote; diff --git 
a/crates/burn-import/src/burn/node/tile.rs b/crates/burn-import/src/burn/node/tile.rs index 900498d253..bda7dc9ecc 100644 --- a/crates/burn-import/src/burn/node/tile.rs +++ b/crates/burn-import/src/burn/node/tile.rs @@ -1,9 +1,7 @@ use super::{Node, NodeCodegen}; -use crate::{ - burn::{Scope, TensorType, ToTokens, Type}, - onnx::op_configuration::TileConfig, -}; +use crate::burn::{Scope, TensorType, ToTokens, Type}; use burn::record::PrecisionSettings; +use onnx_ir::node::tile::TileConfig; use proc_macro2::TokenStream; use quote::quote; diff --git a/crates/burn-import/src/burn/node/top_k.rs b/crates/burn-import/src/burn/node/top_k.rs index d7efb4c924..bbd1c74d61 100644 --- a/crates/burn-import/src/burn/node/top_k.rs +++ b/crates/burn-import/src/burn/node/top_k.rs @@ -1,9 +1,7 @@ use super::{Node, NodeCodegen}; -use crate::{ - burn::{Scope, TensorType, Type}, - onnx::op_configuration::TopKConfig, -}; +use crate::burn::{Scope, TensorType, Type}; use burn::record::PrecisionSettings; +use onnx_ir::node::topk::TopKConfig; use proc_macro2::TokenStream; use quote::{ToTokens, quote}; diff --git a/crates/burn-import/src/burn/node/trilu.rs b/crates/burn-import/src/burn/node/trilu.rs index a71911031e..2bb84e9447 100644 --- a/crates/burn-import/src/burn/node/trilu.rs +++ b/crates/burn-import/src/burn/node/trilu.rs @@ -1,9 +1,7 @@ use super::{Node, NodeCodegen}; -use crate::{ - burn::{Scope, TensorType, ToTokens, Type}, - onnx::op_configuration::TriluConfig, -}; +use crate::burn::{Scope, TensorType, ToTokens, Type}; use burn::record::PrecisionSettings; +use onnx_ir::node::trilu::TriluConfig; use proc_macro2::TokenStream; use quote::quote; diff --git a/crates/burn-import/src/burn/node/unsqueeze.rs b/crates/burn-import/src/burn/node/unsqueeze.rs index 0b0ed0ac85..d8e40503c7 100644 --- a/crates/burn-import/src/burn/node/unsqueeze.rs +++ b/crates/burn-import/src/burn/node/unsqueeze.rs @@ -1,9 +1,7 @@ use super::{Node, NodeCodegen}; -use crate::{ - burn::{BurnImports, Scope, TensorType, ToTokens, Type}, - onnx::op_configuration::UnsqueezeAxes, -}; +use crate::burn::{BurnImports, Scope, TensorType, ToTokens, Type}; use burn::record::PrecisionSettings; +use onnx_ir::node::unsqueeze::UnsqueezeConfig; use proc_macro2::TokenStream; use quote::quote; @@ -11,7 +9,7 @@ use quote::quote; pub struct UnsqueezeNode { pub input: Type, pub output: TensorType, - pub axes: UnsqueezeAxes, + pub axes: UnsqueezeConfig, } impl NodeCodegen for UnsqueezeNode { @@ -22,8 +20,8 @@ impl NodeCodegen for UnsqueezeNode { fn input_types(&self) -> Vec { let input = self.input.clone(); match &self.axes { - UnsqueezeAxes::Static(_) => vec![input], - UnsqueezeAxes::Runtime(rt_type) => vec![input, rt_type.clone()], + UnsqueezeConfig::Static(_) => vec![input], + UnsqueezeConfig::Runtime(rt_type) => vec![input, Type::from(rt_type)], } } fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { @@ -31,17 +29,19 @@ impl NodeCodegen for UnsqueezeNode { let output_rank = self.output.rank.to_tokens(); let axes = match &self.axes { - UnsqueezeAxes::Static(static_axes) => static_axes.to_tokens(), - UnsqueezeAxes::Runtime(Type::Tensor(axes_tensor)) => { - let tensor_name = &axes_tensor.name; - quote! { - #tensor_name.to_data().as_slice::().unwrap().iter().map(|&x| x.to_isize()).collect::>() + UnsqueezeConfig::Static(static_axes) => static_axes.to_tokens(), + UnsqueezeConfig::Runtime(arg) => match Type::from(arg) { + Type::Tensor(axes_tensor) => { + let tensor_name = &axes_tensor.name; + quote! 
{ + #tensor_name.to_data().as_slice::().unwrap().iter().map(|&x| x.to_isize()).collect::>() + } } - } - _ => panic!( - "UnsqueezeNode received invalid axes type: expected static axes or tensor but got {:?}", - self.axes - ), + _ => panic!( + "UnsqueezeNode received invalid axes type: expected tensor but got {:?}", + arg + ), + }, }; match &self.input { @@ -76,7 +76,7 @@ impl NodeCodegen for UnsqueezeNode { _ => {} } match &self.axes { - UnsqueezeAxes::Runtime(_) => { + UnsqueezeConfig::Runtime(_) => { imports.register("alloc::vec::Vec"); imports.register("burn::tensor::cast::ToElement"); } @@ -103,7 +103,7 @@ mod tests { graph.register(UnsqueezeNode::new( Type::Tensor(TensorType::new_float("tensor1", 3)), TensorType::new_float("tensor2", 5), - UnsqueezeAxes::Static([0, 4].into()), + UnsqueezeConfig::Static([0, 4].into()), )); graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]); diff --git a/crates/burn-import/src/onnx/mod.rs b/crates/burn-import/src/onnx/mod.rs index 29a99a6680..e387b0bdb9 100644 --- a/crates/burn-import/src/onnx/mod.rs +++ b/crates/burn-import/src/onnx/mod.rs @@ -1,3 +1,2 @@ -pub mod op_configuration; mod to_burn; pub use to_burn::*; diff --git a/crates/burn-import/src/onnx/op_configuration.rs b/crates/burn-import/src/onnx/op_configuration.rs deleted file mode 100644 index 0dd7f39fd6..0000000000 --- a/crates/burn-import/src/onnx/op_configuration.rs +++ /dev/null @@ -1,461 +0,0 @@ -//! This file contains the configuration structs and functions for various ONNX operations. - -// use crate::burn::Type; -use onnx_ir::ir::{ArgType, Argument, AttributeValue, Data, ElementType, Node, TensorData}; - -use crate::burn::Type; - -/// Configuration for the Pad operation. -#[derive(Debug, Clone, new)] -pub struct PadConfig { - /// The paddings to be applied to each dimension. - pub pads: Vec, - /// The constant value to fill the padded areas with. - pub constant_value: f32, -} - -/// Shape information for the Expand operation. -#[derive(Debug, Clone)] -pub enum ExpandShape { - /// Static shape information known at compile time. - Static(Vec), - /// Runtime shape that will be determined during execution. - Runtime(Argument), -} - -/// Creates an ExpandShape configuration from the given Node. -/// -/// Extracts shape information from the node's second input to determine -/// whether to use static or runtime shape expansion. -pub fn expand_config(node: &Node) -> ExpandShape { - match &node.inputs[1].ty { - ArgType::Tensor(tensor) => { - assert_eq!(tensor.rank, 1, "Expand: shape tensor must be 1D"); - assert!( - matches!(tensor.elem_type, ElementType::Int64), - "Expand: shape tensor must have element type int64" - ); - } - ArgType::Shape(_) => { - // Shapes are always 1-D int64 data, so nothing to assert here - } - _ => panic!("Only tensor input is valid for shape"), - } - - match &node.inputs[1].value { - Some(TensorData { - data: Data::Int64s(shape), - .. - }) => ExpandShape::Static(shape.clone()), - None => { - // we were unable to statically determine the input value, so we'll need to fetch it at runtime - ExpandShape::Runtime(node.inputs[1].clone()) - } - _ => panic!( - "Shape data type must be int64, is {:?}", - &node.inputs[1].value - ), - } -} - -/// Configuration for the Split operation. -#[derive(Clone, Debug, new)] -pub struct SplitConfig { - /// The axis along which to split the input tensor. - pub axis: usize, - /// The uniform size of each split when splitting evenly. 
- pub split_size: Option, - /// Custom sizes for each split when splitting unevenly. - pub split_sizes: Option>, -} - -/// Configuration for the Tile operation. -#[derive(Debug, Clone, new)] -pub struct TileConfig { - /// The number of times to repeat each dimension. - pub repeats: Vec, -} - -/// Axes specification for the Unsqueeze operation. -#[derive(Debug, Clone, new)] -pub enum UnsqueezeAxes { - /// Static axes known at compile time. - Static(Vec), - /// Runtime axes that will be determined during execution. - Runtime(Type), -} - -/// Creates a TileConfig from the node attributes and inputs. -pub fn tile_config(node: &Node) -> TileConfig { - let repeat = node - .inputs - .get(1) - .map(|input| { - if let Some(TensorData { data, .. }) = &input.value { - data.clone() - .into_i64s() - .iter() - .map(|&x| x as usize) - .collect() - } else { - vec![] - } - }) - .unwrap_or_default(); - TileConfig::new(repeat) -} - -/// Configuration for the TopK operation. -#[derive(Debug, Clone, new)] -pub struct TopKConfig { - /// The axis along which to perform the top-k selection. - pub axis: usize, - /// The number of top elements to select. - pub k: usize, -} - -/// Creates a TopKConfig from the node attributes and inputs. -pub fn top_k_config(node: &Node) -> TopKConfig { - // Extract the shape of the input data tensor - let data_tensor = match node.inputs.first().unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - let k = match node.inputs.get(1) { - Some(k_tensor) => k_tensor - .clone() - .value - .expect("TopK: only constant 'k' tensor is currently supported") - .data - .into_i64s()[0], - _ => node - .attrs - .get("k") - .expect("TopK: number of top elements 'k' is missing") - .clone() - .into_i64(), - }; - - let mut axis = match node.attrs.get("axis") { - Some(axis) => axis.clone().into_i64(), - None => -1, - }; - - // If axis is negative, it is counted from the end - if axis < 0 { - axis += data_tensor.rank as i64; - } - - if let Some(largest) = node.attrs.get("largest") { - if largest.clone().into_i64() != 1 { - unimplemented!("TopK: only largest elements is supported") - } - }; - - if let Some(sorted) = node.attrs.get("sorted") { - if sorted.clone().into_i64() != 1 { - unimplemented!("TopK: only sorted elements is supported") - } - }; - - TopKConfig::new(axis as usize, k as usize) -} - -/// Configuration for the Trilu operation. -#[derive(Debug, Clone, new)] -pub struct TriluConfig { - /// Whether to return the upper triangular matrix. - pub upper: bool, - /// The diagonal offset. - pub diagonal: i64, -} - -/// Creates a TriluConfig from the node attributes and inputs. -pub fn trilu_config(node: &Node) -> TriluConfig { - let mut upper = true; - let mut diagonal = 0; - for (key, value) in node.attrs.iter() { - match key.as_str() { - "upper" => upper = value.clone().into_i64() != 0, - _ => {} - } - } - // The second input of the Trilu node is the diagonal value, coming from a constant node - if let Some(diagonal_arg) = node.inputs.get(1) { - if let Some(TensorData { - data: Data::Int64(diagonal_val), - .. - }) = &diagonal_arg.value - { - diagonal = *diagonal_val; - } - } - TriluConfig::new(upper, diagonal) -} - -/// Creates a PadConfig from the node attributes and inputs. -pub fn pad_config(node: &Node) -> PadConfig { - fn get_pads_input(node: &Node) -> Vec { - match &node.inputs[1].value { - Some(TensorData { data, .. 
}) => data.clone().into_i64s(), - _ => Vec::new(), - } - } - fn get_pads(node: &Node) -> Vec { - if node.inputs.is_empty() { - panic!("Pad: must provide data as input") - } - if node.inputs.len() >= 4 { - panic!("Pad: axes input is not supported") - } - - let input_dim = match &node.inputs.first().unwrap().ty { - ArgType::Tensor(tensor) => tensor.rank, - _ => panic!("Pad: Only tensor input is valid"), - }; - - // TODO: Handle more possible attributes - let mut pads: Vec = get_pads_input(node) - .into_iter() - .map(|x| x as usize) - .collect(); - - for (key, value) in node.attrs.iter() { - match key.as_str() { - "pads" => { - pads = value - .clone() - .into_i64s() - .iter() - .map(|&x| { - if x < 0 { - panic!("Pad: Negative pad is not supported"); - } - x as usize - }) - .collect() - } - "mode" => { - let mode = value.clone().into_string(); - if mode != "constant" { - panic!("only constant mode is supported, given mode is {}", mode); - } - } - - _ => {} - } - } - - if pads.is_empty() { - panic!("Pad: pads should be given as attribute or as input"); - } - - if pads.len() != input_dim * 2 { - panic!("Pad: pads should be a 1D tensor of shape [2 * num_axes]"); - } - // TODO: Burn's pad should support 1D tensor - if input_dim < 2 { - panic!("Pad: input tensor should be rank 2 or higher"); - } - - let left_index = input_dim - 1; - let top_index = input_dim - 2; - let right_index = pads.len() - 1; - let bottom_index = pads.len() - 2; - let index_list = [left_index, top_index, right_index, bottom_index]; - - for (index, &item) in pads.iter().enumerate() { - if !index_list.contains(&index) && item != 0 { - panic!( - "Pad: padding will only be applied to the last two dimensions but found non zero padding for other dimensions" - ); - } - } - - let left = pads[left_index]; - let top = pads[top_index]; - let right = pads[right_index]; - let bottom = pads[bottom_index]; - vec![left, right, top, bottom] - } - fn get_constant_value(node: &Node) -> f32 { - // TODO: Support int, boolean - let mut constant_value = node.inputs - .get(2) - .and_then(|input| match &input.value.as_ref().expect("Value input must be present").data { - Data::Float16s(constant_value) => { - constant_value.first().map(|&f| f32::from(f)) - } - Data::Float32s(constant_value) => { - constant_value.first().copied() - } - Data::Float64s(constant_value) => { - constant_value.first().map(|&f| f as f32) - } - Data::Float16(constant_value) => Some(f32::from(*constant_value)), - Data::Float32(constant_value) => Some(*constant_value), - Data::Float64(constant_value) => Some(*constant_value as f32), - _ => panic!("Pad: only float values are currently supported for constant value, submit an issue on github"), - }) - .unwrap_or(0.0); - - if node.attrs.contains_key("value") { - constant_value = node.attrs.get("value").map(|value| match value { - AttributeValue::Float32(value) => *value, - _ => panic!("Pad: only float32 values are currently supported for constant value as attribute, submit an issue on github"), - }).expect("constant_value should have had a value now"); - } - constant_value - } - - let pads = get_pads(node); - let constant_value = get_constant_value(node); - - PadConfig::new(pads, constant_value) -} - -/// Creates UnsqueezeAxes configuration from the node attributes. -/// -/// Note: This function should only execute if the second input is a constant. -/// If it wasn't and the output shape was known, unsqueeze has been remapped to reshape. 
-pub fn unsqueeze_config(node: &Node) -> UnsqueezeAxes { - // Check if axes attribute exists - for (key, value) in node.attrs.iter() { - match key.as_str() { - "axes" => return UnsqueezeAxes::Static(value.clone().into_i64s()), - _ => {} - } - } - - assert!( - !node.inputs.is_empty(), - "Unsqueeze: axes tensor must be present" - ); - - let input_value = &node.inputs[1]; - - match &node.inputs[1].ty { - ArgType::Tensor(tensor) => { - assert_eq!(tensor.rank, 1, "Unsqueeze: axes tensor must be 1D"); - if let Some(TensorData { - data: Data::Int64s(shape), - .. - }) = input_value.value.as_ref() - { - UnsqueezeAxes::Static(shape.clone()) - } else { - UnsqueezeAxes::Runtime(crate::burn::Type::from(&node.inputs[1])) - } - } - _ => panic!("Arg for unsqueeze must be tensor or scalar"), - } -} - -/// Creates a SplitConfig from the node attributes and inputs. -pub fn split_config(node: &Node) -> SplitConfig { - // Initialize the axis to split along (default is 0 as per ONNX specification) - let mut axis: i64 = 0; - // Holds the uniform split size if calculated or provided - let mut split_size: Option = None; - // Holds the custom split sizes if provided as input - let mut split_sizes: Option> = None; - - // Extract the input tensor type to determine rank and shape - let tensor = match node.inputs.first().unwrap().ty { - ArgType::Tensor(ref tensor) => tensor, - _ => panic!("Split: Input must be a valid tensor"), - }; - - // Optionally store the number of outputs if provided as an attribute - let mut num_outputs: Option = None; - - // Iterate through node attributes to extract relevant values - for (key, value) in node.attrs.iter() { - match key.as_str() { - "axis" => axis = value.clone().into_i64(), - "num_outputs" => num_outputs = Some(value.clone().into_i64() as usize), - _ => {} - } - } - - // Handle the case when num_outputs is provided to calculate uniform split size - if let Some(num_outputs) = num_outputs { - if num_outputs == 0 { - panic!("Split: 'num_outputs' must be a positive value greater than zero"); - } - - let dim_size = tensor - .static_shape - .as_ref() - .expect("Split: Static shape must be known to calculate split size")[axis as usize]; - - // Calculate the split size considering any remainder for non-evenly divisible dimensions - let calculated_split_size = - dim_size / (num_outputs - (dim_size % num_outputs != 0) as usize); - - if calculated_split_size == 0 { - panic!( - "Split: Calculated split size is zero. 
Please ensure 'num_outputs' is valid for the dimension size" - ); - } - - // Assign the calculated split size - split_size = Some(calculated_split_size); - } - - // Adjust axis if negative to count from the end as per ONNX spec - if axis < 0 { - axis += tensor.rank as i64; - } - - // Check for custom split sizes provided as a second input - if node.inputs.len() > 1 && node.inputs[1].value.is_some() { - let sizes = node.inputs[1] - .value - .as_ref() - .unwrap() - .data - .clone() - .into_usizes(); - - if !sizes.is_empty() { - split_sizes = Some(sizes); - } - } - - // Ensure that only one of 'split_sizes' or 'num_outputs' is specified - if split_sizes.is_some() && split_size.is_some() { - panic!( - "Split: Cannot specify both 'split' input and 'num_outputs' attribute simultaneously" - ); - } - - // Infer split_size if neither custom split_sizes nor split_size is provided - if split_sizes.is_none() && split_size.is_none() { - let num_outputs = node.outputs.len(); - let dim_size = tensor - .static_shape - .as_ref() - .expect("Split: Static shape must be known to infer split size")[axis as usize]; - - // Calculate inferred split size based on number of outputs - let calculated_split_size = - dim_size / (num_outputs - (dim_size % num_outputs != 0) as usize); - - if calculated_split_size == 0 { - panic!( - "Split: Inferred split size is zero. Please ensure the number of outputs is valid for the dimension size" - ); - } - - split_size = Some(calculated_split_size); - } - - // Return the configuration for splitting operation - SplitConfig { - axis: axis as usize, - split_size, - split_sizes, - } -} diff --git a/crates/burn-import/src/onnx/to_burn.rs b/crates/burn-import/src/onnx/to_burn.rs index 693a1185d3..a6e0ddea5d 100644 --- a/crates/burn-import/src/onnx/to_burn.rs +++ b/crates/burn-import/src/onnx/to_burn.rs @@ -69,10 +69,6 @@ use crate::{ logger::init_log, }; -use super::op_configuration::{ - expand_config, pad_config, split_config, tile_config, top_k_config, trilu_config, - unsqueeze_config, -}; use onnx_ir::{ convert_constant_value, ir::{ @@ -84,15 +80,17 @@ use onnx_ir::{ batch_norm::batch_norm_config, clip::clip_config, concat::concat_config, conv_transpose1d::conv_transpose1d_config, conv_transpose2d::conv_transpose2d_config, conv_transpose3d::conv_transpose3d_config, conv1d::conv1d_config, conv2d::conv2d_config, - conv3d::conv3d_config, dropout::dropout_config, flatten::flatten_config, - gather::gather_config, gemm::gemm_config, hard_sigmoid::hard_sigmoid_config, - layer_norm::layer_norm_config, leaky_relu::leaky_relu_config, linear::linear_config, - log_softmax::log_softmax_config, max_pool1d::max_pool1d_config, - max_pool2d::max_pool2d_config, one_hot::one_hot_config, reduce_max::reduce_max_config, - reduce_mean::reduce_mean_config, reduce_min::reduce_min_config, - reduce_prod::reduce_prod_config, reduce_sum::reduce_sum_config, reshape::reshape_config, - resize::resize_config, slice::slice_config, softmax::softmax_config, - squeeze::squeeze_config, transpose::transpose_config, + conv3d::conv3d_config, dropout::dropout_config, expand::expand_config, + flatten::flatten_config, gather::gather_config, gemm::gemm_config, + hard_sigmoid::hard_sigmoid_config, layer_norm::layer_norm_config, + leaky_relu::leaky_relu_config, linear::linear_config, log_softmax::log_softmax_config, + max_pool1d::max_pool1d_config, max_pool2d::max_pool2d_config, one_hot::one_hot_config, + pad::pad_config, reduce_max::reduce_max_config, reduce_mean::reduce_mean_config, + reduce_min::reduce_min_config, 
reduce_prod::reduce_prod_config, + reduce_sum::reduce_sum_config, reshape::reshape_config, resize::resize_config, + slice::slice_config, softmax::softmax_config, split::split_config, squeeze::squeeze_config, + tile::tile_config, topk::top_k_config, transpose::transpose_config, trilu::trilu_config, + unsqueeze::unsqueeze_config, }, parse_onnx, util::shape_config, diff --git a/crates/onnx-ir/src/node/expand.rs b/crates/onnx-ir/src/node/expand.rs index 7e60f2a22b..32a890fed3 100644 --- a/crates/onnx-ir/src/node/expand.rs +++ b/crates/onnx-ir/src/node/expand.rs @@ -1,4 +1,7 @@ -use crate::ir::{ArgType, Data, Node, TensorType}; +use crate::{ + Argument, ElementType, TensorData, + ir::{ArgType, Data, Node, TensorType}, +}; /// Updates the output rank and shape for the Expand operation based on the provided shape input. /// If the shape is a constant, the rank and static shape of the output are set accordingly. @@ -50,13 +53,61 @@ pub fn expand_update_outputs(node: &mut Node) { } } +/// Shape information for the Expand operation. +#[derive(Debug, Clone)] +pub enum ExpandShape { + /// Static shape information known at compile time. + Static(Vec), + /// Runtime shape that will be determined during execution. + Runtime(Argument), +} + +/// Creates an ExpandShape configuration from the given Node. +/// +/// Extracts shape information from the node's second input to determine +/// whether to use static or runtime shape expansion. +pub fn expand_config(node: &Node) -> ExpandShape { + match &node.inputs[1].ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.rank, 1, "Expand: shape tensor must be 1D"); + assert!( + matches!(tensor.elem_type, ElementType::Int64), + "Expand: shape tensor must have element type int64" + ); + } + ArgType::Shape(_) => { + // Shapes are always 1-D int64 data, so nothing to assert here + } + _ => panic!("Only tensor input is valid for shape"), + } + + match &node.inputs[1].value { + Some(TensorData { + data: Data::Int64s(shape), + .. 
+ }) => ExpandShape::Static(shape.clone()), + None => { + // we were unable to statically determine the input value, so we'll need to fetch it at runtime + ExpandShape::Runtime(node.inputs[1].clone()) + } + _ => panic!( + "Shape data type must be int64, is {:?}", + &node.inputs[1].value + ), + } +} + #[cfg(test)] mod tests { use super::*; use crate::ir::{Argument, ElementType, NodeType, TensorData}; use std::collections::HashMap; - fn create_test_node(input_rank: usize, shape_value: Option>) -> Node { + fn create_test_node( + input_rank: usize, + shape_value: Option>, + shape_type: Option, + ) -> Node { let inputs = vec![ Argument { name: "input".to_string(), @@ -70,19 +121,21 @@ mod tests { }, Argument { name: "shape".to_string(), - ty: if shape_value.is_some() { - ArgType::Tensor(TensorType { - elem_type: ElementType::Int64, - rank: 1, - static_shape: Some(vec![shape_value.as_ref().unwrap().len()]), - }) - } else { - ArgType::Tensor(TensorType { - elem_type: ElementType::Int64, - rank: 1, - static_shape: Some(vec![3]), // Example: a shape with 3 dimensions - }) - }, + ty: shape_type.unwrap_or_else(|| { + if shape_value.is_some() { + ArgType::Tensor(TensorType { + elem_type: ElementType::Int64, + rank: 1, + static_shape: Some(vec![shape_value.as_ref().unwrap().len()]), + }) + } else { + ArgType::Tensor(TensorType { + elem_type: ElementType::Int64, + rank: 1, + static_shape: Some(vec![3]), // Example: a shape with 3 dimensions + }) + } + }), value: shape_value.map(|shape| TensorData { shape: vec![shape.len()], data: Data::Int64s(shape), @@ -113,7 +166,7 @@ mod tests { #[test] fn test_expand_with_constant_shape() { - let mut node = create_test_node(2, Some(vec![2, 3, 4])); + let mut node = create_test_node(2, Some(vec![2, 3, 4]), None); expand_update_outputs(&mut node); @@ -129,7 +182,7 @@ mod tests { #[test] fn test_expand_with_dynamic_shape() { - let mut node = create_test_node(2, None); + let mut node = create_test_node(2, None, None); expand_update_outputs(&mut node); @@ -146,9 +199,110 @@ mod tests { #[test] #[should_panic(expected = "Expand operation requires exactly two inputs")] fn test_expand_with_incorrect_inputs() { - let mut node = create_test_node(2, Some(vec![2, 3, 4])); + let mut node = create_test_node(2, Some(vec![2, 3, 4]), None); node.inputs.pop(); // Remove one input expand_update_outputs(&mut node); } + + // Tests for expand_config function + + #[test] + fn test_expand_config_with_static_shape() { + let node = create_test_node(2, Some(vec![2, 3, 4]), None); + let config = expand_config(&node); + + match config { + ExpandShape::Static(shape) => { + assert_eq!(shape, vec![2, 3, 4]); + } + ExpandShape::Runtime(_) => panic!("Expected Static config, got Runtime"), + } + } + + #[test] + fn test_expand_config_with_runtime_shape() { + let node = create_test_node(2, None, None); + let config = expand_config(&node); + + match config { + ExpandShape::Static(_) => panic!("Expected Runtime config, got Static"), + ExpandShape::Runtime(arg) => { + assert_eq!(arg.name, "shape"); + match arg.ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.elem_type, ElementType::Int64); + assert_eq!(tensor.rank, 1); + } + _ => panic!("Expected tensor type for runtime shape"), + } + } + } + } + + #[test] + fn test_expand_config_with_shape_type() { + let shape_type = ArgType::Shape(3); + let node = create_test_node(2, None, Some(shape_type)); + let config = expand_config(&node); + + match config { + ExpandShape::Static(_) => panic!("Expected Runtime config, got Static"), + 
ExpandShape::Runtime(arg) => { + assert_eq!(arg.name, "shape"); + match arg.ty { + ArgType::Shape(rank) => { + assert_eq!(rank, 3); + } + _ => panic!("Expected shape type for runtime shape"), + } + } + } + } + + #[test] + #[should_panic(expected = "Expand: shape tensor must be 1D")] + fn test_expand_config_with_invalid_shape_rank() { + let invalid_shape_type = ArgType::Tensor(TensorType { + elem_type: ElementType::Int64, + rank: 2, // Invalid rank, should be 1 + static_shape: None, + }); + let node = create_test_node(2, None, Some(invalid_shape_type)); + let _ = expand_config(&node); + } + + #[test] + #[should_panic(expected = "Expand: shape tensor must have element type int64")] + fn test_expand_config_with_invalid_shape_type() { + let invalid_shape_type = ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, // Invalid element type, should be Int64 + rank: 1, + static_shape: None, + }); + let node = create_test_node(2, None, Some(invalid_shape_type)); + let _ = expand_config(&node); + } + + #[test] + #[should_panic(expected = "Only tensor input is valid for shape")] + fn test_expand_config_with_invalid_input_type() { + let invalid_shape_type = ArgType::Scalar(ElementType::Int64); + let node = create_test_node(2, None, Some(invalid_shape_type)); + let _ = expand_config(&node); + } + + #[test] + #[should_panic(expected = "Shape data type must be int64")] + fn test_expand_config_with_invalid_value_type() { + let mut node = create_test_node(2, None, None); + + // Replace the value with a non-Int64s value + node.inputs[1].value = Some(TensorData { + shape: vec![1], + data: Data::Float32s(vec![1.0]), // Invalid data type + }); + + let _ = expand_config(&node); + } } diff --git a/crates/onnx-ir/src/node/mod.rs b/crates/onnx-ir/src/node/mod.rs index 4e560c7453..917ae45189 100644 --- a/crates/onnx-ir/src/node/mod.rs +++ b/crates/onnx-ir/src/node/mod.rs @@ -37,6 +37,7 @@ pub mod matmul; pub mod max_pool1d; pub mod max_pool2d; pub mod one_hot; +pub mod pad; pub mod padding; pub mod random; pub mod random_like; @@ -53,7 +54,9 @@ pub mod slice; pub mod softmax; pub mod split; pub mod squeeze; +pub mod tile; pub mod topk; pub mod transpose; +pub mod trilu; pub mod unsqueeze; pub mod where_op; diff --git a/crates/onnx-ir/src/node/pad.rs b/crates/onnx-ir/src/node/pad.rs new file mode 100644 index 0000000000..3335464f40 --- /dev/null +++ b/crates/onnx-ir/src/node/pad.rs @@ -0,0 +1,400 @@ +use crate::ir::{ArgType, AttributeValue, Data, Node, TensorData}; + +/// Configuration for the Pad operation. +#[derive(Debug, Clone, PartialEq)] +pub struct PadConfig { + /// The paddings to be applied to each dimension. + pub pads: Vec<usize>, + /// The constant value to fill the padded areas with. + pub constant_value: f32, +} + +impl PadConfig { + pub fn new(pads: Vec<usize>, constant_value: f32) -> Self { + PadConfig { + pads, + constant_value, + } + } +} + +/// Creates a PadConfig from the node attributes and inputs. +pub fn pad_config(node: &Node) -> PadConfig { + fn get_pads_input(node: &Node) -> Vec<i64> { + if node.inputs.len() <= 1 { + return Vec::new(); + } + + match &node.inputs[1].value { + Some(TensorData { data, ..
}) => data.clone().into_i64s(), + _ => Vec::new(), + } + } + fn get_pads(node: &Node) -> Vec<usize> { + if node.inputs.is_empty() { + panic!("Pad: must provide data as input") + } + if node.inputs.len() >= 4 { + panic!("Pad: axes input is not supported") + } + + let input_dim = match &node.inputs.first().unwrap().ty { + ArgType::Tensor(tensor) => tensor.rank, + _ => panic!("Pad: Only tensor input is valid"), + }; + + // TODO: Handle more possible attributes + let mut pads: Vec<usize> = get_pads_input(node) + .into_iter() + .map(|x| x as usize) + .collect(); + + for (key, value) in node.attrs.iter() { + match key.as_str() { + "pads" => { + pads = value + .clone() + .into_i64s() + .iter() + .map(|&x| { + if x < 0 { + panic!("Pad: Negative pad is not supported"); + } + x as usize + }) + .collect() + } + "mode" => { + let mode = value.clone().into_string(); + if mode != "constant" { + panic!("only constant mode is supported, given mode is {}", mode); + } + } + + _ => {} + } + } + + if pads.is_empty() { + panic!("Pad: pads should be given as attribute or as input"); + } + + if pads.len() != input_dim * 2 { + panic!("Pad: pads should be a 1D tensor of shape [2 * num_axes]"); + } + // TODO: Burn's pad should support 1D tensor + if input_dim < 2 { + panic!("Pad: input tensor should be rank 2 or higher"); + } + + let left_index = input_dim - 1; + let top_index = input_dim - 2; + let right_index = pads.len() - 1; + let bottom_index = pads.len() - 2; + let index_list = [left_index, top_index, right_index, bottom_index]; + + for (index, &item) in pads.iter().enumerate() { + if !index_list.contains(&index) && item != 0 { + panic!( + "Pad: padding will only be applied to the last two dimensions but found non zero padding for other dimensions" + ); + } + } + + let left = pads[left_index]; + let top = pads[top_index]; + let right = pads[right_index]; + let bottom = pads[bottom_index]; + vec![left, right, top, bottom] + } + fn get_constant_value(node: &Node) -> f32 { + // TODO: Support int, boolean + let mut constant_value = node.inputs + .get(2) + .and_then(|input| match &input.value.as_ref().expect("Value input must be present").data { + Data::Float16s(constant_value) => { + constant_value.first().map(|&f| f32::from(f)) + } + Data::Float32s(constant_value) => { + constant_value.first().copied() + } + Data::Float64s(constant_value) => { + constant_value.first().map(|&f| f as f32) + } + Data::Float16(constant_value) => Some(f32::from(*constant_value)), + Data::Float32(constant_value) => Some(*constant_value), + Data::Float64(constant_value) => Some(*constant_value as f32), + _ => panic!("Pad: only float values are currently supported for constant value, submit an issue on github"), + }) + .unwrap_or(0.0); + + if node.attrs.contains_key("value") { + constant_value = node.attrs.get("value").map(|value| match value { + AttributeValue::Float32(value) => *value, + _ => panic!("Pad: only float32 values are currently supported for constant value as attribute, submit an issue on github"), + }).expect("constant_value should have had a value now"); + } + constant_value + } + + let pads = get_pads(node); + let constant_value = get_constant_value(node); + + PadConfig::new(pads, constant_value) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{ + Argument, AttributeValue, Data, ElementType, NodeType, TensorData, TensorType, + }; + use std::collections::HashMap; + + fn create_test_node( + pad_attrs: Option<Vec<i64>>, + pad_inputs: Option<Vec<i64>>, + constant_value_attr: Option<f32>, + constant_value_input: Option<f32>, + mode: Option<&str>,
rank: usize, + ) -> Node { + let mut inputs = vec![Argument { + name: "data".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank, + static_shape: None, + }), + value: None, + passed: true, + }]; + + // Add pads input if provided + if let Some(pads) = pad_inputs { + inputs.push(Argument { + name: "pads".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Int64, + rank: 1, + static_shape: None, + }), + value: Some(TensorData { + data: Data::Int64s(pads), + shape: vec![], + }), + passed: true, + }); + } + + // Add constant_value input if provided + if let Some(value) = constant_value_input { + inputs.push(Argument { + name: "constant_value".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 0, + static_shape: None, + }), + value: Some(TensorData { + data: Data::Float32(value), + shape: vec![], + }), + passed: true, + }); + } + + let mut attrs = HashMap::new(); + if let Some(pads) = pad_attrs { + attrs.insert("pads".to_string(), AttributeValue::Int64s(pads)); + } + if let Some(value) = constant_value_attr { + attrs.insert("value".to_string(), AttributeValue::Float32(value)); + } + if let Some(mode_val) = mode { + attrs.insert( + "mode".to_string(), + AttributeValue::String(mode_val.to_string()), + ); + } + + Node { + node_type: NodeType::Pad, + name: "test_pad".to_string(), + inputs, + outputs: vec![Argument { + name: "output".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank, + static_shape: None, + }), + value: None, + passed: true, + }], + attrs, + } + } + + #[test] + fn test_pad_config_with_attrs() { + // Test for 2D tensor (rank 2) + let pads = vec![0, 0, 1, 1]; + let node = create_test_node( + Some(pads.clone()), + None, + Some(0.0), + None, + Some("constant"), + 2, + ); + let config = pad_config(&node); + assert_eq!( + config, + PadConfig { + pads: vec![0, 1, 0, 1], + constant_value: 0.0 + } + ); + } + + #[test] + fn test_pad_config_with_inputs() { + // For a 2D tensor, pads should have 4 values (2*rank) + let pads = vec![0, 0, 1, 1]; + let node = create_test_node(None, Some(pads.clone()), None, Some(1.0), None, 2); + let config = pad_config(&node); + assert_eq!( + config, + PadConfig { + pads: vec![0, 1, 0, 1], + constant_value: 1.0 + } + ); + } + + #[test] + fn test_pad_config_with_3d_tensor() { + // For a 3D tensor, pads should have 6 values (2*rank) + let pads = vec![0, 0, 0, 0, 1, 1]; + let node = create_test_node( + Some(pads.clone()), + None, + Some(0.5), + None, + Some("constant"), + 3, + ); + let config = pad_config(&node); + assert_eq!( + config, + PadConfig { + pads: vec![0, 1, 0, 1], + constant_value: 0.5 + } + ); + } + + #[test] + fn test_pad_config_attrs_override_inputs() { + // Attributes should override inputs + let attr_pads = vec![0, 0, 2, 2]; + let input_pads = vec![0, 0, 1, 1]; + let node = create_test_node( + Some(attr_pads.clone()), + Some(input_pads), + Some(0.0), + Some(1.0), + Some("constant"), + 2, + ); + let config = pad_config(&node); + assert_eq!( + config, + PadConfig { + pads: vec![0, 2, 0, 2], + constant_value: 0.0 + } + ); + } + + #[test] + #[should_panic(expected = "Pad: must provide data as input")] + fn test_pad_config_no_inputs() { + let mut node = create_test_node(None, None, None, None, None, 2); + node.inputs = vec![]; + let _ = pad_config(&node); + } + + #[test] + #[should_panic(expected = "Pad: Only tensor input is valid")] + fn test_pad_config_invalid_input_type() { + let mut node = 
create_test_node(Some(vec![0, 0, 1, 1]), None, None, None, None, 2); + node.inputs[0].ty = ArgType::Scalar(ElementType::Float32); + let _ = pad_config(&node); + } + + #[test] + #[should_panic(expected = "Pad: axes input is not supported")] + fn test_pad_config_with_axes_input() { + // Create node with 4 inputs (including axes) + let mut node = create_test_node(None, Some(vec![0, 0, 1, 1]), None, Some(0.0), None, 2); + node.inputs.push(Argument { + name: "axes".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Int64, + rank: 1, + static_shape: None, + }), + value: Some(TensorData { + data: Data::Int64s(vec![0, 1]), + shape: vec![], + }), + passed: true, + }); + let _ = pad_config(&node); + } + + #[test] + #[should_panic(expected = "Pad: Negative pad is not supported")] + fn test_pad_config_negative_pad() { + let node = create_test_node(Some(vec![0, 0, -1, 1]), None, None, None, None, 2); + let _ = pad_config(&node); + } + + #[test] + #[should_panic(expected = "only constant mode is supported")] + fn test_pad_config_unsupported_mode() { + let node = create_test_node(Some(vec![0, 0, 1, 1]), None, None, None, Some("reflect"), 2); + let _ = pad_config(&node); + } + + #[test] + #[should_panic(expected = "Pad: pads should be given as attribute or as input")] + fn test_pad_config_no_pads() { + let node = create_test_node(None, None, None, None, None, 2); + let _ = pad_config(&node); + } + + #[test] + #[should_panic(expected = "Pad: pads should be a 1D tensor of shape [2 * num_axes]")] + fn test_pad_config_invalid_pads_length() { + let node = create_test_node(Some(vec![0, 0, 1]), None, None, None, None, 2); + let _ = pad_config(&node); + } + + #[test] + #[should_panic(expected = "Pad: input tensor should be rank 2 or higher")] + fn test_pad_config_invalid_tensor_rank() { + let node = create_test_node(Some(vec![0, 1]), None, None, None, None, 1); + let _ = pad_config(&node); + } + + #[test] + #[should_panic(expected = "Pad: padding will only be applied to the last two dimensions")] + fn test_pad_config_non_zero_padding_on_other_dimensions() { + // For a 3D tensor, we try to set non-zero padding on first dimension + let node = create_test_node(Some(vec![1, 0, 0, 0, 1, 1]), None, None, None, None, 3); + let _ = pad_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/split.rs b/crates/onnx-ir/src/node/split.rs index 82514f5057..1ba20c9eaf 100644 --- a/crates/onnx-ir/src/node/split.rs +++ b/crates/onnx-ir/src/node/split.rs @@ -25,24 +25,176 @@ pub fn split_update_outputs(node: &mut Node) { } } +/// Configuration for the Split operation. +#[derive(Clone, Debug)] +pub struct SplitConfig { + /// The axis along which to split the input tensor. + pub axis: usize, + /// The uniform size of each split when splitting evenly. + pub split_size: Option<usize>, + /// Custom sizes for each split when splitting unevenly. + pub split_sizes: Option<Vec<usize>>, +} + +impl SplitConfig { + pub fn new(axis: usize, split_size: Option<usize>, split_sizes: Option<Vec<usize>>) -> Self { + SplitConfig { + axis, + split_size, + split_sizes, + } + } +} + +/// Creates a SplitConfig from the node attributes and inputs.
+pub fn split_config(node: &Node) -> SplitConfig { + // Initialize the axis to split along (default is 0 as per ONNX specification) + let mut axis: i64 = 0; + // Holds the uniform split size if calculated or provided + let mut split_size: Option<usize> = None; + // Holds the custom split sizes if provided as input + let mut split_sizes: Option<Vec<usize>> = None; + + // Extract the input tensor type to determine rank and shape + let tensor = match node.inputs.first().unwrap().ty { + ArgType::Tensor(ref tensor) => tensor, + _ => panic!("Split: Input must be a valid tensor"), + }; + + // Optionally store the number of outputs if provided as an attribute + let mut num_outputs: Option<usize> = None; + + // Iterate through node attributes to extract relevant values + for (key, value) in node.attrs.iter() { + match key.as_str() { + "axis" => axis = value.clone().into_i64(), + "num_outputs" => num_outputs = Some(value.clone().into_i64() as usize), + _ => {} + } + } + + // Handle the case when num_outputs is provided to calculate uniform split size + if let Some(num_outputs) = num_outputs { + if num_outputs == 0 { + panic!("Split: 'num_outputs' must be a positive value greater than zero"); + } + + let dim_size = tensor + .static_shape + .as_ref() + .expect("Split: Static shape must be known to calculate split size")[axis as usize]; + + // Calculate the split size considering any remainder for non-evenly divisible dimensions + let calculated_split_size = + dim_size / (num_outputs - (dim_size % num_outputs != 0) as usize); + + if calculated_split_size == 0 { + panic!( + "Split: Calculated split size is zero. Please ensure 'num_outputs' is valid for the dimension size" + ); + } + + // Assign the calculated split size + split_size = Some(calculated_split_size); + } + + // Adjust axis if negative to count from the end as per ONNX spec + if axis < 0 { + axis += tensor.rank as i64; + } + + // Check for custom split sizes provided as a second input + if node.inputs.len() > 1 && node.inputs[1].value.is_some() { + let sizes = node.inputs[1] + .value + .as_ref() + .unwrap() + .data + .clone() + .into_usizes(); + + if !sizes.is_empty() { + split_sizes = Some(sizes); + } + } + + // Ensure that only one of 'split_sizes' or 'num_outputs' is specified + if split_sizes.is_some() && split_size.is_some() { + panic!( + "Split: Cannot specify both 'split' input and 'num_outputs' attribute simultaneously" + ); + } + + // Infer split_size if neither custom split_sizes nor split_size is provided + if split_sizes.is_none() && split_size.is_none() { + let num_outputs = node.outputs.len(); + let dim_size = tensor + .static_shape + .as_ref() + .expect("Split: Static shape must be known to infer split size")[axis as usize]; + + // Calculate inferred split size based on number of outputs + let calculated_split_size = + dim_size / (num_outputs - (dim_size % num_outputs != 0) as usize); + + if calculated_split_size == 0 { + panic!( + "Split: Inferred split size is zero.
Please ensure the number of outputs is valid for the dimension size" + ); + } + + split_size = Some(calculated_split_size); + } + + // Return the configuration for splitting operation + SplitConfig { + axis: axis as usize, + split_size, + split_sizes, + } +} + #[cfg(test)] mod tests { use super::*; - use crate::ir::{Argument, ElementType, NodeType}; + use crate::ir::{Argument, AttributeValue, Data, ElementType, NodeType, TensorData}; use std::collections::HashMap; - fn create_test_node(input_rank: usize, num_outputs: usize) -> Node { - let inputs = vec![Argument { + fn create_test_node( + input_rank: usize, + num_outputs: usize, + static_shape: Option>, + attrs: Option>, + split_sizes_input: Option>, + ) -> Node { + let mut inputs = vec![Argument { name: "input".to_string(), ty: ArgType::Tensor(TensorType { elem_type: ElementType::Float32, rank: input_rank, - static_shape: None, + static_shape, }), value: None, passed: true, }]; + // Add split sizes input if provided + if let Some(sizes) = split_sizes_input { + inputs.push(Argument { + name: "split".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Int64, + rank: 1, + static_shape: Some(vec![sizes.len()]), + }), + value: Some(TensorData { + shape: vec![sizes.len()], + data: Data::Int64s(sizes), + }), + passed: true, + }); + } + let mut outputs = Vec::new(); for i in 0..num_outputs { outputs.push(Argument { @@ -57,20 +209,18 @@ mod tests { }); } - let attrs = HashMap::new(); - Node { node_type: NodeType::Split, name: "test_split".to_string(), inputs, outputs, - attrs, + attrs: attrs.unwrap_or_default(), } } #[test] fn test_split_single_output() { - let mut node = create_test_node(3, 1); + let mut node = create_test_node(3, 1, None, None, None); split_update_outputs(&mut node); assert_eq!(node.outputs.len(), 1); @@ -85,7 +235,7 @@ mod tests { #[test] fn test_split_multiple_outputs() { - let mut node = create_test_node(4, 3); + let mut node = create_test_node(4, 3, None, None, None); split_update_outputs(&mut node); assert_eq!(node.outputs.len(), 3); @@ -103,8 +253,166 @@ mod tests { #[test] #[should_panic(expected = "Split: Input must be a tensor")] fn test_split_invalid_input() { - let mut node = create_test_node(3, 2); + let mut node = create_test_node(3, 2, None, None, None); node.inputs[0].ty = ArgType::Scalar(ElementType::Float32); split_update_outputs(&mut node); } + + // Tests for split_config function + + #[test] + fn test_split_config_default_axis() { + // Create a node with static shape and 2 outputs + let static_shape = Some(vec![10, 20, 30]); + let node = create_test_node(3, 2, static_shape, None, None); + + let config = split_config(&node); + + // Default axis should be 0, and split_size should be calculated + assert_eq!(config.axis, 0); + assert_eq!(config.split_size, Some(5)); // 10 / 2 = 5 + assert_eq!(config.split_sizes, None); + } + + #[test] + fn test_split_config_specified_axis() { + // Create a node with static shape, 2 outputs, and a specified axis + let static_shape = Some(vec![10, 20, 30]); + let mut attrs = HashMap::new(); + attrs.insert("axis".to_string(), AttributeValue::Int64(1)); // Split along axis 1 + + let node = create_test_node(3, 2, static_shape, Some(attrs), None); + + let config = split_config(&node); + + assert_eq!(config.axis, 1); + assert_eq!(config.split_size, Some(10)); // 20 / 2 = 10 + assert_eq!(config.split_sizes, None); + } + + #[test] + fn test_split_config_negative_axis() { + // Test with negative axis (should count from the end) + let static_shape = Some(vec![10, 20, 
30]); + let mut attrs = HashMap::new(); + attrs.insert("axis".to_string(), AttributeValue::Int64(-1)); // Last axis (index 2) + + let node = create_test_node(3, 3, static_shape, Some(attrs), None); + + let config = split_config(&node); + + assert_eq!(config.axis, 2); // -1 should be converted to 2 + assert_eq!(config.split_size, Some(10)); // 30 / 3 = 10 + assert_eq!(config.split_sizes, None); + } + + #[test] + fn test_split_config_num_outputs_attr() { + // Test with explicitly specified num_outputs attribute + let static_shape = Some(vec![12, 24, 36]); + let mut attrs = HashMap::new(); + attrs.insert("num_outputs".to_string(), AttributeValue::Int64(4)); + + let node = create_test_node(3, 4, static_shape, Some(attrs), None); + + let config = split_config(&node); + + assert_eq!(config.axis, 0); + assert_eq!(config.split_size, Some(3)); // 12 / 4 = 3 + assert_eq!(config.split_sizes, None); + } + + #[test] + fn test_split_config_with_split_sizes_input() { + // Test with explicit split sizes provided as second input + let static_shape = Some(vec![10, 20, 30]); + let split_sizes = vec![5, 15]; // Custom split sizes along default axis + + let node = create_test_node(3, 2, static_shape, None, Some(split_sizes.clone())); + + let config = split_config(&node); + + assert_eq!(config.axis, 0); + assert_eq!(config.split_size, None); + assert_eq!(config.split_sizes, Some(vec![5, 15])); + } + + #[test] + #[should_panic( + expected = "Split: Cannot specify both 'split' input and 'num_outputs' attribute simultaneously" + )] + fn test_split_config_both_splits_and_num_outputs() { + // Test with both split sizes input and num_outputs attribute (should panic) + let static_shape = Some(vec![10, 20, 30]); + let mut attrs = HashMap::new(); + attrs.insert("num_outputs".to_string(), AttributeValue::Int64(2)); + let split_sizes = vec![3, 7]; + + let node = create_test_node(3, 2, static_shape, Some(attrs), Some(split_sizes)); + + let _ = split_config(&node); + } + + #[test] + #[should_panic(expected = "Split: 'num_outputs' must be a positive value greater than zero")] + fn test_split_config_zero_num_outputs() { + // Test with num_outputs attribute set to 0 (should panic) + let static_shape = Some(vec![10, 20, 30]); + let mut attrs = HashMap::new(); + attrs.insert("num_outputs".to_string(), AttributeValue::Int64(0)); + + let node = create_test_node(3, 0, static_shape, Some(attrs), None); + + let _ = split_config(&node); + } + + #[test] + #[should_panic(expected = "Split: Calculated split size is zero")] + fn test_split_config_invalid_num_outputs() { + // Test with num_outputs larger than the dimension size (should result in split_size = 0) + let static_shape = Some(vec![5, 10, 15]); + let mut attrs = HashMap::new(); + attrs.insert("num_outputs".to_string(), AttributeValue::Int64(10)); // Larger than dim 0 size + + let node = create_test_node(3, 10, static_shape, Some(attrs), None); + + let _ = split_config(&node); + } + + #[test] + #[should_panic(expected = "Split: Static shape must be known to calculate split size")] + fn test_split_config_no_static_shape() { + // Test with no static shape available + let mut attrs = HashMap::new(); + attrs.insert("num_outputs".to_string(), AttributeValue::Int64(2)); + + let node = create_test_node(3, 2, None, Some(attrs), None); + + let _ = split_config(&node); + } + + #[test] + #[should_panic(expected = "Split: Input must be a valid tensor")] + fn test_split_config_invalid_input_type() { + // Test with invalid input type + let mut node = create_test_node(3, 2, Some(vec![10, 20, 
30]), None, None); + node.inputs[0].ty = ArgType::Scalar(ElementType::Float32); + + let _ = split_config(&node); + } + + #[test] + fn test_split_config_non_even_split() { + // Test with non-evenly divisible dimension size + let static_shape = Some(vec![11, 22, 33]); // 11 is not evenly divisible by 3 + let mut attrs = HashMap::new(); + attrs.insert("axis".to_string(), AttributeValue::Int64(0)); + + let node = create_test_node(3, 3, static_shape, Some(attrs), None); + + let config = split_config(&node); + + // 11 / (3-1) = 5, since the dimension is not evenly divisible + assert_eq!(config.split_size, Some(5)); + } } diff --git a/crates/onnx-ir/src/node/tile.rs b/crates/onnx-ir/src/node/tile.rs new file mode 100644 index 0000000000..557e6abfe7 --- /dev/null +++ b/crates/onnx-ir/src/node/tile.rs @@ -0,0 +1,197 @@ +use crate::{Node, TensorData}; + +/// Configuration for the Tile operation. +#[derive(Debug, Clone, PartialEq)] +pub struct TileConfig { + /// The number of times to repeat each dimension. + pub repeats: Vec<usize>, +} + +impl TileConfig { + pub fn new(repeats: Vec<usize>) -> Self { + TileConfig { repeats } + } +} + +/// Creates a TileConfig from the node attributes and inputs. +pub fn tile_config(node: &Node) -> TileConfig { + let repeat = node + .inputs + .get(1) + .map(|input| { + if let Some(TensorData { data, .. }) = &input.value { + data.clone() + .into_i64s() + .iter() + .map(|&x| x as usize) + .collect() + } else { + vec![] + } + }) + .unwrap_or_default(); + TileConfig::new(repeat) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{ArgType, Argument, Data, ElementType, NodeType, TensorType}; + use std::collections::HashMap; + + /// Helper function to create test nodes with different repeat values + fn create_test_node(repeats: Option<Vec<i64>>, input_rank: usize) -> Node { + let mut inputs = vec![ + // First input: the tensor to tile + Argument { + name: "input".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: input_rank, + static_shape: None, + }), + value: None, + passed: true, + }, + ]; + + // Add repeats input if provided + if let Some(reps) = repeats { + inputs.push(Argument { + name: "repeats".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Int64, + rank: 1, + static_shape: Some(vec![reps.len()]), + }), + value: Some(TensorData { + shape: vec![reps.len()], + data: Data::Int64s(reps), + }), + passed: true, + }); + } + + Node { + node_type: NodeType::Tile, + name: "test_tile".to_string(), + inputs, + outputs: vec![Argument { + name: "output".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: input_rank, // Same rank as input initially + static_shape: None, + }), + value: None, + passed: true, + }], + attrs: HashMap::new(), + } + } + + #[test] + fn test_tile_config_with_repeats() { + // Test with normal repeats values + let repeats = vec![2, 3, 4]; + let node = create_test_node(Some(repeats.clone()), 3); + + let config = tile_config(&node); + + // Should extract repeats correctly + assert_eq!(config.repeats, vec![2, 3, 4]); + } + + #[test] + fn test_tile_config_with_single_repeat() { + // Test with single repeat value + let repeats = vec![5]; + let node = create_test_node(Some(repeats.clone()), 1); + + let config = tile_config(&node); + + assert_eq!(config.repeats, vec![5]); + } + + #[test] + fn test_tile_config_with_zero_repeats() { + // Test with repeats including zeros + let repeats = vec![0, 1, 0]; + let node = create_test_node(Some(repeats.clone()), 3); + + let config
= tile_config(&node); + + assert_eq!(config.repeats, vec![0, 1, 0]); + } + + #[test] + fn test_tile_config_with_large_repeats() { + // Test with large repeats values + let repeats = vec![100, 200]; + let node = create_test_node(Some(repeats.clone()), 2); + + let config = tile_config(&node); + + assert_eq!(config.repeats, vec![100, 200]); + } + + #[test] + fn test_tile_config_without_repeats_input() { + // Test when repeats input is missing + let node = create_test_node(None, 3); + + let config = tile_config(&node); + + // Should return empty repeats + assert_eq!(config.repeats, vec![]); + } + + #[test] + fn test_tile_config_with_negative_repeats() { + // Test with negative repeats values (will be converted to usize) + let repeats = vec![-1, 2, -3]; + let node = create_test_node(Some(repeats), 3); + + let config = tile_config(&node); + + // Negative values get converted to very large positive values due to usize conversion + // This is expected behavior for this function (though may cause issues elsewhere) + assert!(config.repeats[0] > 0); + assert_eq!(config.repeats[1], 2); + assert!(config.repeats[2] > 0); + } + + #[test] + fn test_tile_config_with_empty_repeats() { + // Test with empty repeats array + let repeats = vec![]; + let node = create_test_node(Some(repeats), 3); + + let config = tile_config(&node); + + assert_eq!(config.repeats, vec![]); + } + + #[test] + fn test_tile_config_with_missing_value() { + // Test with repeats input that has no value + let mut node = create_test_node(None, 3); + + // Add repeats input with no value + node.inputs.push(Argument { + name: "repeats".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Int64, + rank: 1, + static_shape: Some(vec![3]), + }), + value: None, // No value provided + passed: true, + }); + + let config = tile_config(&node); + + // Should return empty repeats + assert_eq!(config.repeats, vec![]); + } +} diff --git a/crates/onnx-ir/src/node/topk.rs b/crates/onnx-ir/src/node/topk.rs index e1a0ba429d..6094d80c2c 100644 --- a/crates/onnx-ir/src/node/topk.rs +++ b/crates/onnx-ir/src/node/topk.rs @@ -28,31 +28,108 @@ pub fn top_k_update_output(node: &mut Node) { ); } +/// Configuration for the TopK operation. +#[derive(Debug, Clone, PartialEq)] +pub struct TopKConfig { + /// The axis along which to perform the top-k selection. + pub axis: usize, + /// The number of top elements to select. + pub k: usize, +} + +impl TopKConfig { + /// Creates a new TopKConfig. + pub fn new(axis: usize, k: usize) -> Self { + Self { axis, k } + } +} + +/// Creates a TopKConfig from the node attributes and inputs. 
+pub fn top_k_config(node: &Node) -> TopKConfig { + // Extract the shape of the input data tensor + let data_tensor = match node.inputs.first().unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + let k = match node.inputs.get(1) { + Some(k_tensor) => k_tensor + .clone() + .value + .expect("TopK: only constant 'k' tensor is currently supported") + .data + .into_i64s()[0], + _ => node + .attrs + .get("k") + .expect("TopK: number of top elements 'k' is missing") + .clone() + .into_i64(), + }; + + let mut axis = match node.attrs.get("axis") { + Some(axis) => axis.clone().into_i64(), + None => -1, + }; + + // If axis is negative, it is counted from the end + if axis < 0 { + axis += data_tensor.rank as i64; + } + + if let Some(largest) = node.attrs.get("largest") { + if largest.clone().into_i64() != 1 { + unimplemented!("TopK: only largest elements is supported") + } + }; + + if let Some(sorted) = node.attrs.get("sorted") { + if sorted.clone().into_i64() != 1 { + unimplemented!("TopK: only sorted elements is supported") + } + }; + + TopKConfig::new(axis as usize, k as usize) +} + #[cfg(test)] mod tests { use super::*; - use crate::ir::{Argument, NodeType}; + use crate::ir::{Argument, AttributeValue, Data, NodeType, TensorData}; use std::collections::HashMap; - fn create_test_node(input_rank: usize) -> Node { - let inputs = vec![ - Argument { - name: "X".to_string(), + fn create_test_node( + input_rank: usize, + attrs: Option>, + k_input_value: Option, + ) -> Node { + let mut inputs = vec![Argument { + name: "X".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: input_rank, + static_shape: None, + }), + value: None, + passed: true, + }]; + + // Add K input if provided + if let Some(k) = k_input_value { + inputs.push(Argument { + name: "K".to_string(), ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: input_rank, - static_shape: None, + elem_type: ElementType::Int64, + rank: 0, + static_shape: Some(vec![]), + }), + value: Some(TensorData { + shape: vec![], + data: Data::Int64s(vec![k]), }), - value: None, - passed: true, - }, - Argument { - name: "K".to_string(), - ty: ArgType::Scalar(ElementType::Int64), - value: None, passed: true, - }, - ]; + }); + } let outputs = vec![ Argument { @@ -77,20 +154,21 @@ mod tests { }, ]; - let attrs = HashMap::new(); - Node { node_type: NodeType::TopK, name: "test_topk".to_string(), inputs, outputs, - attrs, + attrs: attrs.unwrap_or_default(), } } #[test] fn test_topk_basic() { - let mut node = create_test_node(3); + let mut node = create_test_node(3, None, None); + // Add K attribute since we didn't provide K input + node.attrs.insert("k".to_string(), AttributeValue::Int64(5)); + top_k_update_output(&mut node); assert_eq!(node.outputs.len(), 2); @@ -117,8 +195,146 @@ mod tests { #[test] #[should_panic(expected = "TopK: invalid input type")] fn test_topk_invalid_input() { - let mut node = create_test_node(3); + let mut node = create_test_node(3, None, None); + node.attrs.insert("k".to_string(), AttributeValue::Int64(5)); node.inputs[0].ty = ArgType::Scalar(ElementType::Float32); top_k_update_output(&mut node); } + + // Tests for top_k_config function + + #[test] + fn test_top_k_config_with_k_attribute() { + // Test when k is provided as an attribute + let mut attrs = HashMap::new(); + attrs.insert("k".to_string(), AttributeValue::Int64(10)); + let node = create_test_node(3, Some(attrs), None); + + let config = top_k_config(&node); + + // 
Default axis should be -1 which gets converted to rank-1 + assert_eq!(config, TopKConfig { axis: 2, k: 10 }); + } + + #[test] + fn test_top_k_config_with_k_input() { + // Test when k is provided as an input + let node = create_test_node(4, None, Some(5)); + + let config = top_k_config(&node); + + // Default axis should be -1 which gets converted to rank-1 + assert_eq!(config, TopKConfig { axis: 3, k: 5 }); + } + + #[test] + fn test_top_k_config_with_explicit_axis() { + // Test with explicitly specified axis + let mut attrs = HashMap::new(); + attrs.insert("k".to_string(), AttributeValue::Int64(3)); + attrs.insert("axis".to_string(), AttributeValue::Int64(1)); + let node = create_test_node(3, Some(attrs), None); + + let config = top_k_config(&node); + + assert_eq!(config, TopKConfig { axis: 1, k: 3 }); + } + + #[test] + fn test_top_k_config_with_negative_axis() { + // Test with negative axis (counts from the end) + let mut attrs = HashMap::new(); + attrs.insert("k".to_string(), AttributeValue::Int64(5)); + attrs.insert("axis".to_string(), AttributeValue::Int64(-2)); // Second-to-last axis + let node = create_test_node(4, Some(attrs), None); + + let config = top_k_config(&node); + + // For rank 4, axis -2 should be 2 + assert_eq!(config, TopKConfig { axis: 2, k: 5 }); + } + + #[test] + fn test_top_k_config_with_largest_attribute() { + // Test with largest attribute set to 1 (default supported behavior) + let mut attrs = HashMap::new(); + attrs.insert("k".to_string(), AttributeValue::Int64(7)); + attrs.insert("largest".to_string(), AttributeValue::Int64(1)); + let node = create_test_node(2, Some(attrs), None); + + let config = top_k_config(&node); + + assert_eq!(config, TopKConfig { axis: 1, k: 7 }); + } + + #[test] + fn test_top_k_config_with_sorted_attribute() { + // Test with sorted attribute set to 1 (default supported behavior) + let mut attrs = HashMap::new(); + attrs.insert("k".to_string(), AttributeValue::Int64(2)); + attrs.insert("sorted".to_string(), AttributeValue::Int64(1)); + let node = create_test_node(3, Some(attrs), None); + + let config = top_k_config(&node); + + assert_eq!(config, TopKConfig { axis: 2, k: 2 }); + } + + #[test] + #[should_panic(expected = "only largest elements is supported")] + fn test_top_k_config_with_largest_false() { + // Test with largest attribute set to 0 (unsupported) + let mut attrs = HashMap::new(); + attrs.insert("k".to_string(), AttributeValue::Int64(3)); + attrs.insert("largest".to_string(), AttributeValue::Int64(0)); + let node = create_test_node(2, Some(attrs), None); + + let _ = top_k_config(&node); + } + + #[test] + #[should_panic(expected = "only sorted elements is supported")] + fn test_top_k_config_with_sorted_false() { + // Test with sorted attribute set to 0 (unsupported) + let mut attrs = HashMap::new(); + attrs.insert("k".to_string(), AttributeValue::Int64(3)); + attrs.insert("sorted".to_string(), AttributeValue::Int64(0)); + let node = create_test_node(2, Some(attrs), None); + + let _ = top_k_config(&node); + } + + #[test] + #[should_panic(expected = "Only tensor input is valid")] + fn test_top_k_config_with_invalid_input_type() { + // Test with invalid input type + let mut node = create_test_node(2, None, None); + node.attrs.insert("k".to_string(), AttributeValue::Int64(3)); + node.inputs[0].ty = ArgType::Scalar(ElementType::Float32); + + let _ = top_k_config(&node); + } + + #[test] + #[should_panic(expected = "TopK: number of top elements 'k' is missing")] + fn test_top_k_config_without_k() { + // Test when k is neither provided as 
input nor attribute + let node = create_test_node(3, None, None); + + let _ = top_k_config(&node); + } + + #[test] + fn test_top_k_config_with_both_k_input_and_attribute() { + // Test when k is provided both as input and attribute + // Input should take precedence + let mut attrs = HashMap::new(); + attrs.insert("k".to_string(), AttributeValue::Int64(10)); + let node = create_test_node(3, Some(attrs), Some(5)); + + let config = top_k_config(&node); + + // K from input should be used (5), not from attribute (10) + assert_eq!(config, TopKConfig { axis: 2, k: 5 }); + } } diff --git a/crates/onnx-ir/src/node/trilu.rs b/crates/onnx-ir/src/node/trilu.rs new file mode 100644 index 0000000000..d4ac189021 --- /dev/null +++ b/crates/onnx-ir/src/node/trilu.rs @@ -0,0 +1,233 @@ +use crate::{Data, Node, TensorData}; + +/// Configuration for the Trilu operation. +#[derive(Debug, Clone, PartialEq)] +pub struct TriluConfig { + /// Whether to return the upper triangular matrix. + pub upper: bool, + /// The diagonal offset. + pub diagonal: i64, +} + +impl TriluConfig { + /// Creates a TriluConfig from the node attributes and inputs. + pub fn new(upper: bool, diagonal: i64) -> Self { + Self { upper, diagonal } + } +} + +/// Creates a TriluConfig from the node attributes and inputs. +pub fn trilu_config(node: &Node) -> TriluConfig { + let mut upper = true; + let mut diagonal = 0; + for (key, value) in node.attrs.iter() { + if key.as_str() == "upper" { + upper = value.clone().into_i64() != 0 + } + } + // The second input of the Trilu node is the diagonal value, coming from a constant node + if let Some(diagonal_arg) = node.inputs.get(1) { + if let Some(TensorData { + data: Data::Int64(diagonal_val), + .. + }) = &diagonal_arg.value + { + diagonal = *diagonal_val; + } + } + TriluConfig::new(upper, diagonal) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{ArgType, Argument, AttributeValue, ElementType, NodeType, TensorType}; + use std::collections::HashMap; + + /// Helper function to create test nodes for Trilu tests + fn create_test_node(upper_attr: Option, diagonal_input: Option) -> Node { + let mut inputs = vec![ + // First input: the tensor to process + Argument { + name: "X".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 2, // Typically a matrix + static_shape: None, + }), + value: None, + passed: true, + }, + ]; + + // Add diagonal input if provided + if let Some(diag) = diagonal_input { + inputs.push(Argument { + name: "k".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Int64, + rank: 0, + static_shape: Some(vec![]), + }), + value: Some(TensorData { + shape: vec![], + data: Data::Int64(diag), + }), + passed: true, + }); + } + + // Create attributes map + let mut attrs = HashMap::new(); + if let Some(upper) = upper_attr { + attrs.insert("upper".to_string(), AttributeValue::Int64(upper)); + } + + Node { + node_type: NodeType::Trilu, + name: "test_trilu".to_string(), + inputs, + outputs: vec![Argument { + name: "Y".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 2, + static_shape: None, + }), + value: None, + passed: true, + }], + attrs, + } + } + + #[test] + fn test_trilu_config_default() { + // Test with no attributes or inputs - should use defaults (upper=true, diagonal=0) + let node = create_test_node(None, None); + + let config = trilu_config(&node); + + assert_eq!( + config, + TriluConfig { + upper: true, + diagonal: 0 + } + ); + } + + #[test] + fn 
test_trilu_config_upper_true() { + // Test with upper=1 attribute + let node = create_test_node(Some(1), None); + + let config = trilu_config(&node); + + assert_eq!( + config, + TriluConfig { + upper: true, + diagonal: 0 + } + ); + } + + #[test] + fn test_trilu_config_upper_false() { + // Test with upper=0 attribute (lower triangular) + let node = create_test_node(Some(0), None); + + let config = trilu_config(&node); + + assert_eq!( + config, + TriluConfig { + upper: false, + diagonal: 0 + } + ); + } + + #[test] + fn test_trilu_config_with_diagonal() { + // Test with diagonal=2 input (offset 2 above main diagonal) + let node = create_test_node(None, Some(2)); + + let config = trilu_config(&node); + + assert_eq!( + config, + TriluConfig { + upper: true, + diagonal: 2 + } + ); + } + + #[test] + fn test_trilu_config_with_negative_diagonal() { + // Test with diagonal=-3 input (offset 3 below main diagonal) + let node = create_test_node(None, Some(-3)); + + let config = trilu_config(&node); + + assert_eq!( + config, + TriluConfig { + upper: true, + diagonal: -3 + } + ); + } + + #[test] + fn test_trilu_config_both_params() { + // Test with both upper attribute and diagonal input + let node = create_test_node(Some(0), Some(1)); + + let config = trilu_config(&node); + + assert_eq!( + config, + TriluConfig { + upper: false, + diagonal: 1 + } + ); + } + + #[test] + fn test_trilu_config_non_binary_upper() { + // Test with non-binary values for the upper attribute + // Any non-zero value should be treated as true + let node = create_test_node(Some(42), None); + + let config = trilu_config(&node); + + assert_eq!( + config, + TriluConfig { + upper: true, + diagonal: 0 + } + ); + } + + #[test] + fn test_trilu_config_negative_non_binary_upper() { + // Test with negative values for the upper attribute + // Any non-zero value should be treated as true + let node = create_test_node(Some(-5), None); + + let config = trilu_config(&node); + + assert_eq!( + config, + TriluConfig { + upper: true, + diagonal: 0 + } + ); + } +} diff --git a/crates/onnx-ir/src/node/unsqueeze.rs b/crates/onnx-ir/src/node/unsqueeze.rs index 8ea66adad6..f7044e6a37 100644 --- a/crates/onnx-ir/src/node/unsqueeze.rs +++ b/crates/onnx-ir/src/node/unsqueeze.rs @@ -1,4 +1,7 @@ -use crate::ir::{ArgType, Data, Node, TensorType}; +use crate::{ + Argument, TensorData, + ir::{ArgType, Data, Node, TensorType}, +}; /// Update output rank for Unsqueeze based on axes. /// Update the output tensor dimension based on the "axes" attribute or the second input @@ -61,12 +64,68 @@ pub fn unsqueeze_update_output(node: &mut Node) { log::debug!("Unsqueeze output rank for {}: {}", node.name, output_rank); } +/// Axes specification for the Unsqueeze operation. +#[derive(Debug, Clone)] +pub enum UnsqueezeConfig { + /// Static axes known at compile time. + Static(Vec), + /// Runtime axes that will be determined during execution. + Runtime(Argument), +} + +/// Creates UnsqueezeAxes configuration from the node attributes. +/// +/// Note: This function should only execute if the second input is a constant. +/// If it wasn't and the output shape was known, unsqueeze has been remapped to reshape. 
+pub fn unsqueeze_config(node: &Node) -> UnsqueezeConfig { + // Check if axes attribute exists + for (key, value) in node.attrs.iter() { + if key.as_str() == "axes" { + return UnsqueezeConfig::Static(value.clone().into_i64s()); + } + } + + assert!( + !node.inputs.is_empty(), + "Unsqueeze: axes tensor must be present" + ); + + let input_value = &node.inputs[1]; + + match &node.inputs[1].ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.rank, 1, "Unsqueeze: axes tensor must be 1D"); + if let Some(TensorData { + data: Data::Int64s(shape), + .. + }) = input_value.value.as_ref() + { + UnsqueezeConfig::Static(shape.clone()) + } else { + UnsqueezeConfig::Runtime(node.inputs[1].clone()) + } + } + _ => panic!("Arg for unsqueeze must be tensor or scalar"), + } +} + #[cfg(test)] mod tests { use super::*; use crate::ir::{Argument, AttributeValue, ElementType, NodeType, TensorData}; use std::collections::HashMap; + // Implement custom equality for UnsqueezeConfig to make testing easier + impl PartialEq for UnsqueezeConfig { + fn eq(&self, other: &UnsqueezeConfig) -> bool { + match (self, other) { + (UnsqueezeConfig::Static(a), UnsqueezeConfig::Static(b)) => a == b, + (UnsqueezeConfig::Runtime(a), UnsqueezeConfig::Runtime(b)) => a.name == b.name, + _ => false, + } + } + } + fn create_test_node_with_attr(input_rank: usize, axes: Vec) -> Node { let inputs = vec![Argument { name: "X".to_string(), @@ -102,7 +161,8 @@ mod tests { } } - fn create_test_node_with_input(input_rank: usize, axes: Vec) -> Node { + fn create_test_node_with_input(input_rank: usize, axes: Vec, with_value: bool) -> Node { + let axes_len = axes.len(); let inputs = vec![ Argument { name: "X".to_string(), @@ -119,12 +179,16 @@ mod tests { ty: ArgType::Tensor(TensorType { elem_type: ElementType::Int64, rank: 1, - static_shape: Some(vec![axes.len()]), - }), - value: Some(TensorData { - data: Data::Int64s(axes), - shape: vec![1], + static_shape: Some(vec![axes_len]), }), + value: if with_value { + Some(TensorData { + data: Data::Int64s(axes.clone()), + shape: vec![axes_len], + }) + } else { + None + }, passed: true, }, ]; @@ -149,6 +213,8 @@ mod tests { } } + // Tests for unsqueeze_update_output function + #[test] fn test_unsqueeze_with_attr() { let mut node = create_test_node_with_attr(2, vec![0, 3]); @@ -165,7 +231,7 @@ mod tests { #[test] fn test_unsqueeze_with_input() { - let mut node = create_test_node_with_input(3, vec![1, 2, 4]); + let mut node = create_test_node_with_input(3, vec![1, 2, 4], true); unsqueeze_update_output(&mut node); match &node.outputs[0].ty { @@ -199,4 +265,100 @@ mod tests { node.inputs[0].ty = ArgType::Shape(1); unsqueeze_update_output(&mut node); } + + // Tests for unsqueeze_config function + + #[test] + fn test_unsqueeze_config_with_attr() { + // Test with axes provided as attribute + let axes = vec![0, 2, 4]; + let node = create_test_node_with_attr(3, axes.clone()); + + let config = unsqueeze_config(&node); + + assert_eq!(config, UnsqueezeConfig::Static(axes)); + } + + #[test] + fn test_unsqueeze_config_with_static_input() { + // Test with axes provided as input tensor with static value + let axes = vec![1, 3]; + let node = create_test_node_with_input(2, axes.clone(), true); + + let config = unsqueeze_config(&node); + + assert_eq!(config, UnsqueezeConfig::Static(axes)); + } + + #[test] + fn test_unsqueeze_config_with_runtime_input() { + // Test with axes provided as input tensor but without static value + let axes = vec![0, 2]; + let node = create_test_node_with_input(2, axes.clone(), false); + + let 
config = unsqueeze_config(&node); + + // Should return a Runtime config since the axes are only known at runtime + match config { + UnsqueezeConfig::Static(_) => panic!("Expected Runtime config"), + UnsqueezeConfig::Runtime(arg) => { + assert_eq!(arg.name, "axes"); + } + } + } + + #[test] + fn test_unsqueeze_config_negative_axes() { + // Test with negative axes (should be handled by the caller) + let axes = vec![-1, -3]; + let node = create_test_node_with_attr(3, axes.clone()); + + let config = unsqueeze_config(&node); + + assert_eq!(config, UnsqueezeConfig::Static(axes)); + } + + #[test] + fn test_unsqueeze_config_empty_axes() { + // Test with empty axes array (edge case) + let axes = vec![]; + let node = create_test_node_with_attr(2, axes.clone()); + + let config = unsqueeze_config(&node); + + assert_eq!(config, UnsqueezeConfig::Static(axes)); + } + + #[test] + #[should_panic(expected = "index out of bounds")] + fn test_unsqueeze_config_missing_axes() { + // Test with neither axes attribute nor input + let mut node = create_test_node_with_attr(2, vec![0]); + node.attrs.clear(); // Remove the axes attribute + node.inputs = vec![node.inputs[0].clone()]; // Remove the axes input + + let _ = unsqueeze_config(&node); + } + + #[test] + #[should_panic(expected = "Unsqueeze: axes tensor must be 1D")] + fn test_unsqueeze_config_invalid_axes_rank() { + // Test with axes tensor that is not 1D + let mut node = create_test_node_with_input(2, vec![0, 1], true); + if let ArgType::Tensor(ref mut tensor) = node.inputs[1].ty { + tensor.rank = 2; // Invalid rank for axes + } + + let _ = unsqueeze_config(&node); + } + + #[test] + #[should_panic(expected = "Arg for unsqueeze must be tensor or scalar")] + fn test_unsqueeze_config_invalid_axes_type() { + // Test with axes input that is not a tensor + let mut node = create_test_node_with_input(2, vec![0], false); + node.inputs[1].ty = ArgType::Shape(1); // Invalid type for axes + + let _ = unsqueeze_config(&node); + } } From 9efe0a08d66cf1e61d1dbb86f49ee412a45cf559 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Thu, 1 May 2025 22:21:03 -0500 Subject: [PATCH 23/37] Update the documentation --- contributor-book/src/SUMMARY.md | 2 +- .../guides/onnx-to-burn-conversion-tool.md | 399 +++++++++--------- crates/burn-import/DEVELOPMENT.md | 80 ---- crates/burn-import/README.md | 77 +++- crates/burn-import/onnx-tests/README.md | 202 +++++++-- crates/onnx-ir/README.md | 89 +++- 6 files changed, 522 insertions(+), 327 deletions(-) delete mode 100644 crates/burn-import/DEVELOPMENT.md diff --git a/contributor-book/src/SUMMARY.md b/contributor-book/src/SUMMARY.md index 885a9cb5ad..0f6292f51a 100644 --- a/contributor-book/src/SUMMARY.md +++ b/contributor-book/src/SUMMARY.md @@ -10,7 +10,7 @@ - [Tensor](./project-architecture/tensor.md) - [Backend](./project-architecture/backend.md) - [Guides for Contributors](./guides/README.md) - - [Onnx To Burn Conversion Tool: A Development Guide](./guides/onnx-to-burn-conversion-tool.md) + - [ONNX to Burn: Development Guide](./guides/onnx-to-burn-conversion-tool.md) - [Adding a New Operation to Burn](./guides/adding-a-new-operation-to-burn.md) - [Frequently Encountered Issues](./frequently-encountered-issues/README.md) - [Issues Related To Adding Operators](./frequently-encountered-issues/issues-while-adding-ops.md) diff --git a/contributor-book/src/guides/onnx-to-burn-conversion-tool.md b/contributor-book/src/guides/onnx-to-burn-conversion-tool.md index 175decf425..3c93f07726 100644 --- 
a/contributor-book/src/guides/onnx-to-burn-conversion-tool.md +++ b/contributor-book/src/guides/onnx-to-burn-conversion-tool.md @@ -1,4 +1,4 @@ -# ONNX to Burn Conversion Tool: Development Guide +# ONNX to Burn: Development Guide This guide offers in-depth design insights and step-by-step procedures for developers working on the ONNX to Burn conversion tool. This tool allows the importation of ONNX models into the Burn deep @@ -8,28 +8,6 @@ weights to Burn state files. For an introduction to ONNX import in Burn, see [this section of the Burn book](https://burn.dev/burn-book/import/onnx-model.html). -## Table of Contents - -- [ONNX to Burn Conversion Tool: Development Guide](#onnx-to-burn-conversion-tool-development-guide) - - [Table of Contents](#table-of-contents) - - [Design Overview](#design-overview) - - [Design Goals](#design-goals) - - [Design Decisions](#design-decisions) - - [Adding New Operators](#adding-new-operators) - - [Implementing a New Operator](#implementing-a-new-operator) - - [Step 1: Visibility](#step-1-visibility) - - [Step 2: Node Implementation](#step-2-node-implementation) - - [Within Onnx-IR](#within-onnx-ir) - - [Within burn-import](#within-burn-import) - - [Step 3: Registering New Operations](#step-3-registering-new-operations) - - [Step 4: Create a Config Function](#step-4-create-a-config-function) - - [Step 5: Dimension Inference](#step-5-dimension-inference) - - [Step 6: Integrate into the Graph Building Process](#step-6-integrate-into-the-graph-building-process) - - [Step 7: Add Newly Supported Op!](#step-7-add-newly-supported-op) - - [Misc:](#misc) - - [Testing](#testing) - - [Resources](#resources) - ## Design Overview ### Design Goals @@ -66,8 +44,7 @@ To extend `burn-import` with support for new ONNX operators, follow these steps: model contains the expected operators. 4. **Generate IR and Burn Graph**: Navigate to - [crates/burn-import/](https://github.com/tracel-ai/burn/tree/925716f89d0249cbc6bd14f85f40967bd7ef80a8/crates/burn-import) - and run: + [crates/burn-import/](https://github.com/tracel-ai/burn/tree/main/crates/burn-import) and run: ``` cargo r -- ./onnx-tests/tests//.onnx ./out @@ -81,225 +58,261 @@ To extend `burn-import` with support for new ONNX operators, follow these steps: the Burn model in Rust code, and `my-model.json` includes the model data. 7. **Add End-to-End Test**: Include the test in - [crates/burn-import/onnx-tests/tests/onnx_tests.rs](https://github.com/tracel-ai/burn/blob/925716f89d0249cbc6bd14f85f40967bd7ef80a8/crates/burn-import/onnx-tests/tests/test_onnx.rs). + [crates/burn-import/onnx-tests/tests/test_onnx.rs](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/onnx-tests/tests/test_onnx.rs). Further details can be found in the - [onnx-tests README](https://github.com/tracel-ai/burn/blob/925716f89d0249cbc6bd14f85f40967bd7ef80a8/crates/burn-import/onnx-tests/README.md). + [onnx-tests README](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/onnx-tests/README.md). ## Implementing a New Operator To extend the capabilities of the Burn library by supporting new operations imported from ONNX graphs, developers must go through a few systematic steps. Here, we detail the process, using the implementation of the `Squeeze` operation to illustrate points as needed. All file/directory paths -are relative to `burn/crates/burn-import/`. +are relative to the root of the burn repository. 
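Before walking through the steps, it helps to keep the running example's semantics in mind. The sketch below is purely illustrative (it is not the code `burn-import` generates) and assumes Burn's usual `Tensor::squeeze` API: squeezing removes a size-1 dimension, so the output rank is the input rank minus one — the same rank change the IR later has to infer.

```rust
use burn::tensor::{Tensor, backend::Backend};

/// Illustrative only: an ONNX `Squeeze` along axis 1 turns a `[2, 1, 4]` tensor
/// into a `[2, 4]` tensor, dropping the rank from 3 to 2.
fn squeeze_axis_1<B: Backend>(input: Tensor<B, 3>) -> Tensor<B, 2> {
    // Remove dimension 1, which is expected to have size 1.
    input.squeeze(1)
}
```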
### Step 1: Visibility -To make a new operation accessible to the rest of the Burn project, you need to declare the module -within the -[`mod.rs` file](https://github.com/tracel-ai/burn/blob/925716f89d0249cbc6bd14f85f40967bd7ef80a8/crates/burn-import/src/burn/node/mod.rs#L43) -located in the `src/burn/node/` directory. +To make a new operation accessible, there are two key modules to update: + +1. In `crates/onnx-ir/src/node/mod.rs`, add your new operation module to make it visible within the + IR +2. In `crates/burn-import/src/burn/node/mod.rs`, make the corresponding node type visible within + burn-import ### Step 2: Node Implementation -#### Within Onnx-IR +#### Within onnx-ir + +The `onnx-ir` crate handles the Intermediate Representation (IR) of ONNX models. For each operation: + +1. Add the operation to the `NodeType` enum in `crates/onnx-ir/src/ir.rs`. + +2. Create a new module file in `crates/onnx-ir/src/node/.rs`. This file should + include: + + - A `_config` function to extract operation parameters + - A `_update_output` function for dimension inference + +3. If the operation might work with constants, add it to the list of node types checked for + constants in `crates/onnx-ir/src/from_onnx.rs`. -If the node type does not exist within the -[`NodeType` enum](https://github.com/tracel-ai/burn/blob/925716f89d0249cbc6bd14f85f40967bd7ef80a8/crates/onnx-ir/src/ir.rs#L273), -it will need to be added (support for custom operators is planned). If the node might be provided an -input which is a constant or the output of an identity node, it will need to be added to the list of -nodeTypes -[checked for constants](https://github.com/tracel-ai/burn/blob/925716f89d0249cbc6bd14f85f40967bd7ef80a8/crates/onnx-ir/src/from_onnx.rs#L21). -The node will need to be added to `rank_inference`, and in most cases the work parsing side will be -done. If a node requires extra parsing (such as handling an edge case like potentially remapping an -unsqueeze to a reshape) the best place for that is after check constants and prior to rank_inference -in -[`OnnxGraphBuilder::Build`](https://github.com/tracel-ai/burn/blob/925716f89d0249cbc6bd14f85f40967bd7ef80a8/crates/onnx-ir/src/from_onnx.rs#L222) +For example, the squeeze operation is defined in `crates/onnx-ir/src/node/squeeze.rs` and contains: + +- A `squeeze_config` function that extracts axes from node attributes +- A `squeeze_update_output` function that updates output dimensions by reducing input rank #### Within burn-import -Create a new file named `.rs` in the `src/burn/node/` directory. -This file will define the structure and functionality of your new operation. By convention, the -necessary information for carrying out an operation is encapsulated within a struct named -`Node`. For the `Squeeze` operation, we defined a -[struct called `SqueezeNode`](https://github.com/tracel-ai/burn/blob/925716f89d0249cbc6bd14f85f40967bd7ef80a8/crates/burn-import/src/burn/node/squeeze.rs#L8) -that holds necessary information about the input tensor, output tensor, and axes for the operation. -**If implementing a unary or binary operation, please see note below.** - -The core of integrating a new operation involves implementing the `NodeCodegen` trait for your node. -This trait defines how the node generates code during the graph compilation process. The -implementation must provide methods to define input and output types, to generate the forward pass -code, and to encapsulate the node into the more general `Node` structure. 
Specifically: - -- `output_types` and `input_types` return the tensor (or element) types for the output and inputs of - the node, respectively. -- `forward` generates the Rust code that performs the operation during the execution phase. The - `quote!` macro is used to generate rust code. Ensure that this is syntactically correct using Burn - code. -- `into_node` wraps the specific node in a general `Node` type, facilitating its inclusion in the - broader Burn graph structure. - -This file is also where you would put `test_codegen_nodes()`, to make sure that the generated code -works within the Burn library. +1. Create a new file named `.rs` in the `crates/burn-import/src/burn/node/` + directory. This file will define the structure and functionality of your new operation. By + convention, the necessary information for carrying out an operation is encapsulated within a + struct named `Node`. For the `Squeeze` operation, we defined a struct called + `SqueezeNode` that holds necessary information about the input tensor, output tensor, and axes + for the operation. **If implementing a unary or binary operation, please see note below.** + +2. The core of integrating a new operation involves implementing the `NodeCodegen` trait for your + node. This trait defines how the node generates code during the graph compilation process. The + implementation must provide methods to define input and output types, to generate the forward + pass code, and to encapsulate the node into the more general `Node` structure. Specifically: + + - `output_types` and `input_types` return the tensor (or element) types for the output and inputs + of the node, respectively. + - `forward` generates the Rust code that performs the operation during the execution phase. The + `quote!` macro is used to generate rust code. Ensure that this is syntactically correct using + Burn code. + - `into_node` wraps the specific node in a general `Node` type, facilitating its inclusion in the + broader Burn graph structure. + +3. This file is also where you would put `test_codegen_nodes()`, to make sure that the generated + code works within the Burn library. **For unary and binary operations:** The implementation of `NodeCodegen` is mostly implemented in -[`binary.rs`](https://github.com/tracel-ai/burn/blob/925716f89d0249cbc6bd14f85f40967bd7ef80a8/crates/burn-import/src/burn/node/binary.rs#L9) -and -[`unary.rs`](https://github.com/tracel-ai/burn/blob/76fe0ed881b3965782f78896433f8bb5e2f13a1b/crates/burn-import/src/burn/node/unary.rs#L13), -so each new operation only has to define a method to execute the function on the input(s) token -stream. +binary.rs and unary.rs, so each new operation only has to define a method to execute the function on +the input(s) token stream. ### Step 3: Registering New Operations -[Register the `NodeType::`](https://github.com/tracel-ai/burn/blob/925716f89d0249cbc6bd14f85f40967bd7ef80a8/crates/burn-import/src/onnx/to_burn.rs#L353) -and -[create an `_conversion(node: Node)` function](https://github.com/tracel-ai/burn/blob/925716f89d0249cbc6bd14f85f40967bd7ef80a8/crates/burn-import/src/onnx/to_burn.rs#L1263), -both in `src/onnx/to_burn.rs`. - -**Registering new operations in the ONNX -> Burn Conversion** -To integrate new operations from an ONNX graph into the Burn framework, each operation must be -registered within the ONNX graph conversion process. This is done in the `src/onnx/to_burn.rs` file, -where the conversion from ONNX nodes to Burn nodes is orchestrated. 
- -In the `into_burn()` method of the `OnnxGraph` struct, operations are matched with their -corresponding conversion functions. This method iterates over each node in the ONNX graph and, -depending on the node type, calls a specific conversion function that translates the ONNX node into -a corresponding Burn node. +1. In `crates/burn-import/src/onnx/to_burn.rs`, add the operation to the match statement in the + `into_burn()` method: ```rust -impl OnnxGraph { +impl ParsedOnnxGraph { pub fn into_burn(self) -> BurnGraph { - let mut graph = BurnGraph::::default(); - let mut unsupported_ops = vec![]; - - for node in self.nodes { + // ... + for node in self.0.nodes { match node.node_type { - NodeType::Add => graph.register(Self::add_conversion(node)), - // Other operations... + // ... NodeType::Squeeze => graph.register(Self::squeeze_conversion(node)), - // Add new operations here + // Add your new operation here } } } } ``` -Here, the `NodeType::Squeeze` matches the ONNX node type with the `squeeze_conversion()` function -that you define to handle the specific attributes and settings of a Squeeze operation. +2. Create a conversion function that creates an instance of your Burn node: -**Define the Conversion Function** -Each operation conversion function extracts necessary information from the ONNX node and constructs -a corresponding Burn node. The structure of these functions generally includes: +```rust +fn squeeze_conversion(node: Node) -> SqueezeNode { + let input = TensorType::from(node.inputs.first().unwrap()); + let output = TensorType::from(node.outputs.first().unwrap()); + let axes = squeeze_config(&node); -1. Extracting input and output tensors from the node. -2. Retrieving and processing operation-specific configurations. -3. Calling `_config()` to parse ONNX node configurations. -4. Creating an instance of the appropriate Burn node - ([defined in step 2](#step-2-node-implementation)) using this information. + SqueezeNode::new(input, output, axes) +} +``` + +This function extracts the necessary information from the ONNX node and passes it to your node's +constructor. ### Step 4: Create a Config Function -[Create an `_config(curr: &Node)`](https://github.com/tracel-ai/burn/blob/925716f89d0249cbc6bd14f85f40967bd7ef80a8/crates/burn-import/src/onnx/op_configuration.rs#L1847) -in `src/onnx/op_configuration.rs`. - -The `squeeze_conversion()` function in `src/onnx/to_burn.rs` from the previous step calls the -`squeeze_config()` function in `src/onnx/op_configuration.rs` in order the parse the ONNX node's -attributes to extract parameters specific to the Squeeze operation. In this case, the axes along -which the squeeze operation is performed. - -> 📘 Info: Understanding Generic `config` Patterns -> -> The `_config()` functions follow a similar pattern: -> -> 1. Extract tensor or scalar types for inputs and outputs. -> 2. Validate the input structure and types for each node, ensuring they conform to expected formats -> (panicking if not). -> 3. Parse and convert configurations or parameters specific to each operation. -> 4. Create and return a node specific to the operation, initialized with extracted values and -> configurations. -> -> For example, config functions handle specific settings like kernel size for pooling or handling -> different tensor and scalar types for power operations. - -These functions translate the more varied and flexible structure of ONNX nodes into the more -structured and type-safe environment of Rust and the Burn framework. Spec compliance is dealt with -here. 
- -### Step 5: Dimension Inference - -If needed, -[create a rank inference function](https://github.com/tracel-ai/burn/blob/925716f89d0249cbc6bd14f85f40967bd7ef80a8/crates/onnx-ir/src/rank_inference.rs#L410), -called `_update_output(node: &mut Node)` in `src/rank_inference.rs`. If dimensions remain -unchanged, use the `same_as_input()` function, for example -`NodeType::AveragePool1d => same_as_input(node)`. Match the `NodeType` to the function in the -`rank_inference()` match block. - -Dimension inference is an important step in the conversion process where Burn determines the -dimensions of each output tensor based on the operation. -[The `rank_inference()`](https://github.com/tracel-ai/burn/blob/925716f89d0249cbc6bd14f85f40967bd7ef80a8/crates/onnx-ir/src/rank_inference.rs#L14) -function is responsible for determining the dimensions of the output tensors for each node in the -graph. It does this by: - -1. **Matching the Node Type**: The function uses a `match` statement on the `node_type` of each node - to apply the correct dimension inference logic depending on the operation. -2. **Applying Operation Specific Logic**: For each operation, a specific inference function is - called that encapsulates the rules for how output dimensions should be derived from the inputs. - -For the Squeeze operation, the dimension inference is handled by the `squeeze_update_output()` -function, which is specifically tailored to handle the nuances of the squeeze operation, which is -currently not that nuanced. The output tensor should be (dimensions of input tensor) - 1. - -> 📘 Info: How `squeeze_update_output()` Works -> -> 1. Validation of axes input: We first check if the second input of the node contains a list of -> integers, which represent the axes along which the squeeze operation is applied. The function -> also validates that only one axis is specified for squeezing, ensuring that the operation's -> requirements within Burn are followed. -> 2. Extracting input dimensions: The input tensor's dimension is extracted from the first input. -> 3. Configuring output dimensions: The output tensor's dimensions are then set to be one less than -> the input tensor’s dimensions, reflecting the reduction in dimensions caused by the squeeze -> operation. -> 4. The function includes several checks that throw errors (panics) if the inputs do not meet the -> expected types or configurations, such as when the axes are not provided as an integer list or -> if the input type is not a tensor. - -By invoking this function within the `rank_inference()` match block, the output dimensions of each -node are updated before the graph is finalized. This ensures that all subsequent operations within -the graph can rely on correct tensor sizes, which is critical for both compiling the graph and for -runtime execution efficiency. - -If something is amiss (ie weird panics are happening), after doing this step and the dimensions of -your output tensor differs from the dimensions of your input, see the warning at the very end. 
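To make the configuration step below concrete, here is a rough sketch of the IR value such a config function receives for a `Squeeze` node, written the way the unit tests inside `onnx-ir` build nodes by hand. Field names mirror the test helpers elsewhere in this change set; the `Int64s` attribute variant and the exact construction are illustrative rather than authoritative.

```rust
use std::collections::HashMap;

use crate::ir::{ArgType, Argument, AttributeValue, ElementType, Node, NodeType, TensorType};

/// Hand-built Squeeze node with one rank-3 float input, an `axes` attribute,
/// and a single output, analogous to the other `create_test_node` helpers.
fn example_squeeze_node() -> Node {
    let tensor = |rank| {
        ArgType::Tensor(TensorType {
            elem_type: ElementType::Float32,
            rank,
            static_shape: None,
        })
    };

    let mut attrs = HashMap::new();
    attrs.insert("axes".to_string(), AttributeValue::Int64s(vec![1]));

    Node {
        node_type: NodeType::Squeeze,
        name: "test_squeeze".to_string(),
        inputs: vec![Argument {
            name: "input".to_string(),
            ty: tensor(3),
            value: None,
            passed: true,
        }],
        outputs: vec![Argument {
            name: "output".to_string(),
            ty: tensor(2),
            value: None,
            passed: true,
        }],
        attrs,
    }
}
```

From a value like this, the config function shown next only has to read the `axes` attribute and validate that the first input is a tensor.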
+In `crates/onnx-ir/src/node/.rs`, create a config function that extracts +operation-specific parameters from the ONNX node: + +```rust +pub fn squeeze_config(curr: &Node) -> Vec { + let axes = curr + .attrs + .iter() + .filter_map(|(key, value)| { + if key == "axes" { + Some(value.clone().into_i64s()) + } else { + None + } + }) + .next() + .unwrap_or_else(Vec::new); + + match curr.inputs.first().unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + axes +} +``` + +This config function is responsible for parsing the ONNX node attributes and extracting +operation-specific parameters. In this case, it extracts the "axes" attribute from the squeeze +operation. + +### Step 5: Rank Inference + +In `crates/onnx-ir/src/node/.rs`, implement a rank inference function that updates +the output rank based on the operation: + +```rust +pub fn squeeze_update_output(node: &mut Node) { + // Extract axes information + let axes = /* ... */; + let input_rank = /* ... */; + let output_rank = input_rank - axes.len(); + + // Update output rank + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type: node.inputs[0].ty.elem_type().clone(), + rank: output_rank, + static_shape: None, + }); +} +``` + +Then register this function in `crates/onnx-ir/src/rank_inference.rs` by adding it to the match +statement: + +```rust +pub fn rank_inference(node: &mut Node) { + match node.node_type { + // ... + NodeType::Squeeze => squeeze_update_output(node), + // Add your new operation here + } +} +``` + +The `rank_inference.rs` file is responsible for determining the output tensor rank for each node in +the graph. + +If the rank remains unchanged, you can use helper functions like `same_as_input()` or +`same_as_input_broadcast()` instead of writing a custom update function. ### Step 6: Integrate into the Graph Building Process -When a new node type is introduced, it must be added to the -[`Node` enum](https://github.com/tracel-ai/burn/blob/925716f89d0249cbc6bd14f85f40967bd7ef80a8/crates/burn-import/src/burn/node/base.rs#L85) -and -[`match_all!` macro](https://github.com/tracel-ai/burn/blob/925716f89d0249cbc6bd14f85f40967bd7ef80a8/crates/burn-import/src/burn/node/base.rs#L138) -in `src/burn/node/base.rs`. +When a new node type is introduced, it must be added to the `Node` enum in +`crates/burn-import/src/burn/node/base.rs` and the `match_all!` macro in the same file. The `Node` enum abstracts over different types of operations (nodes) within a network graph. Each -variant of the enum corresponds to a specific type of operation, and it encapsulates the -operation-specific data structures (like `SqueezeNode1`) that was -[defined in step 2](#step-2-node-implementation). +variant of the enum corresponds to a specific type of operation and encapsulates the +operation-specific data structures (like `SqueezeNode`) that were defined in step 2. ### Step 7: Add Newly Supported Op! -As a reward, add an extra check to -[SUPPORTED-ONNX-OPS.md](https://github.com/tracel-ai/burn/blob/925716f89d0249cbc6bd14f85f40967bd7ef80a8/crates/burn-import/SUPPORTED-ONNX-OPS.md)! +As a reward, add an extra check to `crates/burn-import/SUPPORTED-ONNX-OPS.md`! + +### Lifting Constant Nodes + +If your operation takes inputs from constant nodes (such as weights in Conv1d, shape tensors in +Reshape, etc.), you need to add your operation's `NodeType` to the `LIFT_CONSTANTS_FOR_NODE_TYPES` +array in `crates/onnx-ir/src/from_onnx.rs`. 
+ +```rust +const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 16] = [ + NodeType::BatchNormalization, + // other operations... + NodeType::Squeeze, + NodeType::Unsqueeze, + // Add your operation here if it needs constants to be processed +]; +``` + +"Lifting" constants means converting Constant nodes into direct input values. This is similar to how +ONNX initializers work. For example, instead of having a separate Constant node providing weights to +a Convolution operation, the weights are directly embedded as values in the Convolution node's +inputs. + +This transformation makes it easier to: -### Misc: +1. Access the constant values during node configuration +2. Process operations like Conv1d that expect weights as direct inputs +3. Handle shape-defining inputs needed for operations like Reshape -> 🚧 **Warning**: Dimension Changes -> -> If your operation changes the dimensions of the input tensor, you may need to modify the -> [`LIFT_CONSTANTS_FOR_NODE_TYPES` enum](https://github.com/tracel-ai/burn/blob/925716f89d0249cbc6bd14f85f40967bd7ef80a8/crates/onnx-ir/src/from_onnx.rs#L21) -> in `src/from_onnx.rs` by adding the `NodeType` of your operation to it. +Without this, operations that need to extract configuration from constant inputs (such as shapes, +weights, or other parameters) would not work correctly because they wouldn't have direct access to +those constant values. ## Testing -- Unit tests for the Burn graph to Rust source code conversion are mandatory. -- End-to-end tests should include a test ONNX model and its expected output for each operator. +When implementing a new operator, there are several levels of testing to consider: + +### Unit Testing + +- **Node Configuration**: Write unit tests for the `_config` function in + `crates/onnx-ir/src/node/.rs` to verify that it correctly extracts parameters from + ONNX nodes. + +- **Rank Inference**: Test the `_update_output` function to ensure it correctly + computes output ranks. + +- **Code Generation**: Test the Node implementation in `burn-import` to verify that it generates + correct Rust code. + +### Integration Testing + +- Create small ONNX models that use your operator and test the end-to-end conversion process +- Ensure the generated Rust code compiles and produces the expected outputs +- Add these tests to `crates/burn-import/onnx-tests/tests/test_onnx.rs` + +### End-to-End Testing + +- Test with realistic ONNX models that use your operator in conjunction with others +- Verify that inputs and outputs match between the original ONNX model and the converted Burn model +- Include models that test edge cases (e.g., different input shapes, parameter combinations) + +Testing both the rank inference and node configuration is particularly important as these components +directly affect the correctness of the conversion process. Incorrect rank inference can lead to +mismatched tensor shapes, while incorrect configuration can cause runtime errors or incorrect +results. ## Resources diff --git a/crates/burn-import/DEVELOPMENT.md b/crates/burn-import/DEVELOPMENT.md deleted file mode 100644 index 2d122a042b..0000000000 --- a/crates/burn-import/DEVELOPMENT.md +++ /dev/null @@ -1,80 +0,0 @@ -# ONNX to Burn Conversion Tool: Development Guide - -This guide offers in-depth design insights and step-by-step procedures for developers working on the -ONNX to Burn conversion tool. This tool allows the importation of ONNX models into the Burn deep -learning framework written in Rust. 
It converts both ONNX models to Rust source code and model -weights to Burn state files. - -## Table of Contents - -1. [Design Overview](#Design-Overview) - 1. [Design Goals](#Design-Goals) - 2. [Design Decisions](#Design-Decisions) -2. [Adding New Operators](#Adding-New-Operators) -3. [Testing](#Testing) -4. [Resources](#Resources) - -## Design Overview - -### Design Goals - -- Perform best-effort conversion of ONNX models to Rust source code via Burn APIs. -- Convert ONNX model weights to Burn state files. -- Support ONNX models generated by PyTorch (ONNX Opset 16). -- Produce easy-to-understand and modifiable models. -- Ensure the generated models are trainable using Burn APIs. - -### Design Decisions - -- Limit interaction with ONNX to the Intermediate Representation (IR) stage to simplify the process. -- Ensure operator behavior consistency across different OpSet versions. -- Exclude any ONNX/Protobuf-specific logic from the Burn graph. - -The conversion process involves three main stages: - -1. Convert ONNX model to Intermediate Representation (IR). -2. Translate IR to a Burn graph. -3. Generate Rust source code from the Burn graph. - -## Adding New Operators - -To extend `burn-import` with support for new ONNX operators, follow these steps: - -1. **Create PyTorch Script**: Place a PyTorch script using the new operator under - `./burn-import/onnx-tests/tests//.py`. Make sure to print both input and output tensors - for end-to-end testing. - -2. **Generate ONNX Model**: Run the PyTorch script to produce an ONNX model. - -3. **Visualize ONNX Model**: Use [Netron](https://github.com/lutzroeder/netron) to verify the ONNX - model contains the expected operators. - -4. **Generate IR and Burn Graph**: Navigate to `./burn-import/` and run: - - ``` - cargo r -- ./onnx-tests/tests//.onnx ./out - ``` - -5. **Implement Missing Operators**: If you encounter an error stating that an operator is - unsupported, implement it. The `./out/my-model.graph.txt` should provide relevant information. - -6. **Inspect Generated Files**: The `my-model.graph.txt` contains IR details, `my-model.rs` holds - the Burn model in Rust code, and `my-model.json` includes the model data. - -7. **Add End-to-End Test**: Include the test in `./burn-import/onnx-tests/tests/onnx_tests.rs`. - Further details can be found in the [onnx-tests README](./onnx-tests/README.md). - -## Testing - -- Unit tests for the Burn graph to Rust source code conversion are mandatory. -- End-to-end tests should include a test ONNX model and its expected output for each operator. - -## Resources - -1. [PyTorch to ONNX](https://pytorch.org/docs/stable/onnx.html) -2. [ONNX to PyTorch](https://github.com/ENOT-AutoDL/onnx2torch) -3. [ONNX Introduction](https://onnx.ai/onnx/intro/) -4. [ONNX Operators](https://onnx.ai/onnx/operators/index.html) -5. [ONNX Protos](https://onnx.ai/onnx/api/classes.html) -6. [ONNX Optimizer](https://github.com/onnx/optimizer) -7. [Netron](https://github.com/lutzroeder/netron) diff --git a/crates/burn-import/README.md b/crates/burn-import/README.md index d6a898eccd..7c784c2659 100644 --- a/crates/burn-import/README.md +++ b/crates/burn-import/README.md @@ -1,15 +1,72 @@ -# Importing Models +# Burn Import -The Burn project supports the import of models from various frameworks, emphasizing efficiency and -compatibility. Currently, it handles two primary model formats: +The `burn-import` crate provides tools for importing models from other machine learning frameworks +into the Burn ecosystem. 
It allows users to leverage pre-trained models from popular frameworks +while benefiting from Burn's performance and Rust integration. -1. [ONNX](https://burn.dev/burn-book/import/onnx-model.html): Facilitates direct import, ensuring the - model's performance and structure are maintained. +## Supported Formats -2. [PyTorch](https://burn.dev/burn-book/import/pytorch-model.html): Enables the loading of PyTorch model - weights into Burn’s native model architecture, ensuring seamless integration. +### ONNX -## Contribution +[ONNX](https://onnx.ai/) (Open Neural Network Exchange) is an open standard for machine learning +interoperability. Burn supports importing ONNX models with opset version 16 or higher, converting +them to native Burn code and model weights. -Interested in contributing to `burn-import`? Check out our [development guide](DEVELOPMENT.md) for -more information. +- **Convert models from**: PyTorch, TensorFlow, Keras, scikit-learn, and other ONNX-compatible + frameworks +- **Full code generation**: Generates Rust source code that matches the original model's + architecture +- **Complete state handling**: Extracts and converts model weights to Burn's format + +See the [ONNX import documentation](https://burn.dev/burn-book/import/onnx-model.html) for usage +details. + +### PyTorch + +Burn supports direct import of PyTorch model weights (.pt/.pth files) into Burn model architectures: + +- **Direct weight loading**: Map PyTorch layer weights to equivalent Burn layers +- **Efficient conversion**: No need for ONNX as an intermediate format +- **Compatible with**: Common PyTorch architectures and custom models + +See the [PyTorch import documentation](https://burn.dev/burn-book/import/pytorch-model.html) for +usage details. + +## Extending Support + +### Adding New ONNX Operators + +The crate is designed to be extensible. To add support for new ONNX operators: + +1. Implement the operator in the `onnx-ir` crate +2. Add the operator conversion logic in `src/onnx/to_burn.rs` +3. Register the operator in the conversion pipeline + +See our +[ONNX to Burn conversion guide](https://github.com/tracel-ai/burn/blob/main/contributor-book/src/guides/onnx-to-burn-conversion-tool.md) +for detailed instructions. + +### Adding New Import Formats + +To add support for a new model format: + +1. Create a new module under `src/` for the format +2. Implement the parsing and conversion logic +3. Add CLI support for the new format + +## Testing + +The `onnx-tests` subcrate contains comprehensive tests for the ONNX import functionality: + +- **Unit tests**: Verify specific operator conversions +- **End-to-end tests**: Ensure complete models are correctly imported +- **Comparison tests**: Validate that imported models produce the same outputs as original models + +See the +[ONNX tests README](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/onnx-tests/README.md) +for details on testing. + +## Supported ONNX Operators + +For a complete list of supported ONNX operators, see the +[Supported ONNX Operators table](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/SUPPORTED-ONNX-OPS.md). diff --git a/crates/burn-import/onnx-tests/README.md b/crates/burn-import/onnx-tests/README.md index a2f9fe28ff..ef19bd8cde 100644 --- a/crates/burn-import/onnx-tests/README.md +++ b/crates/burn-import/onnx-tests/README.md @@ -1,49 +1,187 @@ # ONNX Tests -This crate contains ONNX models that are utilized in testing the conversion of ONNX to Burn source -code through the `burn-import` crate. 
The tests are designed as end-to-end tests, ensuring that ONNX -models are accurately converted into Burn source code. Of utmost importance is verifying that the -converted Burn source code compiles without errors and produces the same output as the original ONNX -model. +This crate contains ONNX models used for testing the conversion process from ONNX to Burn source +code through the `burn-import` crate. These tests are designed as end-to-end tests, ensuring that +ONNX models are accurately converted into Burn source code that compiles without errors and produces +the same outputs as the original ONNX model. -Here is the directory structure of this crate: +## Directory Structure -- `tests/`: This directory contains the ONNX model and the Python script to generate it. -- `tests//.onnx`: The ONNX model is generated by the script. -- `tests//.py`: This is the Python script responsible for generating the ONNX model - using PyTorch. -- `tests/test_onnx.rs`: This is the main test file, where all the tests are contained. -- `build.rs`: This build script generates the ONNX models and is executed by `cargo test` before - running the actual tests. +- `tests//`: Each operator or model has its own directory +- `tests//.py`: Python script that generates the ONNX model +- `tests//.onnx`: Generated ONNX model +- `tests/test_onnx.rs`: Main test file containing all end-to-end tests +- `build.rs`: Build script that generates ONNX models before running tests -## Setting up your python environment +## Setting Up Your Python Environment -## With rye +### Using uv (Recommended) -You can use [`uv`](https://docs.astral.sh/uv/) to set up a Python environment with the necessary dependencies. To do so, cd into the `onnx-tests` directory and run `uv sync`. Assuming you are in the top-level `burn` directory, you can run the following command: +You can use [`uv`](https://docs.astral.sh/uv/) to set up a Python environment with the necessary +dependencies: ```sh cd crates/burn-import/onnx-tests uv sync # or uv sync -f ``` -This will create a .venv in the `onnx-tests` directory. +This will create a `.venv` directory with all the required dependencies. -You need to install `onnx==1.15.0` and `torch==2.1.1` in your python environment to add a new test +### Manual Setup -## Adding new tests +If you prefer to set up manually, you need to install the following packages: -Here are the steps to add a new test: +```sh +pip install onnx==1.15.0 torch==2.1.1 +``` + +Additional dependencies are specified in `requirements.lock`. + +## Creating a Test for a New Operator + +### 1. Create the Python Script + +Create a new directory and Python script: + +```sh +mkdir -p tests/my_new_op +touch tests/my_new_op/my_new_op.py +``` + +Your script should: + +- Import the necessary PyTorch modules +- Define a model that uses your operator +- Generate test inputs +- Export the model to ONNX format +- Run the model on test inputs and print the output + +Example structure: + +```python +import torch +import torch.nn as nn +import torch.onnx + +# Define a simple model that uses your operator +class MyModel(nn.Module): + def __init__(self): + super().__init__() + # ... 
+ + def forward(self, x): + # Use your operator here + return my_operation(x) + +# Create an instance of the model +model = MyModel() + +# Generate test input +input_tensor = torch.randn(1, 3, 224, 224) + +# Export the model to ONNX +torch.onnx.export( + model, + input_tensor, + "tests/my_new_op/my_new_op.onnx", + opset_version=16, + input_names=["input"], + output_names=["output"], + do_constant_folding=False # Set to False if you want to preserve specific operators +) + +# Run the model with the test input and print output for test verification +output = model(input_tensor) +print("Input:", input_tensor) +print("Output:", output) +``` + +### 2. Add the Build Step + +Update `build.rs` to include your new model. + +### 3. Add the Test in test_onnx.rs + +First, add your model to the include_models! macro: + +```rust +include_models! { + // Other models... + my_new_op, +} +``` + +Then add a test function. + +## Best Practices for ONNX Testing + +### Model Generation + +1. **Keep Models Simple**: Focus on testing a single operator or a small group of related operators. + +2. **Control Randomness**: Use fixed seeds in your Python scripts to ensure reproducible results: + + ```python + torch.manual_seed(42) + ``` + +3. **Print Test Values**: Always print your input and output tensors in the Python script for + reference. + +4. **Verify Operators**: Use [Netron](https://github.com/lutzroeder/netron) to verify your ONNX + model contains the expected operators. + +5. **Handle Constant Folding**: If PyTorch is optimizing away your operators, use: + ```python + torch.onnx.export(..., do_constant_folding=False) + ``` + +### Test Implementation + +1. **Test Multiple Cases**: Include tests for different input shapes, data types, and parameter + combinations. + +2. **Edge Cases**: Test edge cases like empty tensors, single-element tensors, or very large + tensors. + +3. **Parameter Variations**: If your operator has configurable parameters, test different parameter + values. + +4. **Numerical Precision**: Use appropriate tolerance levels based on operation sensitivity. + +5. **Error Cases**: Test that invalid inputs are properly handled and appropriate errors are raised. + +## Running Tests + +Run all tests with: + +```sh +cargo test +``` + +Run a specific test with: + +```sh +cargo test test_my_new_op +``` + +## Debugging Failed Tests + +If a test fails, you can: + +1. **Inspect ONNX Model**: Use Netron to visualize the model structure. + +2. **Check Intermediate Values**: Add print statements in your Python script to see intermediate + tensor values. + +3. **Generate Rust Code**: Use the `burn-import` CLI to generate Rust code and inspect it: + + ```sh + cargo run -p burn-import -- tests/my_new_op/my_new_op.onnx ./out + ``` + +4. **Trace Through Conversion**: Add debug logging in your implementation to see where things might + be going wrong. -1. Add your Python script to the `tests/` directory. Refer to existing scripts for examples. -2. Run your Python script to generate the ONNX model and inspect the output of the model with the - test data. Use the inputs and outputs in your test. -3. Make sure the ONNX output contains the desired operators by verifying with the - [Netron](https://github.com/lutzroeder/netron) app. Sometimes PyTorch will optimize the model and - remove operators that are not necessary for the model to run. If this happens, you can disable - optimization by setting `torch.onnx.export(..., do_constant_folding=False)`. -4. 
Add an entry to the `build.rs` file to account for the generation of the new ONNX model. -5. Add an entry to `include_models!` in `tests/test_onnx.rs` to include the new ONNX model in the - tests. -6. Include a test in `tests/test_onnx.rs` to test the new ONNX model. -7. Run `cargo test` to ensure your test passes. +5. **Numerical Issues**: If values are close but not equal, it might be a numerical precision issue. + Try adjusting tolerance. diff --git a/crates/onnx-ir/README.md b/crates/onnx-ir/README.md index 69e2cf8085..be6feb69c5 100644 --- a/crates/onnx-ir/README.md +++ b/crates/onnx-ir/README.md @@ -1,13 +1,60 @@ # ONNX-IR -ONNX-IR is a pure Rust library for parsing ONNX models into an intermediate representation that can -be used to generate code for various ML/DL frameworks. It's part of the Burn project, with key -features including ONNX model parsing, rank inference, and node remapping. The crate supports -converting ONNX models to Burn graphs and includes utilities for handling constants and graph -transformations. +ONNX-IR is a pure Rust library for parsing ONNX models into an intermediate representation (IR) that +can be used to generate code for various ML/DL frameworks. It's a core component of the Burn model +import system, providing a clean abstraction layer between ONNX protobuf structures and Burn's +tensor operations. -For a full list of currently supported operators, please check -[here](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/SUPPORTED-ONNX-OPS.md) +## Features + +- **ONNX Model Parsing**: Converts ONNX protobuf definitions to a clean Rust IR +- **Rank Inference**: Automatically infers tensor ranks throughout the computational graph +- **Constant Folding**: Lifts constants from separate nodes into direct node inputs +- **Node Remapping**: Maps generic ONNX operations to dimension-specific implementations +- **Framework Independence**: The IR is designed to be independent of any specific ML framework + +## Architecture + +The ONNX-IR crate is designed with the following components: + +- **IR Core** (`ir.rs`): Defines the core data structures such as `Node`, `NodeType`, `Argument`, + etc. 
+- **Protocol Conversion** (`proto_conversion.rs`): Converts ONNX protobuf structures to IR +- **ONNX Parsing** (`from_onnx.rs`): Handles the parsing of ONNX models into the IR +- **Rank Inference** (`rank_inference.rs`): Computes output tensor ranks for each operation +- **Node Implementations** (`node/`): Contains operation-specific configurations and rank inference + functions +- **Node Remapping** (`node_remap.rs`): Maps generic ONNX operations to dimension-specific + alternatives + +## Usage + +ONNX-IR is typically used through the `burn-import` crate, but can also be used standalone: + +```rust +use onnx_ir::{parse_onnx, OnnxGraph}; +use std::path::Path; + +// Parse an ONNX model into the IR +let graph: OnnxGraph = parse_onnx(Path::new("path/to/model.onnx")); + +// Work with the IR +for node in &graph.nodes { + println!("Node: {}, Type: {:?}", node.name, node.node_type); + + // Access inputs and outputs + for input in &node.inputs { + println!(" Input: {}", input.name); + } + + for output in &node.outputs { + println!(" Output: {}", output.name); + } +} + +// Convert to another framework's representation +// (This is typically done by burn-import or another conversion layer) +``` ## ONNX Compatibility @@ -35,8 +82,28 @@ inferred_model = shape_inference.infer_shapes(upgraded_model) onnx.save(inferred_model, 'upgraded_model.onnx') ``` -For a full list of currently supported operators, please check -[here](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/SUPPORTED-ONNX-OPS.md) +## Adding New Operators + +To add support for a new ONNX operator: + +1. Add a new NodeType variant in `ir.rs` +2. Create a new module in `node/.rs` with: + - A configuration function that extracts parameters from ONNX nodes + - A rank inference function that updates output tensor ranks +3. Register the rank inference function in `rank_inference.rs` +4. If the operation works with constants, add it to `LIFT_CONSTANTS_FOR_NODE_TYPES` in + `from_onnx.rs` + +For detailed implementation guidance, see the +[ONNX to Burn conversion guide](https://github.com/tracel-ai/burn/blob/main/contributor-book/src/guides/onnx-to-burn-conversion-tool.md). + +## Supported Operators + +For a full list of currently supported ONNX operators, please see the +[Supported ONNX Operators table](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/SUPPORTED-ONNX-OPS.md). + +## Integration with Burn -To see how to use this for generating burn graphs, see -[here](crates/burn-import/src/onnx/to_burn.rs). +ONNX-IR serves as the foundation for the ONNX import support in Burn. The conversion from ONNX-IR to +Burn graphs is implemented in +[`burn-import/src/onnx/to_burn.rs`](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/src/onnx/to_burn.rs). From d571dd69014897aea1a4feb3deb27c2e6cf539c0 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Thu, 8 May 2025 13:41:11 -0500 Subject: [PATCH 24/37] Remove "Features" section --- crates/onnx-ir/README.md | 8 -------- 1 file changed, 8 deletions(-) diff --git a/crates/onnx-ir/README.md b/crates/onnx-ir/README.md index be6feb69c5..9cc248a8a6 100644 --- a/crates/onnx-ir/README.md +++ b/crates/onnx-ir/README.md @@ -5,14 +5,6 @@ can be used to generate code for various ML/DL frameworks. It's a core component import system, providing a clean abstraction layer between ONNX protobuf structures and Burn's tensor operations. 
-## Features - -- **ONNX Model Parsing**: Converts ONNX protobuf definitions to a clean Rust IR -- **Rank Inference**: Automatically infers tensor ranks throughout the computational graph -- **Constant Folding**: Lifts constants from separate nodes into direct node inputs -- **Node Remapping**: Maps generic ONNX operations to dimension-specific implementations -- **Framework Independence**: The IR is designed to be independent of any specific ML framework - ## Architecture The ONNX-IR crate is designed with the following components: From 67d0067eeab681ffad71a73d2015ceb797fe976a Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Thu, 8 May 2025 13:48:04 -0500 Subject: [PATCH 25/37] Converted links to resources --- crates/onnx-ir/README.md | 32 ++++++++++---------------------- 1 file changed, 10 insertions(+), 22 deletions(-) diff --git a/crates/onnx-ir/README.md b/crates/onnx-ir/README.md index 9cc248a8a6..de80a3beb9 100644 --- a/crates/onnx-ir/README.md +++ b/crates/onnx-ir/README.md @@ -74,28 +74,16 @@ inferred_model = shape_inference.infer_shapes(upgraded_model) onnx.save(inferred_model, 'upgraded_model.onnx') ``` -## Adding New Operators +## Resources -To add support for a new ONNX operator: +- **ONNX to Burn Conversion Guide**: For detailed implementation guidance on adding new operators, + see the + [ONNX to Burn conversion guide](https://github.com/tracel-ai/burn/blob/main/contributor-book/src/guides/onnx-to-burn-conversion-tool.md). -1. Add a new NodeType variant in `ir.rs` -2. Create a new module in `node/.rs` with: - - A configuration function that extracts parameters from ONNX nodes - - A rank inference function that updates output tensor ranks -3. Register the rank inference function in `rank_inference.rs` -4. If the operation works with constants, add it to `LIFT_CONSTANTS_FOR_NODE_TYPES` in - `from_onnx.rs` +- **Supported ONNX Operators**: For a full list of currently supported ONNX operators, please see + the + [Supported ONNX Operators table](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/SUPPORTED-ONNX-OPS.md). -For detailed implementation guidance, see the -[ONNX to Burn conversion guide](https://github.com/tracel-ai/burn/blob/main/contributor-book/src/guides/onnx-to-burn-conversion-tool.md). - -## Supported Operators - -For a full list of currently supported ONNX operators, please see the -[Supported ONNX Operators table](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/SUPPORTED-ONNX-OPS.md). - -## Integration with Burn - -ONNX-IR serves as the foundation for the ONNX import support in Burn. The conversion from ONNX-IR to -Burn graphs is implemented in -[`burn-import/src/onnx/to_burn.rs`](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/src/onnx/to_burn.rs). +- **Burn Integration**: ONNX-IR serves as the foundation for the ONNX import support in Burn. The + conversion from ONNX-IR to Burn graphs is implemented in + [`burn-import/src/onnx/to_burn.rs`](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/src/onnx/to_burn.rs). 
From 2ee93176d57427fe1f7f21a6e1e450f3f9c3c42e Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Thu, 8 May 2025 13:50:02 -0500 Subject: [PATCH 26/37] Remove deadcode --- crates/burn-import/src/burn/node/expand.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/crates/burn-import/src/burn/node/expand.rs b/crates/burn-import/src/burn/node/expand.rs index c0afecfe6a..3b3143c5dc 100644 --- a/crates/burn-import/src/burn/node/expand.rs +++ b/crates/burn-import/src/burn/node/expand.rs @@ -174,8 +174,6 @@ mod tests { fn test_codegen_expand_tensor() { let mut graph = BurnGraph::::default(); - // let shape_tensor_type = TensorType::new_int("tensor3", 1); - graph.register(ExpandNode::new( TensorType::new_float("tensor1", 4), TensorType::new_float("tensor2", 4), From f27645e3354b31bb440d4bd8d58bbd49cf14c427 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Thu, 8 May 2025 14:15:09 -0500 Subject: [PATCH 27/37] Shorten burn-import readme --- crates/burn-import/README.md | 82 ++++++++---------------------------- 1 file changed, 17 insertions(+), 65 deletions(-) diff --git a/crates/burn-import/README.md b/crates/burn-import/README.md index 7c784c2659..7ce4beff84 100644 --- a/crates/burn-import/README.md +++ b/crates/burn-import/README.md @@ -1,72 +1,24 @@ # Burn Import -The `burn-import` crate provides tools for importing models from other machine learning frameworks -into the Burn ecosystem. It allows users to leverage pre-trained models from popular frameworks -while benefiting from Burn's performance and Rust integration. +The `burn-import` crate enables seamless integration of pre-trained models from popular machine +learning frameworks into the Burn ecosystem. This functionality allows you to leverage existing +models while benefiting from Burn's performance optimizations and native Rust integration. -## Supported Formats +## Supported Import Formats -### ONNX +Burn currently supports three primary model import formats, each serving different use cases: -[ONNX](https://onnx.ai/) (Open Neural Network Exchange) is an open standard for machine learning -interoperability. Burn supports importing ONNX models with opset version 16 or higher, converting -them to native Burn code and model weights. 
+| Format | Description | Use Case | +| ----------------------------------------------------------------------------------- | ----------------------------------------- | ------------------------------------------------------------------------------------------------------ | +| [**ONNX** (Guide)](https://burn.dev/burn-book/import/onnx-model.html) | Open Neural Network Exchange format | Direct import of complete model architectures and weights from any framework that supports ONNX export | +| [**PyTorch** (Guide)](https://burn.dev/burn-book/import/pytorch-model.html) | PyTorch weights (.pt, .pth) | Loading weights from PyTorch models into a matching Burn architecture | +| [**Safetensors** (Guide)](https://burn.dev/burn-book/import/safetensors-model.html) | Hugging Face's model serialization format | Loading a model's tensor weights into a matching Burn architecture | -- **Convert models from**: PyTorch, TensorFlow, Keras, scikit-learn, and other ONNX-compatible - frameworks -- **Full code generation**: Generates Rust source code that matches the original model's - architecture -- **Complete state handling**: Extracts and converts model weights to Burn's format +## ONNX Contributor Resources -See the [ONNX import documentation](https://burn.dev/burn-book/import/onnx-model.html) for usage -details. - -### PyTorch - -Burn supports direct import of PyTorch model weights (.pt/.pth files) into Burn model architectures: - -- **Direct weight loading**: Map PyTorch layer weights to equivalent Burn layers -- **Efficient conversion**: No need for ONNX as an intermediate format -- **Compatible with**: Common PyTorch architectures and custom models - -See the [PyTorch import documentation](https://burn.dev/burn-book/import/pytorch-model.html) for -usage details. - -## Extending Support - -### Adding New ONNX Operators - -The crate is designed to be extensible. To add support for new ONNX operators: - -1. Implement the operator in the `onnx-ir` crate -2. Add the operator conversion logic in `src/onnx/to_burn.rs` -3. Register the operator in the conversion pipeline - -See our -[ONNX to Burn conversion guide](https://github.com/tracel-ai/burn/blob/main/contributor-book/src/guides/onnx-to-burn-conversion-tool.md) -for detailed instructions. - -### Adding New Import Formats - -To add support for a new model format: - -1. Create a new module under `src/` for the format -2. Implement the parsing and conversion logic -3. Add CLI support for the new format - -## Testing - -The `onnx-tests` subcrate contains comprehensive tests for the ONNX import functionality: - -- **Unit tests**: Verify specific operator conversions -- **End-to-end tests**: Ensure complete models are correctly imported -- **Comparison tests**: Validate that imported models produce the same outputs as original models - -See the -[ONNX tests README](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/onnx-tests/README.md) -for details on testing. - -## Supported ONNX Operators - -For a complete list of supported ONNX operators, see the -[Supported ONNX Operators table](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/SUPPORTED-ONNX-OPS.md). 
+- [ONNX to Burn conversion guide](https://burn.dev/contributor-book/guides/onnx-to-burn-conversion-tool.html) - + Instructions for adding support for additional ONNX operators +- [ONNX tests README](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/onnx-tests/README.md) - + Testing procedures for ONNX operators +- [Supported ONNX Operators table](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/SUPPORTED-ONNX-OPS.md) - + Complete list of currently supported ONNX operators From 5ef940113fadcae59990cecb4aa9409ded3dd1b6 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Thu, 8 May 2025 14:27:28 -0500 Subject: [PATCH 28/37] Remove inline comments --- crates/burn-import/src/burn/node/avg_pool1d.rs | 7 +------ crates/burn-import/src/burn/node/avg_pool2d.rs | 7 +------ crates/burn-import/src/burn/node/batch_norm.rs | 6 +----- crates/burn-import/src/burn/node/conv1d.rs | 11 +---------- crates/burn-import/src/burn/node/conv2d.rs | 14 +++++++------- crates/burn-import/src/burn/node/conv3d.rs | 14 +++++++------- .../src/burn/node/conv_transpose_2d.rs | 11 +---------- .../src/burn/node/conv_transpose_3d.rs | 16 ++++++++-------- 8 files changed, 27 insertions(+), 59 deletions(-) diff --git a/crates/burn-import/src/burn/node/avg_pool1d.rs b/crates/burn-import/src/burn/node/avg_pool1d.rs index 0002175c54..0c4b1a1693 100644 --- a/crates/burn-import/src/burn/node/avg_pool1d.rs +++ b/crates/burn-import/src/burn/node/avg_pool1d.rs @@ -105,12 +105,7 @@ mod tests { "avg_pool1d", TensorType::new_float("input", 3), TensorType::new_float("output", 3), - AvgPool1dConfig::new( - 3, // kernel_size - 1, // stride - PaddingConfig1d::Valid, // padding - true, // count_include_pad - ), + AvgPool1dConfig::new(3, 1, PaddingConfig1d::Valid, true), )); graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); diff --git a/crates/burn-import/src/burn/node/avg_pool2d.rs b/crates/burn-import/src/burn/node/avg_pool2d.rs index af70737e4d..8ab1af2ee2 100644 --- a/crates/burn-import/src/burn/node/avg_pool2d.rs +++ b/crates/burn-import/src/burn/node/avg_pool2d.rs @@ -109,12 +109,7 @@ mod tests { "avg_pool2d", TensorType::new_float("input", 4), TensorType::new_float("output", 4), - AvgPool2dConfig::new( - [3, 3], // kernel_size - [1, 1], // strides - PaddingConfig2d::Valid, // padding - true, // count_include_pad - ), + AvgPool2dConfig::new([3, 3], [1, 1], PaddingConfig2d::Valid, true), )); graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); diff --git a/crates/burn-import/src/burn/node/batch_norm.rs b/crates/burn-import/src/burn/node/batch_norm.rs index 76bce6a8e6..ae0147f09c 100644 --- a/crates/burn-import/src/burn/node/batch_norm.rs +++ b/crates/burn-import/src/burn/node/batch_norm.rs @@ -172,11 +172,7 @@ mod tests { TensorData::from([2f32]), TensorData::from([2f32]), TensorData::from([2f32]), - BatchNormConfig::new( - 128, // num_features - 0.00001, // epsilon - 0.1, // momentum - ), + BatchNormConfig::new(128, 0.00001, 0.1), )); graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); diff --git a/crates/burn-import/src/burn/node/conv1d.rs b/crates/burn-import/src/burn/node/conv1d.rs index 5b517f103f..3eb0974769 100644 --- a/crates/burn-import/src/burn/node/conv1d.rs +++ b/crates/burn-import/src/burn/node/conv1d.rs @@ -149,16 +149,7 @@ mod tests { TensorType::new_float("output", 4), TensorData::from([2f32]), None, - Conv1dConfig::new( - 3, // channels_in - 3, // 
channels_out - 3, // kernel_size - 1, // stride - PaddingConfig1d::Valid, // padding - 1, // dilation - 1, // groups - true, // bias - ), + Conv1dConfig::new(3, 3, 3, 1, PaddingConfig1d::Valid, 1, 1, true), )); graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); diff --git a/crates/burn-import/src/burn/node/conv2d.rs b/crates/burn-import/src/burn/node/conv2d.rs index 9e87a2afcb..1e2336b923 100644 --- a/crates/burn-import/src/burn/node/conv2d.rs +++ b/crates/burn-import/src/burn/node/conv2d.rs @@ -149,13 +149,13 @@ mod tests { TensorData::from([2f32]), None, Conv2dConfig::new( - [3, 3], // kernel_size - [3, 3], // stride - [1, 1], // dilation - PaddingConfig2d::Valid, // padding - [1, 1], // output_padding - 1, // groups - true, // bias + [3, 3], + [3, 3], + [1, 1], + PaddingConfig2d::Valid, + [1, 1], + 1, + true, ), )); diff --git a/crates/burn-import/src/burn/node/conv3d.rs b/crates/burn-import/src/burn/node/conv3d.rs index 0bc3766ad2..694ec1d2d7 100644 --- a/crates/burn-import/src/burn/node/conv3d.rs +++ b/crates/burn-import/src/burn/node/conv3d.rs @@ -149,13 +149,13 @@ mod tests { TensorData::from([2f32]), None, Conv3dConfig::new( - [3, 3], // kernel_size - [3, 3, 3], // stride - [1, 1, 1], // dilation - [1, 1, 1], // output_padding - 1, // groups - true, // bias - PaddingConfig3d::Valid, // padding + [3, 3], + [3, 3, 3], + [1, 1, 1], + [1, 1, 1], + 1, + true, + PaddingConfig3d::Valid, ), )); diff --git a/crates/burn-import/src/burn/node/conv_transpose_2d.rs b/crates/burn-import/src/burn/node/conv_transpose_2d.rs index e1880c6127..ad6932c7c2 100644 --- a/crates/burn-import/src/burn/node/conv_transpose_2d.rs +++ b/crates/burn-import/src/burn/node/conv_transpose_2d.rs @@ -150,16 +150,7 @@ mod tests { TensorType::new_float("output", 4), TensorData::from([2f32]), None, - ConvTranspose2dConfig::new( - [3, 3], // kernel_size - [1, 1], // stride - [0, 0], // dilation - [0, 0], // padding - [0, 0], // output_padding - [0, 0], // padding_out - 1, // groups - true, // bias - ), + ConvTranspose2dConfig::new([3, 3], [1, 1], [0, 0], [0, 0], [0, 0], [0, 0], 1, true), )); graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); diff --git a/crates/burn-import/src/burn/node/conv_transpose_3d.rs b/crates/burn-import/src/burn/node/conv_transpose_3d.rs index ae56aa689d..41dc2c83f6 100644 --- a/crates/burn-import/src/burn/node/conv_transpose_3d.rs +++ b/crates/burn-import/src/burn/node/conv_transpose_3d.rs @@ -151,14 +151,14 @@ mod tests { TensorData::from([2f32]), None, ConvTranspose3dConfig::new( - [3, 3], // kernel_size - [1, 1, 1], // stride - [0, 0, 0], // dilation - [0, 0, 0], // padding - [0, 0, 0], // output_padding - [0, 0, 0], // output_padding additional - 1, // groups - true, // bias + [3, 3], + [1, 1, 1], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + 1, + true, ), )); From 92890e785bc4507f385a1f6af307da742eed7f56 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Thu, 8 May 2025 15:05:33 -0500 Subject: [PATCH 29/37] Add NodeBuilder and refactor test code --- crates/onnx-ir/src/node/argmax.rs | 46 +-- crates/onnx-ir/src/node/cast.rs | 59 ++-- crates/onnx-ir/src/node/conv2d.rs | 89 ++---- crates/onnx-ir/src/node/gemm.rs | 71 +---- crates/onnx-ir/src/node/matmul.rs | 50 +-- crates/onnx-ir/src/node/mod.rs | 3 + crates/onnx-ir/src/node/reshape.rs | 59 +--- crates/onnx-ir/src/node/resize.rs | 106 ++----- crates/onnx-ir/src/node/softmax.rs | 39 +-- crates/onnx-ir/src/node/test_utils.rs | 
432 ++++++++++++++++++++++++++ crates/onnx-ir/src/node/where_op.rs | 61 +--- 11 files changed, 569 insertions(+), 446 deletions(-) create mode 100644 crates/onnx-ir/src/node/test_utils.rs diff --git a/crates/onnx-ir/src/node/argmax.rs b/crates/onnx-ir/src/node/argmax.rs index 8657859681..153eeac7b8 100644 --- a/crates/onnx-ir/src/node/argmax.rs +++ b/crates/onnx-ir/src/node/argmax.rs @@ -79,45 +79,17 @@ pub fn argmax_update_outputs(node: &mut Node) { #[cfg(test)] mod tests { use super::*; - use crate::ir::{Argument, AttributeValue, ElementType, NodeType, TensorType}; - use std::collections::HashMap; + use crate::ir::{Argument, ElementType, NodeType}; + use crate::node::test_utils::NodeBuilder; fn create_test_node(axis: i64, select_last_index: i64, keepdims: i64) -> Node { - let inputs = vec![Argument { - name: "data".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 3, - static_shape: None, - }), - value: None, - passed: true, - }]; - - let mut attrs = HashMap::new(); - attrs.insert("axis".to_string(), AttributeValue::Int64(axis)); - attrs.insert( - "select_last_index".to_string(), - AttributeValue::Int64(select_last_index), - ); - attrs.insert("keepdims".to_string(), AttributeValue::Int64(keepdims)); - - Node { - node_type: NodeType::ArgMax, - name: "test_argmax".to_string(), - inputs, - outputs: vec![Argument { - name: "output".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Int64, - rank: 3, - static_shape: None, - }), - value: None, - passed: true, - }], - attrs, - } + NodeBuilder::new(NodeType::ArgMax, "test_argmax") + .input_tensor_f32("data", 3, None) + .output_tensor_i64("output", 3, None) + .attr_int("axis", axis) + .attr_int("select_last_index", select_last_index) + .attr_int("keepdims", keepdims) + .build() } #[test] diff --git a/crates/onnx-ir/src/node/cast.rs b/crates/onnx-ir/src/node/cast.rs index afb805c74a..b3dd289de1 100644 --- a/crates/onnx-ir/src/node/cast.rs +++ b/crates/onnx-ir/src/node/cast.rs @@ -48,40 +48,24 @@ pub fn cast_update_outputs(node: &mut Node) { #[cfg(test)] mod tests { use super::*; - use crate::ir::{Argument, NodeType}; - use std::collections::HashMap; + use crate::ir::{Argument, NodeType, TensorType}; + use crate::node::test_utils::NodeBuilder; fn create_test_node(input_rank: usize, to_type: i64) -> Node { - let mut attrs = HashMap::new(); - attrs.insert("to".to_string(), AttributeValue::Int64(to_type)); - - let inputs = vec![Argument { - name: "X".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: input_rank, - static_shape: None, - }), - value: None, - passed: true, - }]; + NodeBuilder::new(NodeType::Cast, "test_cast") + .input_tensor_f32("X", input_rank, None) + .output_tensor_f32("Y", input_rank, None) // Element type will be overwritten + .attr_int("to", to_type) + .build() + } - Node { - node_type: NodeType::Cast, - name: "test_cast".to_string(), - inputs, - outputs: vec![Argument { - name: "Y".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, // This will be overwritten - rank: 0, - static_shape: None, - }), - value: None, - passed: true, - }], - attrs, - } + // Additional test function to demonstrate scalar inputs + fn create_scalar_test_node(to_type: i64) -> Node { + NodeBuilder::new(NodeType::Cast, "test_cast") + .input_scalar_f32("X") + .output_scalar_f32("Y") // Element type will be overwritten + .attr_int("to", to_type) + .build() } #[test] @@ -134,4 +118,17 @@ mod tests { }); cast_update_outputs(&mut 
node); } + + #[test] + fn test_cast_scalar_to_bool() { + let mut node = create_scalar_test_node(DataType::BOOL.value() as i64); + cast_update_outputs(&mut node); + + match &node.outputs[0].ty { + ArgType::Scalar(elem_type) => { + assert_eq!(*elem_type, ElementType::Bool); + } + _ => panic!("Expected scalar output"), + } + } } diff --git a/crates/onnx-ir/src/node/conv2d.rs b/crates/onnx-ir/src/node/conv2d.rs index 1899546932..961053502e 100644 --- a/crates/onnx-ir/src/node/conv2d.rs +++ b/crates/onnx-ir/src/node/conv2d.rs @@ -92,10 +92,8 @@ pub fn conv2d_config(curr: &Node) -> Conv2dConfig { #[cfg(test)] mod tests { use super::*; - use crate::ir::{ - ArgType, Argument, AttributeValue, Data, ElementType, NodeType, TensorData, TensorType, - }; - use std::collections::HashMap; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; fn create_test_node( kernel_shape: Vec, @@ -105,73 +103,26 @@ mod tests { group: i64, has_bias: bool, ) -> Node { - let weight_tensor = TensorData { - data: Data::Float32s(vec![0.0; 16]), // Not important for the test - shape: vec![4, 2, 2, 2], // [output_channels, input_channels/groups, k_h, k_w] - }; - - let mut inputs = vec![ - Argument { - name: "data".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 4, - static_shape: None, - }), - value: None, - passed: true, - }, - Argument { - name: "weight".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 4, - static_shape: None, - }), - value: Some(weight_tensor), - passed: true, - }, - ]; - + // Weight tensor data - not important for the test + let weight_data = vec![0.0; 16]; + // [output_channels, input_channels/groups, k_h, k_w] + let weight_shape = vec![4, 2, 2, 2]; + + let mut builder = NodeBuilder::new(NodeType::Conv2d, "test_conv2d") + .input_tensor_f32("data", 4, None) + .input_tensor_f32_data("weight", weight_data, weight_shape) + .output_tensor_f32("output", 4, None) + .attr_ints("kernel_shape", kernel_shape) + .attr_ints("strides", strides) + .attr_ints("pads", pads) + .attr_ints("dilations", dilations) + .attr_int("group", group); + if has_bias { - inputs.push(Argument { - name: "bias".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 1, - static_shape: None, - }), - value: None, - passed: true, - }); - } - - let mut attrs = HashMap::new(); - attrs.insert( - "kernel_shape".to_string(), - AttributeValue::Int64s(kernel_shape), - ); - attrs.insert("strides".to_string(), AttributeValue::Int64s(strides)); - attrs.insert("pads".to_string(), AttributeValue::Int64s(pads)); - attrs.insert("dilations".to_string(), AttributeValue::Int64s(dilations)); - attrs.insert("group".to_string(), AttributeValue::Int64(group)); - - Node { - node_type: NodeType::Conv2d, - name: "test_conv2d".to_string(), - inputs, - outputs: vec![Argument { - name: "output".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 4, - static_shape: None, - }), - value: None, - passed: true, - }], - attrs, + builder = builder.input_tensor_f32("bias", 1, None); } + + builder.build() } #[test] diff --git a/crates/onnx-ir/src/node/gemm.rs b/crates/onnx-ir/src/node/gemm.rs index 4424b66d16..f939a8f111 100644 --- a/crates/onnx-ir/src/node/gemm.rs +++ b/crates/onnx-ir/src/node/gemm.rs @@ -62,8 +62,8 @@ pub fn gemm_config(curr: &Node) -> (f32, f32, i64, i64) { #[cfg(test)] mod tests { use super::*; - use crate::ir::{ArgType, Argument, AttributeValue, ElementType, NodeType, TensorType}; - 
use std::collections::HashMap; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; fn create_test_node( alpha: Option, @@ -71,69 +71,26 @@ mod tests { trans_a: Option, trans_b: Option, ) -> Node { - let inputs = vec![ - Argument { - name: "A".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 2, - static_shape: None, - }), - value: None, - passed: true, - }, - Argument { - name: "B".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 2, - static_shape: None, - }), - value: None, - passed: true, - }, - Argument { - name: "C".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 2, - static_shape: None, - }), - value: None, - passed: true, - }, - ]; - - let mut attrs = HashMap::new(); + let mut builder = NodeBuilder::new(NodeType::Gemm, "test_gemm") + .input_tensor_f32("A", 2, None) + .input_tensor_f32("B", 2, None) + .input_tensor_f32("C", 2, None) + .output_tensor_f32("Y", 2, None); + if let Some(alpha_val) = alpha { - attrs.insert("alpha".to_string(), AttributeValue::Float32(alpha_val)); + builder = builder.attr_float("alpha", alpha_val); } if let Some(beta_val) = beta { - attrs.insert("beta".to_string(), AttributeValue::Float32(beta_val)); + builder = builder.attr_float("beta", beta_val); } if let Some(trans_a_val) = trans_a { - attrs.insert("transA".to_string(), AttributeValue::Int64(trans_a_val)); + builder = builder.attr_int("transA", trans_a_val); } if let Some(trans_b_val) = trans_b { - attrs.insert("transB".to_string(), AttributeValue::Int64(trans_b_val)); - } - - Node { - node_type: NodeType::Gemm, - name: "test_gemm".to_string(), - inputs, - outputs: vec![Argument { - name: "Y".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 2, - static_shape: None, - }), - value: None, - passed: true, - }], - attrs, + builder = builder.attr_int("transB", trans_b_val); } + + builder.build() } #[test] diff --git a/crates/onnx-ir/src/node/matmul.rs b/crates/onnx-ir/src/node/matmul.rs index 3cfc7de62f..403f5d4942 100644 --- a/crates/onnx-ir/src/node/matmul.rs +++ b/crates/onnx-ir/src/node/matmul.rs @@ -38,51 +38,15 @@ pub fn matmul_update_outputs(node: &mut Node) { #[cfg(test)] mod tests { use super::*; - use crate::ir::{Argument, ElementType, NodeType}; - use std::collections::HashMap; + use crate::ir::{ElementType, NodeType}; + use crate::node::test_utils::NodeBuilder; fn create_test_node(a_rank: usize, b_rank: usize) -> Node { - let inputs = vec![ - Argument { - name: "A".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: a_rank, - static_shape: None, - }), - value: None, - passed: true, - }, - Argument { - name: "B".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: b_rank, - static_shape: None, - }), - value: None, - passed: true, - }, - ]; - - let outputs = vec![Argument { - name: "C".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 0, // Will be updated - static_shape: None, - }), - value: None, - passed: true, - }]; - - Node { - node_type: NodeType::MatMul, - name: "test_matmul".to_string(), - inputs, - outputs, - attrs: HashMap::new(), - } + NodeBuilder::new(NodeType::MatMul, "test_matmul") + .input_tensor_f32("A", a_rank, None) + .input_tensor_f32("B", b_rank, None) + .output_tensor_f32("C", 0, None) // Rank will be updated + .build() } #[test] diff --git a/crates/onnx-ir/src/node/mod.rs 
b/crates/onnx-ir/src/node/mod.rs index 917ae45189..f47a06d9a6 100644 --- a/crates/onnx-ir/src/node/mod.rs +++ b/crates/onnx-ir/src/node/mod.rs @@ -7,6 +7,9 @@ //! This modular structure allows for clean separation of operation implementations //! and facilitates easier maintenance and extension of the ONNX operation set. +#[cfg(test)] +pub mod test_utils; + pub mod argmax; pub mod avg_pool1d; pub mod avg_pool2d; diff --git a/crates/onnx-ir/src/node/reshape.rs b/crates/onnx-ir/src/node/reshape.rs index 142ee40dc7..e4d8eda39c 100644 --- a/crates/onnx-ir/src/node/reshape.rs +++ b/crates/onnx-ir/src/node/reshape.rs @@ -82,59 +82,20 @@ pub fn reshape_config(node: &Node) -> Vec { #[cfg(test)] mod tests { use super::*; - use crate::ir::{ - ArgType, Argument, AttributeValue, Data, ElementType, NodeType, TensorData, TensorType, - }; - use std::collections::HashMap; + use crate::ir::{NodeType}; + use crate::node::test_utils::NodeBuilder; fn create_test_node(allowzero: i64, shape_vec: Vec) -> Node { - let inputs = vec![ - Argument { - name: "data".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 4, - static_shape: None, - }), - value: None, - passed: true, - }, - Argument { - name: "shape".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Int64, - rank: 1, - static_shape: None, - }), - value: Some(TensorData { - data: Data::Int64s(shape_vec), - shape: vec![2], - }), - passed: true, - }, - ]; - - let mut attrs = HashMap::new(); + let mut builder = NodeBuilder::new(NodeType::Reshape, "test_reshape") + .input_tensor_f32("data", 4, None) + .input_tensor_i64_data("shape", shape_vec.clone(), vec![shape_vec.len()]) + .output_tensor_f32("reshaped", 2, None); + if allowzero != 0 { - attrs.insert("allowzero".to_string(), AttributeValue::Int64(allowzero)); - } - - Node { - node_type: NodeType::Reshape, - name: "test_reshape".to_string(), - inputs, - outputs: vec![Argument { - name: "reshaped".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 2, - static_shape: None, - }), - value: None, - passed: true, - }], - attrs, + builder = builder.attr_int("allowzero", allowzero); } + + builder.build() } #[test] diff --git a/crates/onnx-ir/src/node/resize.rs b/crates/onnx-ir/src/node/resize.rs index 51180dddaa..a1884b707a 100644 --- a/crates/onnx-ir/src/node/resize.rs +++ b/crates/onnx-ir/src/node/resize.rs @@ -131,10 +131,8 @@ pub fn resize_config(node: &Node) -> (String, Vec, Vec) { #[cfg(test)] mod tests { use super::*; - use crate::ir::{ - Argument, AttributeValue, Data, ElementType, NodeType, TensorData, TensorType, - }; - use std::collections::HashMap; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; fn create_test_node( mode: &str, @@ -142,81 +140,39 @@ mod tests { sizes: Option>, roi: Option>, ) -> Node { - let mut inputs = vec![Argument { - name: "X".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 4, // N,C,H,W format - static_shape: None, - }), - value: None, - passed: true, - }]; - + let mut builder = NodeBuilder::new(NodeType::Resize, "test_resize") + .input_tensor_f32("X", 4, None) // N,C,H,W format + .output_tensor_f32("Y", 4, None) + .attr_string("mode", mode); + // Add ROI input if provided - inputs.push(Argument { - name: "roi".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 1, - static_shape: None, - }), - value: roi.map(|data| TensorData { - data: Data::Float32s(data), - shape: vec![8], // 
For 4D input (start x, start y, end x, end y) - }), - passed: true, - }); - + if let Some(roi_data) = roi { + builder = builder.input_tensor_f32_data("roi", roi_data, vec![8]); + // For 4D input (start x, start y, end x, end y) + } else { + // Empty ROI still needs to be added as a placeholder + builder = builder.input_tensor_f32("roi", 1, None); + } + // Add scales input if provided - inputs.push(Argument { - name: "scales".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 1, - static_shape: None, - }), - value: scales.map(|data| TensorData { - data: Data::Float32s(data), - shape: vec![4], // N,C,H,W scales - }), - passed: true, - }); - + if let Some(scales_data) = scales { + builder = builder.input_tensor_f32_data("scales", scales_data, vec![4]); + // N,C,H,W scales + } else { + // Empty scales still needs to be added as a placeholder + builder = builder.input_tensor_f32("scales", 1, None); + } + // Add sizes input if provided - inputs.push(Argument { - name: "sizes".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Int64, - rank: 1, - static_shape: None, - }), - value: sizes.map(|data| TensorData { - data: Data::Int64s(data), - shape: vec![4], // N,C,H,W sizes - }), - passed: true, - }); - - let mut attrs = HashMap::new(); - attrs.insert("mode".to_string(), AttributeValue::String(mode.to_string())); - - Node { - node_type: NodeType::Resize, - name: "test_resize".to_string(), - inputs, - outputs: vec![Argument { - name: "Y".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 4, - static_shape: None, - }), - value: None, - passed: true, - }], - attrs, + if let Some(sizes_data) = sizes { + builder = builder.input_tensor_i64_data("sizes", sizes_data, vec![4]); + // N,C,H,W sizes + } else { + // Empty sizes still needs to be added as a placeholder + builder = builder.input_tensor_i64("sizes", 1, None); } + + builder.build() } #[test] diff --git a/crates/onnx-ir/src/node/softmax.rs b/crates/onnx-ir/src/node/softmax.rs index 14f3a1c37a..bad8bcdec9 100644 --- a/crates/onnx-ir/src/node/softmax.rs +++ b/crates/onnx-ir/src/node/softmax.rs @@ -37,40 +37,15 @@ pub fn softmax_config(node: &Node) -> usize { #[cfg(test)] mod tests { use super::*; - use crate::ir::{Argument, AttributeValue, ElementType, NodeType, TensorType}; - use std::collections::HashMap; + use crate::ir::{Argument, ElementType, NodeType, TensorType}; + use crate::node::test_utils::NodeBuilder; fn create_test_node(axis: i64, input_rank: usize) -> Node { - let inputs = vec![Argument { - name: "data".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: input_rank, - static_shape: None, - }), - value: None, - passed: true, - }]; - - let mut attrs = HashMap::new(); - attrs.insert("axis".to_string(), AttributeValue::Int64(axis)); - - Node { - node_type: NodeType::Softmax, - name: "test_softmax".to_string(), - inputs, - outputs: vec![Argument { - name: "output".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: input_rank, - static_shape: None, - }), - value: None, - passed: true, - }], - attrs, - } + NodeBuilder::new(NodeType::Softmax, "test_softmax") + .input_tensor_f32("data", input_rank, None) + .output_tensor_f32("output", input_rank, None) + .attr_int("axis", axis) + .build() } #[test] diff --git a/crates/onnx-ir/src/node/test_utils.rs b/crates/onnx-ir/src/node/test_utils.rs new file mode 100644 index 0000000000..b24e4b98d3 --- /dev/null +++ 
b/crates/onnx-ir/src/node/test_utils.rs @@ -0,0 +1,432 @@ +use crate::ir::{ + ArgType, Argument, AttributeValue, Data, ElementType, Node, NodeType, TensorData, TensorType, +}; +use std::collections::HashMap; + +/// Builder for creating test node instances with convenient defaults and simple API. +pub struct NodeBuilder { + node_type: NodeType, + name: String, + inputs: Vec, + outputs: Vec, + attrs: HashMap, +} + +impl NodeBuilder { + /// Create a new builder with the specified node type and name + pub fn new(node_type: NodeType, name: &str) -> Self { + Self { + node_type, + name: name.to_string(), + inputs: Vec::new(), + outputs: Vec::new(), + attrs: HashMap::new(), + } + } + + /// Add a generic input with the given name and type + /// + /// Note: Prefer using the specialized methods like `input_tensor_f32`, + /// `input_scalar_f32`, etc. for better readability and type safety. + #[doc(hidden)] + pub fn add_input(mut self, name: &str, ty: ArgType) -> Self { + self.inputs.push(Argument { + name: name.to_string(), + ty, + value: None, + passed: true, + }); + self + } + + /// Add a float32 tensor input with the given name and rank + pub fn input_tensor_f32( + self, + name: &str, + rank: usize, + static_shape: Option>, + ) -> Self { + self.add_input( + name, + ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank, + static_shape, + }), + ) + } + + /// Add a float64 tensor input with the given name and rank + pub fn input_tensor_f64( + self, + name: &str, + rank: usize, + static_shape: Option>, + ) -> Self { + self.add_input( + name, + ArgType::Tensor(TensorType { + elem_type: ElementType::Float64, + rank, + static_shape, + }), + ) + } + + /// Add an int32 tensor input with the given name and rank + pub fn input_tensor_i32( + self, + name: &str, + rank: usize, + static_shape: Option>, + ) -> Self { + self.add_input( + name, + ArgType::Tensor(TensorType { + elem_type: ElementType::Int32, + rank, + static_shape, + }), + ) + } + + /// Add an int64 tensor input with the given name and rank + pub fn input_tensor_i64( + self, + name: &str, + rank: usize, + static_shape: Option>, + ) -> Self { + self.add_input( + name, + ArgType::Tensor(TensorType { + elem_type: ElementType::Int64, + rank, + static_shape, + }), + ) + } + + /// Add a bool tensor input with the given name and rank + pub fn input_tensor_bool( + self, + name: &str, + rank: usize, + static_shape: Option>, + ) -> Self { + self.add_input( + name, + ArgType::Tensor(TensorType { + elem_type: ElementType::Bool, + rank, + static_shape, + }), + ) + } + + /// Add a float16 tensor input with the given name and rank + pub fn input_tensor_f16( + self, + name: &str, + rank: usize, + static_shape: Option>, + ) -> Self { + self.add_input( + name, + ArgType::Tensor(TensorType { + elem_type: ElementType::Float16, + rank, + static_shape, + }), + ) + } + + /// Add a string tensor input with the given name and rank + pub fn input_tensor_string( + self, + name: &str, + rank: usize, + static_shape: Option>, + ) -> Self { + self.add_input( + name, + ArgType::Tensor(TensorType { + elem_type: ElementType::String, + rank, + static_shape, + }), + ) + } + + /// Add a scalar input with the given name and element type + pub fn input_scalar(self, name: &str, elem_type: ElementType) -> Self { + self.add_input(name, ArgType::Scalar(elem_type)) + } + + /// Add a float32 scalar input with the given name + pub fn input_scalar_f32(self, name: &str) -> Self { + self.input_scalar(name, ElementType::Float32) + } + + /// Add an int64 scalar input with the given 
name + pub fn input_scalar_i64(self, name: &str) -> Self { + self.input_scalar(name, ElementType::Int64) + } + + /// Add a shape input with the given name and rank + pub fn input_shape(self, name: &str, rank: usize) -> Self { + self.add_input(name, ArgType::Shape(rank)) + } + + /// Add a tensor input with data value + pub fn input_tensor_with_data( + mut self, + name: &str, + elem_type: ElementType, + rank: usize, + data: Data, + shape: Vec, + ) -> Self { + let arg = Argument { + name: name.to_string(), + ty: ArgType::Tensor(TensorType { + elem_type, + rank, + static_shape: None, + }), + value: Some(TensorData { data, shape }), + passed: true, + }; + self.inputs.push(arg); + self + } + + /// Add a float32 tensor input with data values + pub fn input_tensor_f32_data(self, name: &str, data: Vec, shape: Vec) -> Self { + self.input_tensor_with_data( + name, + ElementType::Float32, + shape.len(), + Data::Float32s(data), + shape, + ) + } + + /// Add an int64 tensor input with data values + pub fn input_tensor_i64_data(self, name: &str, data: Vec, shape: Vec) -> Self { + self.input_tensor_with_data( + name, + ElementType::Int64, + shape.len(), + Data::Int64s(data), + shape, + ) + } + + /// Add a generic output with the given name and type + /// + /// Note: Prefer using the specialized methods like `output_tensor_f32`, + /// `output_scalar_f32`, etc. for better readability and type safety. + #[doc(hidden)] + pub fn add_output(mut self, name: &str, ty: ArgType) -> Self { + self.outputs.push(Argument { + name: name.to_string(), + ty, + value: None, + passed: true, + }); + self + } + + /// Add a float32 tensor output with the given name and rank + pub fn output_tensor_f32( + self, + name: &str, + rank: usize, + static_shape: Option>, + ) -> Self { + self.add_output( + name, + ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank, + static_shape, + }), + ) + } + + /// Add a float64 tensor output with the given name and rank + pub fn output_tensor_f64( + self, + name: &str, + rank: usize, + static_shape: Option>, + ) -> Self { + self.add_output( + name, + ArgType::Tensor(TensorType { + elem_type: ElementType::Float64, + rank, + static_shape, + }), + ) + } + + /// Add an int32 tensor output with the given name and rank + pub fn output_tensor_i32( + self, + name: &str, + rank: usize, + static_shape: Option>, + ) -> Self { + self.add_output( + name, + ArgType::Tensor(TensorType { + elem_type: ElementType::Int32, + rank, + static_shape, + }), + ) + } + + /// Add an int64 tensor output with the given name and rank + pub fn output_tensor_i64( + self, + name: &str, + rank: usize, + static_shape: Option>, + ) -> Self { + self.add_output( + name, + ArgType::Tensor(TensorType { + elem_type: ElementType::Int64, + rank, + static_shape, + }), + ) + } + + /// Add a bool tensor output with the given name and rank + pub fn output_tensor_bool( + self, + name: &str, + rank: usize, + static_shape: Option>, + ) -> Self { + self.add_output( + name, + ArgType::Tensor(TensorType { + elem_type: ElementType::Bool, + rank, + static_shape, + }), + ) + } + + /// Add a float16 tensor output with the given name and rank + pub fn output_tensor_f16( + self, + name: &str, + rank: usize, + static_shape: Option>, + ) -> Self { + self.add_output( + name, + ArgType::Tensor(TensorType { + elem_type: ElementType::Float16, + rank, + static_shape, + }), + ) + } + + /// Add a string tensor output with the given name and rank + pub fn output_tensor_string( + self, + name: &str, + rank: usize, + static_shape: Option>, + ) -> 
Self { + self.add_output( + name, + ArgType::Tensor(TensorType { + elem_type: ElementType::String, + rank, + static_shape, + }), + ) + } + + /// Add a scalar output with the given name and element type + pub fn output_scalar(self, name: &str, elem_type: ElementType) -> Self { + self.add_output(name, ArgType::Scalar(elem_type)) + } + + /// Add a float32 scalar output with the given name + pub fn output_scalar_f32(self, name: &str) -> Self { + self.output_scalar(name, ElementType::Float32) + } + + /// Add an int64 scalar output with the given name + pub fn output_scalar_i64(self, name: &str) -> Self { + self.output_scalar(name, ElementType::Int64) + } + + /// Add a shape output with the given name and rank + pub fn output_shape(self, name: &str, rank: usize) -> Self { + self.add_output(name, ArgType::Shape(rank)) + } + + /// Add an integer attribute + pub fn attr_int(mut self, name: &str, value: i64) -> Self { + self.attrs + .insert(name.to_string(), AttributeValue::Int64(value)); + self + } + + /// Add a float attribute + pub fn attr_float(mut self, name: &str, value: f32) -> Self { + self.attrs + .insert(name.to_string(), AttributeValue::Float32(value)); + self + } + + /// Add a string attribute + pub fn attr_string(mut self, name: &str, value: &str) -> Self { + self.attrs + .insert(name.to_string(), AttributeValue::String(value.to_string())); + self + } + + /// Add an integer array attribute + pub fn attr_ints(mut self, name: &str, values: Vec) -> Self { + self.attrs + .insert(name.to_string(), AttributeValue::Int64s(values)); + self + } + + /// Add a float array attribute + pub fn attr_floats(mut self, name: &str, values: Vec) -> Self { + self.attrs + .insert(name.to_string(), AttributeValue::Float32s(values)); + self + } + + /// Add a string array attribute + pub fn attr_strings(mut self, name: &str, values: Vec) -> Self { + self.attrs + .insert(name.to_string(), AttributeValue::Strings(values)); + self + } + + /// Build the node + pub fn build(self) -> Node { + Node { + node_type: self.node_type, + name: self.name, + inputs: self.inputs, + outputs: self.outputs, + attrs: self.attrs, + } + } +} diff --git a/crates/onnx-ir/src/node/where_op.rs b/crates/onnx-ir/src/node/where_op.rs index 4a9b67767a..0f480ab822 100644 --- a/crates/onnx-ir/src/node/where_op.rs +++ b/crates/onnx-ir/src/node/where_op.rs @@ -51,61 +51,16 @@ pub fn where_update_outputs(node: &mut Node) { #[cfg(test)] mod tests { use super::*; - use crate::ir::{Argument, NodeType}; - use std::collections::HashMap; + use crate::ir::{NodeType, TensorType}; + use crate::node::test_utils::NodeBuilder; fn create_test_node(condition_rank: usize, x_rank: usize, y_rank: usize) -> Node { - let inputs = vec![ - Argument { - name: "condition".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Bool, - rank: condition_rank, - static_shape: None, - }), - value: None, - passed: true, - }, - Argument { - name: "X".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: x_rank, - static_shape: None, - }), - value: None, - passed: true, - }, - Argument { - name: "Y".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: y_rank, - static_shape: None, - }), - value: None, - passed: true, - }, - ]; - - let outputs = vec![Argument { - name: "output".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 0, // Will be updated - static_shape: None, - }), - value: None, - passed: true, - }]; - - Node { - node_type: 
NodeType::Where, - name: "test_where".to_string(), - inputs, - outputs, - attrs: HashMap::new(), - } + NodeBuilder::new(NodeType::Where, "test_where") + .input_tensor_bool("condition", condition_rank, None) + .input_tensor_f32("X", x_rank, None) + .input_tensor_f32("Y", y_rank, None) + .output_tensor_f32("output", 0, None) // Rank will be updated + .build() } #[test] From 83112debaecd4e3cdb84ce207d0fe846e3f49705 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Thu, 8 May 2025 15:32:12 -0500 Subject: [PATCH 30/37] Refactor tests to use NodeBuilder --- crates/onnx-ir/src/node/constant.rs | 28 ++------ crates/onnx-ir/src/node/constant_of_shape.rs | 36 ++--------- crates/onnx-ir/src/node/reduce_sum.rs | 68 +++++--------------- crates/onnx-ir/src/node/test_utils.rs | 7 ++ 4 files changed, 33 insertions(+), 106 deletions(-) diff --git a/crates/onnx-ir/src/node/constant.rs b/crates/onnx-ir/src/node/constant.rs index 3bd48c0c18..2af9aabb04 100644 --- a/crates/onnx-ir/src/node/constant.rs +++ b/crates/onnx-ir/src/node/constant.rs @@ -69,31 +69,13 @@ pub fn constant_update_outputs(node: &mut Node) { #[cfg(test)] mod tests { use super::*; - use crate::ir::{Argument, NodeType, TensorData}; - use std::collections::HashMap; + use crate::ir::{NodeType, TensorData}; + use crate::node::test_utils::NodeBuilder; fn create_test_node() -> Node { - let inputs = vec![]; - - let attrs = HashMap::new(); - // Empty attrs initially - - Node { - node_type: NodeType::Constant, - name: "test_constant".to_string(), - inputs, - outputs: vec![Argument { - name: "output".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, // This will be overwritten - rank: 0, - static_shape: None, - }), - value: None, - passed: true, - }], - attrs, - } + NodeBuilder::new(NodeType::Constant, "test_constant") + .output_tensor_f32("output", 0, None) // This will be overwritten + .build() } #[test] diff --git a/crates/onnx-ir/src/node/constant_of_shape.rs b/crates/onnx-ir/src/node/constant_of_shape.rs index a97c807397..723a788562 100644 --- a/crates/onnx-ir/src/node/constant_of_shape.rs +++ b/crates/onnx-ir/src/node/constant_of_shape.rs @@ -63,38 +63,14 @@ pub fn constant_of_shape_update_output(node: &mut Node) { #[cfg(test)] mod tests { use super::*; - use crate::ir::{Argument, AttributeValue, Data, NodeType, TensorData}; - use std::collections::HashMap; + use crate::ir::{AttributeValue, Data, NodeType, TensorData}; + use crate::node::test_utils::NodeBuilder; fn create_test_node(input_ty: ArgType) -> Node { - let inputs = vec![Argument { - name: "shape".to_string(), - ty: input_ty, - value: None, - passed: true, - }]; - - let outputs = vec![Argument { - name: "output".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, // Will be updated - rank: 0, // Will be updated - static_shape: None, - }), - value: None, - passed: true, - }]; - - let attrs = HashMap::new(); - // Default value attribute not set initially - - Node { - node_type: NodeType::ConstantOfShape, - name: "test_constantofshape".to_string(), - inputs, - outputs, - attrs, - } + NodeBuilder::new(NodeType::ConstantOfShape, "test_constantofshape") + .add_input("shape", input_ty) + .output_tensor_f32("output", 0, None) // Will be updated + .build() } #[test] diff --git a/crates/onnx-ir/src/node/reduce_sum.rs b/crates/onnx-ir/src/node/reduce_sum.rs index 6182d22c09..2fa43cae8e 100644 --- a/crates/onnx-ir/src/node/reduce_sum.rs +++ b/crates/onnx-ir/src/node/reduce_sum.rs @@ 
-94,72 +94,34 @@ pub fn reduce_sum_update_outputs(node: &mut Node) { #[cfg(test)] mod tests { use super::*; - use crate::ir::{ - Argument, AttributeValue, Data, ElementType, NodeType, TensorData, TensorType, - }; - use std::collections::HashMap; + use crate::ir::{NodeType}; + use crate::node::test_utils::NodeBuilder; fn create_test_node( axes: Option>, keepdims: Option, with_axes_input: bool, ) -> Node { - let mut inputs = vec![Argument { - name: "data".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 3, - static_shape: None, - }), - value: None, - passed: true, - }]; - + let mut builder = NodeBuilder::new(NodeType::ReduceSum, "test_reduce_sum") + .input_tensor_f32("data", 3, None) + .output_tensor_f32("reduced", 3, None); + // Add axes input if requested if with_axes_input && axes.is_some() { - let axes_clone = axes.clone().unwrap(); - inputs.push(Argument { - name: "axes".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Int64, - rank: 1, - static_shape: None, - }), - value: Some(TensorData { - data: Data::Int64s(axes_clone.clone()), - shape: vec![axes_clone.len()], - }), - passed: true, - }); + let axes_vec = axes.clone().unwrap(); + builder = builder.input_tensor_i64_data("axes", axes_vec.clone(), vec![axes_vec.len()]); } - - let mut attrs = HashMap::new(); + + // Add attributes if !with_axes_input && axes.is_some() { - attrs.insert( - "axes".to_string(), - AttributeValue::Int64s(axes.clone().unwrap()), - ); + builder = builder.attr_ints("axes", axes.clone().unwrap()); } + if let Some(kd) = keepdims { - attrs.insert("keepdims".to_string(), AttributeValue::Int64(kd)); - } - - Node { - node_type: NodeType::ReduceSum, - name: "test_reduce_sum".to_string(), - inputs, - outputs: vec![Argument { - name: "reduced".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 3, - static_shape: None, - }), - value: None, - passed: true, - }], - attrs, + builder = builder.attr_int("keepdims", kd); } + + builder.build() } #[test] diff --git a/crates/onnx-ir/src/node/test_utils.rs b/crates/onnx-ir/src/node/test_utils.rs index b24e4b98d3..811b0d56b1 100644 --- a/crates/onnx-ir/src/node/test_utils.rs +++ b/crates/onnx-ir/src/node/test_utils.rs @@ -418,6 +418,13 @@ impl NodeBuilder { .insert(name.to_string(), AttributeValue::Strings(values)); self } + + /// Add a tensor attribute + pub fn attr_tensor(mut self, name: &str, tensor: TensorData) -> Self { + self.attrs + .insert(name.to_string(), AttributeValue::Tensor(tensor)); + self + } /// Build the node pub fn build(self) -> Node { From 6dfe2c00ffa68a19f516111e95058d83d3545f0d Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Thu, 8 May 2025 16:03:16 -0500 Subject: [PATCH 31/37] Refactored tests to use NodeBuilder --- crates/onnx-ir/src/node/avg_pool1d.rs | 53 ++------- crates/onnx-ir/src/node/avg_pool2d.rs | 53 ++------- crates/onnx-ir/src/node/batch_norm.rs | 112 +++--------------- crates/onnx-ir/src/node/clip.rs | 109 +++--------------- crates/onnx-ir/src/node/comparison.rs | 50 ++------ crates/onnx-ir/src/node/concat.rs | 44 ++----- crates/onnx-ir/src/node/conv2d.rs | 2 +- crates/onnx-ir/src/node/max_pool1d.rs | 48 ++------ crates/onnx-ir/src/node/max_pool2d.rs | 48 ++------ crates/onnx-ir/src/node/one_hot.rs | 73 ++---------- crates/onnx-ir/src/node/pad.rs | 91 ++++----------- crates/onnx-ir/src/node/resize.rs | 6 +- crates/onnx-ir/src/node/slice.rs | 159 ++++---------------------- 
crates/onnx-ir/src/node/split.rs | 70 +++++------- crates/onnx-ir/src/node/test_utils.rs | 73 ++++++++++++ crates/onnx-ir/src/node/tile.rs | 59 +++------- crates/onnx-ir/src/node/unsqueeze.rs | 109 +++++------------- 17 files changed, 287 insertions(+), 872 deletions(-) diff --git a/crates/onnx-ir/src/node/avg_pool1d.rs b/crates/onnx-ir/src/node/avg_pool1d.rs index af5bc6dabe..99f58e49bb 100644 --- a/crates/onnx-ir/src/node/avg_pool1d.rs +++ b/crates/onnx-ir/src/node/avg_pool1d.rs @@ -75,8 +75,8 @@ pub fn avg_pool1d_config(curr: &Node) -> AvgPool1dConfig { #[cfg(test)] mod tests { use super::*; - use crate::ir::{ArgType, Argument, AttributeValue, ElementType, NodeType, TensorType}; - use std::collections::HashMap; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; fn create_test_node( kernel_shape: Vec, @@ -85,46 +85,15 @@ mod tests { count_include_pad: i64, ceil_mode: i64, ) -> Node { - let inputs = vec![Argument { - name: "data".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 3, - static_shape: None, - }), - value: None, - passed: true, - }]; - - let mut attrs = HashMap::new(); - attrs.insert( - "kernel_shape".to_string(), - AttributeValue::Int64s(kernel_shape), - ); - attrs.insert("strides".to_string(), AttributeValue::Int64s(strides)); - attrs.insert("pads".to_string(), AttributeValue::Int64s(pads)); - attrs.insert( - "count_include_pad".to_string(), - AttributeValue::Int64(count_include_pad), - ); - attrs.insert("ceil_mode".to_string(), AttributeValue::Int64(ceil_mode)); - - Node { - node_type: NodeType::AveragePool1d, - name: "test_avgpool1d".to_string(), - inputs, - outputs: vec![Argument { - name: "output".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 3, - static_shape: None, - }), - value: None, - passed: true, - }], - attrs, - } + NodeBuilder::new(NodeType::AveragePool1d, "test_avgpool1d") + .input_tensor_f32("data", 3, None) + .output_tensor_f32("output", 3, None) + .attr_ints("kernel_shape", kernel_shape) + .attr_ints("strides", strides) + .attr_ints("pads", pads) + .attr_int("count_include_pad", count_include_pad) + .attr_int("ceil_mode", ceil_mode) + .build() } #[test] diff --git a/crates/onnx-ir/src/node/avg_pool2d.rs b/crates/onnx-ir/src/node/avg_pool2d.rs index 4e95cc8228..58644ca237 100644 --- a/crates/onnx-ir/src/node/avg_pool2d.rs +++ b/crates/onnx-ir/src/node/avg_pool2d.rs @@ -67,8 +67,8 @@ pub fn avg_pool2d_config(curr: &Node) -> AvgPool2dConfig { #[cfg(test)] mod tests { use super::*; - use crate::ir::{ArgType, Argument, AttributeValue, ElementType, NodeType, TensorType}; - use std::collections::HashMap; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; fn create_test_node( kernel_shape: Vec, @@ -77,46 +77,15 @@ mod tests { count_include_pad: i64, ceil_mode: i64, ) -> Node { - let inputs = vec![Argument { - name: "data".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 4, - static_shape: None, - }), - value: None, - passed: true, - }]; - - let mut attrs = HashMap::new(); - attrs.insert( - "kernel_shape".to_string(), - AttributeValue::Int64s(kernel_shape), - ); - attrs.insert("strides".to_string(), AttributeValue::Int64s(strides)); - attrs.insert("pads".to_string(), AttributeValue::Int64s(pads)); - attrs.insert( - "count_include_pad".to_string(), - AttributeValue::Int64(count_include_pad), - ); - attrs.insert("ceil_mode".to_string(), AttributeValue::Int64(ceil_mode)); - - Node { - node_type: 
NodeType::AveragePool2d, - name: "test_avgpool2d".to_string(), - inputs, - outputs: vec![Argument { - name: "output".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 4, - static_shape: None, - }), - value: None, - passed: true, - }], - attrs, - } + NodeBuilder::new(NodeType::AveragePool2d, "test_avgpool2d") + .input_tensor_f32("data", 4, None) + .output_tensor_f32("output", 4, None) + .attr_ints("kernel_shape", kernel_shape) + .attr_ints("strides", strides) + .attr_ints("pads", pads) + .attr_int("count_include_pad", count_include_pad) + .attr_int("ceil_mode", ceil_mode) + .build() } #[test] diff --git a/crates/onnx-ir/src/node/batch_norm.rs b/crates/onnx-ir/src/node/batch_norm.rs index c3e83ff3d4..0cef6c14e5 100644 --- a/crates/onnx-ir/src/node/batch_norm.rs +++ b/crates/onnx-ir/src/node/batch_norm.rs @@ -50,105 +50,23 @@ pub fn batch_norm_config(node: &Node) -> BatchNormConfig { #[cfg(test)] mod tests { use super::*; - use crate::ir::{ - ArgType, Argument, AttributeValue, Data, ElementType, NodeType, TensorData, TensorType, - }; - use std::collections::HashMap; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; fn create_test_node(epsilon: f32, momentum: f32, num_features: usize) -> Node { - let weight_tensor = TensorData { - data: Data::Float32s(vec![1.0; num_features]), // Not important for the test - shape: vec![num_features], - }; - - let bias_tensor = TensorData { - data: Data::Float32s(vec![0.0; num_features]), // Not important for the test - shape: vec![num_features], - }; - - let mean_tensor = TensorData { - data: Data::Float32s(vec![0.0; num_features]), // Not important for the test - shape: vec![num_features], - }; - - let var_tensor = TensorData { - data: Data::Float32s(vec![1.0; num_features]), // Not important for the test - shape: vec![num_features], - }; - - let inputs = vec![ - Argument { - name: "X".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 4, // NCHW format - static_shape: None, - }), - value: None, - passed: true, - }, - Argument { - name: "scale".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 1, - static_shape: None, - }), - value: Some(weight_tensor), - passed: true, - }, - Argument { - name: "bias".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 1, - static_shape: None, - }), - value: Some(bias_tensor), - passed: true, - }, - Argument { - name: "mean".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 1, - static_shape: None, - }), - value: Some(mean_tensor), - passed: true, - }, - Argument { - name: "var".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 1, - static_shape: None, - }), - value: Some(var_tensor), - passed: true, - }, - ]; - - let mut attrs = HashMap::new(); - attrs.insert("epsilon".to_string(), AttributeValue::Float32(epsilon)); - attrs.insert("momentum".to_string(), AttributeValue::Float32(momentum)); - - Node { - node_type: NodeType::BatchNormalization, - name: "test_batchnorm".to_string(), - inputs, - outputs: vec![Argument { - name: "output".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 4, - static_shape: None, - }), - value: None, - passed: true, - }], - attrs, - } + let ones = vec![1.0; num_features]; + let zeros = vec![0.0; num_features]; + + NodeBuilder::new(NodeType::BatchNormalization, "test_batchnorm") + 
.input_tensor_f32("X", 4, None) // NCHW format + .input_tensor_f32_data("scale", ones.clone(), vec![num_features]) + .input_tensor_f32_data("bias", zeros.clone(), vec![num_features]) + .input_tensor_f32_data("mean", zeros.clone(), vec![num_features]) + .input_tensor_f32_data("var", ones.clone(), vec![num_features]) + .output_tensor_f32("output", 4, None) + .attr_float("epsilon", epsilon) + .attr_float("momentum", momentum) + .build() } #[test] diff --git a/crates/onnx-ir/src/node/clip.rs b/crates/onnx-ir/src/node/clip.rs index c6f010578c..2f5bbedbe1 100644 --- a/crates/onnx-ir/src/node/clip.rs +++ b/crates/onnx-ir/src/node/clip.rs @@ -56,107 +56,32 @@ pub fn clip_config(node: &Node) -> (Option, Option) { #[cfg(test)] mod tests { use super::*; - use crate::ir::{ - ArgType, Argument, AttributeValue, Data, ElementType, NodeType, TensorData, TensorType, - }; - use std::collections::HashMap; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; fn create_test_node_with_attributes(min: Option, max: Option) -> Node { - let inputs = vec![Argument { - name: "X".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 4, - static_shape: None, - }), - value: None, - passed: true, - }]; - - let mut attrs = HashMap::new(); + let mut builder = NodeBuilder::new(NodeType::Clip, "test_clip") + .input_tensor_f32("X", 4, None) + .output_tensor_f32("Y", 4, None); + if let Some(min_val) = min { - attrs.insert("min".to_string(), AttributeValue::Float32(min_val)); + builder = builder.attr_float("min", min_val); } + if let Some(max_val) = max { - attrs.insert("max".to_string(), AttributeValue::Float32(max_val)); - } - - Node { - node_type: NodeType::Clip, - name: "test_clip".to_string(), - inputs, - outputs: vec![Argument { - name: "Y".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 4, - static_shape: None, - }), - value: None, - passed: true, - }], - attrs, + builder = builder.attr_float("max", max_val); } + + builder.build() } fn create_test_node_with_inputs(min: Option, max: Option) -> Node { - let mut inputs = vec![Argument { - name: "X".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 4, - static_shape: None, - }), - value: None, - passed: true, - }]; - - // Add min input - inputs.push(Argument { - name: "min".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 0, - static_shape: None, - }), - value: min.map(|val| TensorData { - data: Data::Float32(val), - shape: vec![], - }), - passed: true, - }); - - // Add max input - inputs.push(Argument { - name: "max".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 0, - static_shape: None, - }), - value: max.map(|val| TensorData { - data: Data::Float32(val), - shape: vec![], - }), - passed: true, - }); - - Node { - node_type: NodeType::Clip, - name: "test_clip".to_string(), - inputs, - outputs: vec![Argument { - name: "Y".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 4, - static_shape: None, - }), - value: None, - passed: true, - }], - attrs: HashMap::new(), - } + NodeBuilder::new(NodeType::Clip, "test_clip") + .input_tensor_f32("X", 4, None) + .input_scalar_tensor_f32("min", min) + .input_scalar_tensor_f32("max", max) + .output_tensor_f32("Y", 4, None) + .build() } #[test] diff --git a/crates/onnx-ir/src/node/comparison.rs b/crates/onnx-ir/src/node/comparison.rs index 37bbb91bfc..fca4cf0621 100644 --- 
a/crates/onnx-ir/src/node/comparison.rs +++ b/crates/onnx-ir/src/node/comparison.rs @@ -32,51 +32,15 @@ pub fn elementwise_comparison_outputs(node: &mut Node) { #[cfg(test)] mod tests { use super::*; - use crate::ir::{Argument, NodeType}; - use std::collections::HashMap; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; fn create_test_node(input1_rank: usize, input2_rank: usize) -> Node { - let inputs = vec![ - Argument { - name: "A".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: input1_rank, - static_shape: None, - }), - value: None, - passed: true, - }, - Argument { - name: "B".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: input2_rank, - static_shape: None, - }), - value: None, - passed: true, - }, - ]; - - let outputs = vec![Argument { - name: "result".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Bool, - rank: 0, // Will be updated - static_shape: None, - }), - value: None, - passed: true, - }]; - - Node { - node_type: NodeType::Equal, - name: "test_comparison".to_string(), - inputs, - outputs, - attrs: HashMap::new(), - } + NodeBuilder::new(NodeType::Equal, "test_comparison") + .input_tensor_f32("A", input1_rank, None) + .input_tensor_f32("B", input2_rank, None) + .output_tensor_bool("result", 0, None) // rank will be updated + .build() } #[test] diff --git a/crates/onnx-ir/src/node/concat.rs b/crates/onnx-ir/src/node/concat.rs index b2b4c4b73e..71a4ee328c 100644 --- a/crates/onnx-ir/src/node/concat.rs +++ b/crates/onnx-ir/src/node/concat.rs @@ -53,45 +53,15 @@ pub fn concat_config(node: &Node) -> usize { #[cfg(test)] mod tests { use super::*; - use crate::ir::{Argument, AttributeValue, ElementType, NodeType, TensorType}; - use std::collections::HashMap; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; fn create_test_node(axis: i64, input_rank: usize, num_inputs: usize) -> Node { - let mut inputs = Vec::new(); - - // Create multiple inputs for concat - for i in 0..num_inputs { - inputs.push(Argument { - name: format!("data_{}", i), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: input_rank, - static_shape: None, - }), - value: None, - passed: true, - }); - } - - let mut attrs = HashMap::new(); - attrs.insert("axis".to_string(), AttributeValue::Int64(axis)); - - Node { - node_type: NodeType::Concat, - name: "test_concat".to_string(), - inputs, - outputs: vec![Argument { - name: "output".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: input_rank, - static_shape: None, - }), - value: None, - passed: true, - }], - attrs, - } + NodeBuilder::new(NodeType::Concat, "test_concat") + .input_tensors_f32::>("data", num_inputs, input_rank, None) + .output_tensor_f32("output", input_rank, None) + .attr_int("axis", axis) + .build() } #[test] diff --git a/crates/onnx-ir/src/node/conv2d.rs b/crates/onnx-ir/src/node/conv2d.rs index 961053502e..e783c7d973 100644 --- a/crates/onnx-ir/src/node/conv2d.rs +++ b/crates/onnx-ir/src/node/conv2d.rs @@ -110,7 +110,7 @@ mod tests { let mut builder = NodeBuilder::new(NodeType::Conv2d, "test_conv2d") .input_tensor_f32("data", 4, None) - .input_tensor_f32_data("weight", weight_data, weight_shape) + .input_tensor_f32_data("weight", weight_data.clone(), weight_shape) .output_tensor_f32("output", 4, None) .attr_ints("kernel_shape", kernel_shape) .attr_ints("strides", strides) diff --git a/crates/onnx-ir/src/node/max_pool1d.rs 
b/crates/onnx-ir/src/node/max_pool1d.rs index 3523cf91f9..5538aa2fc6 100644 --- a/crates/onnx-ir/src/node/max_pool1d.rs +++ b/crates/onnx-ir/src/node/max_pool1d.rs @@ -84,10 +84,10 @@ pub fn max_pool1d_config(curr: &Node) -> MaxPool1dConfig { mod tests { use super::*; use crate::{ - ir::{ArgType, Argument, AttributeValue, ElementType, NodeType, TensorType}, + ir::NodeType, node::padding::PaddingConfig1d, + node::test_utils::NodeBuilder, }; - use std::collections::HashMap; fn create_test_node( kernel_shape: Vec, @@ -95,42 +95,14 @@ mod tests { pads: Vec, dilation: Vec, ) -> Node { - let inputs = vec![Argument { - name: "data".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 3, - static_shape: None, - }), - value: None, - passed: true, - }]; - - let mut attrs = HashMap::new(); - attrs.insert( - "kernel_shape".to_string(), - AttributeValue::Int64s(kernel_shape), - ); - attrs.insert("strides".to_string(), AttributeValue::Int64s(stride)); - attrs.insert("pads".to_string(), AttributeValue::Int64s(pads)); - attrs.insert("dilations".to_string(), AttributeValue::Int64s(dilation)); - - Node { - node_type: NodeType::MaxPool1d, - name: "test_maxpool1d".to_string(), - inputs, - outputs: vec![Argument { - name: "output".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 3, - static_shape: None, - }), - value: None, - passed: true, - }], - attrs, - } + NodeBuilder::new(NodeType::MaxPool1d, "test_maxpool1d") + .input_tensor_f32("data", 3, None) + .output_tensor_f32("output", 3, None) + .attr_ints("kernel_shape", kernel_shape) + .attr_ints("strides", stride) + .attr_ints("pads", pads) + .attr_ints("dilations", dilation) + .build() } #[test] diff --git a/crates/onnx-ir/src/node/max_pool2d.rs b/crates/onnx-ir/src/node/max_pool2d.rs index 62c3c71971..f77cae0c83 100644 --- a/crates/onnx-ir/src/node/max_pool2d.rs +++ b/crates/onnx-ir/src/node/max_pool2d.rs @@ -72,8 +72,8 @@ pub fn max_pool2d_config(curr: &Node) -> MaxPool2dConfig { #[cfg(test)] mod tests { use super::*; - use crate::ir::{ArgType, Argument, AttributeValue, ElementType, NodeType, TensorType}; - use std::collections::HashMap; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; fn create_test_node( kernel_shape: Vec, @@ -81,42 +81,14 @@ mod tests { pads: Vec, dilations: Vec, ) -> Node { - let inputs = vec![Argument { - name: "data".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 4, - static_shape: None, - }), - value: None, - passed: true, - }]; - - let mut attrs = HashMap::new(); - attrs.insert( - "kernel_shape".to_string(), - AttributeValue::Int64s(kernel_shape), - ); - attrs.insert("strides".to_string(), AttributeValue::Int64s(strides)); - attrs.insert("pads".to_string(), AttributeValue::Int64s(pads)); - attrs.insert("dilations".to_string(), AttributeValue::Int64s(dilations)); - - Node { - node_type: NodeType::MaxPool2d, - name: "test_maxpool2d".to_string(), - inputs, - outputs: vec![Argument { - name: "output".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 4, - static_shape: None, - }), - value: None, - passed: true, - }], - attrs, - } + NodeBuilder::new(NodeType::MaxPool2d, "test_maxpool2d") + .input_tensor_f32("data", 4, None) + .output_tensor_f32("output", 4, None) + .attr_ints("kernel_shape", kernel_shape) + .attr_ints("strides", strides) + .attr_ints("pads", pads) + .attr_ints("dilations", dilations) + .build() } #[test] diff --git 
a/crates/onnx-ir/src/node/one_hot.rs b/crates/onnx-ir/src/node/one_hot.rs index 2b01bee530..598122e075 100644 --- a/crates/onnx-ir/src/node/one_hot.rs +++ b/crates/onnx-ir/src/node/one_hot.rs @@ -47,72 +47,21 @@ pub fn one_hot_output_shape(node: &mut Node) { #[cfg(test)] mod tests { use super::*; - use crate::ir::{ - ArgType, Argument, AttributeValue, Data, ElementType, NodeType, TensorData, TensorType, - }; - use std::collections::HashMap; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; fn create_test_node(depth: i64, values: Vec, axis: Option) -> Node { - let inputs = vec![ - Argument { - name: "indices".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Int64, - rank: 2, - static_shape: None, - }), - value: None, - passed: true, - }, - Argument { - name: "depth".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Int64, - rank: 0, - static_shape: None, - }), - value: Some(TensorData { - data: Data::Int64(depth), - shape: vec![], - }), - passed: true, - }, - Argument { - name: "values".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 1, - static_shape: None, - }), - value: Some(TensorData { - data: Data::Float32s(values), - shape: vec![2], // always [off_value, on_value] - }), - passed: true, - }, - ]; - - let mut attrs = HashMap::new(); + let mut builder = NodeBuilder::new(NodeType::OneHot, "test_one_hot") + .input_tensor_i64("indices", 2, None) + .input_scalar_tensor_i64("depth", depth) + .input_tensor_f32_data("values", values.clone(), vec![2]) // always [off_value, on_value] + .output_tensor_f32("output", 3, None); // rank increases by 1 + if let Some(axis_val) = axis { - attrs.insert("axis".to_string(), AttributeValue::Int64(axis_val)); - } - - Node { - node_type: NodeType::OneHot, - name: "test_one_hot".to_string(), - inputs, - outputs: vec![Argument { - name: "output".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 3, // rank increases by 1 - static_shape: None, - }), - value: None, - passed: true, - }], - attrs, + builder = builder.attr_int("axis", axis_val); } + + builder.build() } #[test] diff --git a/crates/onnx-ir/src/node/pad.rs b/crates/onnx-ir/src/node/pad.rs index 3335464f40..2a72125f02 100644 --- a/crates/onnx-ir/src/node/pad.rs +++ b/crates/onnx-ir/src/node/pad.rs @@ -146,10 +146,8 @@ pub fn pad_config(node: &Node) -> PadConfig { #[cfg(test)] mod tests { use super::*; - use crate::ir::{ - Argument, AttributeValue, Data, ElementType, NodeType, TensorData, TensorType, - }; - use std::collections::HashMap; + use crate::ir::{ArgType, Argument, Data, ElementType, NodeType, TensorData, TensorType}; + use crate::node::test_utils::NodeBuilder; fn create_test_node( pad_attrs: Option>, @@ -159,81 +157,34 @@ mod tests { mode: Option<&str>, rank: usize, ) -> Node { - let mut inputs = vec![Argument { - name: "data".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank, - static_shape: None, - }), - value: None, - passed: true, - }]; - - // Add pads input if provided - if let Some(pads) = pad_inputs { - inputs.push(Argument { - name: "pads".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Int64, - rank: 1, - static_shape: None, - }), - value: Some(TensorData { - data: Data::Int64s(pads), - shape: vec![], - }), - passed: true, - }); + let mut builder = NodeBuilder::new(NodeType::Pad, "test_pad") + .input_tensor_f32("data", rank, None) + .output_tensor_f32("output", rank, 
None); + + // Add pad inputs if provided + if let Some(pads) = pad_inputs.clone() { + builder = builder.input_tensor_i64_data("pads", pads, vec![]); } - - // Add constant_value input if provided + + // Add constant value input if provided if let Some(value) = constant_value_input { - inputs.push(Argument { - name: "constant_value".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 0, - static_shape: None, - }), - value: Some(TensorData { - data: Data::Float32(value), - shape: vec![], - }), - passed: true, - }); + builder = builder.input_scalar_tensor_f32("constant_value", Some(value)); } - - let mut attrs = HashMap::new(); + + // Add attributes if provided if let Some(pads) = pad_attrs { - attrs.insert("pads".to_string(), AttributeValue::Int64s(pads)); + builder = builder.attr_ints("pads", pads); } + if let Some(value) = constant_value_attr { - attrs.insert("value".to_string(), AttributeValue::Float32(value)); + builder = builder.attr_float("value", value); } + if let Some(mode_val) = mode { - attrs.insert( - "mode".to_string(), - AttributeValue::String(mode_val.to_string()), - ); - } - - Node { - node_type: NodeType::Pad, - name: "test_pad".to_string(), - inputs, - outputs: vec![Argument { - name: "output".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank, - static_shape: None, - }), - value: None, - passed: true, - }], - attrs, + builder = builder.attr_string("mode", mode_val); } + + builder.build() } #[test] diff --git a/crates/onnx-ir/src/node/resize.rs b/crates/onnx-ir/src/node/resize.rs index a1884b707a..4d46081be0 100644 --- a/crates/onnx-ir/src/node/resize.rs +++ b/crates/onnx-ir/src/node/resize.rs @@ -147,7 +147,7 @@ mod tests { // Add ROI input if provided if let Some(roi_data) = roi { - builder = builder.input_tensor_f32_data("roi", roi_data, vec![8]); + builder = builder.input_tensor_f32_data("roi", roi_data.clone(), vec![8]); // For 4D input (start x, start y, end x, end y) } else { // Empty ROI still needs to be added as a placeholder @@ -156,7 +156,7 @@ mod tests { // Add scales input if provided if let Some(scales_data) = scales { - builder = builder.input_tensor_f32_data("scales", scales_data, vec![4]); + builder = builder.input_tensor_f32_data("scales", scales_data.clone(), vec![4]); // N,C,H,W scales } else { // Empty scales still needs to be added as a placeholder @@ -165,7 +165,7 @@ mod tests { // Add sizes input if provided if let Some(sizes_data) = sizes { - builder = builder.input_tensor_i64_data("sizes", sizes_data, vec![4]); + builder = builder.input_tensor_i64_data("sizes", sizes_data.clone(), vec![4]); // N,C,H,W sizes } else { // Empty sizes still needs to be added as a placeholder diff --git a/crates/onnx-ir/src/node/slice.rs b/crates/onnx-ir/src/node/slice.rs index 1424095be7..f9169c72a9 100644 --- a/crates/onnx-ir/src/node/slice.rs +++ b/crates/onnx-ir/src/node/slice.rs @@ -131,9 +131,8 @@ pub fn slice_update_output_rank(node: &mut Node) { #[cfg(test)] mod tests { - use std::collections::HashMap; - - use crate::ir::{Argument, AttributeValue, ElementType, NodeType, TensorType}; + use crate::ir::{ElementType, NodeType}; + use crate::node::test_utils::NodeBuilder; use super::*; @@ -143,151 +142,39 @@ mod tests { axes: Option>, use_attrs: bool, ) -> Node { - let mut inputs = vec![Argument { - name: "data".to_string(), - ty: crate::ir::ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 3, - static_shape: None, - }), - value: None, - passed: true, - }]; + let mut 
builder = NodeBuilder::new(NodeType::Slice, "test_slice") + .input_tensor_f32("data", 3, None) + .output_default("output"); if !use_attrs { // Add inputs as tensors - inputs.push(Argument { - name: "starts".to_string(), - ty: crate::ir::ArgType::Tensor(TensorType { - elem_type: ElementType::Int64, - rank: 1, - static_shape: Some(vec![starts.len()]), - }), - value: Some(TensorData { - data: Data::Int64s(starts.clone()), - shape: vec![starts.len()], - }), - passed: true, - }); - - inputs.push(Argument { - name: "ends".to_string(), - ty: crate::ir::ArgType::Tensor(TensorType { - elem_type: ElementType::Int64, - rank: 1, - static_shape: Some(vec![ends.len()]), - }), - value: Some(TensorData { - data: Data::Int64s(ends.clone()), - shape: vec![ends.len()], - }), - passed: true, - }); + builder = builder.input_tensor_i64_data("starts", starts.clone(), vec![starts.len()]); + builder = builder.input_tensor_i64_data("ends", ends.clone(), vec![ends.len()]); - if let Some(axes_vec) = &axes { - inputs.push(Argument { - name: "axes".to_string(), - ty: crate::ir::ArgType::Tensor(TensorType { - elem_type: ElementType::Int64, - rank: 1, - static_shape: Some(vec![axes_vec.len()]), - }), - value: Some(TensorData { - data: Data::Int64s(axes_vec.clone()), - shape: vec![axes_vec.len()], - }), - passed: true, - }); + if let Some(axes_vec) = axes.clone() { + builder = builder.input_tensor_i64_data("axes", axes_vec.clone(), vec![axes_vec.len()]); } - } - - let mut attrs = HashMap::new(); - if use_attrs { - attrs.insert("starts".to_string(), AttributeValue::Int64s(starts)); - attrs.insert("ends".to_string(), AttributeValue::Int64s(ends)); + } else { + // Add attributes + builder = builder.attr_ints("starts", starts); + builder = builder.attr_ints("ends", ends); + if let Some(axes_vec) = axes { - attrs.insert("axes".to_string(), AttributeValue::Int64s(axes_vec)); + builder = builder.attr_ints("axes", axes_vec); } } - Node { - node_type: NodeType::Slice, - name: "test_slice".to_string(), - inputs, - outputs: vec![Argument { - name: "output".to_string(), - ty: ArgType::default(), - value: None, - passed: true, - }], - attrs, - } + builder.build() } fn create_shape_input_node(start: i64, end: i64) -> Node { - let mut node = Node { - node_type: NodeType::Slice, - name: "test_slice_shape".to_string(), - inputs: vec![Argument { - name: "data".to_string(), - ty: ArgType::Shape(5), // 1-dimensional shape (important: matches what the tests expect) - value: None, - passed: true, - }], - outputs: vec![Argument { - name: "output".to_string(), - ty: ArgType::default(), - value: None, - passed: true, - }], - attrs: HashMap::new(), - }; - - // Add starts and ends as tensors - node.inputs.push(Argument { - name: "starts".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Int64, - rank: 1, - static_shape: Some(vec![1]), - }), - value: Some(TensorData { - data: Data::Int64s(vec![start]), - shape: vec![1], - }), - passed: true, - }); - - node.inputs.push(Argument { - name: "ends".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Int64, - rank: 1, - static_shape: Some(vec![1]), - }), - value: Some(TensorData { - data: Data::Int64s(vec![end]), - shape: vec![1], - }), - passed: true, - }); - - // Add axes tensor to specify dimension 0 - node.inputs.push(Argument { - name: "axes".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Int64, - rank: 1, - static_shape: Some(vec![1]), - }), - value: Some(TensorData { - data: Data::Int64s(vec![0]), - shape: vec![1], - }), - 
passed: true, - }); - - node + NodeBuilder::new(NodeType::Slice, "test_slice_shape") + .input_shape("data", 5) + .input_tensor_i64_data("starts", vec![start], vec![1]) + .input_tensor_i64_data("ends", vec![end], vec![1]) + .input_tensor_i64_data("axes", vec![0], vec![1]) + .output_default("output") + .build() } #[test] diff --git a/crates/onnx-ir/src/node/split.rs b/crates/onnx-ir/src/node/split.rs index 1ba20c9eaf..221686abac 100644 --- a/crates/onnx-ir/src/node/split.rs +++ b/crates/onnx-ir/src/node/split.rs @@ -157,7 +157,8 @@ pub fn split_config(node: &Node) -> SplitConfig { #[cfg(test)] mod tests { use super::*; - use crate::ir::{Argument, AttributeValue, Data, ElementType, NodeType, TensorData}; + use crate::ir::{AttributeValue, ElementType, NodeType, ArgType}; + use crate::node::test_utils::NodeBuilder; use std::collections::HashMap; fn create_test_node( @@ -167,55 +168,36 @@ mod tests { attrs: Option>, split_sizes_input: Option>, ) -> Node { - let mut inputs = vec![Argument { - name: "input".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: input_rank, - static_shape, - }), - value: None, - passed: true, - }]; - + // Start with input tensor + let mut builder = NodeBuilder::new(NodeType::Split, "test_split") + .input_tensor_f32("input", input_rank, static_shape); + // Add split sizes input if provided if let Some(sizes) = split_sizes_input { - inputs.push(Argument { - name: "split".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Int64, - rank: 1, - static_shape: Some(vec![sizes.len()]), - }), - value: Some(TensorData { - shape: vec![sizes.len()], - data: Data::Int64s(sizes), - }), - passed: true, - }); + builder = builder.input_tensor_i64_data( + "split", + sizes.clone(), + vec![sizes.len()] + ); } - - let mut outputs = Vec::new(); + + // Add output tensors for i in 0..num_outputs { - outputs.push(Argument { - name: format!("output_{}", i), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 0, // Will be updated - static_shape: None, - }), - value: None, - passed: true, - }); + builder = builder.output_tensor_f32( + &format!("output_{}", i), + 0, // Will be updated + None + ); } - - Node { - node_type: NodeType::Split, - name: "test_split".to_string(), - inputs, - outputs, - attrs: attrs.unwrap_or_default(), + + // Add attributes if provided + let mut node = builder.build(); + + if let Some(attributes) = attrs { + node.attrs = attributes; } + + node } #[test] diff --git a/crates/onnx-ir/src/node/test_utils.rs b/crates/onnx-ir/src/node/test_utils.rs index 811b0d56b1..a046dddc20 100644 --- a/crates/onnx-ir/src/node/test_utils.rs +++ b/crates/onnx-ir/src/node/test_utils.rs @@ -222,6 +222,68 @@ impl NodeBuilder { shape, ) } + + + /// Add a float32 scalar tensor input (rank 0) + pub fn input_scalar_tensor_f32( + mut self, + name: &str, + value: Option, + ) -> Self { + let arg = Argument { + name: name.to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 0, + static_shape: None, + }), + value: value.map(|val| TensorData { + data: Data::Float32(val), + shape: vec![], + }), + passed: true, + }; + self.inputs.push(arg); + self + } + + /// Add an int64 scalar tensor input (rank 0) + pub fn input_scalar_tensor_i64( + mut self, + name: &str, + value: i64, + ) -> Self { + let arg = Argument { + name: name.to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Int64, + rank: 0, + static_shape: None, + }), + value: Some(TensorData { + data: 
Data::Int64(value), + shape: vec![], + }), + passed: true, + }; + self.inputs.push(arg); + self + } + + + /// Add multiple tensor inputs with the same type but different names + pub fn input_tensors_f32( + mut self, + name_prefix: &str, + count: usize, + rank: usize, + static_shape: Option>, + ) -> Self { + for i in 0..count { + self = self.input_tensor_f32(&format!("{}_{}", name_prefix, i), rank, static_shape.clone()); + } + self + } /// Add a generic output with the given name and type /// @@ -426,6 +488,17 @@ impl NodeBuilder { self } + /// Add a default output with the given name + pub fn output_default(mut self, name: &str) -> Self { + self.outputs.push(Argument { + name: name.to_string(), + ty: ArgType::default(), + value: None, + passed: true, + }); + self + } + /// Build the node pub fn build(self) -> Node { Node { diff --git a/crates/onnx-ir/src/node/tile.rs b/crates/onnx-ir/src/node/tile.rs index 557e6abfe7..e1f664af98 100644 --- a/crates/onnx-ir/src/node/tile.rs +++ b/crates/onnx-ir/src/node/tile.rs @@ -36,58 +36,25 @@ pub fn tile_config(node: &Node) -> TileConfig { #[cfg(test)] mod tests { use super::*; - use crate::ir::{ArgType, Argument, Data, ElementType, NodeType, TensorType}; - use std::collections::HashMap; + use crate::ir::{Argument, ArgType, NodeType, TensorType, ElementType}; + use crate::node::test_utils::NodeBuilder; /// Helper function to create test nodes with different repeat values fn create_test_node(repeats: Option>, input_rank: usize) -> Node { - let mut inputs = vec![ - // First input: the tensor to tile - Argument { - name: "input".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: input_rank, - static_shape: None, - }), - value: None, - passed: true, - }, - ]; - + let mut builder = NodeBuilder::new(NodeType::Tile, "test_tile") + .input_tensor_f32("input", input_rank, None) + .output_tensor_f32("output", input_rank, None); // Same rank as input initially + // Add repeats input if provided if let Some(reps) = repeats { - inputs.push(Argument { - name: "repeats".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Int64, - rank: 1, - static_shape: Some(vec![reps.len()]), - }), - value: Some(TensorData { - shape: vec![reps.len()], - data: Data::Int64s(reps), - }), - passed: true, - }); - } - - Node { - node_type: NodeType::Tile, - name: "test_tile".to_string(), - inputs, - outputs: vec![Argument { - name: "output".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: input_rank, // Same rank as input initially - static_shape: None, - }), - value: None, - passed: true, - }], - attrs: HashMap::new(), + builder = builder.input_tensor_i64_data( + "repeats", + reps.clone(), + vec![reps.len()] + ); } + + builder.build() } #[test] diff --git a/crates/onnx-ir/src/node/unsqueeze.rs b/crates/onnx-ir/src/node/unsqueeze.rs index f7044e6a37..42e9810922 100644 --- a/crates/onnx-ir/src/node/unsqueeze.rs +++ b/crates/onnx-ir/src/node/unsqueeze.rs @@ -112,8 +112,8 @@ pub fn unsqueeze_config(node: &Node) -> UnsqueezeConfig { #[cfg(test)] mod tests { use super::*; - use crate::ir::{Argument, AttributeValue, ElementType, NodeType, TensorData}; - use std::collections::HashMap; + use crate::ir::{ElementType, NodeType}; + use crate::node::test_utils::NodeBuilder; // Implement custom equality for UnsqueezeConfig to make testing easier impl PartialEq for UnsqueezeConfig { @@ -127,90 +127,37 @@ mod tests { } fn create_test_node_with_attr(input_rank: usize, axes: Vec) -> Node { - let inputs 
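// Sketch only, not part of the patch: how the new NodeBuilder helpers compose
// in a test. The node type, names, and values are hypothetical; assumes the
// test_utils imports shown above.
let node = NodeBuilder::new(NodeType::Identity, "example_node")
    .input_tensors_f32("data", 3, 4, None) // adds inputs data_0, data_1, data_2, each rank 4
    .input_scalar_tensor_f32("min", Some(0.0)) // rank-0 float input carrying a value
    .input_scalar_tensor_i64("depth", 5) // rank-0 int64 input carrying a value
    .output_default("out") // output with ArgType::default()
    .build();
assert_eq!(node.inputs.len(), 5);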
= vec![Argument { - name: "X".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: input_rank, - static_shape: None, - }), - value: None, - passed: true, - }]; - - let outputs = vec![Argument { - name: "Y".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 0, // Will be updated - static_shape: None, - }), - value: None, - passed: true, - }]; - - let mut attrs = HashMap::new(); - attrs.insert("axes".to_string(), AttributeValue::Int64s(axes.clone())); - - Node { - node_type: NodeType::Unsqueeze, - name: "test_unsqueeze".to_string(), - inputs, - outputs, - attrs, - } + let builder = NodeBuilder::new(NodeType::Unsqueeze, "test_unsqueeze") + .input_tensor_f32("X", input_rank, None) + .output_tensor_f32("Y", 0, None) // Will be updated + .attr_ints("axes", axes); + + builder.build() } fn create_test_node_with_input(input_rank: usize, axes: Vec, with_value: bool) -> Node { let axes_len = axes.len(); - let inputs = vec![ - Argument { - name: "X".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: input_rank, - static_shape: None, - }), - value: None, - passed: true, - }, - Argument { - name: "axes".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Int64, - rank: 1, - static_shape: Some(vec![axes_len]), - }), - value: if with_value { - Some(TensorData { - data: Data::Int64s(axes.clone()), - shape: vec![axes_len], - }) - } else { - None - }, - passed: true, - }, - ]; - - let outputs = vec![Argument { - name: "Y".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 0, // Will be updated - static_shape: None, - }), - value: None, - passed: true, - }]; - - Node { - node_type: NodeType::Unsqueeze, - name: "test_unsqueeze".to_string(), - inputs, - outputs, - attrs: HashMap::new(), + let mut builder = NodeBuilder::new(NodeType::Unsqueeze, "test_unsqueeze") + .input_tensor_f32("X", input_rank, None) + .output_tensor_f32("Y", 0, None); // Will be updated + + // Add axes input with or without value + if with_value { + builder = builder.input_tensor_i64_data( + "axes", + axes.clone(), + vec![axes_len] + ); + } else { + // Input without value + builder = builder.input_tensor_i64( + "axes", + 1, + Some(vec![axes_len]) + ); } + + builder.build() } // Tests for unsqueeze_update_output function From a4592a3a2075b1be2340941678846fe7e66e7ec7 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Thu, 8 May 2025 16:13:11 -0500 Subject: [PATCH 32/37] Refactored tests to use NodeBuilder --- crates/onnx-ir/src/node/batch_norm.rs | 2 +- crates/onnx-ir/src/node/clip.rs | 6 +- crates/onnx-ir/src/node/conv2d.rs | 6 +- crates/onnx-ir/src/node/gemm.rs | 4 +- crates/onnx-ir/src/node/hard_sigmoid.rs | 41 +++---------- crates/onnx-ir/src/node/layer_norm.rs | 82 +++++-------------------- crates/onnx-ir/src/node/leaky_relu.rs | 39 +++--------- crates/onnx-ir/src/node/log_softmax.rs | 46 ++++---------- crates/onnx-ir/src/node/max_pool1d.rs | 6 +- crates/onnx-ir/src/node/one_hot.rs | 4 +- crates/onnx-ir/src/node/pad.rs | 12 ++-- crates/onnx-ir/src/node/reduce_max.rs | 41 +++---------- crates/onnx-ir/src/node/reduce_mean.rs | 41 +++---------- crates/onnx-ir/src/node/reduce_min.rs | 41 +++---------- crates/onnx-ir/src/node/reduce_prod.rs | 41 +++---------- crates/onnx-ir/src/node/reduce_sum.rs | 10 +-- crates/onnx-ir/src/node/reshape.rs | 6 +- crates/onnx-ir/src/node/resize.rs | 14 ++--- 
crates/onnx-ir/src/node/shape.rs | 47 ++++---------- crates/onnx-ir/src/node/slice.rs | 5 +- crates/onnx-ir/src/node/split.rs | 31 +++++----- crates/onnx-ir/src/node/squeeze.rs | 43 ++++--------- crates/onnx-ir/src/node/test_utils.rs | 28 ++++----- crates/onnx-ir/src/node/tile.rs | 12 ++-- crates/onnx-ir/src/node/topk.rs | 75 +++++++--------------- crates/onnx-ir/src/node/transpose.rs | 44 ++++--------- crates/onnx-ir/src/node/trilu.rs | 56 +++-------------- crates/onnx-ir/src/node/unsqueeze.rs | 18 ++---- 28 files changed, 213 insertions(+), 588 deletions(-) diff --git a/crates/onnx-ir/src/node/batch_norm.rs b/crates/onnx-ir/src/node/batch_norm.rs index 0cef6c14e5..aaa0165c8f 100644 --- a/crates/onnx-ir/src/node/batch_norm.rs +++ b/crates/onnx-ir/src/node/batch_norm.rs @@ -56,7 +56,7 @@ mod tests { fn create_test_node(epsilon: f32, momentum: f32, num_features: usize) -> Node { let ones = vec![1.0; num_features]; let zeros = vec![0.0; num_features]; - + NodeBuilder::new(NodeType::BatchNormalization, "test_batchnorm") .input_tensor_f32("X", 4, None) // NCHW format .input_tensor_f32_data("scale", ones.clone(), vec![num_features]) diff --git a/crates/onnx-ir/src/node/clip.rs b/crates/onnx-ir/src/node/clip.rs index 2f5bbedbe1..e33bf381f0 100644 --- a/crates/onnx-ir/src/node/clip.rs +++ b/crates/onnx-ir/src/node/clip.rs @@ -63,15 +63,15 @@ mod tests { let mut builder = NodeBuilder::new(NodeType::Clip, "test_clip") .input_tensor_f32("X", 4, None) .output_tensor_f32("Y", 4, None); - + if let Some(min_val) = min { builder = builder.attr_float("min", min_val); } - + if let Some(max_val) = max { builder = builder.attr_float("max", max_val); } - + builder.build() } diff --git a/crates/onnx-ir/src/node/conv2d.rs b/crates/onnx-ir/src/node/conv2d.rs index e783c7d973..3048bdfefe 100644 --- a/crates/onnx-ir/src/node/conv2d.rs +++ b/crates/onnx-ir/src/node/conv2d.rs @@ -107,7 +107,7 @@ mod tests { let weight_data = vec![0.0; 16]; // [output_channels, input_channels/groups, k_h, k_w] let weight_shape = vec![4, 2, 2, 2]; - + let mut builder = NodeBuilder::new(NodeType::Conv2d, "test_conv2d") .input_tensor_f32("data", 4, None) .input_tensor_f32_data("weight", weight_data.clone(), weight_shape) @@ -117,11 +117,11 @@ mod tests { .attr_ints("pads", pads) .attr_ints("dilations", dilations) .attr_int("group", group); - + if has_bias { builder = builder.input_tensor_f32("bias", 1, None); } - + builder.build() } diff --git a/crates/onnx-ir/src/node/gemm.rs b/crates/onnx-ir/src/node/gemm.rs index f939a8f111..803da4970d 100644 --- a/crates/onnx-ir/src/node/gemm.rs +++ b/crates/onnx-ir/src/node/gemm.rs @@ -76,7 +76,7 @@ mod tests { .input_tensor_f32("B", 2, None) .input_tensor_f32("C", 2, None) .output_tensor_f32("Y", 2, None); - + if let Some(alpha_val) = alpha { builder = builder.attr_float("alpha", alpha_val); } @@ -89,7 +89,7 @@ mod tests { if let Some(trans_b_val) = trans_b { builder = builder.attr_int("transB", trans_b_val); } - + builder.build() } diff --git a/crates/onnx-ir/src/node/hard_sigmoid.rs b/crates/onnx-ir/src/node/hard_sigmoid.rs index 636dbc48a6..7b1a517f00 100644 --- a/crates/onnx-ir/src/node/hard_sigmoid.rs +++ b/crates/onnx-ir/src/node/hard_sigmoid.rs @@ -19,41 +19,16 @@ pub fn hard_sigmoid_config(node: &Node) -> (f64, f64) { #[cfg(test)] mod tests { use super::*; - use crate::ir::{ArgType, Argument, AttributeValue, ElementType, NodeType, TensorType}; - use std::collections::HashMap; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; fn create_test_node(alpha: f32, beta: f32) 
-> Node { - let inputs = vec![Argument { - name: "X".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 4, - static_shape: None, - }), - value: None, - passed: true, - }]; - - let mut attrs = HashMap::new(); - attrs.insert("alpha".to_string(), AttributeValue::Float32(alpha)); - attrs.insert("beta".to_string(), AttributeValue::Float32(beta)); - - Node { - node_type: NodeType::HardSigmoid, - name: "test_hard_sigmoid".to_string(), - inputs, - outputs: vec![Argument { - name: "Y".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 4, - static_shape: None, - }), - value: None, - passed: true, - }], - attrs, - } + NodeBuilder::new(NodeType::HardSigmoid, "test_hard_sigmoid") + .input_tensor_f32("X", 4, None) + .output_tensor_f32("Y", 4, None) + .attr_float("alpha", alpha) + .attr_float("beta", beta) + .build() } #[test] diff --git a/crates/onnx-ir/src/node/layer_norm.rs b/crates/onnx-ir/src/node/layer_norm.rs index 9911f9b6c1..8158adcefc 100644 --- a/crates/onnx-ir/src/node/layer_norm.rs +++ b/crates/onnx-ir/src/node/layer_norm.rs @@ -64,76 +64,22 @@ pub fn layer_norm_config(node: &Node) -> (LayerNormConfig, bool) { #[cfg(test)] mod tests { use super::*; - use crate::ir::{ - ArgType, Argument, AttributeValue, Data, ElementType, NodeType, TensorData, TensorType, - }; - use std::collections::HashMap; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; fn create_test_node(epsilon: f32, axis: i64, stash_type: i64, num_features: usize) -> Node { - let weight_tensor = TensorData { - data: Data::Float32s(vec![1.0; num_features]), // Not important for the test - shape: vec![num_features], - }; - - let bias_tensor = TensorData { - data: Data::Float32s(vec![0.0; num_features]), // Not important for the test - shape: vec![num_features], - }; - - let inputs = vec![ - Argument { - name: "X".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 3, - static_shape: None, - }), - value: None, - passed: true, - }, - Argument { - name: "scale".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 1, - static_shape: None, - }), - value: Some(weight_tensor), - passed: true, - }, - Argument { - name: "bias".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 1, - static_shape: None, - }), - value: Some(bias_tensor), - passed: true, - }, - ]; - - let mut attrs = HashMap::new(); - attrs.insert("epsilon".to_string(), AttributeValue::Float32(epsilon)); - attrs.insert("axis".to_string(), AttributeValue::Int64(axis)); - attrs.insert("stash_type".to_string(), AttributeValue::Int64(stash_type)); - - Node { - node_type: NodeType::LayerNormalization, - name: "test_layernorm".to_string(), - inputs, - outputs: vec![Argument { - name: "output".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 3, - static_shape: None, - }), - value: None, - passed: true, - }], - attrs, - } + let weight_data = vec![1.0; num_features]; // Not important for the test + let bias_data = vec![0.0; num_features]; // Not important for the test + + NodeBuilder::new(NodeType::LayerNormalization, "test_layernorm") + .input_tensor_f32("X", 3, None) + .input_tensor_f32_data("scale", weight_data, vec![num_features]) + .input_tensor_f32_data("bias", bias_data, vec![num_features]) + .output_tensor_f32("output", 3, None) + .attr_float("epsilon", epsilon) + .attr_int("axis", axis) + .attr_int("stash_type", stash_type) + 
.build() } #[test] diff --git a/crates/onnx-ir/src/node/leaky_relu.rs b/crates/onnx-ir/src/node/leaky_relu.rs index 6dcad47fc9..ebeff2464a 100644 --- a/crates/onnx-ir/src/node/leaky_relu.rs +++ b/crates/onnx-ir/src/node/leaky_relu.rs @@ -16,40 +16,15 @@ pub fn leaky_relu_config(node: &Node) -> f64 { #[cfg(test)] mod tests { use super::*; - use crate::ir::{ArgType, Argument, AttributeValue, ElementType, NodeType, TensorType}; - use std::collections::HashMap; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; fn create_test_node(alpha: f32) -> Node { - let inputs = vec![Argument { - name: "X".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 4, - static_shape: None, - }), - value: None, - passed: true, - }]; - - let mut attrs = HashMap::new(); - attrs.insert("alpha".to_string(), AttributeValue::Float32(alpha)); - - Node { - node_type: NodeType::LeakyRelu, - name: "test_leaky_relu".to_string(), - inputs, - outputs: vec![Argument { - name: "Y".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 4, - static_shape: None, - }), - value: None, - passed: true, - }], - attrs, - } + NodeBuilder::new(NodeType::LeakyRelu, "test_leaky_relu") + .input_tensor_f32("X", 4, None) + .output_tensor_f32("Y", 4, None) + .attr_float("alpha", alpha) + .build() } #[test] diff --git a/crates/onnx-ir/src/node/log_softmax.rs b/crates/onnx-ir/src/node/log_softmax.rs index 61f0619df5..5a6ce66248 100644 --- a/crates/onnx-ir/src/node/log_softmax.rs +++ b/crates/onnx-ir/src/node/log_softmax.rs @@ -37,40 +37,15 @@ pub fn log_softmax_config(node: &Node) -> usize { #[cfg(test)] mod tests { use super::*; - use crate::ir::{Argument, AttributeValue, ElementType, NodeType, TensorType}; - use std::collections::HashMap; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; fn create_test_node(axis: i64, input_rank: usize) -> Node { - let inputs = vec![Argument { - name: "data".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: input_rank, - static_shape: None, - }), - value: None, - passed: true, - }]; - - let mut attrs = HashMap::new(); - attrs.insert("axis".to_string(), AttributeValue::Int64(axis)); - - Node { - node_type: NodeType::LogSoftmax, - name: "test_log_softmax".to_string(), - inputs, - outputs: vec![Argument { - name: "output".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: input_rank, - static_shape: None, - }), - value: None, - passed: true, - }], - attrs, - } + NodeBuilder::new(NodeType::LogSoftmax, "test_log_softmax") + .input_tensor_f32("data", input_rank, None) + .output_tensor_f32("output", input_rank, None) + .attr_int("axis", axis) + .build() } #[test] @@ -91,10 +66,11 @@ mod tests { #[should_panic(expected = "LogSoftmax: multiple inputs are not supported")] fn test_log_softmax_config_multiple_inputs() { let mut node = create_test_node(1, 3); - node.inputs.push(Argument { + // Add an extra input to cause the expected panic + node.inputs.push(crate::ir::Argument { name: "extra".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, + ty: crate::ir::ArgType::Tensor(crate::ir::TensorType { + elem_type: crate::ir::ElementType::Float32, rank: 1, static_shape: None, }), diff --git a/crates/onnx-ir/src/node/max_pool1d.rs b/crates/onnx-ir/src/node/max_pool1d.rs index 5538aa2fc6..6446c6794f 100644 --- a/crates/onnx-ir/src/node/max_pool1d.rs +++ b/crates/onnx-ir/src/node/max_pool1d.rs @@ -83,11 
+83,7 @@ pub fn max_pool1d_config(curr: &Node) -> MaxPool1dConfig { #[cfg(test)] mod tests { use super::*; - use crate::{ - ir::NodeType, - node::padding::PaddingConfig1d, - node::test_utils::NodeBuilder, - }; + use crate::{ir::NodeType, node::padding::PaddingConfig1d, node::test_utils::NodeBuilder}; fn create_test_node( kernel_shape: Vec, diff --git a/crates/onnx-ir/src/node/one_hot.rs b/crates/onnx-ir/src/node/one_hot.rs index 598122e075..1ccff14d45 100644 --- a/crates/onnx-ir/src/node/one_hot.rs +++ b/crates/onnx-ir/src/node/one_hot.rs @@ -56,11 +56,11 @@ mod tests { .input_scalar_tensor_i64("depth", depth) .input_tensor_f32_data("values", values.clone(), vec![2]) // always [off_value, on_value] .output_tensor_f32("output", 3, None); // rank increases by 1 - + if let Some(axis_val) = axis { builder = builder.attr_int("axis", axis_val); } - + builder.build() } diff --git a/crates/onnx-ir/src/node/pad.rs b/crates/onnx-ir/src/node/pad.rs index 2a72125f02..c1eceda76f 100644 --- a/crates/onnx-ir/src/node/pad.rs +++ b/crates/onnx-ir/src/node/pad.rs @@ -160,30 +160,30 @@ mod tests { let mut builder = NodeBuilder::new(NodeType::Pad, "test_pad") .input_tensor_f32("data", rank, None) .output_tensor_f32("output", rank, None); - + // Add pad inputs if provided if let Some(pads) = pad_inputs.clone() { builder = builder.input_tensor_i64_data("pads", pads, vec![]); } - + // Add constant value input if provided if let Some(value) = constant_value_input { builder = builder.input_scalar_tensor_f32("constant_value", Some(value)); } - + // Add attributes if provided if let Some(pads) = pad_attrs { builder = builder.attr_ints("pads", pads); } - + if let Some(value) = constant_value_attr { builder = builder.attr_float("value", value); } - + if let Some(mode_val) = mode { builder = builder.attr_string("mode", mode_val); } - + builder.build() } diff --git a/crates/onnx-ir/src/node/reduce_max.rs b/crates/onnx-ir/src/node/reduce_max.rs index 883914f1b3..2cae7dfa99 100644 --- a/crates/onnx-ir/src/node/reduce_max.rs +++ b/crates/onnx-ir/src/node/reduce_max.rs @@ -80,45 +80,22 @@ pub fn reduce_max_update_outputs(node: &mut Node) { #[cfg(test)] mod tests { use super::*; - use crate::ir::{Argument, AttributeValue, ElementType, NodeType, TensorType}; - use std::collections::HashMap; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; fn create_test_node(axes: Option>, keepdims: Option) -> Node { - let inputs = vec![Argument { - name: "data".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 3, - static_shape: None, - }), - value: None, - passed: true, - }]; - - let mut attrs = HashMap::new(); + let mut builder = NodeBuilder::new(NodeType::ReduceMax, "test_reduce_max") + .input_tensor_f32("data", 3, None) + .output_tensor_f32("reduced", 3, None); + if let Some(axes_val) = axes { - attrs.insert("axes".to_string(), AttributeValue::Int64s(axes_val.clone())); + builder = builder.attr_ints("axes", axes_val); } if let Some(kd) = keepdims { - attrs.insert("keepdims".to_string(), AttributeValue::Int64(kd)); + builder = builder.attr_int("keepdims", kd); } - Node { - node_type: NodeType::ReduceMax, - name: "test_reduce_max".to_string(), - inputs, - outputs: vec![Argument { - name: "reduced".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 3, - static_shape: None, - }), - value: None, - passed: true, - }], - attrs, - } + builder.build() } #[test] diff --git a/crates/onnx-ir/src/node/reduce_mean.rs 
b/crates/onnx-ir/src/node/reduce_mean.rs index d3c0775226..81849d2504 100644 --- a/crates/onnx-ir/src/node/reduce_mean.rs +++ b/crates/onnx-ir/src/node/reduce_mean.rs @@ -79,45 +79,22 @@ pub fn reduce_mean_update_outputs(node: &mut Node) { #[cfg(test)] mod tests { use super::*; - use crate::ir::{Argument, AttributeValue, ElementType, NodeType, TensorType}; - use std::collections::HashMap; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; fn create_test_node(axes: Option>, keepdims: Option) -> Node { - let inputs = vec![Argument { - name: "data".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 3, - static_shape: None, - }), - value: None, - passed: true, - }]; - - let mut attrs = HashMap::new(); + let mut builder = NodeBuilder::new(NodeType::ReduceMean, "test_reduce_mean") + .input_tensor_f32("data", 3, None) + .output_tensor_f32("reduced", 3, None); + if let Some(axes_val) = axes { - attrs.insert("axes".to_string(), AttributeValue::Int64s(axes_val.clone())); + builder = builder.attr_ints("axes", axes_val); } if let Some(kd) = keepdims { - attrs.insert("keepdims".to_string(), AttributeValue::Int64(kd)); + builder = builder.attr_int("keepdims", kd); } - Node { - node_type: NodeType::ReduceMean, - name: "test_reduce_mean".to_string(), - inputs, - outputs: vec![Argument { - name: "reduced".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 3, - static_shape: None, - }), - value: None, - passed: true, - }], - attrs, - } + builder.build() } #[test] diff --git a/crates/onnx-ir/src/node/reduce_min.rs b/crates/onnx-ir/src/node/reduce_min.rs index ccf28acb66..494079454f 100644 --- a/crates/onnx-ir/src/node/reduce_min.rs +++ b/crates/onnx-ir/src/node/reduce_min.rs @@ -78,45 +78,22 @@ pub fn reduce_min_update_outputs(node: &mut Node) { #[cfg(test)] mod tests { use super::*; - use crate::ir::{Argument, AttributeValue, ElementType, NodeType, TensorType}; - use std::collections::HashMap; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; fn create_test_node(axes: Option>, keepdims: Option) -> Node { - let inputs = vec![Argument { - name: "data".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 3, - static_shape: None, - }), - value: None, - passed: true, - }]; - - let mut attrs = HashMap::new(); + let mut builder = NodeBuilder::new(NodeType::ReduceMin, "test_reduce_min") + .input_tensor_f32("data", 3, None) + .output_tensor_f32("reduced", 3, None); + if let Some(axes_val) = axes { - attrs.insert("axes".to_string(), AttributeValue::Int64s(axes_val.clone())); + builder = builder.attr_ints("axes", axes_val); } if let Some(kd) = keepdims { - attrs.insert("keepdims".to_string(), AttributeValue::Int64(kd)); + builder = builder.attr_int("keepdims", kd); } - Node { - node_type: NodeType::ReduceMin, - name: "test_reduce_min".to_string(), - inputs, - outputs: vec![Argument { - name: "reduced".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 3, - static_shape: None, - }), - value: None, - passed: true, - }], - attrs, - } + builder.build() } #[test] diff --git a/crates/onnx-ir/src/node/reduce_prod.rs b/crates/onnx-ir/src/node/reduce_prod.rs index 2f4df96b84..62196a7b4b 100644 --- a/crates/onnx-ir/src/node/reduce_prod.rs +++ b/crates/onnx-ir/src/node/reduce_prod.rs @@ -81,45 +81,22 @@ pub fn reduce_prod_update_outputs(node: &mut Node) { #[cfg(test)] mod tests { use super::*; - use crate::ir::{Argument, 
AttributeValue, ElementType, NodeType, TensorType}; - use std::collections::HashMap; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; fn create_test_node(axes: Option>, keepdims: Option) -> Node { - let inputs = vec![Argument { - name: "data".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 3, - static_shape: None, - }), - value: None, - passed: true, - }]; - - let mut attrs = HashMap::new(); + let mut builder = NodeBuilder::new(NodeType::ReduceProd, "test_reduce_prod") + .input_tensor_f32("data", 3, None) + .output_tensor_f32("reduced", 3, None); + if let Some(axes_val) = axes { - attrs.insert("axes".to_string(), AttributeValue::Int64s(axes_val.clone())); + builder = builder.attr_ints("axes", axes_val); } if let Some(kd) = keepdims { - attrs.insert("keepdims".to_string(), AttributeValue::Int64(kd)); + builder = builder.attr_int("keepdims", kd); } - Node { - node_type: NodeType::ReduceProd, - name: "test_reduce_prod".to_string(), - inputs, - outputs: vec![Argument { - name: "reduced".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 3, - static_shape: None, - }), - value: None, - passed: true, - }], - attrs, - } + builder.build() } #[test] diff --git a/crates/onnx-ir/src/node/reduce_sum.rs b/crates/onnx-ir/src/node/reduce_sum.rs index 2fa43cae8e..6b2fae1283 100644 --- a/crates/onnx-ir/src/node/reduce_sum.rs +++ b/crates/onnx-ir/src/node/reduce_sum.rs @@ -94,7 +94,7 @@ pub fn reduce_sum_update_outputs(node: &mut Node) { #[cfg(test)] mod tests { use super::*; - use crate::ir::{NodeType}; + use crate::ir::NodeType; use crate::node::test_utils::NodeBuilder; fn create_test_node( @@ -105,22 +105,22 @@ mod tests { let mut builder = NodeBuilder::new(NodeType::ReduceSum, "test_reduce_sum") .input_tensor_f32("data", 3, None) .output_tensor_f32("reduced", 3, None); - + // Add axes input if requested if with_axes_input && axes.is_some() { let axes_vec = axes.clone().unwrap(); builder = builder.input_tensor_i64_data("axes", axes_vec.clone(), vec![axes_vec.len()]); } - + // Add attributes if !with_axes_input && axes.is_some() { builder = builder.attr_ints("axes", axes.clone().unwrap()); } - + if let Some(kd) = keepdims { builder = builder.attr_int("keepdims", kd); } - + builder.build() } diff --git a/crates/onnx-ir/src/node/reshape.rs b/crates/onnx-ir/src/node/reshape.rs index e4d8eda39c..9d64e5baf1 100644 --- a/crates/onnx-ir/src/node/reshape.rs +++ b/crates/onnx-ir/src/node/reshape.rs @@ -82,7 +82,7 @@ pub fn reshape_config(node: &Node) -> Vec { #[cfg(test)] mod tests { use super::*; - use crate::ir::{NodeType}; + use crate::ir::NodeType; use crate::node::test_utils::NodeBuilder; fn create_test_node(allowzero: i64, shape_vec: Vec) -> Node { @@ -90,11 +90,11 @@ mod tests { .input_tensor_f32("data", 4, None) .input_tensor_i64_data("shape", shape_vec.clone(), vec![shape_vec.len()]) .output_tensor_f32("reshaped", 2, None); - + if allowzero != 0 { builder = builder.attr_int("allowzero", allowzero); } - + builder.build() } diff --git a/crates/onnx-ir/src/node/resize.rs b/crates/onnx-ir/src/node/resize.rs index 4d46081be0..4d43e58824 100644 --- a/crates/onnx-ir/src/node/resize.rs +++ b/crates/onnx-ir/src/node/resize.rs @@ -144,34 +144,34 @@ mod tests { .input_tensor_f32("X", 4, None) // N,C,H,W format .output_tensor_f32("Y", 4, None) .attr_string("mode", mode); - + // Add ROI input if provided if let Some(roi_data) = roi { - builder = builder.input_tensor_f32_data("roi", roi_data.clone(), vec![8]); + 
builder = builder.input_tensor_f32_data("roi", roi_data.clone(), vec![8]); // For 4D input (start x, start y, end x, end y) } else { // Empty ROI still needs to be added as a placeholder builder = builder.input_tensor_f32("roi", 1, None); } - + // Add scales input if provided if let Some(scales_data) = scales { - builder = builder.input_tensor_f32_data("scales", scales_data.clone(), vec![4]); + builder = builder.input_tensor_f32_data("scales", scales_data.clone(), vec![4]); // N,C,H,W scales } else { // Empty scales still needs to be added as a placeholder builder = builder.input_tensor_f32("scales", 1, None); } - + // Add sizes input if provided if let Some(sizes_data) = sizes { - builder = builder.input_tensor_i64_data("sizes", sizes_data.clone(), vec![4]); + builder = builder.input_tensor_i64_data("sizes", sizes_data.clone(), vec![4]); // N,C,H,W sizes } else { // Empty sizes still needs to be added as a placeholder builder = builder.input_tensor_i64("sizes", 1, None); } - + builder.build() } diff --git a/crates/onnx-ir/src/node/shape.rs b/crates/onnx-ir/src/node/shape.rs index 694fc05436..1d4a9770e1 100644 --- a/crates/onnx-ir/src/node/shape.rs +++ b/crates/onnx-ir/src/node/shape.rs @@ -58,45 +58,23 @@ pub fn shape_update_outputs(node: &mut Node) { #[cfg(test)] mod tests { use super::*; - use crate::ir::{Argument, AttributeValue, ElementType, NodeType, TensorType}; - use std::collections::HashMap; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; fn create_test_node(start: Option, end: Option, rank: usize) -> Node { - let inputs = vec![Argument { - name: "data".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank, - static_shape: None, - }), - value: None, - passed: true, - }]; + let mut builder = NodeBuilder::new(NodeType::Shape, "test_shape") + .input_tensor_f32("data", rank, None) + .output_tensor_i64("shape", 1, None); - let mut attrs = HashMap::new(); if let Some(start_val) = start { - attrs.insert("start".to_string(), AttributeValue::Int64(start_val)); + builder = builder.attr_int("start", start_val); } + if let Some(end_val) = end { - attrs.insert("end".to_string(), AttributeValue::Int64(end_val)); + builder = builder.attr_int("end", end_val); } - Node { - node_type: NodeType::Shape, - name: "test_shape".to_string(), - inputs, - outputs: vec![Argument { - name: "shape".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Int64, - rank: 1, - static_shape: None, - }), - value: None, - passed: true, - }], - attrs, - } + builder.build() } #[test] @@ -143,10 +121,11 @@ mod tests { #[should_panic(expected = "Shape: multiple inputs are not supported")] fn test_shape_config_multiple_inputs() { let mut node = create_test_node(None, None, 4); - node.inputs.push(Argument { + // Add an extra input to cause the expected panic + node.inputs.push(crate::ir::Argument { name: "extra".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, + ty: crate::ir::ArgType::Tensor(crate::ir::TensorType { + elem_type: crate::ir::ElementType::Float32, rank: 4, static_shape: None, }), diff --git a/crates/onnx-ir/src/node/slice.rs b/crates/onnx-ir/src/node/slice.rs index f9169c72a9..f1441db485 100644 --- a/crates/onnx-ir/src/node/slice.rs +++ b/crates/onnx-ir/src/node/slice.rs @@ -152,13 +152,14 @@ mod tests { builder = builder.input_tensor_i64_data("ends", ends.clone(), vec![ends.len()]); if let Some(axes_vec) = axes.clone() { - builder = builder.input_tensor_i64_data("axes", axes_vec.clone(), 
vec![axes_vec.len()]); + builder = + builder.input_tensor_i64_data("axes", axes_vec.clone(), vec![axes_vec.len()]); } } else { // Add attributes builder = builder.attr_ints("starts", starts); builder = builder.attr_ints("ends", ends); - + if let Some(axes_vec) = axes { builder = builder.attr_ints("axes", axes_vec); } diff --git a/crates/onnx-ir/src/node/split.rs b/crates/onnx-ir/src/node/split.rs index 221686abac..8f5f70af9a 100644 --- a/crates/onnx-ir/src/node/split.rs +++ b/crates/onnx-ir/src/node/split.rs @@ -157,7 +157,7 @@ pub fn split_config(node: &Node) -> SplitConfig { #[cfg(test)] mod tests { use super::*; - use crate::ir::{AttributeValue, ElementType, NodeType, ArgType}; + use crate::ir::{ArgType, AttributeValue, ElementType, NodeType}; use crate::node::test_utils::NodeBuilder; use std::collections::HashMap; @@ -169,34 +169,33 @@ mod tests { split_sizes_input: Option>, ) -> Node { // Start with input tensor - let mut builder = NodeBuilder::new(NodeType::Split, "test_split") - .input_tensor_f32("input", input_rank, static_shape); - + let mut builder = NodeBuilder::new(NodeType::Split, "test_split").input_tensor_f32( + "input", + input_rank, + static_shape, + ); + // Add split sizes input if provided if let Some(sizes) = split_sizes_input { - builder = builder.input_tensor_i64_data( - "split", - sizes.clone(), - vec![sizes.len()] - ); + builder = builder.input_tensor_i64_data("split", sizes.clone(), vec![sizes.len()]); } - + // Add output tensors for i in 0..num_outputs { builder = builder.output_tensor_f32( - &format!("output_{}", i), - 0, // Will be updated - None + &format!("output_{}", i), + 0, // Will be updated + None, ); } - + // Add attributes if provided let mut node = builder.build(); - + if let Some(attributes) = attrs { node.attrs = attributes; } - + node } diff --git a/crates/onnx-ir/src/node/squeeze.rs b/crates/onnx-ir/src/node/squeeze.rs index 9c528cd29e..cdb4c9c8c4 100644 --- a/crates/onnx-ir/src/node/squeeze.rs +++ b/crates/onnx-ir/src/node/squeeze.rs @@ -61,42 +61,21 @@ pub fn squeeze_update_output(node: &mut Node) { #[cfg(test)] mod tests { use super::*; - use crate::ir::{Argument, AttributeValue, ElementType, NodeType, TensorType}; - use std::collections::HashMap; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; fn create_test_node(axes: Option>, rank: usize) -> Node { - let inputs = vec![Argument { - name: "data".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank, - static_shape: None, - }), - value: None, - passed: true, - }]; - - let mut attrs = HashMap::new(); - if let Some(ref axes_val) = axes { - attrs.insert("axes".to_string(), AttributeValue::Int64s(axes_val.clone())); - } + let output_rank = rank - (axes.as_ref().map_or(0, |a| a.len())); + + let mut builder = NodeBuilder::new(NodeType::Squeeze, "test_squeeze") + .input_tensor_f32("data", rank, None) + .output_tensor_f32("squeezed", output_rank, None); - Node { - node_type: NodeType::Squeeze, - name: "test_squeeze".to_string(), - inputs, - outputs: vec![Argument { - name: "squeezed".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: rank - (axes.as_ref().map_or(0, |a| a.len())), - static_shape: None, - }), - value: None, - passed: true, - }], - attrs, + if let Some(axes_val) = axes { + builder = builder.attr_ints("axes", axes_val); } + + builder.build() } #[test] diff --git a/crates/onnx-ir/src/node/test_utils.rs b/crates/onnx-ir/src/node/test_utils.rs index a046dddc20..d6df87f50f 100644 --- 
a/crates/onnx-ir/src/node/test_utils.rs +++ b/crates/onnx-ir/src/node/test_utils.rs @@ -222,14 +222,9 @@ impl NodeBuilder { shape, ) } - - + /// Add a float32 scalar tensor input (rank 0) - pub fn input_scalar_tensor_f32( - mut self, - name: &str, - value: Option, - ) -> Self { + pub fn input_scalar_tensor_f32(mut self, name: &str, value: Option) -> Self { let arg = Argument { name: name.to_string(), ty: ArgType::Tensor(TensorType { @@ -246,13 +241,9 @@ impl NodeBuilder { self.inputs.push(arg); self } - + /// Add an int64 scalar tensor input (rank 0) - pub fn input_scalar_tensor_i64( - mut self, - name: &str, - value: i64, - ) -> Self { + pub fn input_scalar_tensor_i64(mut self, name: &str, value: i64) -> Self { let arg = Argument { name: name.to_string(), ty: ArgType::Tensor(TensorType { @@ -269,8 +260,7 @@ impl NodeBuilder { self.inputs.push(arg); self } - - + /// Add multiple tensor inputs with the same type but different names pub fn input_tensors_f32( mut self, @@ -280,7 +270,11 @@ impl NodeBuilder { static_shape: Option>, ) -> Self { for i in 0..count { - self = self.input_tensor_f32(&format!("{}_{}", name_prefix, i), rank, static_shape.clone()); + self = self.input_tensor_f32( + &format!("{}_{}", name_prefix, i), + rank, + static_shape.clone(), + ); } self } @@ -480,7 +474,7 @@ impl NodeBuilder { .insert(name.to_string(), AttributeValue::Strings(values)); self } - + /// Add a tensor attribute pub fn attr_tensor(mut self, name: &str, tensor: TensorData) -> Self { self.attrs diff --git a/crates/onnx-ir/src/node/tile.rs b/crates/onnx-ir/src/node/tile.rs index e1f664af98..10ada204ef 100644 --- a/crates/onnx-ir/src/node/tile.rs +++ b/crates/onnx-ir/src/node/tile.rs @@ -36,7 +36,7 @@ pub fn tile_config(node: &Node) -> TileConfig { #[cfg(test)] mod tests { use super::*; - use crate::ir::{Argument, ArgType, NodeType, TensorType, ElementType}; + use crate::ir::{ArgType, Argument, ElementType, NodeType, TensorType}; use crate::node::test_utils::NodeBuilder; /// Helper function to create test nodes with different repeat values @@ -44,16 +44,12 @@ mod tests { let mut builder = NodeBuilder::new(NodeType::Tile, "test_tile") .input_tensor_f32("input", input_rank, None) .output_tensor_f32("output", input_rank, None); // Same rank as input initially - + // Add repeats input if provided if let Some(reps) = repeats { - builder = builder.input_tensor_i64_data( - "repeats", - reps.clone(), - vec![reps.len()] - ); + builder = builder.input_tensor_i64_data("repeats", reps.clone(), vec![reps.len()]); } - + builder.build() } diff --git a/crates/onnx-ir/src/node/topk.rs b/crates/onnx-ir/src/node/topk.rs index 6094d80c2c..c4a7219156 100644 --- a/crates/onnx-ir/src/node/topk.rs +++ b/crates/onnx-ir/src/node/topk.rs @@ -95,7 +95,8 @@ pub fn top_k_config(node: &Node) -> TopKConfig { #[cfg(test)] mod tests { use super::*; - use crate::ir::{Argument, AttributeValue, Data, NodeType, TensorData}; + use crate::ir::{AttributeValue, NodeType}; + use crate::node::test_utils::NodeBuilder; use std::collections::HashMap; fn create_test_node( @@ -103,64 +104,32 @@ mod tests { attrs: Option>, k_input_value: Option, ) -> Node { - let mut inputs = vec![Argument { - name: "X".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: input_rank, - static_shape: None, - }), - value: None, - passed: true, - }]; + let mut builder = NodeBuilder::new(NodeType::TopK, "test_topk") + .input_tensor_f32("X", input_rank, None) + .output_tensor_f32("Values", 0, None) // Rank will be updated + 
.output_tensor_i64("Indices", 0, None); // Rank will be updated // Add K input if provided if let Some(k) = k_input_value { - inputs.push(Argument { - name: "K".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Int64, - rank: 0, - static_shape: Some(vec![]), - }), - value: Some(TensorData { - shape: vec![], - data: Data::Int64s(vec![k]), - }), - passed: true, - }); + builder = builder.input_tensor_i64_data("K", vec![k], vec![]); } - let outputs = vec![ - Argument { - name: "Values".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 0, // Will be updated - static_shape: None, - }), - value: None, - passed: true, - }, - Argument { - name: "Indices".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Int64, - rank: 0, // Will be updated - static_shape: None, - }), - value: None, - passed: true, - }, - ]; - - Node { - node_type: NodeType::TopK, - name: "test_topk".to_string(), - inputs, - outputs, - attrs: attrs.unwrap_or_default(), + // Add attributes if provided + if let Some(attr_map) = attrs { + for (key, value) in attr_map { + match value { + AttributeValue::Int64(val) => builder = builder.attr_int(&key, val), + AttributeValue::Int64s(vals) => builder = builder.attr_ints(&key, vals), + AttributeValue::Float32(val) => builder = builder.attr_float(&key, val), + AttributeValue::Float32s(vals) => builder = builder.attr_floats(&key, vals), + AttributeValue::String(val) => builder = builder.attr_string(&key, &val), + AttributeValue::Strings(vals) => builder = builder.attr_strings(&key, vals), + _ => panic!("Unsupported attribute type"), + } + } } + + builder.build() } #[test] diff --git a/crates/onnx-ir/src/node/transpose.rs b/crates/onnx-ir/src/node/transpose.rs index 4b32f4316b..fbff00c2c9 100644 --- a/crates/onnx-ir/src/node/transpose.rs +++ b/crates/onnx-ir/src/node/transpose.rs @@ -27,42 +27,19 @@ pub fn transpose_config(curr: &Node) -> Vec { #[cfg(test)] mod tests { use super::*; - use crate::ir::{Argument, AttributeValue, ElementType, NodeType, TensorType}; - use std::collections::HashMap; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; fn create_test_node(perm: Option>, rank: usize) -> Node { - let inputs = vec![Argument { - name: "data".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank, - static_shape: None, - }), - value: None, - passed: true, - }]; + let mut builder = NodeBuilder::new(NodeType::Transpose, "test_transpose") + .input_tensor_f32("data", rank, None) + .output_tensor_f32("transposed", rank, None); - let mut attrs = HashMap::new(); if let Some(perm_val) = perm { - attrs.insert("perm".to_string(), AttributeValue::Int64s(perm_val)); + builder = builder.attr_ints("perm", perm_val); } - Node { - node_type: NodeType::Transpose, - name: "test_transpose".to_string(), - inputs, - outputs: vec![Argument { - name: "transposed".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank, - static_shape: None, - }), - value: None, - passed: true, - }], - attrs, - } + builder.build() } #[test] @@ -83,10 +60,11 @@ mod tests { #[should_panic(expected = "Transpose: multiple inputs are not supported")] fn test_transpose_config_multiple_inputs() { let mut node = create_test_node(None, 3); - node.inputs.push(Argument { + // Add an extra input to cause the expected panic + node.inputs.push(crate::ir::Argument { name: "extra".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, + ty: 
crate::ir::ArgType::Tensor(crate::ir::TensorType { + elem_type: crate::ir::ElementType::Float32, rank: 3, static_shape: None, }), diff --git a/crates/onnx-ir/src/node/trilu.rs b/crates/onnx-ir/src/node/trilu.rs index d4ac189021..41a022f777 100644 --- a/crates/onnx-ir/src/node/trilu.rs +++ b/crates/onnx-ir/src/node/trilu.rs @@ -41,64 +41,26 @@ pub fn trilu_config(node: &Node) -> TriluConfig { #[cfg(test)] mod tests { use super::*; - use crate::ir::{ArgType, Argument, AttributeValue, ElementType, NodeType, TensorType}; - use std::collections::HashMap; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; /// Helper function to create test nodes for Trilu tests fn create_test_node(upper_attr: Option, diagonal_input: Option) -> Node { - let mut inputs = vec![ - // First input: the tensor to process - Argument { - name: "X".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 2, // Typically a matrix - static_shape: None, - }), - value: None, - passed: true, - }, - ]; + let mut builder = NodeBuilder::new(NodeType::Trilu, "test_trilu") + .input_tensor_f32("X", 2, None) // Typically a matrix + .output_tensor_f32("Y", 2, None); // Add diagonal input if provided if let Some(diag) = diagonal_input { - inputs.push(Argument { - name: "k".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Int64, - rank: 0, - static_shape: Some(vec![]), - }), - value: Some(TensorData { - shape: vec![], - data: Data::Int64(diag), - }), - passed: true, - }); + builder = builder.input_scalar_tensor_i64("k", diag); } - // Create attributes map - let mut attrs = HashMap::new(); + // Add upper attribute if provided if let Some(upper) = upper_attr { - attrs.insert("upper".to_string(), AttributeValue::Int64(upper)); + builder = builder.attr_int("upper", upper); } - Node { - node_type: NodeType::Trilu, - name: "test_trilu".to_string(), - inputs, - outputs: vec![Argument { - name: "Y".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 2, - static_shape: None, - }), - value: None, - passed: true, - }], - attrs, - } + builder.build() } #[test] diff --git a/crates/onnx-ir/src/node/unsqueeze.rs b/crates/onnx-ir/src/node/unsqueeze.rs index 42e9810922..38f58a5700 100644 --- a/crates/onnx-ir/src/node/unsqueeze.rs +++ b/crates/onnx-ir/src/node/unsqueeze.rs @@ -131,7 +131,7 @@ mod tests { .input_tensor_f32("X", input_rank, None) .output_tensor_f32("Y", 0, None) // Will be updated .attr_ints("axes", axes); - + builder.build() } @@ -140,23 +140,15 @@ mod tests { let mut builder = NodeBuilder::new(NodeType::Unsqueeze, "test_unsqueeze") .input_tensor_f32("X", input_rank, None) .output_tensor_f32("Y", 0, None); // Will be updated - + // Add axes input with or without value if with_value { - builder = builder.input_tensor_i64_data( - "axes", - axes.clone(), - vec![axes_len] - ); + builder = builder.input_tensor_i64_data("axes", axes.clone(), vec![axes_len]); } else { // Input without value - builder = builder.input_tensor_i64( - "axes", - 1, - Some(vec![axes_len]) - ); + builder = builder.input_tensor_i64("axes", 1, Some(vec![axes_len])); } - + builder.build() } From 52ed266984459eaf05a082c71de136cb4eef1263 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Thu, 8 May 2025 16:28:45 -0500 Subject: [PATCH 33/37] Refactored tests to use NodeBuilder --- crates/onnx-ir/src/node/flatten.rs | 70 +++++++++----------------- crates/onnx-ir/src/node/log_softmax.rs | 19 +++---- 
crates/onnx-ir/src/node/random.rs | 33 ++++-------- crates/onnx-ir/src/node/random_like.rs | 42 +++------------- crates/onnx-ir/src/node/softmax.rs | 20 ++++---- 5 files changed, 57 insertions(+), 127 deletions(-) diff --git a/crates/onnx-ir/src/node/flatten.rs b/crates/onnx-ir/src/node/flatten.rs index 99073c9b98..394f01dc17 100644 --- a/crates/onnx-ir/src/node/flatten.rs +++ b/crates/onnx-ir/src/node/flatten.rs @@ -66,40 +66,15 @@ pub fn flatten_config(curr: &Node) -> usize { #[cfg(test)] mod tests { use super::*; - use crate::ir::{Argument, AttributeValue, ElementType, NodeType, TensorType}; - use std::collections::HashMap; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; fn create_test_node(axis: i64) -> Node { - let inputs = vec![Argument { - name: "data".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 4, - static_shape: None, - }), - value: None, - passed: true, - }]; - - let mut attrs = HashMap::new(); - attrs.insert("axis".to_string(), AttributeValue::Int64(axis)); - - Node { - node_type: NodeType::Flatten, - name: "test_flatten".to_string(), - inputs, - outputs: vec![Argument { - name: "output".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 2, - static_shape: None, - }), - value: None, - passed: true, - }], - attrs, - } + NodeBuilder::new(NodeType::Flatten, "test_flatten") + .input_tensor_f32("data", 4, None) + .output_tensor_f32("output", 2, None) + .attr_int("axis", axis) + .build() } #[test] @@ -120,11 +95,14 @@ mod tests { #[should_panic(expected = "Flatten: input tensor must have at least 2 dimensions")] fn test_flatten_config_with_low_rank() { let mut node = create_test_node(1); - node.inputs[0].ty = ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 1, - static_shape: None, - }); + // Replace the input with one that has lower rank + let input = NodeBuilder::new(NodeType::Identity, "temp") + .input_tensor_f32("x", 1, None) + .build() + .inputs + .pop() + .unwrap(); + node.inputs[0] = input; let _ = flatten_config(&node); } @@ -132,16 +110,14 @@ mod tests { #[should_panic(expected = "Flatten: multiple inputs are not supported")] fn test_flatten_config_with_multiple_inputs() { let mut node = create_test_node(1); - node.inputs.push(Argument { - name: "extra".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 1, - static_shape: None, - }), - value: None, - passed: true, - }); + // Add an extra input + let extra_input = NodeBuilder::new(NodeType::Identity, "temp") + .input_tensor_f32("extra", 1, None) + .build() + .inputs + .pop() + .unwrap(); + node.inputs.push(extra_input); let _ = flatten_config(&node); } } diff --git a/crates/onnx-ir/src/node/log_softmax.rs b/crates/onnx-ir/src/node/log_softmax.rs index 5a6ce66248..c06e4b74e1 100644 --- a/crates/onnx-ir/src/node/log_softmax.rs +++ b/crates/onnx-ir/src/node/log_softmax.rs @@ -66,17 +66,14 @@ mod tests { #[should_panic(expected = "LogSoftmax: multiple inputs are not supported")] fn test_log_softmax_config_multiple_inputs() { let mut node = create_test_node(1, 3); - // Add an extra input to cause the expected panic - node.inputs.push(crate::ir::Argument { - name: "extra".to_string(), - ty: crate::ir::ArgType::Tensor(crate::ir::TensorType { - elem_type: crate::ir::ElementType::Float32, - rank: 1, - static_shape: None, - }), - value: None, - passed: true, - }); + // Add an extra input + let extra_input = NodeBuilder::new(NodeType::Identity, "temp") + 
.input_tensor_f32("extra", 1, None) + .build() + .inputs + .pop() + .unwrap(); + node.inputs.push(extra_input); let _ = log_softmax_config(&node); } } diff --git a/crates/onnx-ir/src/node/random.rs b/crates/onnx-ir/src/node/random.rs index 32ada33208..c441cea478 100644 --- a/crates/onnx-ir/src/node/random.rs +++ b/crates/onnx-ir/src/node/random.rs @@ -40,32 +40,16 @@ pub fn random_update_output(node: &mut Node) { #[cfg(test)] mod tests { use super::*; - use crate::ir::{Argument, AttributeValue, NodeType}; - use std::collections::HashMap; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; + use crate::protos::tensor_proto::DataType; fn create_test_node(dtype: i32, shape: Vec) -> Node { - let mut attrs = HashMap::new(); - attrs.insert("dtype".to_string(), AttributeValue::Int64(dtype as i64)); - attrs.insert("shape".to_string(), AttributeValue::Int64s(shape.clone())); - - let outputs = vec![Argument { - name: "output".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, // Will be updated - rank: 0, // Will be updated - static_shape: None, - }), - value: None, - passed: true, - }]; - - Node { - node_type: NodeType::RandomNormal, - name: "test_random".to_string(), - inputs: vec![], - outputs, - attrs, - } + NodeBuilder::new(NodeType::RandomNormal, "test_random") + .output_tensor_f32("output", 0, None) // Rank 0 will be updated + .attr_int("dtype", dtype as i64) + .attr_ints("shape", shape) + .build() } #[test] @@ -99,6 +83,7 @@ mod tests { #[test] #[should_panic(expected = "required shape attribute missing")] fn test_random_normal_missing_shape() { + // Create node and then manually remove the shape attribute let mut node = create_test_node(DataType::FLOAT.value(), vec![2, 3]); node.attrs.remove("shape"); random_update_output(&mut node); diff --git a/crates/onnx-ir/src/node/random_like.rs b/crates/onnx-ir/src/node/random_like.rs index c9ccb01acc..25a09b8d70 100644 --- a/crates/onnx-ir/src/node/random_like.rs +++ b/crates/onnx-ir/src/node/random_like.rs @@ -38,42 +38,16 @@ pub fn random_like_update_output(node: &mut Node) { #[cfg(test)] mod tests { use super::*; - use crate::ir::{Argument, AttributeValue, NodeType}; - use std::collections::HashMap; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; + use crate::protos::tensor_proto::DataType; fn create_test_node(dtype: i32, input_rank: usize, static_shape: Option>) -> Node { - let mut attrs = HashMap::new(); - attrs.insert("dtype".to_string(), AttributeValue::Int64(dtype as i64)); - - let inputs = vec![Argument { - name: "input".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: input_rank, - static_shape, - }), - value: None, - passed: true, - }]; - - let outputs = vec![Argument { - name: "output".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, // Will be updated - rank: 0, // Will be updated - static_shape: None, - }), - value: None, - passed: true, - }]; - - Node { - node_type: NodeType::RandomNormalLike, - name: "test_random_like".to_string(), - inputs, - outputs, - attrs, - } + NodeBuilder::new(NodeType::RandomNormalLike, "test_random_like") + .input_tensor_f32("input", input_rank, static_shape) + .output_tensor_f32("output", 0, None) // Rank 0 will be updated + .attr_int("dtype", dtype as i64) + .build() } #[test] diff --git a/crates/onnx-ir/src/node/softmax.rs b/crates/onnx-ir/src/node/softmax.rs index bad8bcdec9..60f018a2ae 100644 --- a/crates/onnx-ir/src/node/softmax.rs +++ 
b/crates/onnx-ir/src/node/softmax.rs @@ -37,7 +37,7 @@ pub fn softmax_config(node: &Node) -> usize { #[cfg(test)] mod tests { use super::*; - use crate::ir::{Argument, ElementType, NodeType, TensorType}; + use crate::ir::NodeType; use crate::node::test_utils::NodeBuilder; fn create_test_node(axis: i64, input_rank: usize) -> Node { @@ -66,16 +66,14 @@ mod tests { #[should_panic(expected = "Softmax: multiple inputs are not supported")] fn test_softmax_config_multiple_inputs() { let mut node = create_test_node(1, 3); - node.inputs.push(Argument { - name: "extra".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 1, - static_shape: None, - }), - value: None, - passed: true, - }); + // Add an extra input + let extra_input = NodeBuilder::new(NodeType::Identity, "temp") + .input_tensor_f32("extra", 1, None) + .build() + .inputs + .pop() + .unwrap(); + node.inputs.push(extra_input); let _ = softmax_config(&node); } } From eca4a8aa2a4e213f3789db2b6ea87212453717cf Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Thu, 8 May 2025 16:36:19 -0500 Subject: [PATCH 34/37] Refactored tests to use NodeBuilder --- crates/onnx-ir/src/node/tile.rs | 20 ++++++++---------- crates/onnx-ir/src/node/where_op.rs | 32 +++++++++++++++++++---------- 2 files changed, 30 insertions(+), 22 deletions(-) diff --git a/crates/onnx-ir/src/node/tile.rs b/crates/onnx-ir/src/node/tile.rs index 10ada204ef..9538fd635c 100644 --- a/crates/onnx-ir/src/node/tile.rs +++ b/crates/onnx-ir/src/node/tile.rs @@ -36,7 +36,7 @@ pub fn tile_config(node: &Node) -> TileConfig { #[cfg(test)] mod tests { use super::*; - use crate::ir::{ArgType, Argument, ElementType, NodeType, TensorType}; + use crate::ir::NodeType; use crate::node::test_utils::NodeBuilder; /// Helper function to create test nodes with different repeat values @@ -141,16 +141,14 @@ mod tests { let mut node = create_test_node(None, 3); // Add repeats input with no value - node.inputs.push(Argument { - name: "repeats".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Int64, - rank: 1, - static_shape: Some(vec![3]), - }), - value: None, // No value provided - passed: true, - }); + node.inputs.push( + NodeBuilder::new(NodeType::Identity, "temp") + .input_tensor_i64("repeats", 1, Some(vec![3])) + .build() + .inputs + .pop() + .unwrap(), + ); let config = tile_config(&node); diff --git a/crates/onnx-ir/src/node/where_op.rs b/crates/onnx-ir/src/node/where_op.rs index 0f480ab822..18a0bde986 100644 --- a/crates/onnx-ir/src/node/where_op.rs +++ b/crates/onnx-ir/src/node/where_op.rs @@ -51,7 +51,7 @@ pub fn where_update_outputs(node: &mut Node) { #[cfg(test)] mod tests { use super::*; - use crate::ir::{NodeType, TensorType}; + use crate::ir::NodeType; use crate::node::test_utils::NodeBuilder; fn create_test_node(condition_rank: usize, x_rank: usize, y_rank: usize) -> Node { @@ -94,11 +94,16 @@ mod tests { #[should_panic(expected = "Where condition must be boolean!")] fn test_where_invalid_condition() { let mut node = create_test_node(2, 2, 2); - node.inputs[0].ty = ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, // Not boolean - rank: 2, - static_shape: None, - }); + + // Replace condition with non-boolean tensor + let non_bool_input = NodeBuilder::new(NodeType::Identity, "temp") + .input_tensor_f32("x", 2, None) + .build() + .inputs + .pop() + .unwrap(); + + node.inputs[0] = non_bool_input; where_update_outputs(&mut node); } @@ -106,11 +111,16 @@ mod tests { 
#[should_panic(expected = "Where x and y have different element types!")] fn test_where_mismatched_types() { let mut node = create_test_node(2, 2, 2); - node.inputs[2].ty = ArgType::Tensor(TensorType { - elem_type: ElementType::Int64, // Different from X - rank: 2, - static_shape: None, - }); + + // Replace Y with int64 tensor (different from X's float32) + let int64_input = NodeBuilder::new(NodeType::Identity, "temp") + .input_tensor_i64("y", 2, None) + .build() + .inputs + .pop() + .unwrap(); + + node.inputs[2] = int64_input; where_update_outputs(&mut node); } } From 689ba268f3145df00ebfce883f6c4f447a9252bb Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Thu, 8 May 2025 16:54:55 -0500 Subject: [PATCH 35/37] Refactored tests to use NodeBuilder --- crates/onnx-ir/src/node/avg_pool1d.rs | 4 +- crates/onnx-ir/src/node/avg_pool2d.rs | 4 +- crates/onnx-ir/src/node/conv1d.rs | 90 ++------ crates/onnx-ir/src/node/conv2d.rs | 2 +- crates/onnx-ir/src/node/conv3d.rs | 87 ++----- crates/onnx-ir/src/node/conv_transpose1d.rs | 243 ++++++++------------ crates/onnx-ir/src/node/conv_transpose2d.rs | 147 ++++-------- crates/onnx-ir/src/node/conv_transpose3d.rs | 141 ++++-------- crates/onnx-ir/src/node/gemm.rs | 34 ++- crates/onnx-ir/src/node/layer_norm.rs | 2 +- crates/onnx-ir/src/node/max_pool1d.rs | 4 +- crates/onnx-ir/src/node/max_pool2d.rs | 4 +- crates/onnx-ir/src/node/reshape.rs | 6 +- 13 files changed, 262 insertions(+), 506 deletions(-) diff --git a/crates/onnx-ir/src/node/avg_pool1d.rs b/crates/onnx-ir/src/node/avg_pool1d.rs index 99f58e49bb..73d6d70291 100644 --- a/crates/onnx-ir/src/node/avg_pool1d.rs +++ b/crates/onnx-ir/src/node/avg_pool1d.rs @@ -47,7 +47,9 @@ pub fn avg_pool1d_config(curr: &Node) -> AvgPool1dConfig { "pads" => pads = value.clone().into_i64s(), "count_include_pad" => count_include_pad = value.clone().into_i64(), "ceil_mode" => ceil_mode = value.clone().into_i64(), - _ => {} + // These are attributes that are allowed but not used in this implementation + "auto_pad" | "storage_order" => {} + _ => panic!("Unexpected attribute for AvgPool1d: {key}"), } } diff --git a/crates/onnx-ir/src/node/avg_pool2d.rs b/crates/onnx-ir/src/node/avg_pool2d.rs index 58644ca237..bc6e5a0f73 100644 --- a/crates/onnx-ir/src/node/avg_pool2d.rs +++ b/crates/onnx-ir/src/node/avg_pool2d.rs @@ -46,7 +46,9 @@ pub fn avg_pool2d_config(curr: &Node) -> AvgPool2dConfig { "pads" => pads = value.clone().into_i64s(), "count_include_pad" => count_include_pad = value.clone().into_i64(), "ceil_mode" => ceil_mode = value.clone().into_i64(), - _ => {} + // These are attributes that are allowed but not used in this implementation + "auto_pad" | "storage_order" => {} + _ => panic!("Unexpected attribute for AvgPool2d: {key}"), } } diff --git a/crates/onnx-ir/src/node/conv1d.rs b/crates/onnx-ir/src/node/conv1d.rs index cc8420faf6..3ab2dcf1ef 100644 --- a/crates/onnx-ir/src/node/conv1d.rs +++ b/crates/onnx-ir/src/node/conv1d.rs @@ -99,10 +99,8 @@ pub fn conv1d_config(curr: &Node) -> Conv1dConfig { #[cfg(test)] mod tests { use super::*; - use crate::ir::{ - ArgType, Argument, AttributeValue, Data, ElementType, NodeType, TensorData, TensorType, - }; - use std::collections::HashMap; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; fn create_test_node( kernel_shape: Vec, @@ -112,75 +110,33 @@ mod tests { group: i64, has_bias: bool, ) -> Node { - let mut inputs = vec![Argument { - name: "data".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: 
ElementType::Float32, - rank: 3, - static_shape: None, - }), - value: None, - passed: true, - }]; - - // Add weight tensor - inputs.push(Argument { - name: "weight".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 3, - static_shape: None, - }), - value: Some(TensorData { - data: Data::Float32s(vec![0.1; 16]), - shape: vec![2, 2, 4], // [out_channels, in_channels, kernel_size] - }), - passed: true, - }); + // Create weight tensor data + let weight_data = vec![0.1; 16]; + + // Start building the node with input and weight + let mut builder = NodeBuilder::new(NodeType::Conv1d, "test_conv1d") + .input_tensor_f32("data", 3, None) + .input_tensor_f32_data( + "weight", + weight_data, + vec![2, 2, 4], // [out_channels, in_channels, kernel_size] + ) + .output_tensor_f32("output", 3, None); // Add bias if needed if has_bias { - inputs.push(Argument { - name: "bias".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 1, - static_shape: None, - }), - value: Some(TensorData { - data: Data::Float32s(vec![0.1, 0.2]), - shape: vec![2], - }), - passed: true, - }); + builder = builder.input_tensor_f32_data("bias", vec![0.1, 0.2], vec![2]); } - let mut attrs = HashMap::new(); - attrs.insert( - "kernel_shape".to_string(), - AttributeValue::Int64s(kernel_shape), - ); - attrs.insert("strides".to_string(), AttributeValue::Int64s(strides)); - attrs.insert("pads".to_string(), AttributeValue::Int64s(pads)); - attrs.insert("dilations".to_string(), AttributeValue::Int64s(dilations)); - attrs.insert("group".to_string(), AttributeValue::Int64(group)); + // Add attributes + builder = builder + .attr_ints("kernel_shape", kernel_shape) + .attr_ints("strides", strides) + .attr_ints("pads", pads) + .attr_ints("dilations", dilations) + .attr_int("group", group); - Node { - node_type: NodeType::Conv1d, - name: "test_conv1d".to_string(), - inputs, - outputs: vec![Argument { - name: "output".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 3, - static_shape: None, - }), - value: None, - passed: true, - }], - attrs, - } + builder.build() } #[test] diff --git a/crates/onnx-ir/src/node/conv2d.rs b/crates/onnx-ir/src/node/conv2d.rs index 3048bdfefe..6715069e26 100644 --- a/crates/onnx-ir/src/node/conv2d.rs +++ b/crates/onnx-ir/src/node/conv2d.rs @@ -68,7 +68,7 @@ pub fn conv2d_config(curr: &Node) -> Conv2dConfig { "pads" => pads = value.clone().into_i64s(), "dilations" => dilations = value.clone().into_i64s(), "group" => group = value.clone().into_i64() as usize, - _ => {} + _ => panic!("Unexpected attribute for Conv2d: {key}"), } } diff --git a/crates/onnx-ir/src/node/conv3d.rs b/crates/onnx-ir/src/node/conv3d.rs index 97f71951f2..043ec6a5d6 100644 --- a/crates/onnx-ir/src/node/conv3d.rs +++ b/crates/onnx-ir/src/node/conv3d.rs @@ -68,7 +68,7 @@ pub fn conv3d_config(curr: &Node) -> Conv3dConfig { "pads" => pads = value.clone().into_i64s(), "dilations" => dilations = value.clone().into_i64s(), "group" => group = value.clone().into_i64() as usize, - _ => {} + _ => panic!("Unexpected attribute for Conv3d: {key}"), } } @@ -104,10 +104,8 @@ pub fn conv3d_config(curr: &Node) -> Conv3dConfig { #[cfg(test)] mod tests { use super::*; - use crate::ir::{ - ArgType, Argument, AttributeValue, Data, ElementType, NodeType, TensorData, TensorType, - }; - use std::collections::HashMap; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; fn create_test_node( kernel_shape: Vec, @@ -117,73 +115,30 @@ mod tests { 
group: i64, has_bias: bool, ) -> Node { - let weight_tensor = TensorData { - data: Data::Float32s(vec![0.0; 32]), // Not important for the test - shape: vec![4, 2, 2, 2, 2], // [output_channels, input_channels/groups, k_d, k_h, k_w] - }; + // Create weight tensor data (not important for the test) + let weight_data = vec![0.0; 32]; + let weight_shape = vec![4, 2, 2, 2, 2]; // [output_channels, input_channels/groups, k_d, k_h, k_w] - let mut inputs = vec![ - Argument { - name: "data".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 5, - static_shape: None, - }), - value: None, - passed: true, - }, - Argument { - name: "weight".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 5, - static_shape: None, - }), - value: Some(weight_tensor), - passed: true, - }, - ]; + // Start building the node with input and weight + let mut builder = NodeBuilder::new(NodeType::Conv3d, "test_conv3d") + .input_tensor_f32("data", 5, None) + .input_tensor_f32_data("weight", weight_data, weight_shape) + .output_tensor_f32("output", 5, None); + // Add bias if needed if has_bias { - inputs.push(Argument { - name: "bias".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 1, - static_shape: None, - }), - value: None, - passed: true, - }); + builder = builder.input_tensor_f32("bias", 1, None); } - let mut attrs = HashMap::new(); - attrs.insert( - "kernel_shape".to_string(), - AttributeValue::Int64s(kernel_shape), - ); - attrs.insert("strides".to_string(), AttributeValue::Int64s(strides)); - attrs.insert("pads".to_string(), AttributeValue::Int64s(pads)); - attrs.insert("dilations".to_string(), AttributeValue::Int64s(dilations)); - attrs.insert("group".to_string(), AttributeValue::Int64(group)); + // Add attributes + builder = builder + .attr_ints("kernel_shape", kernel_shape) + .attr_ints("strides", strides) + .attr_ints("pads", pads) + .attr_ints("dilations", dilations) + .attr_int("group", group); - Node { - node_type: NodeType::Conv3d, - name: "test_conv3d".to_string(), - inputs, - outputs: vec![Argument { - name: "output".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 5, - static_shape: None, - }), - value: None, - passed: true, - }], - attrs, - } + builder.build() } #[test] diff --git a/crates/onnx-ir/src/node/conv_transpose1d.rs b/crates/onnx-ir/src/node/conv_transpose1d.rs index e82b3d8132..961de340cc 100644 --- a/crates/onnx-ir/src/node/conv_transpose1d.rs +++ b/crates/onnx-ir/src/node/conv_transpose1d.rs @@ -1,4 +1,4 @@ -use crate::ir::{AttributeValue, Node}; +use crate::ir::Node; /// Configuration for ConvTranspose1d operations extracted from ONNX nodes #[derive(Debug, Clone)] @@ -24,99 +24,95 @@ pub struct ConvTranspose1dConfig { } impl ConvTranspose1dConfig { - /// Create a new ConvTranspose1dConfig from the attributes of the node - pub fn new(curr: &Node) -> Self { - let mut attrs = curr.attrs.clone(); - - // Extract kernel_shape, default to an empty vector if not present - let kernel_shape = attrs - .remove("kernel_shape") - .map(AttributeValue::into_i64s) - .unwrap_or_default(); - - // Extract strides, default to 1 if not present - let stride = attrs - .remove("strides") - .map(AttributeValue::into_i64s) - .unwrap_or_else(|| vec![1]); - - // Extract padding, default to 0 if not present - let pads = attrs - .remove("pads") - .map(AttributeValue::into_i64s) - .unwrap_or_else(|| vec![0, 0]); - - // Extract dilations, default to 1 if not present - 
let dilations = attrs - .remove("dilations") - .map(AttributeValue::into_i64s) - .unwrap_or_else(|| vec![1]); - - // Extract group attribute, default to 1 - let group = attrs - .remove("group") - .map(AttributeValue::into_i64) - .unwrap_or(1) as usize; - - // Extract output_padding, default to 0 if not present - let output_padding = attrs - .remove("output_padding") - .map(AttributeValue::into_i64s) - .unwrap_or_else(|| vec![0]); - - // Ensure no unused attributes remain - if !attrs.is_empty() { - panic!("Not all attributes are used: {attrs:?}"); - } - - // Check the pads are symmetric - if pads.len() != 2 || pads[0] != pads[1] { - panic!( - "Asymmetric padding is not supported for ConvTranspose1d: {:?}", - pads - ); - } - - let weight_shape = curr.inputs[1] - .value - .as_ref() - .expect("ConvTranspose1d: weight tensor must be present") - .shape - .clone(); - - // Check if bias is present (third input) - let bias = curr.inputs.len() == 3; - - // Extract channels from the weight tensor shape [out_channels, in_channels] - let channels_in = weight_shape[1] * group; - let channels_out = weight_shape[0]; - + /// Create a new ConvTranspose1dConfig + #[allow(clippy::too_many_arguments)] + pub fn new( + channels_in: usize, + channels_out: usize, + kernel_size: usize, + stride: usize, + padding: usize, + dilation: usize, + groups: usize, + bias: bool, + padding_out: usize, + ) -> Self { Self { channels_in, channels_out, - kernel_size: kernel_shape[0] as usize, - stride: stride[0] as usize, - padding: pads[0] as usize, - dilation: dilations[0] as usize, - padding_out: output_padding[0] as usize, - groups: group, + kernel_size, + stride, + padding, + dilation, + groups, bias, + padding_out, } } } /// Create a ConvTranspose1dConfig from the attributes of the node pub fn conv_transpose1d_config(curr: &Node) -> ConvTranspose1dConfig { - ConvTranspose1dConfig::new(curr) + let mut kernel_shape = Vec::new(); // Default to empty vector + let mut stride = vec![1]; // Default stride to 1 + let mut pads = vec![0, 0]; // Default padding to 0 + let mut dilations = vec![1]; // Default dilation to 1 + let mut group: usize = 1; // Default group to 1 + let mut output_padding = vec![0]; // Default output padding to 0 + + // Extract attributes + for (key, value) in curr.attrs.iter() { + match key.as_str() { + "kernel_shape" => kernel_shape = value.clone().into_i64s(), + "strides" => stride = value.clone().into_i64s(), + "pads" => pads = value.clone().into_i64s(), + "dilations" => dilations = value.clone().into_i64s(), + "group" => group = value.clone().into_i64() as usize, + "output_padding" => output_padding = value.clone().into_i64s(), + _ => panic!("Unexpected attribute for ConvTranspose1d: {key}"), + } + } + + // Check the pads are symmetric + if pads.len() != 2 || pads[0] != pads[1] { + panic!( + "Asymmetric padding is not supported for ConvTranspose1d: {:?}", + pads + ); + } + + let weight_shape = curr.inputs[1] + .value + .as_ref() + .expect("ConvTranspose1d: weight tensor must be present") + .shape + .clone(); + + // Check if bias is present (third input) + let bias = curr.inputs.len() == 3; + + // Extract channels from the weight tensor shape [out_channels, in_channels] + let channels_in = weight_shape[1] * group; + let channels_out = weight_shape[0]; + + ConvTranspose1dConfig { + channels_in, + channels_out, + kernel_size: kernel_shape[0] as usize, + stride: stride[0] as usize, + padding: pads[0] as usize, + dilation: dilations[0] as usize, + padding_out: output_padding[0] as usize, + groups: group, + bias, 
+ } } #[cfg(test)] mod tests { use super::*; - use crate::ir::{ - ArgType, Argument, AttributeValue, Data, ElementType, NodeType, TensorData, TensorType, - }; - use std::collections::HashMap; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; fn create_test_node( kernel_shape: Vec, @@ -127,79 +123,34 @@ mod tests { output_padding: Vec, has_bias: bool, ) -> Node { - let mut inputs = vec![Argument { - name: "data".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 3, - static_shape: None, - }), - value: None, - passed: true, - }]; - - // Add weight tensor - inputs.push(Argument { - name: "weight".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 3, - static_shape: None, - }), - value: Some(TensorData { - data: Data::Float32s(vec![0.1; 16]), - shape: vec![2, 2, 4], // [out_channels, in_channels, kernel_size] - }), - passed: true, - }); + // Create weight tensor data + let weight_data = vec![0.1; 16]; + + // Start building the node with input and weight + let mut builder = NodeBuilder::new(NodeType::ConvTranspose1d, "test_conv_transpose1d") + .input_tensor_f32("data", 3, None) + .input_tensor_f32_data( + "weight", + weight_data, + vec![2, 2, 4], // [out_channels, in_channels, kernel_size] + ) + .output_tensor_f32("output", 3, None); // Add bias if needed if has_bias { - inputs.push(Argument { - name: "bias".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 1, - static_shape: None, - }), - value: Some(TensorData { - data: Data::Float32s(vec![0.1, 0.2]), - shape: vec![2], - }), - passed: true, - }); + builder = builder.input_tensor_f32_data("bias", vec![0.1, 0.2], vec![2]); } - let mut attrs = HashMap::new(); - attrs.insert( - "kernel_shape".to_string(), - AttributeValue::Int64s(kernel_shape), - ); - attrs.insert("strides".to_string(), AttributeValue::Int64s(stride)); - attrs.insert("pads".to_string(), AttributeValue::Int64s(pads)); - attrs.insert("dilations".to_string(), AttributeValue::Int64s(dilations)); - attrs.insert("group".to_string(), AttributeValue::Int64(group)); - attrs.insert( - "output_padding".to_string(), - AttributeValue::Int64s(output_padding), - ); + // Add attributes + builder = builder + .attr_ints("kernel_shape", kernel_shape) + .attr_ints("strides", stride) + .attr_ints("pads", pads) + .attr_ints("dilations", dilations) + .attr_int("group", group) + .attr_ints("output_padding", output_padding); - Node { - node_type: NodeType::ConvTranspose1d, - name: "test_conv_transpose1d".to_string(), - inputs, - outputs: vec![Argument { - name: "output".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 3, - static_shape: None, - }), - value: None, - passed: true, - }], - attrs, - } + builder.build() } #[test] diff --git a/crates/onnx-ir/src/node/conv_transpose2d.rs b/crates/onnx-ir/src/node/conv_transpose2d.rs index 568a459b06..f813ee9136 100644 --- a/crates/onnx-ir/src/node/conv_transpose2d.rs +++ b/crates/onnx-ir/src/node/conv_transpose2d.rs @@ -1,4 +1,4 @@ -use crate::ir::{AttributeValue, Node}; +use crate::ir::Node; /// Configuration for ConvTranspose2d operations. 
#[derive(Debug, Clone, PartialEq, Eq)] @@ -49,36 +49,26 @@ impl ConvTranspose2dConfig { /// Create a ConvTranspose2dConfig from the attributes of the node pub fn conv_transpose2d_config(curr: &Node) -> ConvTranspose2dConfig { - let mut attrs = curr.attrs.clone(); - let kernel_shape = attrs - .remove("kernel_shape") - .map(AttributeValue::into_i64s) - .unwrap_or_default(); - let stride = attrs - .remove("strides") - .map(AttributeValue::into_i64s) - .unwrap_or_else(|| vec![1, 1]); - let pads = attrs - .remove("pads") - .map(AttributeValue::into_i64s) - .unwrap_or_else(|| vec![0, 0, 0, 0]); - let dilations = attrs - .remove("dilations") - .map(AttributeValue::into_i64s) - .unwrap_or_else(|| vec![1, 1]); - let group = attrs - .remove("group") - .map(AttributeValue::into_i64) - .unwrap_or(1) as usize; - let output_padding = attrs - .remove("output_padding") - .map(AttributeValue::into_i64s) - .unwrap_or_else(|| vec![0, 0]); - - // Trick with remove + empty check is simplest way to not forget some attribute for runtime: - if !attrs.is_empty() { - panic!("Not all attributes are used: {attrs:?}"); + let mut kernel_shape = Vec::new(); // Default to empty vector + let mut stride = vec![1, 1]; // Default stride to 1 + let mut pads = vec![0, 0, 0, 0]; // Default padding to 0 + let mut dilations = vec![1, 1]; // Default dilation to 1 + let mut group: usize = 1; // Default group to 1 + let mut output_padding = vec![0, 0]; // Default output padding to 0 + + // Extract attributes + for (key, value) in curr.attrs.iter() { + match key.as_str() { + "kernel_shape" => kernel_shape = value.clone().into_i64s(), + "strides" => stride = value.clone().into_i64s(), + "pads" => pads = value.clone().into_i64s(), + "dilations" => dilations = value.clone().into_i64s(), + "group" => group = value.clone().into_i64() as usize, + "output_padding" => output_padding = value.clone().into_i64s(), + _ => panic!("Unexpected attribute for ConvTranspose2d: {key}"), + } } + // Check the pads are symmetric. 
let [left, top, right, bottom] = [pads[0], pads[1], pads[2], pads[3]]; if left < 0 || top < 0 || right < 0 || bottom < 0 { @@ -115,10 +105,8 @@ pub fn conv_transpose2d_config(curr: &Node) -> ConvTranspose2dConfig { #[cfg(test)] mod tests { use super::*; - use crate::ir::{ - ArgType, Argument, AttributeValue, Data, ElementType, NodeType, TensorData, TensorType, - }; - use std::collections::HashMap; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; fn create_test_node( kernel_shape: Vec, @@ -129,77 +117,34 @@ mod tests { group: i64, has_bias: bool, ) -> Node { - let weight_tensor = TensorData { - data: Data::Float32s(vec![0.0; 16]), // Not important for the test - shape: vec![2, 4, 2, 2], // [input_channels, output_channels/groups, k_h, k_w] - }; - - let mut inputs = vec![ - Argument { - name: "data".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 4, - static_shape: None, - }), - value: None, - passed: true, - }, - Argument { - name: "weight".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 4, - static_shape: None, - }), - value: Some(weight_tensor), - passed: true, - }, - ]; - + // Create weight tensor data + let weight_data = vec![0.0; 16]; // Not important for the test + + // Start building the node with input and weight + let mut builder = NodeBuilder::new(NodeType::ConvTranspose2d, "test_convtranspose2d") + .input_tensor_f32("data", 4, None) + .input_tensor_f32_data( + "weight", + weight_data, + vec![2, 4, 2, 2], // [out_channels, in_channels, k_h, k_w] + ) + .output_tensor_f32("output", 4, None); + + // Add bias if needed if has_bias { - inputs.push(Argument { - name: "bias".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 1, - static_shape: None, - }), - value: None, - passed: true, - }); + builder = builder.input_tensor_f32("bias", 1, None); } - let mut attrs = HashMap::new(); - attrs.insert( - "kernel_shape".to_string(), - AttributeValue::Int64s(kernel_shape), - ); - attrs.insert("strides".to_string(), AttributeValue::Int64s(strides)); - attrs.insert("pads".to_string(), AttributeValue::Int64s(pads)); - attrs.insert("dilations".to_string(), AttributeValue::Int64s(dilations)); - attrs.insert( - "output_padding".to_string(), - AttributeValue::Int64s(output_padding), - ); - attrs.insert("group".to_string(), AttributeValue::Int64(group)); + // Add attributes + builder = builder + .attr_ints("kernel_shape", kernel_shape) + .attr_ints("strides", strides) + .attr_ints("pads", pads) + .attr_ints("dilations", dilations) + .attr_ints("output_padding", output_padding) + .attr_int("group", group); - Node { - node_type: NodeType::ConvTranspose2d, - name: "test_convtranspose2d".to_string(), - inputs, - outputs: vec![Argument { - name: "output".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 4, - static_shape: None, - }), - value: None, - passed: true, - }], - attrs, - } + builder.build() } #[test] diff --git a/crates/onnx-ir/src/node/conv_transpose3d.rs b/crates/onnx-ir/src/node/conv_transpose3d.rs index 0bac8cd1fa..288800776e 100644 --- a/crates/onnx-ir/src/node/conv_transpose3d.rs +++ b/crates/onnx-ir/src/node/conv_transpose3d.rs @@ -1,4 +1,4 @@ -use crate::ir::{AttributeValue, Node}; +use crate::ir::Node; /// Configuration for ConvTranspose3d operations. 
#[derive(Debug, Clone, PartialEq, Eq)] @@ -49,36 +49,26 @@ impl ConvTranspose3dConfig { /// Create a ConvTranspose3dConfig from the attributes of the node pub fn conv_transpose3d_config(curr: &Node) -> ConvTranspose3dConfig { - let mut attrs = curr.attrs.clone(); - let kernel_shape = attrs - .remove("kernel_shape") - .map(AttributeValue::into_i64s) - .unwrap_or_default(); - let stride = attrs - .remove("strides") - .map(AttributeValue::into_i64s) - .unwrap_or_else(|| vec![1, 1, 1]); - let pads = attrs - .remove("pads") - .map(AttributeValue::into_i64s) - .unwrap_or_else(|| vec![0, 0, 0, 0, 0, 0]); - let dilations = attrs - .remove("dilations") - .map(AttributeValue::into_i64s) - .unwrap_or_else(|| vec![1, 1, 1]); - let group = attrs - .remove("group") - .map(AttributeValue::into_i64) - .unwrap_or(1) as usize; - let output_padding = attrs - .remove("output_padding") - .map(AttributeValue::into_i64s) - .unwrap_or_else(|| vec![0, 0, 0]); + let mut kernel_shape = Vec::new(); // Default to empty vector + let mut stride = vec![1, 1, 1]; // Default stride to 1 + let mut pads = vec![0, 0, 0, 0, 0, 0]; // Default padding to 0 + let mut dilations = vec![1, 1, 1]; // Default dilation to 1 + let mut group: usize = 1; // Default group to 1 + let mut output_padding = vec![0, 0, 0]; // Default output padding to 0 - // Trick with remove + empty check is simplest way to not forget some attribute for runtime: - if !attrs.is_empty() { - panic!("Not all attributes are used: {attrs:?}"); + // Extract attributes + for (key, value) in curr.attrs.iter() { + match key.as_str() { + "kernel_shape" => kernel_shape = value.clone().into_i64s(), + "strides" => stride = value.clone().into_i64s(), + "pads" => pads = value.clone().into_i64s(), + "dilations" => dilations = value.clone().into_i64s(), + "group" => group = value.clone().into_i64() as usize, + "output_padding" => output_padding = value.clone().into_i64s(), + _ => panic!("Unexpected attribute for ConvTranspose3d: {key}"), + } } + // Check the pads are symmetric. 
let [left, top, front, right, bottom, back] = [pads[0], pads[1], pads[2], pads[3], pads[4], pads[5]]; @@ -129,10 +119,8 @@ pub fn conv_transpose3d_config(curr: &Node) -> ConvTranspose3dConfig { #[cfg(test)] mod tests { use super::*; - use crate::ir::{ - ArgType, Argument, AttributeValue, Data, ElementType, NodeType, TensorData, TensorType, - }; - use std::collections::HashMap; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; fn create_test_node( kernel_shape: Vec, @@ -143,77 +131,34 @@ mod tests { group: i64, has_bias: bool, ) -> Node { - let weight_tensor = TensorData { - data: Data::Float32s(vec![0.0; 32]), // Not important for the test - shape: vec![2, 4, 2, 2, 2], // [input_channels, output_channels/groups, k_d, k_h, k_w] - }; + // Create weight tensor data + let weight_data = vec![0.0; 32]; // Not important for the test - let mut inputs = vec![ - Argument { - name: "data".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 5, - static_shape: None, - }), - value: None, - passed: true, - }, - Argument { - name: "weight".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 5, - static_shape: None, - }), - value: Some(weight_tensor), - passed: true, - }, - ]; + // Start building the node with input and weight + let mut builder = NodeBuilder::new(NodeType::ConvTranspose3d, "test_convtranspose3d") + .input_tensor_f32("data", 5, None) + .input_tensor_f32_data( + "weight", + weight_data, + vec![2, 4, 2, 2, 2], // [out_channels, in_channels, k_d, k_h, k_w] + ) + .output_tensor_f32("output", 5, None); + // Add bias if needed if has_bias { - inputs.push(Argument { - name: "bias".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 1, - static_shape: None, - }), - value: None, - passed: true, - }); + builder = builder.input_tensor_f32("bias", 1, None); } - let mut attrs = HashMap::new(); - attrs.insert( - "kernel_shape".to_string(), - AttributeValue::Int64s(kernel_shape), - ); - attrs.insert("strides".to_string(), AttributeValue::Int64s(strides)); - attrs.insert("pads".to_string(), AttributeValue::Int64s(pads)); - attrs.insert("dilations".to_string(), AttributeValue::Int64s(dilations)); - attrs.insert( - "output_padding".to_string(), - AttributeValue::Int64s(output_padding), - ); - attrs.insert("group".to_string(), AttributeValue::Int64(group)); + // Add attributes + builder = builder + .attr_ints("kernel_shape", kernel_shape) + .attr_ints("strides", strides) + .attr_ints("pads", pads) + .attr_ints("dilations", dilations) + .attr_ints("output_padding", output_padding) + .attr_int("group", group); - Node { - node_type: NodeType::ConvTranspose3d, - name: "test_convtranspose3d".to_string(), - inputs, - outputs: vec![Argument { - name: "output".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 5, - static_shape: None, - }), - value: None, - passed: true, - }], - attrs, - } + builder.build() } #[test] diff --git a/crates/onnx-ir/src/node/gemm.rs b/crates/onnx-ir/src/node/gemm.rs index 803da4970d..b81d5ce436 100644 --- a/crates/onnx-ir/src/node/gemm.rs +++ b/crates/onnx-ir/src/node/gemm.rs @@ -35,26 +35,20 @@ pub fn gemm_output_shape(node: &mut Node) { } pub fn gemm_config(curr: &Node) -> (f32, f32, i64, i64) { - let alpha = curr - .attrs - .get("alpha") - .map(|val| val.clone().into_f32()) - .unwrap_or(1.0); - let beta = curr - .attrs - .get("beta") - .map(|val| val.clone().into_f32()) - .unwrap_or(1.0); - let trans_a = curr - 
.attrs - .get("transA") - .map(|val| val.clone().into_i64()) - .unwrap_or(0); - let trans_b = curr - .attrs - .get("transB") - .map(|val| val.clone().into_i64()) - .unwrap_or(0); + let mut alpha: f32 = 1.0; + let mut beta: f32 = 1.0; + let mut trans_a: i64 = 0; + let mut trans_b: i64 = 0; + + for (key, value) in curr.attrs.iter() { + match key.as_str() { + "alpha" => alpha = value.clone().into_f32(), + "beta" => beta = value.clone().into_f32(), + "transA" => trans_a = value.clone().into_i64(), + "transB" => trans_b = value.clone().into_i64(), + _ => panic!("Unexpected attribute for Gemm: {key}"), + } + } (alpha, beta, trans_a, trans_b) } diff --git a/crates/onnx-ir/src/node/layer_norm.rs b/crates/onnx-ir/src/node/layer_norm.rs index 8158adcefc..e97e0d085a 100644 --- a/crates/onnx-ir/src/node/layer_norm.rs +++ b/crates/onnx-ir/src/node/layer_norm.rs @@ -47,7 +47,7 @@ pub fn layer_norm_config(node: &Node) -> (LayerNormConfig, bool) { "axis" => axis = value.clone().into_i64(), "epsilon" => epsilon = value.clone().into_f32(), "stash_type" => stash_type = value.clone().into_i64(), - _ => {} + _ => panic!("Unexpected attribute for LayerNorm: {key}"), } } diff --git a/crates/onnx-ir/src/node/max_pool1d.rs b/crates/onnx-ir/src/node/max_pool1d.rs index 6446c6794f..6112e7fba1 100644 --- a/crates/onnx-ir/src/node/max_pool1d.rs +++ b/crates/onnx-ir/src/node/max_pool1d.rs @@ -58,7 +58,9 @@ pub fn max_pool1d_config(curr: &Node) -> MaxPool1dConfig { "strides" => stride = value.clone().into_i64s(), "pads" => pads = value.clone().into_i64s(), "dilations" => dilation = value.clone().into_i64s(), - _ => {} + // These are attributes that are allowed but not used in this implementation + "auto_pad" | "ceil_mode" | "storage_order" => {} + _ => panic!("Unexpected attribute for MaxPool1d: {key}"), } } diff --git a/crates/onnx-ir/src/node/max_pool2d.rs b/crates/onnx-ir/src/node/max_pool2d.rs index f77cae0c83..9883f86b6a 100644 --- a/crates/onnx-ir/src/node/max_pool2d.rs +++ b/crates/onnx-ir/src/node/max_pool2d.rs @@ -57,7 +57,9 @@ pub fn max_pool2d_config(curr: &Node) -> MaxPool2dConfig { "strides" => strides = value.clone().into_i64s(), "pads" => pads = value.clone().into_i64s(), "dilations" => dilations = value.clone().into_i64s(), - _ => {} + // These are attributes that are allowed but not used in this implementation + "auto_pad" | "ceil_mode" | "storage_order" => {} + _ => panic!("Unexpected attribute for MaxPool2d: {key}"), } } diff --git a/crates/onnx-ir/src/node/reshape.rs b/crates/onnx-ir/src/node/reshape.rs index 9d64e5baf1..b48b4176b4 100644 --- a/crates/onnx-ir/src/node/reshape.rs +++ b/crates/onnx-ir/src/node/reshape.rs @@ -54,8 +54,10 @@ pub fn reshape_config(node: &Node) -> Vec { let mut allowzero = 0; for (key, value) in node.attrs.iter() { - if key.as_str() == "allowzero" { - allowzero = value.clone().into_i64() + match key.as_str() { + "allowzero" => allowzero = value.clone().into_i64(), + "shape" => {} // This can be used when shape is not provided as input - handled elsewhere + _ => panic!("Unexpected attribute for Reshape: {key}"), } } From a89840419ab54912b64470b2bf60f82f9b491a22 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Thu, 8 May 2025 17:28:09 -0500 Subject: [PATCH 36/37] Refactored tests to use NodeBuilder --- crates/onnx-ir/src/node/dropout.rs | 94 +++++------------------------- crates/onnx-ir/src/node/expand.rs | 71 +++++----------------- crates/onnx-ir/src/node/gather.rs | 63 +++++--------------- crates/onnx-ir/src/node/linear.rs | 
76 +++++------------------- crates/onnx-ir/src/node/range.rs | 49 +++------------- 5 files changed, 66 insertions(+), 287 deletions(-) diff --git a/crates/onnx-ir/src/node/dropout.rs b/crates/onnx-ir/src/node/dropout.rs index ef5820c0c1..05b3edb2d2 100644 --- a/crates/onnx-ir/src/node/dropout.rs +++ b/crates/onnx-ir/src/node/dropout.rs @@ -46,91 +46,23 @@ pub fn dropout_config(node: &Node) -> DropoutConfig { #[cfg(test)] mod tests { use super::*; - use crate::ir::{ - ArgType, Argument, AttributeValue, Data, ElementType, NodeType, TensorData, TensorType, - }; - use std::collections::HashMap; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; fn create_test_node_with_attr(ratio: f32) -> Node { - let inputs = vec![Argument { - name: "data".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 3, - static_shape: None, - }), - value: None, - passed: true, - }]; - - let mut attrs = HashMap::new(); - attrs.insert("ratio".to_string(), AttributeValue::Float32(ratio)); - - Node { - node_type: NodeType::Dropout, - name: "test_dropout".to_string(), - inputs, - outputs: vec![Argument { - name: "output".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 3, - static_shape: None, - }), - value: None, - passed: true, - }], - attrs, - } + NodeBuilder::new(NodeType::Dropout, "test_dropout") + .input_tensor_f32("data", 3, None) + .output_tensor_f32("output", 3, None) + .attr_float("ratio", ratio) + .build() } fn create_test_node_with_input(ratio: f32) -> Node { - let ratio_tensor = TensorData { - data: Data::Float32(ratio), - shape: vec![], - }; - - let inputs = vec![ - Argument { - name: "data".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 3, - static_shape: None, - }), - value: None, - passed: true, - }, - Argument { - name: "ratio".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 0, - static_shape: None, - }), - value: Some(ratio_tensor), - passed: true, - }, - ]; - - let attrs = HashMap::new(); - - Node { - node_type: NodeType::Dropout, - name: "test_dropout".to_string(), - inputs, - outputs: vec![Argument { - name: "output".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 3, - static_shape: None, - }), - value: None, - passed: true, - }], - attrs, - } + NodeBuilder::new(NodeType::Dropout, "test_dropout") + .input_tensor_f32("data", 3, None) + .input_scalar_tensor_f32("ratio", Some(ratio)) + .output_tensor_f32("output", 3, None) + .build() } #[test] @@ -151,7 +83,7 @@ mod tests { #[should_panic(expected = "Dropout configuration must have at least 2 inputs")] fn test_dropout_config_missing_input() { let mut node = create_test_node_with_input(0.5); - node.attrs = HashMap::new(); // Remove attributes + node.attrs.clear(); // Remove attributes node.inputs.remove(1); // Remove ratio input let _ = dropout_config(&node); } diff --git a/crates/onnx-ir/src/node/expand.rs b/crates/onnx-ir/src/node/expand.rs index 32a890fed3..a78377de4d 100644 --- a/crates/onnx-ir/src/node/expand.rs +++ b/crates/onnx-ir/src/node/expand.rs @@ -100,68 +100,29 @@ pub fn expand_config(node: &Node) -> ExpandShape { #[cfg(test)] mod tests { use super::*; - use crate::ir::{Argument, ElementType, NodeType, TensorData}; - use std::collections::HashMap; + use crate::ir::{ElementType, NodeType, TensorData}; + use crate::node::test_utils::NodeBuilder; fn create_test_node( input_rank: usize, shape_value: Option>, 
shape_type: Option, ) -> Node { - let inputs = vec![ - Argument { - name: "input".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: input_rank, - static_shape: None, - }), - value: None, - passed: true, - }, - Argument { - name: "shape".to_string(), - ty: shape_type.unwrap_or_else(|| { - if shape_value.is_some() { - ArgType::Tensor(TensorType { - elem_type: ElementType::Int64, - rank: 1, - static_shape: Some(vec![shape_value.as_ref().unwrap().len()]), - }) - } else { - ArgType::Tensor(TensorType { - elem_type: ElementType::Int64, - rank: 1, - static_shape: Some(vec![3]), // Example: a shape with 3 dimensions - }) - } - }), - value: shape_value.map(|shape| TensorData { - shape: vec![shape.len()], - data: Data::Int64s(shape), - }), - passed: true, - }, - ]; - - let outputs = vec![Argument { - name: "output".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 0, // Will be updated - static_shape: None, - }), - value: None, - passed: true, - }]; - - Node { - node_type: NodeType::Expand, - name: "test_expand".to_string(), - inputs, - outputs, - attrs: HashMap::new(), + let mut builder = NodeBuilder::new(NodeType::Expand, "test_expand") + .input_tensor_f32("input", input_rank, None) + .output_tensor_f32("output", 0, None); // Rank 0 will be updated + + if let Some(shape) = shape_value { + builder = builder.input_tensor_i64_data("shape", shape.clone(), vec![shape.len()]); + } else if let Some(st) = shape_type { + // Use the provided custom shape type + builder = builder.add_input("shape", st); + } else { + // Default case with dynamic shape + builder = builder.input_tensor_i64("shape", 1, Some(vec![3])); } + + builder.build() } #[test] diff --git a/crates/onnx-ir/src/node/gather.rs b/crates/onnx-ir/src/node/gather.rs index cf80da01d6..d785f35652 100644 --- a/crates/onnx-ir/src/node/gather.rs +++ b/crates/onnx-ir/src/node/gather.rs @@ -108,58 +108,25 @@ pub fn gather_config(curr: &Node) -> usize { #[cfg(test)] mod tests { use super::*; - use crate::ir::{Argument, AttributeValue, ElementType, NodeType, TensorType}; - use std::collections::HashMap; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; fn create_test_node(axis: i64, input_rank: usize, is_shape: bool) -> Node { - let input_ty = if is_shape { - ArgType::Shape(1) + // Start building the node with the appropriate input type + let mut builder = NodeBuilder::new(NodeType::Gather, "test_gather").attr_int("axis", axis); + + if is_shape { + builder = builder.add_input("data", ArgType::Shape(1)); } else { - ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: input_rank, - static_shape: None, - }) - }; - - let inputs = vec![ - Argument { - name: "data".to_string(), - ty: input_ty, - value: None, - passed: true, - }, - Argument { - name: "indices".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Int64, - rank: 1, - static_shape: None, - }), - value: None, - passed: true, - }, - ]; - - let mut attrs = HashMap::new(); - attrs.insert("axis".to_string(), AttributeValue::Int64(axis)); - - Node { - node_type: NodeType::Gather, - name: "test_gather".to_string(), - inputs, - outputs: vec![Argument { - name: "output".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: input_rank, - static_shape: None, - }), - value: None, - passed: true, - }], - attrs, + builder = builder.input_tensor_f32("data", input_rank, None); } + + // Add indices and output + builder = builder + 
.input_tensor_i64("indices", 1, None) + .output_tensor_f32("output", input_rank, None); + + builder.build() } #[test] diff --git a/crates/onnx-ir/src/node/linear.rs b/crates/onnx-ir/src/node/linear.rs index ecff7b9a25..a9afa761ac 100644 --- a/crates/onnx-ir/src/node/linear.rs +++ b/crates/onnx-ir/src/node/linear.rs @@ -79,74 +79,26 @@ pub fn linear_config(node: &Node) -> LinearConfig { #[cfg(test)] mod tests { use super::*; - use crate::ir::{ArgType, Argument, Data, ElementType, NodeType, TensorData, TensorType}; - use std::collections::HashMap; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; fn create_test_node(has_bias: bool, weight_dims: Vec) -> Node { - let weight_tensor = TensorData { - data: Data::Float32s(vec![0.0; weight_dims.iter().product()]), // Not important for the test - shape: weight_dims.clone(), - }; - - let mut inputs = vec![ - Argument { - name: "input".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 2, - static_shape: None, - }), - value: None, - passed: true, - }, - Argument { - name: "weight".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: weight_dims.len(), - static_shape: None, - }), - value: Some(weight_tensor), - passed: true, - }, - ]; + // Create weight tensor data + let weight_data = vec![0.0; weight_dims.iter().product()]; // Not important for the test + // Start building the node with input and weight + let mut builder = NodeBuilder::new(NodeType::Gemm, "test_linear") + .input_tensor_f32("input", 2, None) + .input_tensor_f32_data("weight", weight_data, weight_dims.clone()) + .output_tensor_f32("output", 2, None); + + // Add bias if needed if has_bias { - let bias_tensor = TensorData { - data: Data::Float32s(vec![0.0; weight_dims[1]]), // bias size equals output size - shape: vec![weight_dims[1]], - }; - - inputs.push(Argument { - name: "bias".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 1, - static_shape: None, - }), - value: Some(bias_tensor), - passed: true, - }); + let bias_data = vec![0.0; weight_dims[1]]; // bias size equals output size + builder = builder.input_tensor_f32_data("bias", bias_data, vec![weight_dims[1]]); } - let attrs = HashMap::new(); - - Node { - node_type: NodeType::Gemm, - name: "test_linear".to_string(), - inputs, - outputs: vec![Argument { - name: "output".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 2, - static_shape: None, - }), - value: None, - passed: true, - }], - attrs, - } + builder.build() } #[test] diff --git a/crates/onnx-ir/src/node/range.rs b/crates/onnx-ir/src/node/range.rs index 8fe97e43f5..8a21ca5086 100644 --- a/crates/onnx-ir/src/node/range.rs +++ b/crates/onnx-ir/src/node/range.rs @@ -24,49 +24,16 @@ pub fn range_update_outputs(node: &mut Node) { #[cfg(test)] mod tests { use super::*; - use crate::ir::{Argument, NodeType}; - use std::collections::HashMap; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; fn create_test_node() -> Node { - let inputs = vec![ - Argument { - name: "start".to_string(), - ty: ArgType::Scalar(ElementType::Int64), - value: None, - passed: true, - }, - Argument { - name: "limit".to_string(), - ty: ArgType::Scalar(ElementType::Int64), - value: None, - passed: true, - }, - Argument { - name: "delta".to_string(), - ty: ArgType::Scalar(ElementType::Int64), - value: None, - passed: true, - }, - ]; - - let outputs = vec![Argument { - name: "output".to_string(), - ty: 
ArgType::Tensor(TensorType { - elem_type: ElementType::Int64, - rank: 0, // Will be updated - static_shape: None, - }), - value: None, - passed: true, - }]; - - Node { - node_type: NodeType::Range, - name: "test_range".to_string(), - inputs, - outputs, - attrs: HashMap::new(), - } + NodeBuilder::new(NodeType::Range, "test_range") + .input_scalar_i64("start") + .input_scalar_i64("limit") + .input_scalar_i64("delta") + .output_tensor_i64("output", 0, None) // Rank 0 will be updated + .build() } #[test] From 86d8c726a0048bc7a7866c3060564b50ee11e59c Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Thu, 8 May 2025 17:35:15 -0500 Subject: [PATCH 37/37] Remove inline comments --- crates/burn-import/src/burn/node/base.rs | 28 ++++++++++++------------ 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/crates/burn-import/src/burn/node/base.rs b/crates/burn-import/src/burn/node/base.rs index 1b47261db0..c39b008bcb 100644 --- a/crates/burn-import/src/burn/node/base.rs +++ b/crates/burn-import/src/burn/node/base.rs @@ -372,13 +372,13 @@ pub(crate) mod tests { TensorData::from([2f32]), None, Conv2dConfig::new( - [3, 3], // kernel_size - [3, 3], // stride - [1, 1], // dilation - PaddingConfig2d::Valid, // padding - [1, 1], // output_padding - 1, // groups - true, // bias + [3, 3], + [3, 3], + [1, 1], + PaddingConfig2d::Valid, + [1, 1], + 1, + true, ), )); @@ -453,13 +453,13 @@ pub(crate) mod tests { TensorData::from([2f32]), None, Conv2dConfig::new( - [3, 3], // kernel_size - [3, 3], // stride - [1, 1], // dilation - PaddingConfig2d::Valid, // padding - [1, 1], // output_padding - 1, // groups - true, // bias + [3, 3], + [3, 3], + [1, 1], + PaddingConfig2d::Valid, + [1, 1], + 1, + true, ), )); graph.register(MatmulNode::new(