diff --git a/contributor-book/src/SUMMARY.md b/contributor-book/src/SUMMARY.md index b86475efa3..d3e2c7549c 100644 --- a/contributor-book/src/SUMMARY.md +++ b/contributor-book/src/SUMMARY.md @@ -10,7 +10,7 @@ - [Tensor](./project-architecture/tensor.md) - [Backend](./project-architecture/backend.md) - [Guides for Contributors](./guides/README.md) - - [Onnx To Burn Conversion Tool: A Development Guide](./guides/onnx-to-burn-conversion-tool.md) + - [ONNX to Burn: Development Guide](./guides/onnx-to-burn-conversion-tool.md) - [Adding a New Operation to Burn](./guides/adding-a-new-operation-to-burn.md) - [Submitting Examples to Burn](./guides/submitting-examples.md) - [Frequently Encountered Issues](./frequently-encountered-issues/README.md) diff --git a/contributor-book/src/guides/onnx-to-burn-conversion-tool.md b/contributor-book/src/guides/onnx-to-burn-conversion-tool.md index 175decf425..3c93f07726 100644 --- a/contributor-book/src/guides/onnx-to-burn-conversion-tool.md +++ b/contributor-book/src/guides/onnx-to-burn-conversion-tool.md @@ -1,4 +1,4 @@ -# ONNX to Burn Conversion Tool: Development Guide +# ONNX to Burn: Development Guide This guide offers in-depth design insights and step-by-step procedures for developers working on the ONNX to Burn conversion tool. This tool allows the importation of ONNX models into the Burn deep @@ -8,28 +8,6 @@ weights to Burn state files. For an introduction to ONNX import in Burn, see [this section of the Burn book](https://burn.dev/burn-book/import/onnx-model.html). -## Table of Contents - -- [ONNX to Burn Conversion Tool: Development Guide](#onnx-to-burn-conversion-tool-development-guide) - - [Table of Contents](#table-of-contents) - - [Design Overview](#design-overview) - - [Design Goals](#design-goals) - - [Design Decisions](#design-decisions) - - [Adding New Operators](#adding-new-operators) - - [Implementing a New Operator](#implementing-a-new-operator) - - [Step 1: Visibility](#step-1-visibility) - - [Step 2: Node Implementation](#step-2-node-implementation) - - [Within Onnx-IR](#within-onnx-ir) - - [Within burn-import](#within-burn-import) - - [Step 3: Registering New Operations](#step-3-registering-new-operations) - - [Step 4: Create a Config Function](#step-4-create-a-config-function) - - [Step 5: Dimension Inference](#step-5-dimension-inference) - - [Step 6: Integrate into the Graph Building Process](#step-6-integrate-into-the-graph-building-process) - - [Step 7: Add Newly Supported Op!](#step-7-add-newly-supported-op) - - [Misc:](#misc) - - [Testing](#testing) - - [Resources](#resources) - ## Design Overview ### Design Goals @@ -66,8 +44,7 @@ To extend `burn-import` with support for new ONNX operators, follow these steps: model contains the expected operators. 4. **Generate IR and Burn Graph**: Navigate to - [crates/burn-import/](https://github.com/tracel-ai/burn/tree/925716f89d0249cbc6bd14f85f40967bd7ef80a8/crates/burn-import) - and run: + [crates/burn-import/](https://github.com/tracel-ai/burn/tree/main/crates/burn-import) and run: ``` cargo r -- ./onnx-tests/tests//.onnx ./out @@ -81,225 +58,261 @@ To extend `burn-import` with support for new ONNX operators, follow these steps: the Burn model in Rust code, and `my-model.json` includes the model data. 7. **Add End-to-End Test**: Include the test in - [crates/burn-import/onnx-tests/tests/onnx_tests.rs](https://github.com/tracel-ai/burn/blob/925716f89d0249cbc6bd14f85f40967bd7ef80a8/crates/burn-import/onnx-tests/tests/test_onnx.rs). 
+ [crates/burn-import/onnx-tests/tests/test_onnx.rs](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/onnx-tests/tests/test_onnx.rs). Further details can be found in the - [onnx-tests README](https://github.com/tracel-ai/burn/blob/925716f89d0249cbc6bd14f85f40967bd7ef80a8/crates/burn-import/onnx-tests/README.md). + [onnx-tests README](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/onnx-tests/README.md). ## Implementing a New Operator To extend the capabilities of the Burn library by supporting new operations imported from ONNX graphs, developers must go through a few systematic steps. Here, we detail the process, using the implementation of the `Squeeze` operation to illustrate points as needed. All file/directory paths -are relative to `burn/crates/burn-import/`. +are relative to the root of the burn repository. ### Step 1: Visibility -To make a new operation accessible to the rest of the Burn project, you need to declare the module -within the -[`mod.rs` file](https://github.com/tracel-ai/burn/blob/925716f89d0249cbc6bd14f85f40967bd7ef80a8/crates/burn-import/src/burn/node/mod.rs#L43) -located in the `src/burn/node/` directory. +To make a new operation accessible, there are two key modules to update: + +1. In `crates/onnx-ir/src/node/mod.rs`, add your new operation module to make it visible within the + IR +2. In `crates/burn-import/src/burn/node/mod.rs`, make the corresponding node type visible within + burn-import ### Step 2: Node Implementation -#### Within Onnx-IR +#### Within onnx-ir + +The `onnx-ir` crate handles the Intermediate Representation (IR) of ONNX models. For each operation: + +1. Add the operation to the `NodeType` enum in `crates/onnx-ir/src/ir.rs`. + +2. Create a new module file in `crates/onnx-ir/src/node/.rs`. This file should + include: + + - A `_config` function to extract operation parameters + - A `_update_output` function for dimension inference + +3. If the operation might work with constants, add it to the list of node types checked for + constants in `crates/onnx-ir/src/from_onnx.rs`. -If the node type does not exist within the -[`NodeType` enum](https://github.com/tracel-ai/burn/blob/925716f89d0249cbc6bd14f85f40967bd7ef80a8/crates/onnx-ir/src/ir.rs#L273), -it will need to be added (support for custom operators is planned). If the node might be provided an -input which is a constant or the output of an identity node, it will need to be added to the list of -nodeTypes -[checked for constants](https://github.com/tracel-ai/burn/blob/925716f89d0249cbc6bd14f85f40967bd7ef80a8/crates/onnx-ir/src/from_onnx.rs#L21). -The node will need to be added to `rank_inference`, and in most cases the work parsing side will be -done. If a node requires extra parsing (such as handling an edge case like potentially remapping an -unsqueeze to a reshape) the best place for that is after check constants and prior to rank_inference -in -[`OnnxGraphBuilder::Build`](https://github.com/tracel-ai/burn/blob/925716f89d0249cbc6bd14f85f40967bd7ef80a8/crates/onnx-ir/src/from_onnx.rs#L222) +For example, the squeeze operation is defined in `crates/onnx-ir/src/node/squeeze.rs` and contains: + +- A `squeeze_config` function that extracts axes from node attributes +- A `squeeze_update_output` function that updates output dimensions by reducing input rank #### Within burn-import -Create a new file named `.rs` in the `src/burn/node/` directory. -This file will define the structure and functionality of your new operation. 
By convention, the -necessary information for carrying out an operation is encapsulated within a struct named -`Node`. For the `Squeeze` operation, we defined a -[struct called `SqueezeNode`](https://github.com/tracel-ai/burn/blob/925716f89d0249cbc6bd14f85f40967bd7ef80a8/crates/burn-import/src/burn/node/squeeze.rs#L8) -that holds necessary information about the input tensor, output tensor, and axes for the operation. -**If implementing a unary or binary operation, please see note below.** - -The core of integrating a new operation involves implementing the `NodeCodegen` trait for your node. -This trait defines how the node generates code during the graph compilation process. The -implementation must provide methods to define input and output types, to generate the forward pass -code, and to encapsulate the node into the more general `Node` structure. Specifically: - -- `output_types` and `input_types` return the tensor (or element) types for the output and inputs of - the node, respectively. -- `forward` generates the Rust code that performs the operation during the execution phase. The - `quote!` macro is used to generate rust code. Ensure that this is syntactically correct using Burn - code. -- `into_node` wraps the specific node in a general `Node` type, facilitating its inclusion in the - broader Burn graph structure. - -This file is also where you would put `test_codegen_nodes()`, to make sure that the generated code -works within the Burn library. +1. Create a new file named `.rs` in the `crates/burn-import/src/burn/node/` + directory. This file will define the structure and functionality of your new operation. By + convention, the necessary information for carrying out an operation is encapsulated within a + struct named `Node`. For the `Squeeze` operation, we defined a struct called + `SqueezeNode` that holds necessary information about the input tensor, output tensor, and axes + for the operation. **If implementing a unary or binary operation, please see note below.** + +2. The core of integrating a new operation involves implementing the `NodeCodegen` trait for your + node. This trait defines how the node generates code during the graph compilation process. The + implementation must provide methods to define input and output types, to generate the forward + pass code, and to encapsulate the node into the more general `Node` structure. Specifically: + + - `output_types` and `input_types` return the tensor (or element) types for the output and inputs + of the node, respectively. + - `forward` generates the Rust code that performs the operation during the execution phase. The + `quote!` macro is used to generate rust code. Ensure that this is syntactically correct using + Burn code. + - `into_node` wraps the specific node in a general `Node` type, facilitating its inclusion in the + broader Burn graph structure. + +3. This file is also where you would put `test_codegen_nodes()`, to make sure that the generated + code works within the Burn library. **For unary and binary operations:** The implementation of `NodeCodegen` is mostly implemented in -[`binary.rs`](https://github.com/tracel-ai/burn/blob/925716f89d0249cbc6bd14f85f40967bd7ef80a8/crates/burn-import/src/burn/node/binary.rs#L9) -and -[`unary.rs`](https://github.com/tracel-ai/burn/blob/76fe0ed881b3965782f78896433f8bb5e2f13a1b/crates/burn-import/src/burn/node/unary.rs#L13), -so each new operation only has to define a method to execute the function on the input(s) token -stream. 
+binary.rs and unary.rs, so each new operation only has to define a method to execute the function on +the input(s) token stream. ### Step 3: Registering New Operations -[Register the `NodeType::`](https://github.com/tracel-ai/burn/blob/925716f89d0249cbc6bd14f85f40967bd7ef80a8/crates/burn-import/src/onnx/to_burn.rs#L353) -and -[create an `_conversion(node: Node)` function](https://github.com/tracel-ai/burn/blob/925716f89d0249cbc6bd14f85f40967bd7ef80a8/crates/burn-import/src/onnx/to_burn.rs#L1263), -both in `src/onnx/to_burn.rs`. - -**Registering new operations in the ONNX -> Burn Conversion** -To integrate new operations from an ONNX graph into the Burn framework, each operation must be -registered within the ONNX graph conversion process. This is done in the `src/onnx/to_burn.rs` file, -where the conversion from ONNX nodes to Burn nodes is orchestrated. - -In the `into_burn()` method of the `OnnxGraph` struct, operations are matched with their -corresponding conversion functions. This method iterates over each node in the ONNX graph and, -depending on the node type, calls a specific conversion function that translates the ONNX node into -a corresponding Burn node. +1. In `crates/burn-import/src/onnx/to_burn.rs`, add the operation to the match statement in the + `into_burn()` method: ```rust -impl OnnxGraph { +impl ParsedOnnxGraph { pub fn into_burn(self) -> BurnGraph { - let mut graph = BurnGraph::::default(); - let mut unsupported_ops = vec![]; - - for node in self.nodes { + // ... + for node in self.0.nodes { match node.node_type { - NodeType::Add => graph.register(Self::add_conversion(node)), - // Other operations... + // ... NodeType::Squeeze => graph.register(Self::squeeze_conversion(node)), - // Add new operations here + // Add your new operation here } } } } ``` -Here, the `NodeType::Squeeze` matches the ONNX node type with the `squeeze_conversion()` function -that you define to handle the specific attributes and settings of a Squeeze operation. +2. Create a conversion function that creates an instance of your Burn node: -**Define the Conversion Function** -Each operation conversion function extracts necessary information from the ONNX node and constructs -a corresponding Burn node. The structure of these functions generally includes: +```rust +fn squeeze_conversion(node: Node) -> SqueezeNode { + let input = TensorType::from(node.inputs.first().unwrap()); + let output = TensorType::from(node.outputs.first().unwrap()); + let axes = squeeze_config(&node); -1. Extracting input and output tensors from the node. -2. Retrieving and processing operation-specific configurations. -3. Calling `_config()` to parse ONNX node configurations. -4. Creating an instance of the appropriate Burn node - ([defined in step 2](#step-2-node-implementation)) using this information. + SqueezeNode::new(input, output, axes) +} +``` + +This function extracts the necessary information from the ONNX node and passes it to your node's +constructor. ### Step 4: Create a Config Function -[Create an `_config(curr: &Node)`](https://github.com/tracel-ai/burn/blob/925716f89d0249cbc6bd14f85f40967bd7ef80a8/crates/burn-import/src/onnx/op_configuration.rs#L1847) -in `src/onnx/op_configuration.rs`. - -The `squeeze_conversion()` function in `src/onnx/to_burn.rs` from the previous step calls the -`squeeze_config()` function in `src/onnx/op_configuration.rs` in order the parse the ONNX node's -attributes to extract parameters specific to the Squeeze operation. 
In this case, the axes along -which the squeeze operation is performed. - -> 📘 Info: Understanding Generic `config` Patterns -> -> The `_config()` functions follow a similar pattern: -> -> 1. Extract tensor or scalar types for inputs and outputs. -> 2. Validate the input structure and types for each node, ensuring they conform to expected formats -> (panicking if not). -> 3. Parse and convert configurations or parameters specific to each operation. -> 4. Create and return a node specific to the operation, initialized with extracted values and -> configurations. -> -> For example, config functions handle specific settings like kernel size for pooling or handling -> different tensor and scalar types for power operations. - -These functions translate the more varied and flexible structure of ONNX nodes into the more -structured and type-safe environment of Rust and the Burn framework. Spec compliance is dealt with -here. - -### Step 5: Dimension Inference - -If needed, -[create a rank inference function](https://github.com/tracel-ai/burn/blob/925716f89d0249cbc6bd14f85f40967bd7ef80a8/crates/onnx-ir/src/rank_inference.rs#L410), -called `_update_output(node: &mut Node)` in `src/rank_inference.rs`. If dimensions remain -unchanged, use the `same_as_input()` function, for example -`NodeType::AveragePool1d => same_as_input(node)`. Match the `NodeType` to the function in the -`rank_inference()` match block. - -Dimension inference is an important step in the conversion process where Burn determines the -dimensions of each output tensor based on the operation. -[The `rank_inference()`](https://github.com/tracel-ai/burn/blob/925716f89d0249cbc6bd14f85f40967bd7ef80a8/crates/onnx-ir/src/rank_inference.rs#L14) -function is responsible for determining the dimensions of the output tensors for each node in the -graph. It does this by: - -1. **Matching the Node Type**: The function uses a `match` statement on the `node_type` of each node - to apply the correct dimension inference logic depending on the operation. -2. **Applying Operation Specific Logic**: For each operation, a specific inference function is - called that encapsulates the rules for how output dimensions should be derived from the inputs. - -For the Squeeze operation, the dimension inference is handled by the `squeeze_update_output()` -function, which is specifically tailored to handle the nuances of the squeeze operation, which is -currently not that nuanced. The output tensor should be (dimensions of input tensor) - 1. - -> 📘 Info: How `squeeze_update_output()` Works -> -> 1. Validation of axes input: We first check if the second input of the node contains a list of -> integers, which represent the axes along which the squeeze operation is applied. The function -> also validates that only one axis is specified for squeezing, ensuring that the operation's -> requirements within Burn are followed. -> 2. Extracting input dimensions: The input tensor's dimension is extracted from the first input. -> 3. Configuring output dimensions: The output tensor's dimensions are then set to be one less than -> the input tensor’s dimensions, reflecting the reduction in dimensions caused by the squeeze -> operation. -> 4. The function includes several checks that throw errors (panics) if the inputs do not meet the -> expected types or configurations, such as when the axes are not provided as an integer list or -> if the input type is not a tensor. 
- -By invoking this function within the `rank_inference()` match block, the output dimensions of each -node are updated before the graph is finalized. This ensures that all subsequent operations within -the graph can rely on correct tensor sizes, which is critical for both compiling the graph and for -runtime execution efficiency. - -If something is amiss (ie weird panics are happening), after doing this step and the dimensions of -your output tensor differs from the dimensions of your input, see the warning at the very end. +In `crates/onnx-ir/src/node/.rs`, create a config function that extracts +operation-specific parameters from the ONNX node: + +```rust +pub fn squeeze_config(curr: &Node) -> Vec { + let axes = curr + .attrs + .iter() + .filter_map(|(key, value)| { + if key == "axes" { + Some(value.clone().into_i64s()) + } else { + None + } + }) + .next() + .unwrap_or_else(Vec::new); + + match curr.inputs.first().unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + axes +} +``` + +This config function is responsible for parsing the ONNX node attributes and extracting +operation-specific parameters. In this case, it extracts the "axes" attribute from the squeeze +operation. + +### Step 5: Rank Inference + +In `crates/onnx-ir/src/node/.rs`, implement a rank inference function that updates +the output rank based on the operation: + +```rust +pub fn squeeze_update_output(node: &mut Node) { + // Extract axes information + let axes = /* ... */; + let input_rank = /* ... */; + let output_rank = input_rank - axes.len(); + + // Update output rank + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type: node.inputs[0].ty.elem_type().clone(), + rank: output_rank, + static_shape: None, + }); +} +``` + +Then register this function in `crates/onnx-ir/src/rank_inference.rs` by adding it to the match +statement: + +```rust +pub fn rank_inference(node: &mut Node) { + match node.node_type { + // ... + NodeType::Squeeze => squeeze_update_output(node), + // Add your new operation here + } +} +``` + +The `rank_inference.rs` file is responsible for determining the output tensor rank for each node in +the graph. + +If the rank remains unchanged, you can use helper functions like `same_as_input()` or +`same_as_input_broadcast()` instead of writing a custom update function. ### Step 6: Integrate into the Graph Building Process -When a new node type is introduced, it must be added to the -[`Node` enum](https://github.com/tracel-ai/burn/blob/925716f89d0249cbc6bd14f85f40967bd7ef80a8/crates/burn-import/src/burn/node/base.rs#L85) -and -[`match_all!` macro](https://github.com/tracel-ai/burn/blob/925716f89d0249cbc6bd14f85f40967bd7ef80a8/crates/burn-import/src/burn/node/base.rs#L138) -in `src/burn/node/base.rs`. +When a new node type is introduced, it must be added to the `Node` enum in +`crates/burn-import/src/burn/node/base.rs` and the `match_all!` macro in the same file. The `Node` enum abstracts over different types of operations (nodes) within a network graph. Each -variant of the enum corresponds to a specific type of operation, and it encapsulates the -operation-specific data structures (like `SqueezeNode1`) that was -[defined in step 2](#step-2-node-implementation). +variant of the enum corresponds to a specific type of operation and encapsulates the +operation-specific data structures (like `SqueezeNode`) that were defined in step 2. ### Step 7: Add Newly Supported Op! 
-As a reward, add an extra check to -[SUPPORTED-ONNX-OPS.md](https://github.com/tracel-ai/burn/blob/925716f89d0249cbc6bd14f85f40967bd7ef80a8/crates/burn-import/SUPPORTED-ONNX-OPS.md)! +As a reward, add an extra check to `crates/burn-import/SUPPORTED-ONNX-OPS.md`! + +### Lifting Constant Nodes + +If your operation takes inputs from constant nodes (such as weights in Conv1d, shape tensors in +Reshape, etc.), you need to add your operation's `NodeType` to the `LIFT_CONSTANTS_FOR_NODE_TYPES` +array in `crates/onnx-ir/src/from_onnx.rs`. + +```rust +const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 16] = [ + NodeType::BatchNormalization, + // other operations... + NodeType::Squeeze, + NodeType::Unsqueeze, + // Add your operation here if it needs constants to be processed +]; +``` + +"Lifting" constants means converting Constant nodes into direct input values. This is similar to how +ONNX initializers work. For example, instead of having a separate Constant node providing weights to +a Convolution operation, the weights are directly embedded as values in the Convolution node's +inputs. + +This transformation makes it easier to: -### Misc: +1. Access the constant values during node configuration +2. Process operations like Conv1d that expect weights as direct inputs +3. Handle shape-defining inputs needed for operations like Reshape -> 🚧 **Warning**: Dimension Changes -> -> If your operation changes the dimensions of the input tensor, you may need to modify the -> [`LIFT_CONSTANTS_FOR_NODE_TYPES` enum](https://github.com/tracel-ai/burn/blob/925716f89d0249cbc6bd14f85f40967bd7ef80a8/crates/onnx-ir/src/from_onnx.rs#L21) -> in `src/from_onnx.rs` by adding the `NodeType` of your operation to it. +Without this, operations that need to extract configuration from constant inputs (such as shapes, +weights, or other parameters) would not work correctly because they wouldn't have direct access to +those constant values. ## Testing -- Unit tests for the Burn graph to Rust source code conversion are mandatory. -- End-to-end tests should include a test ONNX model and its expected output for each operator. +When implementing a new operator, there are several levels of testing to consider: + +### Unit Testing + +- **Node Configuration**: Write unit tests for the `_config` function in + `crates/onnx-ir/src/node/.rs` to verify that it correctly extracts parameters from + ONNX nodes. + +- **Rank Inference**: Test the `_update_output` function to ensure it correctly + computes output ranks. + +- **Code Generation**: Test the Node implementation in `burn-import` to verify that it generates + correct Rust code. + +### Integration Testing + +- Create small ONNX models that use your operator and test the end-to-end conversion process +- Ensure the generated Rust code compiles and produces the expected outputs +- Add these tests to `crates/burn-import/onnx-tests/tests/test_onnx.rs` + +### End-to-End Testing + +- Test with realistic ONNX models that use your operator in conjunction with others +- Verify that inputs and outputs match between the original ONNX model and the converted Burn model +- Include models that test edge cases (e.g., different input shapes, parameter combinations) + +Testing both the rank inference and node configuration is particularly important as these components +directly affect the correctness of the conversion process. Incorrect rank inference can lead to +mismatched tensor shapes, while incorrect configuration can cause runtime errors or incorrect +results. 
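To make the integration-testing step above concrete, here is a hedged sketch of what an end-to-end test for the `Squeeze` operator might look like in `crates/burn-import/onnx-tests/tests/test_onnx.rs`. It assumes the conventions used by the existing tests in that file: an `include_models!` macro that pulls in the Rust module generated from the test ONNX file at build time, the `NdArray` backend, and a test model that squeezes away axis 0 of a rank-4 input. The module name, shapes, and expected output are illustrative, not prescriptive; follow the onnx-tests README for the exact setup.

```rust
use burn::tensor::Tensor;

// Illustrative backend choice; any Burn backend works for these tests.
type Backend = burn_ndarray::NdArray<f32>;

// Assumption: the test harness provides an `include_models!` macro that includes
// the module generated from onnx-tests/tests/squeeze/squeeze.onnx at build time.
include_models!(squeeze);

#[test]
fn squeeze() {
    let device = Default::default();

    // The generated model exposes `new(&device)` and `forward(...)`,
    // mirroring the code produced by burn-import.
    let model: squeeze::Model<Backend> = squeeze::Model::new(&device);

    // Illustrative input: a rank-4 tensor whose leading unit dimension is squeezed away.
    let input = Tensor::<Backend, 4>::ones([1, 3, 2, 2], &device);
    let output = model.forward(input);

    // Compare against the output printed by the PyTorch script from step 1.
    // Only the rank/shape is checked here; neighbouring tests also compare the
    // tensor data itself against the values printed by that script.
    assert_eq!(output.shape().dims, [3, 2, 2]);
}
```

Since the PyTorch script from step 1 prints both the input and output tensors, the expected values for the data comparison can be copied directly from its output.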
## Resources diff --git a/crates/burn-core/src/nn/padding.rs b/crates/burn-core/src/nn/padding.rs index 0e847c8532..5c61322073 100644 --- a/crates/burn-core/src/nn/padding.rs +++ b/crates/burn-core/src/nn/padding.rs @@ -7,12 +7,11 @@ use crate::config::Config; /// Padding configuration for 1D operators. #[derive(Config, Debug, PartialEq)] pub enum PaddingConfig1d { - /// Dynamically calculate the amount of padding necessary to ensure that the output size will be - /// the same as the input. + /// Dynamically calculates padding to ensure output size matches input size. Same, - /// Same as no padding. + /// No padding applied. Valid, - /// Applies the specified amount of padding to all inputs. + /// Applies a specific amount of padding to all inputs. Explicit(usize), } @@ -35,12 +34,11 @@ impl PaddingConfig1d { /// Padding configuration for 2D operators. #[derive(Config, Debug, PartialEq)] pub enum PaddingConfig2d { - /// Dynamically calculate the amount of padding necessary to ensure that the output size will be - /// the same as the input. + /// Dynamically calculates padding to preserve input dimensions in output. Same, - /// Same as no padding. + /// No padding applied. Valid, - /// Applies the specified amount of padding to all inputs. + /// Applies specified padding values to height and width dimensions. Explicit(usize, usize), } @@ -70,12 +68,11 @@ impl PaddingConfig2d { /// Padding configuration for 3D operators. #[derive(Config, Debug, PartialEq)] pub enum PaddingConfig3d { - /// Dynamically calculate the amount of padding necessary to ensure that the output size will be - /// the same as the input. + /// Dynamically calculates padding to preserve input dimensions in output. Same, - /// Same as no padding. + /// No padding applied. Valid, - /// Applies the specified amount of padding to all inputs. + /// Applies specified padding values to depth, height, and width dimensions. Explicit(usize, usize, usize), } diff --git a/crates/burn-import/DEVELOPMENT.md b/crates/burn-import/DEVELOPMENT.md deleted file mode 100644 index 2d122a042b..0000000000 --- a/crates/burn-import/DEVELOPMENT.md +++ /dev/null @@ -1,80 +0,0 @@ -# ONNX to Burn Conversion Tool: Development Guide - -This guide offers in-depth design insights and step-by-step procedures for developers working on the -ONNX to Burn conversion tool. This tool allows the importation of ONNX models into the Burn deep -learning framework written in Rust. It converts both ONNX models to Rust source code and model -weights to Burn state files. - -## Table of Contents - -1. [Design Overview](#Design-Overview) - 1. [Design Goals](#Design-Goals) - 2. [Design Decisions](#Design-Decisions) -2. [Adding New Operators](#Adding-New-Operators) -3. [Testing](#Testing) -4. [Resources](#Resources) - -## Design Overview - -### Design Goals - -- Perform best-effort conversion of ONNX models to Rust source code via Burn APIs. -- Convert ONNX model weights to Burn state files. -- Support ONNX models generated by PyTorch (ONNX Opset 16). -- Produce easy-to-understand and modifiable models. -- Ensure the generated models are trainable using Burn APIs. - -### Design Decisions - -- Limit interaction with ONNX to the Intermediate Representation (IR) stage to simplify the process. -- Ensure operator behavior consistency across different OpSet versions. -- Exclude any ONNX/Protobuf-specific logic from the Burn graph. - -The conversion process involves three main stages: - -1. Convert ONNX model to Intermediate Representation (IR). -2. Translate IR to a Burn graph. 
-3. Generate Rust source code from the Burn graph. - -## Adding New Operators - -To extend `burn-import` with support for new ONNX operators, follow these steps: - -1. **Create PyTorch Script**: Place a PyTorch script using the new operator under - `./burn-import/onnx-tests/tests//.py`. Make sure to print both input and output tensors - for end-to-end testing. - -2. **Generate ONNX Model**: Run the PyTorch script to produce an ONNX model. - -3. **Visualize ONNX Model**: Use [Netron](https://github.com/lutzroeder/netron) to verify the ONNX - model contains the expected operators. - -4. **Generate IR and Burn Graph**: Navigate to `./burn-import/` and run: - - ``` - cargo r -- ./onnx-tests/tests//.onnx ./out - ``` - -5. **Implement Missing Operators**: If you encounter an error stating that an operator is - unsupported, implement it. The `./out/my-model.graph.txt` should provide relevant information. - -6. **Inspect Generated Files**: The `my-model.graph.txt` contains IR details, `my-model.rs` holds - the Burn model in Rust code, and `my-model.json` includes the model data. - -7. **Add End-to-End Test**: Include the test in `./burn-import/onnx-tests/tests/onnx_tests.rs`. - Further details can be found in the [onnx-tests README](./onnx-tests/README.md). - -## Testing - -- Unit tests for the Burn graph to Rust source code conversion are mandatory. -- End-to-end tests should include a test ONNX model and its expected output for each operator. - -## Resources - -1. [PyTorch to ONNX](https://pytorch.org/docs/stable/onnx.html) -2. [ONNX to PyTorch](https://github.com/ENOT-AutoDL/onnx2torch) -3. [ONNX Introduction](https://onnx.ai/onnx/intro/) -4. [ONNX Operators](https://onnx.ai/onnx/operators/index.html) -5. [ONNX Protos](https://onnx.ai/onnx/api/classes.html) -6. [ONNX Optimizer](https://github.com/onnx/optimizer) -7. [Netron](https://github.com/lutzroeder/netron) diff --git a/crates/burn-import/README.md b/crates/burn-import/README.md index d6a898eccd..7ce4beff84 100644 --- a/crates/burn-import/README.md +++ b/crates/burn-import/README.md @@ -1,15 +1,24 @@ -# Importing Models +# Burn Import -The Burn project supports the import of models from various frameworks, emphasizing efficiency and -compatibility. Currently, it handles two primary model formats: +The `burn-import` crate enables seamless integration of pre-trained models from popular machine +learning frameworks into the Burn ecosystem. This functionality allows you to leverage existing +models while benefiting from Burn's performance optimizations and native Rust integration. -1. [ONNX](https://burn.dev/burn-book/import/onnx-model.html): Facilitates direct import, ensuring the - model's performance and structure are maintained. +## Supported Import Formats -2. [PyTorch](https://burn.dev/burn-book/import/pytorch-model.html): Enables the loading of PyTorch model - weights into Burn’s native model architecture, ensuring seamless integration. 
+Burn currently supports three primary model import formats, each serving different use cases: -## Contribution +| Format | Description | Use Case | +| ----------------------------------------------------------------------------------- | ----------------------------------------- | ------------------------------------------------------------------------------------------------------ | +| [**ONNX** (Guide)](https://burn.dev/burn-book/import/onnx-model.html) | Open Neural Network Exchange format | Direct import of complete model architectures and weights from any framework that supports ONNX export | +| [**PyTorch** (Guide)](https://burn.dev/burn-book/import/pytorch-model.html) | PyTorch weights (.pt, .pth) | Loading weights from PyTorch models into a matching Burn architecture | +| [**Safetensors** (Guide)](https://burn.dev/burn-book/import/safetensors-model.html) | Hugging Face's model serialization format | Loading a model's tensor weights into a matching Burn architecture | -Interested in contributing to `burn-import`? Check out our [development guide](DEVELOPMENT.md) for -more information. +## ONNX Contributor Resources + +- [ONNX to Burn conversion guide](https://burn.dev/contributor-book/guides/onnx-to-burn-conversion-tool.html) - + Instructions for adding support for additional ONNX operators +- [ONNX tests README](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/onnx-tests/README.md) - + Testing procedures for ONNX operators +- [Supported ONNX Operators table](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/SUPPORTED-ONNX-OPS.md) - + Complete list of currently supported ONNX operators diff --git a/crates/burn-import/src/burn/codegen.rs b/crates/burn-import/src/burn/codegen.rs index 7f511dafd4..dff8ff4875 100644 --- a/crates/burn-import/src/burn/codegen.rs +++ b/crates/burn-import/src/burn/codegen.rs @@ -1,10 +1,7 @@ +use onnx_ir::node::padding::{PaddingConfig1d, PaddingConfig2d, PaddingConfig3d}; use proc_macro2::TokenStream; use quote::quote; -use burn::nn::PaddingConfig1d; -use burn::nn::PaddingConfig2d; -use burn::nn::PaddingConfig3d; - fn convert_primitive(primitive: T) -> TokenStream { let value = format!("{:?}", primitive); @@ -76,7 +73,6 @@ impl ToTokens for f32 { impl ToTokens for PaddingConfig1d { fn to_tokens(&self) -> TokenStream { match self { - Self::Same => quote! { PaddingConfig1d::Same }, Self::Valid => quote! { PaddingConfig1d::Valid }, Self::Explicit(padding) => { let padding = padding.to_tokens(); @@ -90,7 +86,6 @@ impl ToTokens for PaddingConfig1d { impl ToTokens for PaddingConfig2d { fn to_tokens(&self) -> TokenStream { match self { - Self::Same => quote! { PaddingConfig2d::Same }, Self::Valid => quote! { PaddingConfig2d::Valid }, Self::Explicit(padding1, padding2) => { let padding1 = padding1.to_tokens(); @@ -105,7 +100,6 @@ impl ToTokens for PaddingConfig2d { impl ToTokens for PaddingConfig3d { fn to_tokens(&self) -> TokenStream { match self { - Self::Same => quote! { PaddingConfig3d::Same }, Self::Valid => quote! 
{ PaddingConfig3d::Valid }, Self::Explicit(padding1, padding2, padding3) => { let padding1 = padding1.to_tokens(); diff --git a/crates/burn-import/src/burn/node/avg_pool1d.rs b/crates/burn-import/src/burn/node/avg_pool1d.rs index 2a8310bb9c..0c4b1a1693 100644 --- a/crates/burn-import/src/burn/node/avg_pool1d.rs +++ b/crates/burn-import/src/burn/node/avg_pool1d.rs @@ -1,7 +1,8 @@ +use onnx_ir::node::avg_pool1d::AvgPool1dConfig; use proc_macro2::TokenStream; use quote::quote; -use burn::{nn::pool::AvgPool1dConfig, record::PrecisionSettings}; +use burn::record::PrecisionSettings; use super::{Node, NodeCodegen}; use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type}; @@ -93,7 +94,8 @@ impl NodeCodegen for AvgPool1dNode { mod tests { use super::*; use crate::burn::{TensorType, graph::BurnGraph, node::test::assert_tokens}; - use burn::{nn::PaddingConfig1d, record::FullPrecisionSettings}; + use burn::record::FullPrecisionSettings; + use onnx_ir::node::padding::PaddingConfig1d; #[test] fn test_codegen() { @@ -103,9 +105,7 @@ mod tests { "avg_pool1d", TensorType::new_float("input", 3), TensorType::new_float("output", 3), - AvgPool1dConfig::new(3) - .with_stride(1) - .with_padding(PaddingConfig1d::Valid), + AvgPool1dConfig::new(3, 1, PaddingConfig1d::Valid, true), )); graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); diff --git a/crates/burn-import/src/burn/node/avg_pool2d.rs b/crates/burn-import/src/burn/node/avg_pool2d.rs index 2e84a5170b..8ab1af2ee2 100644 --- a/crates/burn-import/src/burn/node/avg_pool2d.rs +++ b/crates/burn-import/src/burn/node/avg_pool2d.rs @@ -1,7 +1,8 @@ +use onnx_ir::node::avg_pool2d::AvgPool2dConfig; use proc_macro2::TokenStream; use quote::quote; -use burn::{nn::pool::AvgPool2dConfig, record::PrecisionSettings}; +use burn::record::PrecisionSettings; use super::{Node, NodeCodegen}; use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type}; @@ -97,7 +98,8 @@ mod tests { graph::BurnGraph, node::{avg_pool2d::AvgPool2dNode, test::assert_tokens}, }; - use burn::{nn::PaddingConfig2d, nn::pool::AvgPool2dConfig, record::FullPrecisionSettings}; + use burn::record::FullPrecisionSettings; + use onnx_ir::node::padding::PaddingConfig2d; #[test] fn test_codegen() { @@ -107,9 +109,7 @@ mod tests { "avg_pool2d", TensorType::new_float("input", 4), TensorType::new_float("output", 4), - AvgPool2dConfig::new([3, 3]) - .with_strides([1, 1]) - .with_padding(PaddingConfig2d::Valid), + AvgPool2dConfig::new([3, 3], [1, 1], PaddingConfig2d::Valid, true), )); graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); diff --git a/crates/burn-import/src/burn/node/base.rs b/crates/burn-import/src/burn/node/base.rs index d77df6172f..c39b008bcb 100644 --- a/crates/burn-import/src/burn/node/base.rs +++ b/crates/burn-import/src/burn/node/base.rs @@ -308,10 +308,8 @@ pub(crate) mod tests { graph::BurnGraph, node::{NodeCodegen, conv2d::Conv2dNode, matmul::MatmulNode, test::assert_tokens}, }; - use burn::{ - nn::PaddingConfig2d, nn::conv::Conv2dConfig, record::FullPrecisionSettings, - tensor::TensorData, - }; + use burn::{record::FullPrecisionSettings, tensor::TensorData}; + use onnx_ir::node::{conv2d::Conv2dConfig, padding::PaddingConfig2d}; use proc_macro2::TokenStream; use quote::quote; @@ -373,7 +371,15 @@ pub(crate) mod tests { TensorType::new_float("tensor4", 4), TensorData::from([2f32]), None, - Conv2dConfig::new([3, 3], [3, 3]).with_padding(PaddingConfig2d::Valid), + Conv2dConfig::new( + [3, 3], + [3, 
3], + [1, 1], + PaddingConfig2d::Valid, + [1, 1], + 1, + true, + ), )); graph.register_input_output( @@ -446,7 +452,15 @@ pub(crate) mod tests { TensorType::new_float("tensor4", 4), TensorData::from([2f32]), None, - Conv2dConfig::new([3, 3], [3, 3]).with_padding(PaddingConfig2d::Valid), + Conv2dConfig::new( + [3, 3], + [3, 3], + [1, 1], + PaddingConfig2d::Valid, + [1, 1], + 1, + true, + ), )); graph.register(MatmulNode::new( TensorType::new_float("tensor3", 4), diff --git a/crates/burn-import/src/burn/node/batch_norm.rs b/crates/burn-import/src/burn/node/batch_norm.rs index 41f7194f4f..ae0147f09c 100644 --- a/crates/burn-import/src/burn/node/batch_norm.rs +++ b/crates/burn-import/src/burn/node/batch_norm.rs @@ -2,10 +2,11 @@ use super::{Node, NodeCodegen, SerializationBackend}; use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type}; use burn::{ module::{ConstantRecord, Param, ParamId}, - nn::{BatchNormConfig, BatchNormRecord}, + nn::BatchNormRecord, record::{PrecisionSettings, Record}, tensor::{Tensor, TensorData}, }; +use onnx_ir::node::batch_norm::BatchNormConfig; use proc_macro2::TokenStream; use quote::quote; use serde::Serialize; @@ -171,7 +172,7 @@ mod tests { TensorData::from([2f32]), TensorData::from([2f32]), TensorData::from([2f32]), - BatchNormConfig::new(128), + BatchNormConfig::new(128, 0.00001, 0.1), )); graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); diff --git a/crates/burn-import/src/burn/node/conv1d.rs b/crates/burn-import/src/burn/node/conv1d.rs index 65a8343697..3eb0974769 100644 --- a/crates/burn-import/src/burn/node/conv1d.rs +++ b/crates/burn-import/src/burn/node/conv1d.rs @@ -2,10 +2,11 @@ use super::{Node, NodeCodegen, SerializationBackend}; use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type}; use burn::{ module::{ConstantRecord, Param, ParamId}, - nn::conv::{Conv1dConfig, Conv1dRecord}, + nn::conv::Conv1dRecord, record::{PrecisionSettings, Record}, tensor::{Tensor, TensorData}, }; +use onnx_ir::node::conv1d::Conv1dConfig; use proc_macro2::TokenStream; use quote::quote; use serde::Serialize; @@ -135,10 +136,8 @@ mod tests { graph::BurnGraph, node::{conv1d::Conv1dNode, test::assert_tokens}, }; - use burn::{ - nn::{PaddingConfig1d, conv::Conv1dConfig}, - record::FullPrecisionSettings, - }; + use burn::record::FullPrecisionSettings; + use onnx_ir::node::padding::PaddingConfig1d; #[test] fn test_codegen() { @@ -150,7 +149,7 @@ mod tests { TensorType::new_float("output", 4), TensorData::from([2f32]), None, - Conv1dConfig::new(3, 3, 3).with_padding(PaddingConfig1d::Valid), + Conv1dConfig::new(3, 3, 3, 1, PaddingConfig1d::Valid, 1, 1, true), )); graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); diff --git a/crates/burn-import/src/burn/node/conv2d.rs b/crates/burn-import/src/burn/node/conv2d.rs index 7de8c97cb6..1e2336b923 100644 --- a/crates/burn-import/src/burn/node/conv2d.rs +++ b/crates/burn-import/src/burn/node/conv2d.rs @@ -2,10 +2,11 @@ use super::{Node, NodeCodegen, SerializationBackend}; use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type}; use burn::{ module::{ConstantRecord, Param, ParamId}, - nn::conv::{Conv2dConfig, Conv2dRecord}, + nn::conv::Conv2dRecord, record::{PrecisionSettings, Record}, tensor::{Tensor, TensorData}, }; +use onnx_ir::node::conv2d::Conv2dConfig; use proc_macro2::TokenStream; use quote::quote; use serde::Serialize; @@ -134,7 +135,8 @@ mod tests { graph::BurnGraph, node::{conv2d::Conv2dNode, 
test::assert_tokens}, }; - use burn::{nn::PaddingConfig2d, nn::conv::Conv2dConfig, record::FullPrecisionSettings}; + use burn::record::FullPrecisionSettings; + use onnx_ir::node::padding::PaddingConfig2d; #[test] fn test_codegen() { @@ -146,7 +148,15 @@ mod tests { TensorType::new_float("output", 4), TensorData::from([2f32]), None, - Conv2dConfig::new([3, 3], [3, 3]).with_padding(PaddingConfig2d::Valid), + Conv2dConfig::new( + [3, 3], + [3, 3], + [1, 1], + PaddingConfig2d::Valid, + [1, 1], + 1, + true, + ), )); graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); diff --git a/crates/burn-import/src/burn/node/conv3d.rs b/crates/burn-import/src/burn/node/conv3d.rs index 9cad0f6e95..694ec1d2d7 100644 --- a/crates/burn-import/src/burn/node/conv3d.rs +++ b/crates/burn-import/src/burn/node/conv3d.rs @@ -2,10 +2,11 @@ use super::{Node, NodeCodegen, SerializationBackend}; use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type}; use burn::{ module::{ConstantRecord, Param, ParamId}, - nn::conv::{Conv3dConfig, Conv3dRecord}, + nn::conv::Conv3dRecord, record::{PrecisionSettings, Record}, tensor::{Tensor, TensorData}, }; +use onnx_ir::node::conv3d::Conv3dConfig; use proc_macro2::TokenStream; use quote::quote; use serde::Serialize; @@ -134,7 +135,8 @@ mod tests { graph::BurnGraph, node::{conv3d::Conv3dNode, test::assert_tokens}, }; - use burn::{nn::PaddingConfig3d, nn::conv::Conv3dConfig, record::FullPrecisionSettings}; + use burn::record::FullPrecisionSettings; + use onnx_ir::node::padding::PaddingConfig3d; #[test] fn test_codegen() { @@ -146,7 +148,15 @@ mod tests { TensorType::new_float("output", 5), TensorData::from([2f32]), None, - Conv3dConfig::new([3, 3], [3, 3, 3]).with_padding(PaddingConfig3d::Valid), + Conv3dConfig::new( + [3, 3], + [3, 3, 3], + [1, 1, 1], + [1, 1, 1], + 1, + true, + PaddingConfig3d::Valid, + ), )); graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); diff --git a/crates/burn-import/src/burn/node/conv_transpose_2d.rs b/crates/burn-import/src/burn/node/conv_transpose_2d.rs index 2fa77ca83c..ad6932c7c2 100644 --- a/crates/burn-import/src/burn/node/conv_transpose_2d.rs +++ b/crates/burn-import/src/burn/node/conv_transpose_2d.rs @@ -2,10 +2,11 @@ use super::{Node, NodeCodegen, SerializationBackend}; use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type}; use burn::{ module::{ConstantRecord, Param, ParamId}, - nn::conv::{ConvTranspose2dConfig, ConvTranspose2dRecord}, + nn::conv::ConvTranspose2dRecord, record::{PrecisionSettings, Record}, tensor::{Tensor, TensorData}, }; +use onnx_ir::node::conv_transpose2d::ConvTranspose2dConfig; use proc_macro2::TokenStream; use quote::quote; use serde::Serialize; @@ -137,7 +138,7 @@ mod tests { graph::BurnGraph, node::{conv_transpose_2d::ConvTranspose2dNode, test::assert_tokens}, }; - use burn::{nn::conv::ConvTranspose2dConfig, record::FullPrecisionSettings}; + use burn::record::FullPrecisionSettings; #[test] fn test_codegen() { @@ -149,7 +150,7 @@ mod tests { TensorType::new_float("output", 4), TensorData::from([2f32]), None, - ConvTranspose2dConfig::new([3, 3], [3, 3]).with_padding([0, 0]), + ConvTranspose2dConfig::new([3, 3], [1, 1], [0, 0], [0, 0], [0, 0], [0, 0], 1, true), )); graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); @@ -172,11 +173,11 @@ mod tests { impl Model { #[allow(unused_variables)] pub fn new(device: &B::Device) -> Self { - let conv_transpose_2d = ConvTranspose2dConfig::new([3, 3], [3, 
3]) - .with_stride([1, 1]) + let conv_transpose_2d = ConvTranspose2dConfig::new([3, 3], [1, 1]) + .with_stride([0, 0]) .with_padding([0, 0]) .with_padding_out([0, 0]) - .with_dilation([1, 1]) + .with_dilation([0, 0]) .with_groups(1) .with_bias(true) .init(device); diff --git a/crates/burn-import/src/burn/node/conv_transpose_3d.rs b/crates/burn-import/src/burn/node/conv_transpose_3d.rs index 098dbcaffe..41dc2c83f6 100644 --- a/crates/burn-import/src/burn/node/conv_transpose_3d.rs +++ b/crates/burn-import/src/burn/node/conv_transpose_3d.rs @@ -2,10 +2,11 @@ use super::{Node, NodeCodegen, SerializationBackend}; use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type}; use burn::{ module::{ConstantRecord, Param, ParamId}, - nn::conv::{ConvTranspose3dConfig, ConvTranspose3dRecord}, + nn::conv::ConvTranspose3dRecord, record::{PrecisionSettings, Record}, tensor::{Tensor, TensorData}, }; +use onnx_ir::node::conv_transpose3d::ConvTranspose3dConfig; use proc_macro2::TokenStream; use quote::quote; use serde::Serialize; @@ -137,7 +138,7 @@ mod tests { graph::BurnGraph, node::{conv_transpose_3d::ConvTranspose3dNode, test::assert_tokens}, }; - use burn::{nn::conv::ConvTranspose3dConfig, record::FullPrecisionSettings}; + use burn::record::FullPrecisionSettings; #[test] fn test_codegen() { @@ -149,7 +150,16 @@ mod tests { TensorType::new_float("output", 5), TensorData::from([2f32]), None, - ConvTranspose3dConfig::new([3, 3], [3, 3, 3]).with_padding([0, 0, 0]), + ConvTranspose3dConfig::new( + [3, 3], + [1, 1, 1], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + 1, + true, + ), )); graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); @@ -172,11 +182,11 @@ mod tests { impl Model { #[allow(unused_variables)] pub fn new(device: &B::Device) -> Self { - let conv_transpose_3d = ConvTranspose3dConfig::new([3, 3], [3, 3, 3]) - .with_stride([1, 1, 1]) + let conv_transpose_3d = ConvTranspose3dConfig::new([3, 3], [1, 1, 1]) + .with_stride([0, 0, 0]) .with_padding([0, 0, 0]) .with_padding_out([0, 0, 0]) - .with_dilation([1, 1, 1]) + .with_dilation([0, 0, 0]) .with_groups(1) .with_bias(true) .init(device); diff --git a/crates/burn-import/src/burn/node/dropout.rs b/crates/burn-import/src/burn/node/dropout.rs index f6a8d37eaa..ee9d1ed8dd 100644 --- a/crates/burn-import/src/burn/node/dropout.rs +++ b/crates/burn-import/src/burn/node/dropout.rs @@ -1,7 +1,8 @@ +use onnx_ir::node::dropout::DropoutConfig; use proc_macro2::TokenStream; use quote::quote; -use burn::{nn::DropoutConfig, record::PrecisionSettings}; +use burn::record::PrecisionSettings; use super::{Node, NodeCodegen}; use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type}; @@ -83,7 +84,7 @@ impl NodeCodegen for DropoutNode { mod tests { use super::*; use crate::burn::{TensorType, graph::BurnGraph, node::test::assert_tokens}; - use burn::{nn::DropoutConfig, record::FullPrecisionSettings}; + use burn::record::FullPrecisionSettings; #[test] fn test_codegen() { diff --git a/crates/burn-import/src/burn/node/expand.rs b/crates/burn-import/src/burn/node/expand.rs index 0fd56124a7..3b3143c5dc 100644 --- a/crates/burn-import/src/burn/node/expand.rs +++ b/crates/burn-import/src/burn/node/expand.rs @@ -1,6 +1,7 @@ use super::{Node, NodeCodegen}; use crate::burn::{Scope, TensorType, ToTokens, Type}; use burn::record::PrecisionSettings; +use onnx_ir::node::expand::ExpandShape; use proc_macro2::TokenStream; use quote::quote; @@ -11,12 +12,6 @@ pub struct ExpandNode { pub shape: ExpandShape, } 
-#[derive(Debug, Clone)] -pub enum ExpandShape { - Static(Vec), - Runtime(Type), -} - impl NodeCodegen for ExpandNode { fn output_types(&self) -> Vec { vec![Type::Tensor(self.output.clone())] @@ -28,10 +23,9 @@ impl NodeCodegen for ExpandNode { // if it is dynamic, the shape will be our 2nd: match &self.shape { ExpandShape::Static(_) => vec![input], - ExpandShape::Runtime(rt_type) => vec![input, rt_type.clone()], + ExpandShape::Runtime(rt_type) => vec![input, Type::from(rt_type)], } } - fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { let input = scope.tensor_use_owned(&self.input, node_position); let output = &self.output.name; @@ -39,22 +33,20 @@ impl NodeCodegen for ExpandNode { let shape = match &self.shape { ExpandShape::Static(static_shape) => static_shape.to_tokens(), - ExpandShape::Runtime(Type::Tensor(shape_tensor)) => { - // Since we don't take ownership of the shape_tensor, `tensor_use_owned` is not needed here. - let tensor_name = &shape_tensor.name; - // The shape of the tensor is statically validated to be rank one during input parsing. - // The tensor must be downloaded from device to CPU for the expand operation. - // Additionally, it needs to be converted to an array for use in BroadcastArgs. - quote! { - TryInto::<[B::IntElem; #output_rank]>::try_into(#tensor_name.to_data().as_slice::().unwrap()).unwrap() + ExpandShape::Runtime(ty) => match Type::from(ty) { + Type::Tensor(shape_tensor) => { + let tensor_name = &shape_tensor.name; + quote! { + TryInto::<[B::IntElem; #output_rank]>::try_into(#tensor_name.to_data().as_slice::().unwrap()).unwrap() + } } - } - ExpandShape::Runtime(Type::Shape(shape)) => { - // Shape implements BroadcastArgs, allowing it to be passed directly to the expand method. - let shape_name = &shape.name; - quote! { #shape_name } - } - _ => panic!("Invalid shape source {:?}", self.shape), + Type::Shape(shape) => { + // Shape implements BroadcastArgs, allowing it to be passed directly to the expand method. + let shape_name = &shape.name; + quote! { #shape_name } + } + b => panic!("Invalid shape source {:?}", b), + }, }; quote! 
{ @@ -70,10 +62,11 @@ impl NodeCodegen for ExpandNode { #[cfg(test)] mod tests { use burn::record::FullPrecisionSettings; + use onnx_ir::{ArgType, Argument, ElementType}; use super::*; use crate::burn::{ - ShapeType, TensorType, + TensorType, graph::BurnGraph, node::{expand::ExpandNode, test::assert_tokens}, }; @@ -121,7 +114,6 @@ mod tests { assert_tokens(graph.codegen(), expected); } - #[test] fn test_codegen_expand_shape() { let mut graph = BurnGraph::::default(); @@ -129,7 +121,12 @@ mod tests { graph.register(ExpandNode::new( TensorType::new_float("tensor1", 4), TensorType::new_float("tensor2", 4), - ExpandShape::Runtime(Type::Shape(ShapeType::new("shape1", 4))), + ExpandShape::Runtime(Argument { + name: "shape1".to_string(), + ty: ArgType::Shape(4), + value: None, + passed: false, + }), )); graph.register_input_output( @@ -177,12 +174,19 @@ mod tests { fn test_codegen_expand_tensor() { let mut graph = BurnGraph::::default(); - let shape_tensor_type = TensorType::new_int("tensor3", 4); - graph.register(ExpandNode::new( TensorType::new_float("tensor1", 4), TensorType::new_float("tensor2", 4), - ExpandShape::Runtime(Type::Tensor(shape_tensor_type)), + ExpandShape::Runtime(Argument { + name: "tensor3".to_string(), + ty: ArgType::Tensor(onnx_ir::TensorType { + elem_type: ElementType::Int32, + rank: 1, + static_shape: None, + }), + value: None, + passed: false, + }), )); graph.register_input_output( @@ -215,7 +219,7 @@ mod tests { pub fn forward( &self, tensor1: Tensor, - tensor3: Tensor, + tensor3: Tensor, ) -> Tensor { let tensor2 = tensor1.expand( TryInto::<[B::IntElem; 4usize]>::try_into(tensor3.to_data().as_slice::().unwrap()) diff --git a/crates/burn-import/src/burn/node/layer_norm.rs b/crates/burn-import/src/burn/node/layer_norm.rs index b9fd6d5d7c..867ce1252e 100644 --- a/crates/burn-import/src/burn/node/layer_norm.rs +++ b/crates/burn-import/src/burn/node/layer_norm.rs @@ -2,10 +2,11 @@ use super::{Node, NodeCodegen, SerializationBackend}; use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type}; use burn::{ module::{ConstantRecord, Param, ParamId}, - nn::{LayerNormConfig, LayerNormRecord}, + nn::LayerNormRecord, record::{PrecisionSettings, Record}, tensor::{Tensor, TensorData}, }; +use onnx_ir::node::layer_norm::LayerNormConfig; use proc_macro2::TokenStream; use quote::quote; use serde::Serialize; diff --git a/crates/burn-import/src/burn/node/linear.rs b/crates/burn-import/src/burn/node/linear.rs index e8828295f6..8510d66131 100644 --- a/crates/burn-import/src/burn/node/linear.rs +++ b/crates/burn-import/src/burn/node/linear.rs @@ -2,10 +2,11 @@ use super::{Node, NodeCodegen, SerializationBackend}; use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type}; use burn::{ module::{Param, ParamId}, - nn::{LinearConfig, LinearRecord}, + nn::LinearRecord, record::{PrecisionSettings, Record}, tensor::{Tensor, TensorData}, }; +use onnx_ir::node::linear::LinearConfig; use proc_macro2::TokenStream; use quote::quote; use serde::Serialize; diff --git a/crates/burn-import/src/burn/node/max_pool1d.rs b/crates/burn-import/src/burn/node/max_pool1d.rs index a77e068779..c092fef7f9 100644 --- a/crates/burn-import/src/burn/node/max_pool1d.rs +++ b/crates/burn-import/src/burn/node/max_pool1d.rs @@ -1,7 +1,8 @@ +use onnx_ir::node::max_pool1d::MaxPool1dConfig; use proc_macro2::TokenStream; use quote::quote; -use burn::{nn::pool::MaxPool1dConfig, record::PrecisionSettings}; +use burn::record::PrecisionSettings; use super::{Node, NodeCodegen}; use 
crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type}; @@ -92,10 +93,8 @@ impl NodeCodegen for MaxPool1dNode { mod tests { use super::*; use crate::burn::{TensorType, graph::BurnGraph, node::test::assert_tokens}; - use burn::{ - nn::{PaddingConfig1d, pool::MaxPool1dConfig}, - record::FullPrecisionSettings, - }; + use burn::record::FullPrecisionSettings; + use onnx_ir::node::padding::PaddingConfig1d; #[test] fn test_codegen() { diff --git a/crates/burn-import/src/burn/node/max_pool2d.rs b/crates/burn-import/src/burn/node/max_pool2d.rs index 3e3ae6ec09..15224c2af8 100644 --- a/crates/burn-import/src/burn/node/max_pool2d.rs +++ b/crates/burn-import/src/burn/node/max_pool2d.rs @@ -1,7 +1,8 @@ +use onnx_ir::node::max_pool2d::MaxPool2dConfig; use proc_macro2::TokenStream; use quote::quote; -use burn::{nn::pool::MaxPool2dConfig, record::PrecisionSettings}; +use burn::record::PrecisionSettings; use super::{Node, NodeCodegen}; use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type}; @@ -96,7 +97,8 @@ mod tests { graph::BurnGraph, node::{max_pool2d::MaxPool2dNode, test::assert_tokens}, }; - use burn::{nn::PaddingConfig2d, nn::pool::MaxPool2dConfig, record::FullPrecisionSettings}; + use burn::record::FullPrecisionSettings; + use onnx_ir::node::padding::PaddingConfig2d; #[test] fn test_codegen() { diff --git a/crates/burn-import/src/burn/node/pad.rs b/crates/burn-import/src/burn/node/pad.rs index 0374a8832d..5b76421d05 100644 --- a/crates/burn-import/src/burn/node/pad.rs +++ b/crates/burn-import/src/burn/node/pad.rs @@ -2,17 +2,11 @@ use std::str::FromStr; use super::{Node, NodeCodegen}; use crate::burn::{Scope, TensorType, ToTokens, Type}; -use burn::config::Config; use burn::record::PrecisionSettings; +use onnx_ir::node::pad::PadConfig; use proc_macro2::TokenStream; use quote::quote; -#[derive(Config, Debug)] -pub struct PadConfig { - pub pads: Vec, - pub constant_value: f32, -} - #[derive(Debug, Clone, new)] pub struct PadNode { pub input: TensorType, diff --git a/crates/burn-import/src/burn/node/split.rs b/crates/burn-import/src/burn/node/split.rs index 2980959099..5943f94b6b 100644 --- a/crates/burn-import/src/burn/node/split.rs +++ b/crates/burn-import/src/burn/node/split.rs @@ -1,17 +1,10 @@ use super::{Node, NodeCodegen}; use crate::burn::{Scope, TensorType, ToTokens, Type}; -use burn::config::Config; use burn::record::PrecisionSettings; +use onnx_ir::node::split::SplitConfig; use proc_macro2::TokenStream; use quote::quote; -#[derive(Config, Debug)] -pub struct SplitConfig { - pub axis: usize, - pub split_size: Option, - pub split_sizes: Option>, -} - #[derive(Debug, Clone, new)] pub struct SplitNode { pub input: TensorType, diff --git a/crates/burn-import/src/burn/node/tile.rs b/crates/burn-import/src/burn/node/tile.rs index 68d31fae21..bda7dc9ecc 100644 --- a/crates/burn-import/src/burn/node/tile.rs +++ b/crates/burn-import/src/burn/node/tile.rs @@ -1,15 +1,10 @@ use super::{Node, NodeCodegen}; use crate::burn::{Scope, TensorType, ToTokens, Type}; -use burn::config::Config; use burn::record::PrecisionSettings; +use onnx_ir::node::tile::TileConfig; use proc_macro2::TokenStream; use quote::quote; -#[derive(Config, Debug)] -pub struct TileConfig { - pub repeats: Vec, -} - #[derive(Debug, Clone, new)] pub struct TileNode { pub input: TensorType, diff --git a/crates/burn-import/src/burn/node/top_k.rs b/crates/burn-import/src/burn/node/top_k.rs index e79a14c51d..bbd1c74d61 100644 --- a/crates/burn-import/src/burn/node/top_k.rs +++ 
b/crates/burn-import/src/burn/node/top_k.rs @@ -1,16 +1,10 @@ use super::{Node, NodeCodegen}; use crate::burn::{Scope, TensorType, Type}; -use burn::config::Config; use burn::record::PrecisionSettings; +use onnx_ir::node::topk::TopKConfig; use proc_macro2::TokenStream; use quote::{ToTokens, quote}; -#[derive(Config, Debug)] -pub struct TopKConfig { - pub axis: usize, - pub k: usize, -} - #[derive(Debug, Clone, new)] pub struct TopKNode { pub input: TensorType, diff --git a/crates/burn-import/src/burn/node/trilu.rs b/crates/burn-import/src/burn/node/trilu.rs index 18c0ac8196..2bb84e9447 100644 --- a/crates/burn-import/src/burn/node/trilu.rs +++ b/crates/burn-import/src/burn/node/trilu.rs @@ -1,16 +1,10 @@ use super::{Node, NodeCodegen}; use crate::burn::{Scope, TensorType, ToTokens, Type}; -use burn::config::Config; use burn::record::PrecisionSettings; +use onnx_ir::node::trilu::TriluConfig; use proc_macro2::TokenStream; use quote::quote; -#[derive(Config, Debug)] -pub struct TriluConfig { - pub upper: bool, - pub diagonal: i64, -} - #[derive(Debug, Clone, new)] pub struct TriluNode { pub input: TensorType, diff --git a/crates/burn-import/src/burn/node/unsqueeze.rs b/crates/burn-import/src/burn/node/unsqueeze.rs index afe22cf220..d8e40503c7 100644 --- a/crates/burn-import/src/burn/node/unsqueeze.rs +++ b/crates/burn-import/src/burn/node/unsqueeze.rs @@ -1,6 +1,7 @@ use super::{Node, NodeCodegen}; use crate::burn::{BurnImports, Scope, TensorType, ToTokens, Type}; use burn::record::PrecisionSettings; +use onnx_ir::node::unsqueeze::UnsqueezeConfig; use proc_macro2::TokenStream; use quote::quote; @@ -8,13 +9,7 @@ use quote::quote; pub struct UnsqueezeNode { pub input: Type, pub output: TensorType, - pub axes: UnsqueezeAxes, -} - -#[derive(Debug, Clone)] -pub enum UnsqueezeAxes { - Static(Vec), - Runtime(Type), + pub axes: UnsqueezeConfig, } impl NodeCodegen for UnsqueezeNode { @@ -25,8 +20,8 @@ impl NodeCodegen for UnsqueezeNode { fn input_types(&self) -> Vec { let input = self.input.clone(); match &self.axes { - UnsqueezeAxes::Static(_) => vec![input], - UnsqueezeAxes::Runtime(rt_type) => vec![input, rt_type.clone()], + UnsqueezeConfig::Static(_) => vec![input], + UnsqueezeConfig::Runtime(rt_type) => vec![input, Type::from(rt_type)], } } fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { @@ -34,17 +29,19 @@ impl NodeCodegen for UnsqueezeNode { let output_rank = self.output.rank.to_tokens(); let axes = match &self.axes { - UnsqueezeAxes::Static(static_axes) => static_axes.to_tokens(), - UnsqueezeAxes::Runtime(Type::Tensor(axes_tensor)) => { - let tensor_name = &axes_tensor.name; - quote! { - #tensor_name.to_data().as_slice::().unwrap().iter().map(|&x| x.to_isize()).collect::>() + UnsqueezeConfig::Static(static_axes) => static_axes.to_tokens(), + UnsqueezeConfig::Runtime(arg) => match Type::from(arg) { + Type::Tensor(axes_tensor) => { + let tensor_name = &axes_tensor.name; + quote! 
{ + #tensor_name.to_data().as_slice::().unwrap().iter().map(|&x| x.to_isize()).collect::>() + } } - } - _ => panic!( - "UnsqueezeNode received invalid axes type: expected static axes or tensor but got {:?}", - self.axes - ), + _ => panic!( + "UnsqueezeNode received invalid axes type: expected tensor but got {:?}", + arg + ), + }, }; match &self.input { @@ -79,7 +76,7 @@ impl NodeCodegen for UnsqueezeNode { _ => {} } match &self.axes { - UnsqueezeAxes::Runtime(_) => { + UnsqueezeConfig::Runtime(_) => { imports.register("alloc::vec::Vec"); imports.register("burn::tensor::cast::ToElement"); } @@ -106,7 +103,7 @@ mod tests { graph.register(UnsqueezeNode::new( Type::Tensor(TensorType::new_float("tensor1", 3)), TensorType::new_float("tensor2", 5), - UnsqueezeAxes::Static([0, 4].into()), + UnsqueezeConfig::Static([0, 4].into()), )); graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]); diff --git a/crates/burn-import/src/onnx/mod.rs b/crates/burn-import/src/onnx/mod.rs index b0b67d79c3..e387b0bdb9 100644 --- a/crates/burn-import/src/onnx/mod.rs +++ b/crates/burn-import/src/onnx/mod.rs @@ -1,3 +1,2 @@ -mod op_configuration; mod to_burn; pub use to_burn::*; diff --git a/crates/burn-import/src/onnx/op_configuration.rs b/crates/burn-import/src/onnx/op_configuration.rs deleted file mode 100644 index e16c5f9772..0000000000 --- a/crates/burn-import/src/onnx/op_configuration.rs +++ /dev/null @@ -1,1943 +0,0 @@ -// TODO Move op_configuration.rs from burn-import to onnx-ir #3091 -// See https://github.com/tracel-ai/burn/issues/3091 - -use burn::nn::{ - BatchNormConfig, DropoutConfig, LayerNormConfig, LinearConfig, PaddingConfig1d, - PaddingConfig2d, PaddingConfig3d, - conv::{ - Conv1dConfig, Conv2dConfig, Conv3dConfig, ConvTranspose1dConfig, ConvTranspose2dConfig, - ConvTranspose3dConfig, - }, - pool::{AvgPool1dConfig, AvgPool2dConfig, MaxPool1dConfig, MaxPool2dConfig}, -}; - -use crate::burn::node::{ - expand::ExpandShape, pad::PadConfig, split::SplitConfig, tile::TileConfig, top_k::TopKConfig, - trilu::TriluConfig, unsqueeze::UnsqueezeAxes, -}; -use onnx_ir::ir::{ArgType, AttributeValue, Data, ElementType, Node, TensorData}; - -/// Create a Conv1dConfig from the attributes of the node -pub fn conv1d_config(curr: &Node) -> Conv1dConfig { - let mut kernel_shape = Vec::new(); // TODO default inferred from weight tensor per spec - let mut strides = vec![1]; - let mut pads = vec![0, 0]; - let mut dilations = vec![1]; - let mut group: usize = 1; - - let weight_shape = curr.inputs[1] - .value - .as_ref() - .expect("Conv1d: weight tensor must be present") - .shape - .clone(); - - // check if the bias is present - let bias = curr.inputs.len() == 3; - - for (key, value) in curr.attrs.iter() { - match key.as_str() { - "kernel_shape" => kernel_shape = value.clone().into_i64s(), - "strides" => strides = value.clone().into_i64s(), - "pads" => pads = value.clone().into_i64s(), - "dilations" => dilations = value.clone().into_i64s(), - "group" => group = value.clone().into_i64() as usize, - _ => {} - } - } - - // the channels are inverted in the weight tensor - let channels_in = weight_shape[1] * group; - let channels_out = weight_shape[0]; - - let padding = padding_config_1d(&pads); - - Conv1dConfig::new(channels_in, channels_out, kernel_shape[0] as usize) - .with_stride(strides[0] as usize) - .with_dilation(dilations[0] as usize) - .with_groups(group) - .with_bias(bias) - .with_padding(padding) -} - -/// Create a Conv2dConfig from the attributes of the node -pub fn conv2d_config(curr: 
&Node) -> Conv2dConfig { - let mut kernel_shape = Vec::new(); // TODO default inferred from weight tensor per spec - let mut strides = vec![1, 1]; - let mut pads = vec![0, 0, 0, 0]; - let mut dilations = vec![1, 1]; - let mut group: usize = 1; - - let weight_shape = curr.inputs[1] - .value - .as_ref() - .expect("Conv2d: weight tensor must be present") - .shape - .clone(); - - // check if the bias is present - let bias = curr.inputs.len() == 3; - - for (key, value) in curr.attrs.iter() { - match key.as_str() { - "kernel_shape" => kernel_shape = value.clone().into_i64s(), - "strides" => strides = value.clone().into_i64s(), - "pads" => pads = value.clone().into_i64s(), - "dilations" => dilations = value.clone().into_i64s(), - "group" => group = value.clone().into_i64() as usize, - _ => {} - } - } - - // the channels are inverted in the weight tensor - let channels_in = weight_shape[1] * group; - let channels_out = weight_shape[0]; - - let padding = padding_config_2d(&pads); - - Conv2dConfig::new( - [channels_in, channels_out], - [kernel_shape[0] as usize, kernel_shape[1] as usize], - ) - .with_stride([strides[0] as usize, strides[1] as usize]) - .with_dilation([dilations[0] as usize, dilations[1] as usize]) - .with_groups(group) - .with_bias(bias) - .with_padding(padding) -} - -/// Create a Conv3dConfig from the attributes of the node -pub fn conv3d_config(curr: &Node) -> Conv3dConfig { - let mut kernel_shape = Vec::new(); // TODO default inferred from weight tensor per spec - let mut strides = vec![1, 1, 1]; - let mut pads = vec![0, 0, 0, 0, 0, 0]; - let mut dilations = vec![1, 1, 1]; - let mut group: usize = 1; - - let weight_shape = curr.inputs[1] - .value - .as_ref() - .expect("Conv3d: weight tensor must be present") - .shape - .clone(); - - // check if the bias is present - let bias = curr.inputs.len() == 3; - - for (key, value) in curr.attrs.iter() { - match key.as_str() { - "kernel_shape" => kernel_shape = value.clone().into_i64s(), - "strides" => strides = value.clone().into_i64s(), - "pads" => pads = value.clone().into_i64s(), - "dilations" => dilations = value.clone().into_i64s(), - "group" => group = value.clone().into_i64() as usize, - _ => {} - } - } - - // the channels are inverted in the weight tensor - let channels_in = weight_shape[1] * group; - let channels_out = weight_shape[0]; - - let padding = padding_config_3d(&pads); - - Conv3dConfig::new( - [channels_in, channels_out], - [ - kernel_shape[0] as usize, - kernel_shape[1] as usize, - kernel_shape[2] as usize, - ], - ) - .with_stride([ - strides[0] as usize, - strides[1] as usize, - strides[2] as usize, - ]) - .with_dilation([ - dilations[0] as usize, - dilations[1] as usize, - dilations[2] as usize, - ]) - .with_groups(group) - .with_bias(bias) - .with_padding(padding) -} - -/// Create a MaxPool2dConfig from the attributes of the node -pub fn max_pool1d_config(curr: &Node) -> MaxPool1dConfig { - let mut kernel_shape = Vec::new(); - let mut stride = vec![1]; - let mut pads = vec![0, 0]; - let mut dilation = vec![1]; - - for (key, value) in curr.attrs.iter() { - match key.as_str() { - "kernel_shape" => kernel_shape = value.clone().into_i64s(), - "strides" => stride = value.clone().into_i64s(), - "pads" => pads = value.clone().into_i64s(), - "dilations" => dilation = value.clone().into_i64s(), - _ => {} - } - } - assert_eq!(kernel_shape.len(), 1); - assert_eq!(dilation.len(), 1); - assert_eq!(stride.len(), 1); - let padding = padding_config_1d(&pads); - - MaxPool1dConfig::new(kernel_shape[0] as usize) - 
.with_stride(stride[0] as usize) - .with_padding(padding) - .with_dilation(dilation[0] as usize) -} - -/// Create a MaxPool2dConfig from the attributes of the node -pub fn max_pool2d_config(curr: &Node) -> MaxPool2dConfig { - let mut kernel_shape = Vec::new(); - let mut strides = vec![1, 1]; - let mut pads = vec![0, 0, 0, 0]; - let mut dilations = vec![1, 1]; - - for (key, value) in curr.attrs.iter() { - match key.as_str() { - "kernel_shape" => kernel_shape = value.clone().into_i64s(), - "strides" => strides = value.clone().into_i64s(), - "pads" => pads = value.clone().into_i64s(), - "dilations" => dilations = value.clone().into_i64s(), - _ => {} - } - } - - let padding = padding_config_2d(&pads); - - MaxPool2dConfig::new([kernel_shape[0] as usize, kernel_shape[1] as usize]) - .with_strides([strides[0] as usize, strides[1] as usize]) - .with_padding(padding) - .with_dilation([dilations[0] as usize, dilations[1] as usize]) -} - -pub fn conv_transpose1d_config(curr: &Node) -> ConvTranspose1dConfig { - let mut attrs = curr.attrs.clone(); - - // Extract kernel_shape, default to an empty vector if not present - let kernel_shape = attrs - .remove("kernel_shape") - .map(AttributeValue::into_i64s) - .unwrap_or_default(); - - // Extract strides, default to 1 if not present - let stride = attrs - .remove("strides") - .map(AttributeValue::into_i64s) - .unwrap_or_else(|| vec![1]); - - // Extract padding, default to 0 if not present - let pads = attrs - .remove("pads") - .map(AttributeValue::into_i64s) - .unwrap_or_else(|| vec![0, 0]); - - // Extract dilations, default to 1 if not present - let dilations = attrs - .remove("dilations") - .map(AttributeValue::into_i64s) - .unwrap_or_else(|| vec![1]); - - // Extract group attribute, default to 1 - let group = attrs - .remove("group") - .map(AttributeValue::into_i64) - .unwrap_or(1) as usize; - - // Extract output_padding, default to 0 if not present - let output_padding = attrs - .remove("output_padding") - .map(AttributeValue::into_i64s) - .unwrap_or_else(|| vec![0]); - - // Ensure no unused attributes remain - if !attrs.is_empty() { - panic!("Not all attributes are used: {attrs:?}"); - } - // Check the pads are symmetric. 
- if pads.len() != 2 || pads[0] != pads[1] { - panic!( - "Asymmetric padding is not supported for ConvTranspose1d: {:?}", - pads - ); - } - - let weight_shape = curr.inputs[1] - .value - .as_ref() - .expect("ConvTranspose1d: weight tensor must be present") - .shape - .clone(); - - // Check if bias is present (third input) - let bias = curr.inputs.len() == 3; - - // Extract channels from the weight tensor shape [out_channels, in_channels] - let channels: [usize; 2] = [weight_shape[1] * group, weight_shape[0]]; - - // Create the ConvTranspose1d configuration - ConvTranspose1dConfig::new(channels, kernel_shape[0] as usize) - .with_stride(stride[0] as usize) - .with_padding(pads[0] as usize) - .with_dilation(dilations[0] as usize) - .with_padding_out(output_padding[0] as usize) - .with_groups(group) - .with_bias(bias) -} - -pub fn conv_transpose2d_config(curr: &Node) -> ConvTranspose2dConfig { - let mut attrs = curr.attrs.clone(); - let kernel_shape = attrs - .remove("kernel_shape") - .map(AttributeValue::into_i64s) - .unwrap_or_default(); - let stride = attrs - .remove("strides") - .map(AttributeValue::into_i64s) - .unwrap_or_else(|| vec![1, 1]); - let pads = attrs - .remove("pads") - .map(AttributeValue::into_i64s) - .unwrap_or_else(|| vec![0, 0, 0, 0]); - let dilations = attrs - .remove("dilations") - .map(AttributeValue::into_i64s) - .unwrap_or_else(|| vec![1, 1]); - let group = attrs - .remove("group") - .map(AttributeValue::into_i64) - .unwrap_or(1) as usize; - let output_padding = attrs - .remove("output_padding") - .map(AttributeValue::into_i64s) - .unwrap_or_else(|| vec![0, 0]); - - // Trick with remove + empty check is simplest way to not forget some attribute for runtime: - if !attrs.is_empty() { - panic!("Not all attributes are used: {attrs:?}"); - } - // Check the pads are symmetric. 
- let [left, top, right, bottom] = [pads[0], pads[1], pads[2], pads[3]]; - if left < 0 || top < 0 || right < 0 || bottom < 0 { - panic!("Negative pad values are not supported"); - } else if (left != right) || (top != bottom) { - panic!("Asymmetric padding is not supported"); - } - - let weight_shape = curr.inputs[1] - .value - .as_ref() - .expect("ConvTranspose2d: weight tensor must be present") - .shape - .clone(); - - // check if the bias is present - let bias = curr.inputs.len() == 3; - - // the channels are inverted in the weight tensor - let channels: [usize; 2] = [weight_shape[1] * group, weight_shape[0]]; - - ConvTranspose2dConfig::new( - channels, - [kernel_shape[0] as usize, kernel_shape[1] as usize], - ) - .with_stride([stride[0] as usize, stride[1] as usize]) - .with_padding([pads[0] as usize, pads[1] as usize]) - .with_dilation([dilations[0] as usize, dilations[1] as usize]) - .with_padding_out([output_padding[0] as usize, output_padding[1] as usize]) - .with_groups(group) - .with_bias(bias) -} - -pub fn conv_transpose3d_config(curr: &Node) -> ConvTranspose3dConfig { - let mut attrs = curr.attrs.clone(); - let kernel_shape = attrs - .remove("kernel_shape") - .map(AttributeValue::into_i64s) - .unwrap_or_default(); - let stride = attrs - .remove("strides") - .map(AttributeValue::into_i64s) - .unwrap_or_else(|| vec![1, 1, 1]); - let pads = attrs - .remove("pads") - .map(AttributeValue::into_i64s) - .unwrap_or_else(|| vec![0, 0, 0, 0, 0, 0]); - let dilations = attrs - .remove("dilations") - .map(AttributeValue::into_i64s) - .unwrap_or_else(|| vec![1, 1, 1]); - let group = attrs - .remove("group") - .map(AttributeValue::into_i64) - .unwrap_or(1) as usize; - let output_padding = attrs - .remove("output_padding") - .map(AttributeValue::into_i64s) - .unwrap_or_else(|| vec![0, 0, 0]); - - // Trick with remove + empty check is simplest way to not forget some attribute for runtime: - if !attrs.is_empty() { - panic!("Not all attributes are used: {attrs:?}"); - } - // Check the pads are symmetric. 
- let [left, top, front, right, bottom, back] = - [pads[0], pads[1], pads[2], pads[3], pads[4], pads[5]]; - - if left < 0 || top < 0 || front < 0 || right < 0 || bottom < 0 || back < 0 { - panic!("Negative pad values are not supported"); - } else if (left != right) || (top != bottom) || (front != back) { - panic!("Asymmetric padding is not supported"); - } - - let weight_shape = curr.inputs[1] - .value - .as_ref() - .expect("ConvTranspose3d: weight tensor must be present") - .shape - .clone(); - - // check if the bias is present - let bias = curr.inputs.len() == 3; - - // the channels are inverted in the weight tensor - let channels: [usize; 2] = [weight_shape[1] * group, weight_shape[0]]; - - ConvTranspose3dConfig::new( - channels, - [ - kernel_shape[0] as usize, - kernel_shape[1] as usize, - kernel_shape[2] as usize, - ], - ) - .with_stride([stride[0] as usize, stride[1] as usize, stride[2] as usize]) - .with_padding([pads[0] as usize, pads[1] as usize, pads[2] as usize]) - .with_dilation([ - dilations[0] as usize, - dilations[1] as usize, - dilations[2] as usize, - ]) - .with_padding_out([ - output_padding[0] as usize, - output_padding[1] as usize, - output_padding[2] as usize, - ]) - .with_groups(group) - .with_bias(bias) -} - -pub fn avg_pool1d_config(curr: &Node) -> AvgPool1dConfig { - let mut kernel_shape = Vec::new(); - let mut strides = vec![1]; - let mut pads = vec![0, 0]; - let mut count_include_pad: i64 = 0; - let mut ceil_mode: i64 = 0; - - for (key, value) in curr.attrs.iter() { - match key.as_str() { - "kernel_shape" => kernel_shape = value.clone().into_i64s(), - "strides" => strides = value.clone().into_i64s(), - "pads" => pads = value.clone().into_i64s(), - "count_include_pad" => count_include_pad = value.clone().into_i64(), - "ceil_mode" => ceil_mode = value.clone().into_i64(), - _ => {} - } - } - assert_eq!(kernel_shape.len(), 1); - assert_eq!(strides.len(), 1); - - if ceil_mode == 1 { - panic!("ceil_mode is not supported"); - } - - let padding = padding_config_1d(&pads); - - AvgPool1dConfig::new(kernel_shape[0] as usize) - .with_stride(strides[0] as usize) - .with_padding(padding) - .with_count_include_pad(count_include_pad == 1) -} -/// Create a AvgPool2dConfig from the attributes of the node -pub fn avg_pool2d_config(curr: &Node) -> AvgPool2dConfig { - let mut kernel_shape = Vec::new(); - let mut strides = vec![1, 1]; - let mut pads = vec![0, 0, 0, 0]; - let mut count_include_pad: i64 = 0; - let mut ceil_mode: i64 = 0; - - for (key, value) in curr.attrs.iter() { - match key.as_str() { - "kernel_shape" => kernel_shape = value.clone().into_i64s(), - "strides" => strides = value.clone().into_i64s(), - "pads" => pads = value.clone().into_i64s(), - "count_include_pad" => count_include_pad = value.clone().into_i64(), - "ceil_mode" => ceil_mode = value.clone().into_i64(), - _ => {} - } - } - - if ceil_mode == 1 { - panic!("ceil_mode is not supported"); - } - - let padding = padding_config_2d(&pads); - - AvgPool2dConfig::new([kernel_shape[0] as usize, kernel_shape[1] as usize]) - .with_strides([strides[0] as usize, strides[1] as usize]) - .with_padding(padding) - .with_count_include_pad(count_include_pad == 1) -} - -pub fn expand_config(node: &Node) -> ExpandShape { - match &node.inputs[1].ty { - ArgType::Tensor(tensor) => { - assert_eq!(tensor.rank, 1, "Expand: shape tensor must be 1D"); - assert!( - matches!(tensor.elem_type, ElementType::Int64), - "Expand: shape tensor must have element type int64" - ); - } - ArgType::Shape(_) => { - // Shapes are always 1-D int64 data, so 
nothing to assert here - } - _ => panic!("Only tensor input is valid for shape"), - } - - match &node.inputs[1].value { - Some(TensorData { - data: Data::Int64s(shape), - .. - }) => ExpandShape::Static(shape.clone()), - None => { - // we were unable to statically determine the input value, so we'll need to fetch it at runtime - ExpandShape::Runtime(crate::burn::Type::from(&node.inputs[1])) - } - _ => panic!( - "Shape data type must be int64, is {:?}", - &node.inputs[1].value - ), - } -} - -/// Create a FlattenConfig from the attributes of the node -pub fn flatten_config(curr: &Node) -> usize { - // the begin dimension is the first dimension (Default: 1 per ONNX spec) - let mut axis: i64 = 1; - - // check if the node has only one input - if curr.inputs.len() != 1 { - panic!( - "Flatten: multiple inputs are not supported (got {:?})", - curr.inputs.len() - ); - } - - // extract the shape of the input tensor - let tensor = match curr.inputs.first().unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // check if the input tensor has at least 2 dimensions - if tensor.rank < 2 { - panic!( - "Flatten: input tensor must have at least 2 dimensions (got {:?})", - tensor.rank - ); - } - - // extract the attributes - for (key, value) in curr.attrs.iter() { - match key.as_str() { - "axis" => axis = value.clone().into_i64(), - _ => {} - } - } - - // if beg_dim is negative, it is counted from the end - if axis < 0 { - axis += tensor.rank as i64; - } - - axis as usize -} - -/// Create a GatherConfig from the attributes of the node -pub fn gather_config(curr: &Node) -> usize { - // Default: 0 per ONNX spec - let mut dim: i64 = 0; - - // check if the node has only one input - if curr.inputs.len() != 2 { - panic!("Gather: index tensor must be present"); - } - - // extract the shape of the input tensor - let input_dim = match curr.inputs.first().unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor.rank as i64, - ArgType::Shape(_shape) => 1, //Shape is always 1-D - other => panic!("Only tensor or shape input is valid, got {:?}", other), - }; - - // extract the attributes - for (key, value) in curr.attrs.iter() { - match key.as_str() { - "axis" => dim = value.clone().into_i64(), - _ => {} - } - } - - // if dim is negative, it is counted from the end - if dim < 0 { - dim += input_dim; - } - - dim as usize -} - -/// Create a LinearConfig from the attributes of the node -pub fn linear_config(node: &Node) -> LinearConfig { - if node.inputs.len() < 2 { - panic!("Linear: missing weight tensor"); - } - - let weight_shape = node.inputs[1] - .value - .as_ref() - .expect("Linear: weight tensor must be present") - .shape - .clone(); - - // check if the weight tensor has at least 2 dimensions - if weight_shape.len() < 2 { - panic!( - "Linear: weight tensor must have at least 2 dimensions (got {:?})", - weight_shape.len() - ); - } - - let (in_size, out_size) = (weight_shape[0], weight_shape[1]); - - // check if the bias is present - let bias = node.inputs.len() == 3 && node.inputs[2].value.is_some(); - - LinearConfig::new(in_size, out_size).with_bias(bias) -} - -/// Create a DropoutConfig from an attribute and state of the node -pub fn dropout_config(node: &Node) -> DropoutConfig { - // Opset 7 and older store probability as an attribute - if node.attrs.contains_key("ratio") { - let prob = node.attrs.get("ratio").unwrap().clone().into_f32(); - return DropoutConfig::new(prob as f64); - } - - if node.inputs.len() < 2 { - panic!("Dropout configuration must have 
at least 2 inputs"); - } - - let ratio = node.inputs[1] - .value - .clone() - .expect("Dropout ratio must be passed in the second input") - .data - .into_scalar(); - - let prob = match ratio { - Data::Float16(ratio) => f64::from(f32::from(ratio)), - Data::Float32(ratio) => ratio as f64, - Data::Float64(ratio) => ratio, - _ => panic!("Dropout ratio must be a float"), - }; - - DropoutConfig::new(prob) -} - -/// Create log_softmax config from the attributes of the node -pub fn log_softmax_config(node: &Node) -> usize { - // the axis is the last dimension (Default: 1 per ONNX spec) - let mut axis: i64 = -1; - - // check if the node has only one input - if node.inputs.len() != 1 { - panic!( - "LogSoftmax: multiple inputs are not supported (got {:?})", - node.inputs.len() - ); - } - - // extract the shape of the input tensor - let tensor = match node.inputs.first().unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // extract the attributes - for (key, value) in node.attrs.iter() { - match key.as_str() { - "axis" => axis = value.clone().into_i64(), - _ => {} - } - } - - // if axis is negative, it is counted from the end - if axis < 0 { - axis += tensor.rank as i64; - } - - axis as usize -} - -/// Create softmax config from the attributes of the node -pub fn softmax_config(node: &Node) -> usize { - // the axis is the last dimension (Default: 1 per ONNX spec) - let mut axis: i64 = -1; - - // check if the node has only one input - if node.inputs.len() != 1 { - panic!( - "Softmax: multiple inputs are not supported (got {:?})", - node.inputs.len() - ); - } - - // extract the shape of the input tensor - let tensor = match node.inputs.first().unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // extract the attributes - for (key, value) in node.attrs.iter() { - match key.as_str() { - "axis" => axis = value.clone().into_i64(), - _ => {} - } - } - - // if axis is negative, it is counted from the end - if axis < 0 { - axis += tensor.rank as i64; - } - - axis as usize -} - -/// Create argmax config from the attributes of the node -pub fn argmax_config(node: &Node) -> usize { - let mut axis: i64 = 0; - - // check if the node has only one input - if node.inputs.len() != 1 { - panic!( - "Argmax: multiple inputs are not supported (got {:?})", - node.inputs.len() - ); - } - - // extract the shape of the input tensor - let tensor = match node.inputs.first().unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // extract the attributes - for (key, value) in node.attrs.iter() { - match key.as_str() { - "axis" => axis = value.clone().into_i64(), - "select_last_index" => { - // not all params are supported in burn - if value.clone().into_i64() != 0 { - log::warn!( - "only select_last_index=0 is supported for argmax in burn. 
Ignoring supplied value (got {:?})", - value - ); - } - } - "keepdims" => { - // not all params are supported in burn - if value.clone().into_i64() != 1 { - panic!( - "Only keepdims=1 is supported for argmax in burn (got {:?})", - value - ); - } - } - _ => {} - } - } - - // if axis is negative, it is counted from the end - if axis < 0 { - axis += tensor.rank as i64; - } - - axis as usize -} - -/// Create concat config from the attributes of the node -pub fn concat_config(node: &Node) -> usize { - // the axis is the last dimension (Default: 1 per ONNX spec) - let mut axis: i64 = 1; - - // extract the shape of the input tensor - let tensor = match node.inputs.first().unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // extract the attributes - for (key, value) in node.attrs.iter() { - match key.as_str() { - "axis" => axis = value.clone().into_i64(), - _ => {} - } - } - - // if axis is negative, it is counted from the end - if axis < 0 { - axis += tensor.rank as i64; - } - - axis as usize -} - -/// Create a BatchNormConfig from the attributes of the node -pub fn batch_norm_config(node: &Node) -> BatchNormConfig { - let weight_shape = node.inputs[1] - .value - .as_ref() - .expect("BatchNorm: weight tensor must be present") - .shape - .clone(); - - let num_features = weight_shape[0]; - - let mut epsilon = 0f32; - let mut momentum = 0f32; - - for (key, value) in node.attrs.iter() { - match key.as_str() { - "momentum" => momentum = value.clone().into_f32(), - "epsilon" => epsilon = value.clone().into_f32(), - _ => {} - } - } - - BatchNormConfig::new(num_features) - .with_epsilon(epsilon as f64) - .with_momentum(momentum as f64) -} - -/// Create a LayerNormConfig from the attributes of the node -pub fn layer_norm_config(node: &Node) -> (LayerNormConfig, bool) { - let weight_shape = node.inputs[1] - .value - .as_ref() - .expect("LayerNorm: weight tensor must be present") - .shape - .clone(); - - let num_features = weight_shape[0]; - - // When `stash_type` is `1` (default), perform operations in 32-bit float and - // cast the results back to original dtype - let mut stash_type = 1; - let mut axis = -1; - let mut epsilon = 1e-5; - - for (key, value) in node.attrs.iter() { - match key.as_str() { - "axis" => axis = value.clone().into_i64(), - "epsilon" => epsilon = value.clone().into_f32(), - "stash_type" => stash_type = value.clone().into_i64(), - _ => {} - } - } - - if axis != -1 && axis != weight_shape.len() as i64 - 1 { - panic!("LayerNorm: normalization is only supported on the last axis right now") - } - - ( - LayerNormConfig::new(num_features).with_epsilon(epsilon as f64), - stash_type == 1, - ) -} - -/// Create a TileConfig from the attributes of the node -pub fn tile_config(node: &Node) -> TileConfig { - let repeat = node - .inputs - .get(1) - .map(|input| { - if let Some(TensorData { data, .. }) = &input.value { - data.clone() - .into_i64s() - .iter() - .map(|&x| x as usize) - .collect() - } else { - vec![] - } - }) - .unwrap_or_default(); - TileConfig::new(repeat) -} - -/// Create a TopKConfig from the attributes of the node. 
-pub fn top_k_config(node: &Node) -> TopKConfig { - // extract the shape of the input data tensor - let data_tensor = match node.inputs.first().unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - let k = match node.inputs.get(1) { - Some(k_tensor) => k_tensor - .clone() - .value - .expect("TopK: only constant 'k' tensor is currently supported") - .data - .into_i64s()[0], - _ => node - .attrs - .get("k") - .expect("TopK: number of top elements 'k' is missing") - .clone() - .into_i64(), - }; - - let mut axis = match node.attrs.get("axis") { - Some(axis) => axis.clone().into_i64(), - None => -1, - }; - - // if axis is negative, it is counted from the end - if axis < 0 { - axis += data_tensor.rank as i64; - } - - if let Some(largest) = node.attrs.get("largest") { - if largest.clone().into_i64() != 1 { - unimplemented!("TopK: only largest elements is supported") - } - }; - - if let Some(sorted) = node.attrs.get("sorted") { - if sorted.clone().into_i64() != 1 { - unimplemented!("TopK: only sorted elements is supported") - } - }; - - TopKConfig::new(axis as usize, k as usize) -} - -/// Create a TriluConfig from the attributes of the node -pub fn trilu_config(node: &Node) -> TriluConfig { - let mut upper = true; - let mut diagonal = 0; - for (key, value) in node.attrs.iter() { - match key.as_str() { - "upper" => upper = value.clone().into_i64() != 0, - _ => {} - } - } - // The second input of the Trilu node is the diagonal value, coming from a constant node - if let Some(diagonal_arg) = node.inputs.get(1) { - if let Some(TensorData { - data: Data::Int64(diagonal_val), - .. - }) = &diagonal_arg.value - { - diagonal = *diagonal_val; - } - } - TriluConfig::new(upper, diagonal) -} - -/// Create a PadConfig from the attributes of the node -pub fn pad_config(node: &Node) -> PadConfig { - fn get_pads_input(node: &Node) -> Vec { - match &node.inputs[1].value { - Some(TensorData { data, .. 
}) => data.clone().into_i64s(), - _ => Vec::new(), - } - } - fn get_pads(node: &Node) -> Vec { - if node.inputs.is_empty() { - panic!("Pad: must provide data as input") - } - if node.inputs.len() >= 4 { - panic!("Pad: axes input is not supported") - } - - let input_dim = match &node.inputs.first().unwrap().ty { - ArgType::Tensor(tensor) => tensor.rank, - _ => panic!("Pad: Only tensor input is valid"), - }; - - //TODO : handle more possible attributes - let mut pads: Vec = get_pads_input(node) - .into_iter() - .map(|x| x as usize) - .collect(); - - for (key, value) in node.attrs.iter() { - match key.as_str() { - "pads" => { - pads = value - .clone() - .into_i64s() - .iter() - .map(|&x| { - if x < 0 { - panic!("Pad: Negative pad is not supported"); - } - x as usize - }) - .collect() - } - "mode" => { - let mode = value.clone().into_string(); - if mode != "constant" { - panic!("only constant mode is supported, given mode is {}", mode); - } - } - - _ => {} - } - } - - if pads.is_empty() { - panic!("Pad: pads should be given as attribute or as input"); - } - - if pads.len() != input_dim * 2 { - panic!("Pad: pads should be a 1D tensor of shape [2 * num_axes]"); - } - // TODO: Burn's pad should support 1D tensor - if input_dim < 2 { - panic!("Pad: input tensor should be rank 2 or higher"); - } - - let left_index = input_dim - 1; - let top_index = input_dim - 2; - let right_index = pads.len() - 1; - let bottom_index = pads.len() - 2; - let index_list = [left_index, top_index, right_index, bottom_index]; - - for (index, &item) in pads.iter().enumerate() { - if !index_list.contains(&index) && item != 0 { - panic!( - "Pad: padding will only be applied to the last two dimensions but found non zero padding for other dimensions" - ); - } - } - - let left = pads[left_index]; - let top = pads[top_index]; - let right = pads[right_index]; - let bottom = pads[bottom_index]; - vec![left, right, top, bottom] - } - fn get_constant_value(node: &Node) -> f32 { - // TODO: support int, boolean - let mut constant_value = node.inputs - .get(2) - .and_then(|input| match &input.value.as_ref().expect("Value input must be present").data { - Data::Float16s(constant_value) => { - constant_value.first().map(|&f| f32::from(f)) - } - Data::Float32s(constant_value) => { - constant_value.first().copied() - } - Data::Float64s(constant_value) => { - constant_value.first().map(|&f| f as f32) - } - Data::Float16(constant_value) => Some(f32::from(*constant_value)), - Data::Float32(constant_value) => Some(*constant_value), - Data::Float64(constant_value) => Some(*constant_value as f32), - _ => panic!("Pad: only float values are currently supported for constant value, submit an issue on github"), - }) - .unwrap_or(0.0); - - if node.attrs.contains_key("value") { - constant_value = node.attrs.get("value").map(|value| match value { - AttributeValue::Float32(value) => *value, - _ => panic!("Pad: only float32 values are currently supported for constant value as attribute, submit an issue on github"), - }).expect("constant_value should have had a value now"); - } - constant_value - } - - let pads = get_pads(node); - let constant_value = get_constant_value(node); - - PadConfig::new(pads, constant_value) -} - -/// Calculate the padding configuration for a 1D operations such as Convolution and Pooling. 
-/// -/// # Arguments -/// -/// * `pads` - The padding values -/// -/// # Panics -/// -/// * If the padding is negative -/// * If the padding is not symmetric -/// -/// # Returns -/// -/// * The padding configuration -/// -/// # Remarks -/// -/// This function is used when the padding is specified as a list of integers, -/// and not used when the padding is specified as a string, e.g. "SAME_UPPER". -fn padding_config_1d(pads: &[i64]) -> PaddingConfig1d { - let [left, right] = [pads[0], pads[1]]; - - if left < 0 || right < 0 { - panic!("Negative pad values are not supported"); - } else if left != right { - panic!("Asymmetric padding is not supported"); - } else if left == 0 && right == 0 { - // i.e. [0, 0] - PaddingConfig1d::Valid - } else if left == right { - // i.e. [2, 2] - PaddingConfig1d::Explicit(left as usize) - } else { - // Unaccounted for padding configuration - panic!("Padding configuration ({:?}) not supported", pads); - } -} - -/// Calculate the padding configuration for a 2D operations such as Convolution and Pooling. -/// -/// # Arguments -/// -/// * `pads` - The padding values -/// -/// # Panics -/// -/// * If the padding is negative -/// * If the padding is not symmetric -/// -/// # Returns -/// -/// * The padding configuration -/// -/// # Remarks -/// -/// This function is used when the padding is specified as a list of integers, -/// and not used when the padding is specified as a string, e.g. "SAME_UPPER". -fn padding_config_2d(pads: &[i64]) -> PaddingConfig2d { - let [left, top, right, bottom] = [pads[0], pads[1], pads[2], pads[3]]; - - if left < 0 || top < 0 || right < 0 || bottom < 0 { - panic!("Negative pad values are not supported"); - } else if (left != right) || (top != bottom) { - panic!("Asymmetric padding is not supported"); - } else if left == 0 && top == 0 && right == 0 && bottom == 0 { - // i.e [0, 0, 0, 0] - PaddingConfig2d::Valid - } else if left == right && top == bottom { - // i.e [2, 3, 2, 3] - PaddingConfig2d::Explicit(left as usize, top as usize) - } else { - // Unaccounted for padding configuration - panic!("Padding configuration ({:?}) not supported", pads); - } -} - -/// Calculate the padding configuration for a 3D operations such as Convolution and Pooling. -/// -/// # Arguments -/// -/// * `pads` - The padding values -/// -/// # Panics -/// -/// * If the padding is negative -/// * If the padding is not symmetric -/// -/// # Returns -/// -/// * The padding configuration -/// -/// # Remarks -/// -/// This function is used when the padding is specified as a list of integers, -/// and not used when the padding is specified as a string, e.g. "SAME_UPPER". 
-fn padding_config_3d(pads: &[i64]) -> PaddingConfig3d { - let [left, top, front, right, bottom, back] = - [pads[0], pads[1], pads[2], pads[3], pads[4], pads[5]]; - - if left < 0 || top < 0 || front < 0 || right < 0 || bottom < 0 || back < 0 { - panic!("Negative pad values are not supported"); - } else if (left != right) || (top != bottom) || (front != back) { - panic!("Asymmetric padding is not supported"); - } else if left == 0 && top == 0 && front == 0 && right == 0 && bottom == 0 && back == 0 { - // i.e [0, 0, 0, 0] - PaddingConfig3d::Valid - } else if left == right && top == bottom && front == back { - // i.e [2, 3, 2, 3] - PaddingConfig3d::Explicit(left as usize, top as usize, front as usize) - } else { - // Unaccounted for padding configuration - panic!("Padding configuration ({:?}) not supported", pads); - } -} - -// Create a LeakyReluConfig from the alpha attribute of the node -pub fn leaky_relu_config(node: &Node) -> f64 { - let mut alpha = 0.01; - - for (key, value) in node.attrs.iter() { - match key.as_str() { - "alpha" => alpha = value.clone().into_f32() as f64, - _ => {} - } - } - - alpha -} - -// Create a HardSigmoidConfig from the alpha and beta attributes of the node -pub fn hard_sigmoid_config(node: &Node) -> (f64, f64) { - let mut alpha = 0.2; - let mut beta = 0.5; - - for (key, value) in node.attrs.iter() { - match key.as_str() { - "alpha" => alpha = value.clone().into_f32() as f64, - "beta" => beta = value.clone().into_f32() as f64, - _ => {} - } - } - - (alpha, beta) -} - -pub fn reshape_config(node: &Node) -> Vec { - let mut allowzero = 0; - - for (key, value) in node.attrs.iter() { - match key.as_str() { - "allowzero" => allowzero = value.clone().into_i64(), - _ => {} - } - } - - // Burn does not support zero size shape (0 means false in ONNX) - // (see https://onnx.ai/onnx/operators/onnx__Reshape.html#attributes) - if allowzero != 0 { - panic!("Zero shape size is not supported"); - } - - // TODO: check "shape" attribute - if node.inputs.len() != 2 || node.inputs[1].value.is_none() { - panic!("Reshape: shape tensor must be present for {:?}", node); - } - - match &node.inputs[1].value { - Some(TensorData { data, shape, .. }) => { - assert_eq!(shape.len(), 1, "Reshape: shape tensor must be 1D"); - data.clone().into_i64s() - } - _ => panic!("Only tensor input is valid for shape"), - } -} - -pub fn resize_config(node: &Node) -> (String, Vec, Vec) { - let mut mode: String = "".to_string(); - - let mut scales: Vec; - let mut sizes: Vec; - - let input = if let ArgType::Tensor(tensor) = &node - .inputs - .first() - .expect("Resize: Input tensor must be present") - .ty - { - tensor - } else { - panic!("Resize: input must be a tensor") - }; - - // Note: we are ignoring some attributes because results are approximately the same - // and we are not supporting all the attributes of the Resize operator. - // However, some attributes are important to be checked and we are checking - // against the default values of the attributes. 
- // TODO revisit this when we have more Resize operators in the model - for (key, value) in node.attrs.iter() { - match key.as_str() { - "antialias" => assert_eq!( - value.clone().into_i32(), - 0, - "Resize: antialias other than 0 is not supported" - ), - "axes" => panic!("Resize: custom axes attribute is not supported"), - "coordinate_transformation_mode" => { - log::warn!("Resize: coordinate_transformation_mode is ignored") - } - - "cubic_coeff_a" => log::warn!("Resize: cubic_coeff_a is ignored"), - "exclude_outside" => assert_eq!( - value.clone().into_i32(), - 0, - "Resize: exclude_outside other than 0 is not supported" - ), - "extrapolation_value" => assert_eq!( - value.clone().into_f32(), - 0.0, - "Resize: extrapolation_value other than 0.0 is not supported" - ), - "keep_aspect_ratio_policy" => { - assert_eq!( - value.clone().into_string().to_lowercase(), - "stretch", - "Resize: keep_aspect_ratio_policy other than 'stretch' is not supported" - ) - } - "mode" => mode = value.clone().into_string().to_lowercase(), - "nearest_mode" => log::warn!("Resize: nearest_mode is ignored"), - - _ => {} - } - } - - let roi: Vec = node - .inputs - .get(1) - .map(|input| { - if let Some(TensorData { data, .. }) = &input.value { - data.clone().into_f32s() - } else { - vec![] - } - }) - .unwrap_or_default(); - - scales = node - .inputs - .get(2) - .map(|input| { - if let Some(TensorData { data, .. }) = &input.value { - data.clone().into_f32s() - } else { - vec![] - } - }) - .unwrap_or_default(); - - sizes = node - .inputs - .get(3) - .map(|input| { - if let Some(TensorData { data, .. }) = &input.value { - data.clone() - .into_i64s() - .iter() - .map(|&x| x as usize) - .collect() - } else { - vec![] - } - }) - .unwrap_or_default(); - - if mode.is_empty() { - panic!("Resize: mode attribute is required") - } - - if !roi.is_empty() { - panic!("Resize: roi input is not supported") - } - - if scales.is_empty() && sizes.is_empty() { - panic!("Resize: either scales or sizes input is required") - } - - if !scales.is_empty() { - assert!(scales.len() == input.rank); - // ignore the fist two items from scales - // because they are the batch and channel dimensions - scales = scales.iter().skip(2).cloned().collect(); - } - - if !sizes.is_empty() { - assert!(sizes.len() == input.rank); - // ignore the fist two items from sizes - // because they are the batch and channel dimensions - sizes = sizes.iter().skip(2).cloned().collect(); - } - - (mode, scales, sizes) -} - -//Note this function should only execute if the second input is a constant -//if it wasn't and the output shape was known, unsqueeze has been remapped to reshape -pub fn unsqueeze_config(node: &Node) -> UnsqueezeAxes { - // Check if axes attribute exists - for (key, value) in node.attrs.iter() { - match key.as_str() { - "axes" => return UnsqueezeAxes::Static(value.clone().into_i64s()), - _ => {} - } - } - - assert!( - !node.inputs.is_empty(), - "Unsqueeze: axes tensor must be present" - ); - - let input_value = &node.inputs[1]; - - match &node.inputs[1].ty { - ArgType::Tensor(tensor) => { - assert_eq!(tensor.rank, 1, "Unsqueeze: axes tensor must be 1D"); - if let Some(TensorData { - data: Data::Int64s(shape), - .. 
- }) = input_value.value.as_ref() - { - UnsqueezeAxes::Static(shape.clone()) - } else { - UnsqueezeAxes::Runtime(crate::burn::Type::from(&node.inputs[1])) - } - } - _ => panic!("Arg for unsqueeze must be tensor or scalar"), - } -} - -pub fn clip_config(node: &Node) -> (Option, Option) { - let mut min_result: Option = None; - let mut max_result: Option = None; - - // For Clip Opset 6+ , the min and max values are attributes - for (key, value) in node.attrs.iter() { - match key.as_str() { - "min" => { - let min = value.clone().into_f32() as f64; - min_result = Some(min); - } - "max" => { - let max = value.clone().into_f32(); - max_result = Some(max as f64); - } - _ => {} - } - } - - // For Clip Opset 11+ , the min and max values are inputs - // Get the min and max values from the input values - if min_result.is_none() && max_result.is_none() { - let min = &node.inputs[1].value; - let max = &node.inputs[2].value; - - if min_result.is_none() && min.is_some() { - let min = min.clone().unwrap().data.into_scalar(); - min_result = match min { - Data::Float16(min) => Some(f32::from(min) as f64), - Data::Float32(min) => Some(min as f64), - Data::Float64(min) => Some(min), - _ => panic!("Clip: only float min is supported"), - }; - } - - if max_result.is_none() && max.is_some() { - let max = max.clone().unwrap().data.into_scalar(); - max_result = match max { - Data::Float16(max) => Some(f32::from(max) as f64), - Data::Float32(max) => Some(max as f64), - Data::Float64(max) => Some(max), - _ => panic!("Clip: only float max is supported"), - }; - } - } - - if min_result.is_none() && max_result.is_none() { - panic!("Clip: min and max values must be either attributes or inputs"); - } - - (min_result, max_result) -} - -pub fn reduce_max_config(node: &Node) -> Option { - let mut axes = Vec::new(); - let mut keepdims = 1; - - let tensor = match node.inputs.first().unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // Extract the attributes - for (key, value) in node.attrs.iter() { - match key.as_str() { - "axes" => axes = value.clone().into_i64s(), - "keepdims" => keepdims = value.clone().into_i64(), - _ => {} - } - } - - if axes.len() > 1 { - panic!("ReduceMax: reducing on multiple dimensions is not supported") - } - - if axes.is_empty() && keepdims == 1 { - panic!("ReduceMax: axes must be provided with keepdims") - } - - if !axes.is_empty() && keepdims == 0 { - // Not supported in Burn - panic!("ReduceMax: the reduce operation must preserve the reduced dimension") - } - - if axes.is_empty() { - None - } else { - let mut dim = axes[0]; - - if dim < 0 { - // Accepted range is [-r, r-1] where r = rank(data) but Burn only supports positive dim - dim += tensor.rank as i64; - } - Some(dim as usize) - } -} - -pub fn reduce_min_config(node: &Node) -> Option { - let mut axes = Vec::new(); - let mut keepdims = 1; - - let tensor = match node.inputs.first().unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // Extract the attributes - for (key, value) in node.attrs.iter() { - match key.as_str() { - "axes" => axes = value.clone().into_i64s(), - "keepdims" => keepdims = value.clone().into_i64(), - _ => {} - } - } - - if axes.len() > 1 { - panic!("ReduceMin: reducing on multiple dimensions is not supported") - } - - if axes.is_empty() && keepdims == 1 { - panic!("ReduceMin: axes must be provided with keepdims") - } - - if !axes.is_empty() && keepdims == 0 { - panic!("ReduceMin: the reduce operation 
must preserve the reduced dimension") - } - - if axes.is_empty() { - None - } else { - let mut dim = axes[0]; - - if dim < 0 { - dim += tensor.rank as i64; - } - Some(dim as usize) - } -} - -pub fn reduce_mean_config(node: &Node) -> Option { - let mut axes = Vec::new(); - let mut keepdims = 1; - - let tensor = match node.inputs.first().unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // Extract the attributes - for (key, value) in node.attrs.iter() { - match key.as_str() { - "axes" => axes = value.clone().into_i64s(), - "keepdims" => keepdims = value.clone().into_i64(), - _ => {} - } - } - - if axes.len() > 1 { - panic!("ReduceMean: reducing on multiple dimensions is not supported") - } - - if axes.is_empty() && keepdims == 1 { - panic!("ReduceMean: axes must be provided with keepdims") - } - - if !axes.is_empty() && keepdims == 0 { - // Not supported in Burn - panic!("ReduceMean: the reduce operation must preserve the reduced dimension") - } - - if axes.is_empty() { - None - } else { - let mut dim = axes[0]; - - if dim < 0 { - // Accepted range is [-r, r-1] where r = rank(data) but Burn only supports positive dim - dim += tensor.rank as i64; - } - Some(dim as usize) - } -} - -pub fn reduce_prod_config(node: &Node) -> Option { - let mut axes = Vec::new(); - let mut keepdims = 1; - - let tensor = match node.inputs.first().unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // Extract the attributes - for (key, value) in node.attrs.iter() { - match key.as_str() { - "axes" => axes = value.clone().into_i64s(), - "keepdims" => keepdims = value.clone().into_i64(), - // TODO: handle noop_with_empty_axes (opset 18) - _ => {} - } - } - - if axes.len() > 1 { - panic!("ReduceProd: reducing on multiple dimensions is not supported") - } - - if axes.is_empty() && keepdims == 1 { - panic!("ReduceProd: axes must be provided with keepdims") - } - - if !axes.is_empty() && keepdims == 0 { - // Not supported in Burn - panic!("ReduceProd: the reduce operation must preserve the reduced dimension") - } - - if axes.is_empty() { - None - } else { - let mut dim = axes[0]; - - if dim < 0 { - // Accepted range is [-r, r-1] where r = rank(data) but Burn only supports positive dim - dim += tensor.rank as i64; - } - Some(dim as usize) - } -} - -pub fn reduce_sum_config(node: &Node) -> Option { - let mut axes = Vec::new(); - let mut keepdims = 1; - - let tensor = match node.inputs.first().unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // Extract the attributes - for (key, value) in node.attrs.iter() { - match key.as_str() { - "keepdims" => keepdims = value.clone().into_i64(), - "axes" => axes = value.clone().into_i64s(), - // TODO: handle noop_with_empty_axes - _ => {} - } - } - - // TODO: Handle case where axes are passed in. Will require its own ReduceSumNode instead of a UnaryNode. 
- if let Some(value) = node - .inputs - .get(1) - .and_then(|argument| argument.value.as_ref()) - { - axes = value.clone().data.into_i64s(); - } - - if axes.len() > 1 { - panic!("ReduceMean: reducing on multiple dimensions is not supported") - } - - if axes.is_empty() && keepdims == 1 { - panic!("ReduceMean: axes must be provided with keepdims") - } - - if !axes.is_empty() && keepdims == 0 { - // Not supported in Burn - panic!("ReduceMean: the reduce operation must preserve the reduced dimension") - } - - if axes.is_empty() { - None - } else { - let mut dim = axes[0]; - - if dim < 0 { - // Accepted range is [-r, r-1] where r = rank(data) but Burn only supports positive dim - dim += tensor.rank as i64; - } - Some(dim as usize) - } -} - -pub fn shape_config(curr: &Node) -> (usize, usize) { - if curr.inputs.len() != 1 { - panic!( - "Shape: multiple inputs are not supported (got {:?})", - curr.inputs.len() - ); - } - - // Extract the shape of the input tensor - let tensor = match curr.inputs.first().unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // Default: all axes up to the last one (included) - let mut start_dim: i64 = 0; - let mut end_dim: i64 = tensor.rank as i64; - - // Extract the attributes - for (key, value) in curr.attrs.iter() { - match key.as_str() { - "start" => start_dim = value.clone().into_i64(), - "end" => end_dim = value.clone().into_i64(), - _ => {} - } - } - - // If dim is negative, it is counted from the end - if start_dim < 0 { - start_dim += tensor.rank as i64; - } - if end_dim < 0 { - end_dim += tensor.rank as i64; - } - - (start_dim as usize, end_dim as usize) -} - -pub fn transpose_config(curr: &Node) -> Vec { - if curr.inputs.len() != 1 { - panic!( - "Transpose: multiple inputs are not supported (got {:?})", - curr.inputs.len() - ); - } - - // Extract the shape of the input tensor - let tensor = match curr.inputs.first().unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - // Default: reverse the dimensions - let mut perm = (0..tensor.rank as i64).rev().collect::>(); - - if let Some(axes) = curr.attrs.get("perm") { - perm = axes.clone().into_i64s(); - } - - perm -} - -pub fn squeeze_config(curr: &Node) -> Vec { - let axes = curr - .attrs - .iter() - .filter_map(|(key, value)| { - if key == "axes" { - Some(value.clone().into_i64s()) - } else { - None - } - }) - .next() - .unwrap_or_else(Vec::new); - - match curr.inputs.first().unwrap().clone().ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - axes -} -pub fn split_config(node: &Node) -> SplitConfig { - // Initialize the axis to split along (default is 0 as per ONNX specification) - let mut axis: i64 = 0; - // Holds the uniform split size if calculated or provided - let mut split_size: Option = None; - // Holds the custom split sizes if provided as input - let mut split_sizes: Option> = None; - - // Extract the input tensor type to determine rank and shape - let tensor = match node.inputs.first().unwrap().ty { - ArgType::Tensor(ref tensor) => tensor, - _ => panic!("Split: Input must be a valid tensor"), - }; - - // Optionally store the number of outputs if provided as an attribute - let mut num_outputs: Option = None; - - // Iterate through node attributes to extract relevant values - for (key, value) in node.attrs.iter() { - match key.as_str() { - "axis" => axis = value.clone().into_i64(), - "num_outputs" => num_outputs = 
Some(value.clone().into_i64() as usize), - _ => {} - } - } - - // Handle the case when num_outputs is provided to calculate uniform split size - if let Some(num_outputs) = num_outputs { - if num_outputs == 0 { - panic!("Split: 'num_outputs' must be a positive value greater than zero"); - } - - let dim_size = tensor - .static_shape - .as_ref() - .expect("Split: Static shape must be known to calculate split size")[axis as usize]; - - // Calculate the split size considering any remainder for non-evenly divisible dimensions - let calculated_split_size = - dim_size / (num_outputs - (dim_size % num_outputs != 0) as usize); - - if calculated_split_size == 0 { - panic!( - "Split: Calculated split size is zero. Please ensure 'num_outputs' is valid for the dimension size" - ); - } - - // Assign the calculated split size - split_size = Some(calculated_split_size); - } - - // Adjust axis if negative to count from the end as per ONNX spec - if axis < 0 { - axis += tensor.rank as i64; - } - - // Check for custom split sizes provided as a second input - if node.inputs.len() > 1 && node.inputs[1].value.is_some() { - let sizes = node.inputs[1] - .value - .as_ref() - .unwrap() - .data - .clone() - .into_usizes(); - - if !sizes.is_empty() { - split_sizes = Some(sizes); - } - } - - // Ensure that only one of 'split_sizes' or 'num_outputs' is specified - if split_sizes.is_some() && split_size.is_some() { - panic!( - "Split: Cannot specify both 'split' input and 'num_outputs' attribute simultaneously" - ); - } - - // Infer split_size if neither custom split_sizes nor split_size is provided - if split_sizes.is_none() && split_size.is_none() { - let num_outputs = node.outputs.len(); - let dim_size = tensor - .static_shape - .as_ref() - .expect("Split: Static shape must be known to infer split size")[axis as usize]; - - // Calculate inferred split size based on number of outputs - let calculated_split_size = - dim_size / (num_outputs - (dim_size % num_outputs != 0) as usize); - - if calculated_split_size == 0 { - panic!( - "Split: Inferred split size is zero. 
Please ensure the number of outputs is valid for the dimension size" - ); - } - - split_size = Some(calculated_split_size); - } - - // Return the configuration for splitting operation - SplitConfig { - axis: axis as usize, - split_size, - split_sizes, - } -} - -pub fn one_hot_config(curr: &Node) -> (usize, [f32; 2], i64) { - let depth = curr.inputs[1] - .value - .clone() - .expect("OneHot: Only constant depth is currently supported") - .data - .into_i64(); - - let values = curr.inputs[2] - .value - .clone() - .expect("OneHot: Only constant on/off values is currently supported") - .data - .into_f32s(); - - let axis = curr - .attrs - .get("axis") - .map(|val| val.clone().into_i64()) - .unwrap_or(-1); - - (depth as usize, values.try_into().unwrap(), axis) -} - -pub fn gemm_config(curr: &Node) -> (f32, f32, i64, i64) { - let alpha = curr - .attrs - .get("alpha") - .map(|val| val.clone().into_f32()) - .unwrap_or(1.0); - let beta = curr - .attrs - .get("beta") - .map(|val| val.clone().into_f32()) - .unwrap_or(1.0); - let trans_a = curr - .attrs - .get("transA") - .map(|val| val.clone().into_i64()) - .unwrap_or(0); - let trans_b = curr - .attrs - .get("transB") - .map(|val| val.clone().into_i64()) - .unwrap_or(0); - - (alpha, beta, trans_a, trans_b) -} diff --git a/crates/burn-import/src/onnx/to_burn.rs b/crates/burn-import/src/onnx/to_burn.rs index 503cf9f733..a6e0ddea5d 100644 --- a/crates/burn-import/src/onnx/to_burn.rs +++ b/crates/burn-import/src/onnx/to_burn.rs @@ -69,24 +69,31 @@ use crate::{ logger::init_log, }; -use super::op_configuration::{ - argmax_config, avg_pool1d_config, avg_pool2d_config, batch_norm_config, clip_config, - concat_config, conv_transpose1d_config, conv_transpose2d_config, conv_transpose3d_config, - conv1d_config, conv2d_config, conv3d_config, dropout_config, expand_config, flatten_config, - gather_config, gemm_config, hard_sigmoid_config, layer_norm_config, leaky_relu_config, - linear_config, log_softmax_config, max_pool1d_config, max_pool2d_config, one_hot_config, - pad_config, reduce_max_config, reduce_mean_config, reduce_min_config, reduce_prod_config, - reduce_sum_config, reshape_config, resize_config, shape_config, softmax_config, split_config, - squeeze_config, tile_config, top_k_config, transpose_config, trilu_config, unsqueeze_config, -}; use onnx_ir::{ convert_constant_value, ir::{ ArgType, Argument as OnnxArgument, Data, ElementType, Node, NodeType, OnnxGraph, TensorType as OnnxTensorType, }, - node::slice::slice_config, + node::{ + argmax::argmax_config, avg_pool1d::avg_pool1d_config, avg_pool2d::avg_pool2d_config, + batch_norm::batch_norm_config, clip::clip_config, concat::concat_config, + conv_transpose1d::conv_transpose1d_config, conv_transpose2d::conv_transpose2d_config, + conv_transpose3d::conv_transpose3d_config, conv1d::conv1d_config, conv2d::conv2d_config, + conv3d::conv3d_config, dropout::dropout_config, expand::expand_config, + flatten::flatten_config, gather::gather_config, gemm::gemm_config, + hard_sigmoid::hard_sigmoid_config, layer_norm::layer_norm_config, + leaky_relu::leaky_relu_config, linear::linear_config, log_softmax::log_softmax_config, + max_pool1d::max_pool1d_config, max_pool2d::max_pool2d_config, one_hot::one_hot_config, + pad::pad_config, reduce_max::reduce_max_config, reduce_mean::reduce_mean_config, + reduce_min::reduce_min_config, reduce_prod::reduce_prod_config, + reduce_sum::reduce_sum_config, reshape::reshape_config, resize::resize_config, + slice::slice_config, softmax::softmax_config, split::split_config, 
squeeze::squeeze_config, + tile::tile_config, topk::top_k_config, transpose::transpose_config, trilu::trilu_config, + unsqueeze::unsqueeze_config, + }, parse_onnx, + util::shape_config, }; pub use crate::burn::graph::RecordType; @@ -1022,6 +1029,8 @@ impl ParsedOnnxGraph { fn conv1d_conversion(node: Node) -> Conv1dNode { let input = TensorType::from(node.inputs.first().unwrap()); let output = TensorType::from(node.outputs.first().unwrap()); + + // Get configuration from onnx-ir let config = conv1d_config(&node); let bias = node.inputs.len() == 3; @@ -1070,6 +1079,8 @@ impl ParsedOnnxGraph { fn max_pool1d_conversion(node: Node) -> MaxPool1dNode { let input = TensorType::from(node.inputs.first().unwrap()); let output = TensorType::from(node.outputs.first().unwrap()); + + // Get configuration from onnx-ir let config = max_pool1d_config(&node); let name = &node.name; @@ -1114,7 +1125,21 @@ impl ParsedOnnxGraph { fn conv_transpose1d_conversion(node: Node) -> ConvTranspose1dNode { let input = TensorType::from(node.inputs.first().unwrap()); let output = TensorType::from(node.outputs.first().unwrap()); - let config = conv_transpose1d_config(&node); + + // Get configuration from onnx-ir + let onnx_config = conv_transpose1d_config(&node); + + // Convert to burn ConvTranspose1dConfig + let config = burn::nn::conv::ConvTranspose1dConfig::new( + [onnx_config.channels_in, onnx_config.channels_out], + onnx_config.kernel_size, + ) + .with_stride(onnx_config.stride) + .with_padding(onnx_config.padding) + .with_dilation(onnx_config.dilation) + .with_padding_out(onnx_config.padding_out) + .with_groups(onnx_config.groups) + .with_bias(onnx_config.bias); let bias = node.inputs.len() == 3; let weight = extract_data_serialize::(1, &node).unwrap(); @@ -1160,6 +1185,8 @@ impl ParsedOnnxGraph { fn avg_pool_1d_conversion(node: Node) -> AvgPool1dNode { let input = TensorType::from(node.inputs.first().unwrap()); let output = TensorType::from(node.outputs.first().unwrap()); + + // Get configuration from onnx-ir let config = avg_pool1d_config(&node); let name = &node.name; diff --git a/crates/onnx-ir/README.md b/crates/onnx-ir/README.md index 69e2cf8085..de80a3beb9 100644 --- a/crates/onnx-ir/README.md +++ b/crates/onnx-ir/README.md @@ -1,13 +1,52 @@ # ONNX-IR -ONNX-IR is a pure Rust library for parsing ONNX models into an intermediate representation that can -be used to generate code for various ML/DL frameworks. It's part of the Burn project, with key -features including ONNX model parsing, rank inference, and node remapping. The crate supports -converting ONNX models to Burn graphs and includes utilities for handling constants and graph -transformations. +ONNX-IR is a pure Rust library for parsing ONNX models into an intermediate representation (IR) that +can be used to generate code for various ML/DL frameworks. It's a core component of the Burn model +import system, providing a clean abstraction layer between ONNX protobuf structures and Burn's +tensor operations. -For a full list of currently supported operators, please check -[here](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/SUPPORTED-ONNX-OPS.md) +## Architecture + +The ONNX-IR crate is designed with the following components: + +- **IR Core** (`ir.rs`): Defines the core data structures such as `Node`, `NodeType`, `Argument`, + etc. 
+- **Protocol Conversion** (`proto_conversion.rs`): Converts ONNX protobuf structures to IR +- **ONNX Parsing** (`from_onnx.rs`): Handles the parsing of ONNX models into the IR +- **Rank Inference** (`rank_inference.rs`): Computes output tensor ranks for each operation +- **Node Implementations** (`node/`): Contains operation-specific configurations and rank inference + functions +- **Node Remapping** (`node_remap.rs`): Maps generic ONNX operations to dimension-specific + alternatives + +## Usage + +ONNX-IR is typically used through the `burn-import` crate, but can also be used standalone: + +```rust +use onnx_ir::{parse_onnx, OnnxGraph}; +use std::path::Path; + +// Parse an ONNX model into the IR +let graph: OnnxGraph = parse_onnx(Path::new("path/to/model.onnx")); + +// Work with the IR +for node in &graph.nodes { + println!("Node: {}, Type: {:?}", node.name, node.node_type); + + // Access inputs and outputs + for input in &node.inputs { + println!(" Input: {}", input.name); + } + + for output in &node.outputs { + println!(" Output: {}", output.name); + } +} + +// Convert to another framework's representation +// (This is typically done by burn-import or another conversion layer) +``` ## ONNX Compatibility @@ -35,8 +74,16 @@ inferred_model = shape_inference.infer_shapes(upgraded_model) onnx.save(inferred_model, 'upgraded_model.onnx') ``` -For a full list of currently supported operators, please check -[here](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/SUPPORTED-ONNX-OPS.md) +## Resources + +- **ONNX to Burn Conversion Guide**: For detailed implementation guidance on adding new operators, + see the + [ONNX to Burn conversion guide](https://github.com/tracel-ai/burn/blob/main/contributor-book/src/guides/onnx-to-burn-conversion-tool.md). + +- **Supported ONNX Operators**: For a full list of currently supported ONNX operators, please see + the + [Supported ONNX Operators table](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/SUPPORTED-ONNX-OPS.md). -To see how to use this for generating burn graphs, see -[here](crates/burn-import/src/onnx/to_burn.rs). +- **Burn Integration**: ONNX-IR serves as the foundation for the ONNX import support in Burn. The + conversion from ONNX-IR to Burn graphs is implemented in + [`burn-import/src/onnx/to_burn.rs`](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/src/onnx/to_burn.rs). 
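
Each module under `node/` follows the same two-function pattern described above: a `<op>_config` function that extracts the operator's attributes from the IR `Node`, and a `<op>_update_output` function that performs rank inference for that operator. The sketch below illustrates that shape for a hypothetical element-wise operator; the `example_*` names and the `alpha` attribute are placeholders for illustration only and are not part of the crate.

```rust
use crate::ir::{ArgType, Node, TensorType};

/// Hypothetical example: extract a single optional `alpha` attribute,
/// falling back to a default when it is absent.
pub fn example_config(node: &Node) -> f32 {
    node.attrs
        .get("alpha")
        .map(|value| value.clone().into_f32())
        .unwrap_or(1.0)
}

/// Hypothetical rank inference for the same operator: the output keeps the
/// input's element type and rank, and the static shape is left unknown.
pub fn example_update_output(node: &mut Node) {
    let tensor = match &node.inputs[0].ty {
        ArgType::Tensor(tensor) => tensor.clone(),
        _ => panic!("Example: only tensor input is supported"),
    };

    node.outputs[0].ty = ArgType::Tensor(TensorType {
        elem_type: tensor.elem_type,
        rank: tensor.rank,
        static_shape: None,
    });
}
```

Real implementations such as `argmax.rs` and `concat.rs` below follow this same structure; their functions are then dispatched from `rank_inference.rs` and consumed by `burn-import` in `to_burn.rs`.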
diff --git a/crates/onnx-ir/src/lib.rs b/crates/onnx-ir/src/lib.rs index b195b88adb..7c7340482e 100644 --- a/crates/onnx-ir/src/lib.rs +++ b/crates/onnx-ir/src/lib.rs @@ -10,4 +10,4 @@ pub mod util; pub use from_onnx::convert_constant_value; pub use from_onnx::parse_onnx; -pub use ir::OnnxGraph; +pub use ir::*; diff --git a/crates/onnx-ir/src/node/argmax.rs b/crates/onnx-ir/src/node/argmax.rs new file mode 100644 index 0000000000..153eeac7b8 --- /dev/null +++ b/crates/onnx-ir/src/node/argmax.rs @@ -0,0 +1,132 @@ +use crate::ir::{ArgType, ElementType, Node, TensorType}; + +/// Create argmax config from the attributes of the node +pub fn argmax_config(node: &Node) -> usize { + let mut axis: i64 = 0; + + // check if the node has only one input + if node.inputs.len() != 1 { + panic!( + "Argmax: multiple inputs are not supported (got {:?})", + node.inputs.len() + ); + } + + // extract the shape of the input tensor + let tensor = match node.inputs.first().unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + // extract the attributes + for (key, value) in node.attrs.iter() { + match key.as_str() { + "axis" => axis = value.clone().into_i64(), + "select_last_index" => { + // not all params are supported in burn + if value.clone().into_i64() != 0 { + log::warn!( + "only select_last_index=0 is supported for argmax in burn. Ignoring supplied value (got {:?})", + value + ); + } + } + "keepdims" => { + // not all params are supported in burn + if value.clone().into_i64() != 1 { + panic!( + "Only keepdims=1 is supported for argmax in burn (got {:?})", + value + ); + } + } + _ => {} + } + } + + // if axis is negative, it is counted from the end + if axis < 0 { + axis += tensor.rank as i64; + } + + axis as usize +} + +/// Update output rank for ArgMax (same as input rank). 
+pub fn argmax_update_outputs(node: &mut Node) { + log::debug!("ArgMax rank inference for node {}", node.name); + + if node.inputs.len() != 1 { + panic!("ArgMax: multiple inputs are not supported"); + } + let tensor = match &node.inputs[0].ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + log::debug!("ArgMax input rank for {}: {}", node.name, tensor.rank); + + // Note: argmax in burn does not support keepdims=false + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type: ElementType::Int64, + rank: tensor.rank, + static_shape: None, + }); + + log::debug!("ArgMax output rank for {}: {}", node.name, tensor.rank); +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{Argument, ElementType, NodeType}; + use crate::node::test_utils::NodeBuilder; + + fn create_test_node(axis: i64, select_last_index: i64, keepdims: i64) -> Node { + NodeBuilder::new(NodeType::ArgMax, "test_argmax") + .input_tensor_f32("data", 3, None) + .output_tensor_i64("output", 3, None) + .attr_int("axis", axis) + .attr_int("select_last_index", select_last_index) + .attr_int("keepdims", keepdims) + .build() + } + + #[test] + fn test_argmax_config_basic() { + let node = create_test_node(0, 0, 1); + let config = argmax_config(&node); + assert_eq!(config, 0); + } + + #[test] + fn test_argmax_config_negative_axis() { + let node = create_test_node(-2, 0, 1); + let config = argmax_config(&node); + assert_eq!(config, 1); // -2 + 3 = 1 + } + + #[test] + #[should_panic(expected = "Argmax: multiple inputs are not supported")] + fn test_argmax_config_multiple_inputs() { + let mut node = create_test_node(0, 0, 1); + node.inputs.push(Argument { + name: "extra".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 1, + static_shape: None, + }), + value: None, + passed: true, + }); + let _ = argmax_config(&node); + } + + #[test] + #[should_panic(expected = "Only keepdims=1 is supported for argmax in burn")] + fn test_argmax_config_keepdims_not_supported() { + let node = create_test_node(0, 0, 0); + let _ = argmax_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/avg_pool1d.rs b/crates/onnx-ir/src/node/avg_pool1d.rs new file mode 100644 index 0000000000..73d6d70291 --- /dev/null +++ b/crates/onnx-ir/src/node/avg_pool1d.rs @@ -0,0 +1,140 @@ +use crate::{ir::Node, node::padding::padding_config_1d}; + +use super::padding::PaddingConfig1d; + +/// Configuration for AvgPool1d operations extracted from ONNX nodes +#[derive(Debug, Clone)] +pub struct AvgPool1dConfig { + /// Kernel size + pub kernel_size: usize, + /// Stride + pub stride: usize, + /// Padding configuration + pub padding: PaddingConfig1d, + /// Whether to include padding in the average calculation + pub count_include_pad: bool, +} + +impl AvgPool1dConfig { + /// Create a new AvgPool1dConfig + pub fn new( + kernel_size: usize, + stride: usize, + padding: PaddingConfig1d, + count_include_pad: bool, + ) -> Self { + Self { + kernel_size, + stride, + padding, + count_include_pad, + } + } +} + +/// Create an AvgPool1dConfig from the attributes of the node +pub fn avg_pool1d_config(curr: &Node) -> AvgPool1dConfig { + let mut kernel_shape = Vec::new(); + let mut strides = vec![1]; + let mut pads = vec![0, 0]; + let mut count_include_pad: i64 = 0; + let mut ceil_mode: i64 = 0; + + for (key, value) in curr.attrs.iter() { + match key.as_str() { + "kernel_shape" => kernel_shape = value.clone().into_i64s(), + "strides" => strides = value.clone().into_i64s(), + "pads" => pads = 
value.clone().into_i64s(), + "count_include_pad" => count_include_pad = value.clone().into_i64(), + "ceil_mode" => ceil_mode = value.clone().into_i64(), + // These are attributes that are allowed but not used in this implementation + "auto_pad" | "storage_order" => {} + _ => panic!("Unexpected attribute for AvgPool1d: {key}"), + } + } + + assert_eq!( + kernel_shape.len(), + 1, + "AvgPool1d: kernel shape must have length 1" + ); + assert_eq!(strides.len(), 1, "AvgPool1d: stride must have length 1"); + + if ceil_mode == 1 { + panic!("ceil_mode is not supported"); + } + + let padding = padding_config_1d(&pads); + + AvgPool1dConfig { + kernel_size: kernel_shape[0] as usize, + stride: strides[0] as usize, + padding, + count_include_pad: count_include_pad == 1, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; + + fn create_test_node( + kernel_shape: Vec, + strides: Vec, + pads: Vec, + count_include_pad: i64, + ceil_mode: i64, + ) -> Node { + NodeBuilder::new(NodeType::AveragePool1d, "test_avgpool1d") + .input_tensor_f32("data", 3, None) + .output_tensor_f32("output", 3, None) + .attr_ints("kernel_shape", kernel_shape) + .attr_ints("strides", strides) + .attr_ints("pads", pads) + .attr_int("count_include_pad", count_include_pad) + .attr_int("ceil_mode", ceil_mode) + .build() + } + + #[test] + fn test_avg_pool1d_config_basic() { + let node = create_test_node(vec![4], vec![1], vec![0, 0], 0, 0); + let config = avg_pool1d_config(&node); + + assert_eq!(config.kernel_size, 4); + assert_eq!(config.stride, 1); + assert!(!config.count_include_pad); + assert!(matches!(config.padding, PaddingConfig1d::Valid)); + } + + #[test] + fn test_avg_pool1d_config_with_padding() { + let node = create_test_node(vec![4], vec![2], vec![2, 2], 0, 0); + let config = avg_pool1d_config(&node); + + assert_eq!(config.kernel_size, 4); + assert_eq!(config.stride, 2); + assert!(!config.count_include_pad); + assert!(matches!(config.padding, PaddingConfig1d::Explicit(2))); + } + + #[test] + fn test_avg_pool1d_config_with_count_include_pad() { + let node = create_test_node(vec![4], vec![1], vec![2, 2], 1, 0); + let config = avg_pool1d_config(&node); + + assert_eq!(config.kernel_size, 4); + assert_eq!(config.stride, 1); + assert!(config.count_include_pad); + assert!(matches!(config.padding, PaddingConfig1d::Explicit(2))); + } + + #[test] + #[should_panic(expected = "ceil_mode is not supported")] + fn test_avg_pool1d_config_with_ceil_mode() { + let node = create_test_node(vec![4], vec![1], vec![0, 0], 0, 1); + let _ = avg_pool1d_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/avg_pool2d.rs b/crates/onnx-ir/src/node/avg_pool2d.rs new file mode 100644 index 0000000000..bc6e5a0f73 --- /dev/null +++ b/crates/onnx-ir/src/node/avg_pool2d.rs @@ -0,0 +1,132 @@ +use crate::ir::Node; +use crate::node::padding::{PaddingConfig2d, padding_config_2d}; + +/// Configuration for AvgPool2d operations +#[derive(Debug, Clone)] +pub struct AvgPool2dConfig { + /// Kernel size [height, width] + pub kernel_size: [usize; 2], + /// Stride [height, width] + pub strides: [usize; 2], + /// Padding configuration + pub padding: PaddingConfig2d, + /// Whether to include padding in the average calculation + pub count_include_pad: bool, +} + +impl AvgPool2dConfig { + /// Create a new AvgPool2dConfig + pub fn new( + kernel_size: [usize; 2], + strides: [usize; 2], + padding: PaddingConfig2d, + count_include_pad: bool, + ) -> Self { + Self { + kernel_size, + strides, + padding, + 
count_include_pad, + } + } +} + +/// Create a AvgPool2dConfig from the attributes of the node +pub fn avg_pool2d_config(curr: &Node) -> AvgPool2dConfig { + let mut kernel_shape = Vec::new(); + let mut strides = vec![1, 1]; + let mut pads = vec![0, 0, 0, 0]; + let mut count_include_pad: i64 = 0; + let mut ceil_mode: i64 = 0; + + for (key, value) in curr.attrs.iter() { + match key.as_str() { + "kernel_shape" => kernel_shape = value.clone().into_i64s(), + "strides" => strides = value.clone().into_i64s(), + "pads" => pads = value.clone().into_i64s(), + "count_include_pad" => count_include_pad = value.clone().into_i64(), + "ceil_mode" => ceil_mode = value.clone().into_i64(), + // These are attributes that are allowed but not used in this implementation + "auto_pad" | "storage_order" => {} + _ => panic!("Unexpected attribute for AvgPool2d: {key}"), + } + } + + if ceil_mode == 1 { + panic!("ceil_mode is not supported"); + } + + let padding = padding_config_2d(&pads); + + AvgPool2dConfig::new( + [kernel_shape[0] as usize, kernel_shape[1] as usize], + [strides[0] as usize, strides[1] as usize], + padding, + count_include_pad == 1, + ) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; + + fn create_test_node( + kernel_shape: Vec, + strides: Vec, + pads: Vec, + count_include_pad: i64, + ceil_mode: i64, + ) -> Node { + NodeBuilder::new(NodeType::AveragePool2d, "test_avgpool2d") + .input_tensor_f32("data", 4, None) + .output_tensor_f32("output", 4, None) + .attr_ints("kernel_shape", kernel_shape) + .attr_ints("strides", strides) + .attr_ints("pads", pads) + .attr_int("count_include_pad", count_include_pad) + .attr_int("ceil_mode", ceil_mode) + .build() + } + + #[test] + fn test_avg_pool2d_config_basic() { + let node = create_test_node(vec![3, 3], vec![1, 1], vec![0, 0, 0, 0], 0, 0); + let config = avg_pool2d_config(&node); + + assert_eq!(config.kernel_size, [3, 3]); + assert_eq!(config.strides, [1, 1]); + assert!(!config.count_include_pad); + assert!(matches!(config.padding, PaddingConfig2d::Valid)); + } + + #[test] + fn test_avg_pool2d_config_with_padding() { + let node = create_test_node(vec![2, 2], vec![2, 2], vec![1, 1, 1, 1], 0, 0); + let config = avg_pool2d_config(&node); + + assert_eq!(config.kernel_size, [2, 2]); + assert_eq!(config.strides, [2, 2]); + assert!(!config.count_include_pad); + assert!(matches!(config.padding, PaddingConfig2d::Explicit(1, 1))); + } + + #[test] + fn test_avg_pool2d_config_with_count_include_pad() { + let node = create_test_node(vec![3, 3], vec![1, 1], vec![1, 1, 1, 1], 1, 0); + let config = avg_pool2d_config(&node); + + assert_eq!(config.kernel_size, [3, 3]); + assert_eq!(config.strides, [1, 1]); + assert!(config.count_include_pad); + assert!(matches!(config.padding, PaddingConfig2d::Explicit(1, 1))); + } + + #[test] + #[should_panic(expected = "ceil_mode is not supported")] + fn test_avg_pool2d_config_with_ceil_mode() { + let node = create_test_node(vec![3, 3], vec![1, 1], vec![0, 0, 0, 0], 0, 1); + let _ = avg_pool2d_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/batch_norm.rs b/crates/onnx-ir/src/node/batch_norm.rs new file mode 100644 index 0000000000..aaa0165c8f --- /dev/null +++ b/crates/onnx-ir/src/node/batch_norm.rs @@ -0,0 +1,91 @@ +use crate::ir::Node; + +/// Configuration for BatchNorm operations +#[derive(Debug, Clone)] +pub struct BatchNormConfig { + /// Number of features (channels) + pub num_features: usize, + /// Small constant added for numerical stability + pub epsilon: 
f64, + /// Momentum for running statistics + pub momentum: f64, +} + +impl BatchNormConfig { + /// Create a new BatchNormConfig + pub fn new(num_features: usize, epsilon: f64, momentum: f64) -> Self { + Self { + num_features, + epsilon, + momentum, + } + } +} + +/// Create a BatchNormConfig from the attributes of the node +pub fn batch_norm_config(node: &Node) -> BatchNormConfig { + let weight_shape = node.inputs[1] + .value + .as_ref() + .expect("BatchNorm: weight tensor must be present") + .shape + .clone(); + + let num_features = weight_shape[0]; + + let mut epsilon = 0f32; + let mut momentum = 0f32; + + for (key, value) in node.attrs.iter() { + match key.as_str() { + "momentum" => momentum = value.clone().into_f32(), + "epsilon" => epsilon = value.clone().into_f32(), + _ => {} + } + } + + BatchNormConfig::new(num_features, epsilon as f64, momentum as f64) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; + + fn create_test_node(epsilon: f32, momentum: f32, num_features: usize) -> Node { + let ones = vec![1.0; num_features]; + let zeros = vec![0.0; num_features]; + + NodeBuilder::new(NodeType::BatchNormalization, "test_batchnorm") + .input_tensor_f32("X", 4, None) // NCHW format + .input_tensor_f32_data("scale", ones.clone(), vec![num_features]) + .input_tensor_f32_data("bias", zeros.clone(), vec![num_features]) + .input_tensor_f32_data("mean", zeros.clone(), vec![num_features]) + .input_tensor_f32_data("var", ones.clone(), vec![num_features]) + .output_tensor_f32("output", 4, None) + .attr_float("epsilon", epsilon) + .attr_float("momentum", momentum) + .build() + } + + #[test] + fn test_batch_norm_config_basic() { + let node = create_test_node(1e-5, 0.9, 64); + let config = batch_norm_config(&node); + + assert_eq!(config.num_features, 64); + assert!(f64::abs(config.epsilon - 1e-5) < 1e-6); + assert!(f64::abs(config.momentum - 0.9) < 1e-6); + } + + #[test] + fn test_batch_norm_config_default_values() { + let node = create_test_node(0.0, 0.0, 32); + let config = batch_norm_config(&node); + + assert_eq!(config.num_features, 32); + assert!(f64::abs(config.epsilon - 0.0) < 1e-6); + assert!(f64::abs(config.momentum - 0.0) < 1e-6); + } +} diff --git a/crates/onnx-ir/src/node/cast.rs b/crates/onnx-ir/src/node/cast.rs new file mode 100644 index 0000000000..b3dd289de1 --- /dev/null +++ b/crates/onnx-ir/src/node/cast.rs @@ -0,0 +1,134 @@ +use crate::ir::{ArgType, AttributeValue, ElementType, Node, TensorType}; +use crate::protos::tensor_proto::DataType; +use protobuf::Enum; + +/// Update output type for Cast operations, preserving rank. 
+pub fn cast_update_outputs(node: &mut Node) { + if node.inputs.len() != 1 { + panic!("Cast: multiple inputs are not supported"); + } + let input = &mut node.inputs[0]; + let output = &mut node.outputs[0]; + + let elem_type = match node.attrs.get("to") { + Some(value) => match &value { + AttributeValue::Int64(type_id) => match DataType::from_i32(*type_id as i32).unwrap() { + DataType::FLOAT => ElementType::Float32, + DataType::INT32 => ElementType::Int32, + DataType::INT64 => ElementType::Int64, + DataType::DOUBLE => ElementType::Float64, + DataType::BOOL => ElementType::Bool, + _ => panic!("Cast: unsupported type"), + }, + _ => panic!("'to' attribute must be an Int64"), + }, + None => panic!("Cast node must have a 'to' attribute"), + }; + + match input.ty.clone() { + ArgType::Tensor(tensor) => { + if tensor.rank == 0 { + // treat 0-dim tensor as scalar + output.ty = ArgType::Scalar(elem_type); + input.ty = ArgType::Scalar(tensor.elem_type); + } else { + // Cast input and output are the same shape, but possibly different types + output.ty = ArgType::Tensor(TensorType { + elem_type, + rank: tensor.rank, + static_shape: None, + }); + } + } + ArgType::Scalar(_) => output.ty = ArgType::Scalar(elem_type), + _ => panic!("Cast: only scalar and tensor inputs are valid"), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{Argument, NodeType, TensorType}; + use crate::node::test_utils::NodeBuilder; + + fn create_test_node(input_rank: usize, to_type: i64) -> Node { + NodeBuilder::new(NodeType::Cast, "test_cast") + .input_tensor_f32("X", input_rank, None) + .output_tensor_f32("Y", input_rank, None) // Element type will be overwritten + .attr_int("to", to_type) + .build() + } + + // Additional test function to demonstrate scalar inputs + fn create_scalar_test_node(to_type: i64) -> Node { + NodeBuilder::new(NodeType::Cast, "test_cast") + .input_scalar_f32("X") + .output_scalar_f32("Y") // Element type will be overwritten + .attr_int("to", to_type) + .build() + } + + #[test] + fn test_cast_float_to_int64() { + let mut node = create_test_node(2, DataType::INT64.value() as i64); + cast_update_outputs(&mut node); + + match &node.outputs[0].ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.elem_type, ElementType::Int64); + assert_eq!(tensor.rank, 2); + } + _ => panic!("Expected tensor output"), + } + } + + #[test] + fn test_cast_scalar_handling() { + let mut node = create_test_node(0, DataType::BOOL.value() as i64); + cast_update_outputs(&mut node); + + match &node.outputs[0].ty { + ArgType::Scalar(elem_type) => { + assert_eq!(*elem_type, ElementType::Bool); + } + _ => panic!("Expected scalar output for 0-rank tensor"), + } + + match &node.inputs[0].ty { + ArgType::Scalar(elem_type) => { + assert_eq!(*elem_type, ElementType::Float32); + } + _ => panic!("Input should have been converted to scalar"), + } + } + + #[test] + #[should_panic(expected = "Cast: multiple inputs are not supported")] + fn test_cast_multiple_inputs() { + let mut node = create_test_node(2, DataType::INT64.value() as i64); + node.inputs.push(Argument { + name: "extra".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 1, + static_shape: None, + }), + value: None, + passed: true, + }); + cast_update_outputs(&mut node); + } + + #[test] + fn test_cast_scalar_to_bool() { + let mut node = create_scalar_test_node(DataType::BOOL.value() as i64); + cast_update_outputs(&mut node); + + match &node.outputs[0].ty { + ArgType::Scalar(elem_type) => { + assert_eq!(*elem_type, 
ElementType::Bool); + } + _ => panic!("Expected scalar output"), + } + } +} diff --git a/crates/onnx-ir/src/node/clip.rs b/crates/onnx-ir/src/node/clip.rs new file mode 100644 index 0000000000..e33bf381f0 --- /dev/null +++ b/crates/onnx-ir/src/node/clip.rs @@ -0,0 +1,141 @@ +use crate::ir::{Data, Node}; + +pub fn clip_config(node: &Node) -> (Option, Option) { + let mut min_result: Option = None; + let mut max_result: Option = None; + + // For Clip Opset 6+ , the min and max values are attributes + for (key, value) in node.attrs.iter() { + match key.as_str() { + "min" => { + let min = value.clone().into_f32() as f64; + min_result = Some(min); + } + "max" => { + let max = value.clone().into_f32(); + max_result = Some(max as f64); + } + _ => {} + } + } + + // For Clip Opset 11+ , the min and max values are inputs + // Get the min and max values from the input values + if min_result.is_none() && max_result.is_none() { + let min = node.inputs.get(1).and_then(|arg| arg.value.clone()); + let max = node.inputs.get(2).and_then(|arg| arg.value.clone()); + + if min_result.is_none() && min.is_some() { + let min = min.unwrap().data.into_scalar(); + min_result = match min { + Data::Float16(min) => Some(f32::from(min) as f64), + Data::Float32(min) => Some(min as f64), + Data::Float64(min) => Some(min), + _ => panic!("Clip: only float min is supported"), + }; + } + + if max_result.is_none() && max.is_some() { + let max = max.unwrap().data.into_scalar(); + max_result = match max { + Data::Float16(max) => Some(f32::from(max) as f64), + Data::Float32(max) => Some(max as f64), + Data::Float64(max) => Some(max), + _ => panic!("Clip: only float max is supported"), + }; + } + } + + if min_result.is_none() && max_result.is_none() { + panic!("Clip: min and max values must be either attributes or inputs"); + } + + (min_result, max_result) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; + + fn create_test_node_with_attributes(min: Option, max: Option) -> Node { + let mut builder = NodeBuilder::new(NodeType::Clip, "test_clip") + .input_tensor_f32("X", 4, None) + .output_tensor_f32("Y", 4, None); + + if let Some(min_val) = min { + builder = builder.attr_float("min", min_val); + } + + if let Some(max_val) = max { + builder = builder.attr_float("max", max_val); + } + + builder.build() + } + + fn create_test_node_with_inputs(min: Option, max: Option) -> Node { + NodeBuilder::new(NodeType::Clip, "test_clip") + .input_tensor_f32("X", 4, None) + .input_scalar_tensor_f32("min", min) + .input_scalar_tensor_f32("max", max) + .output_tensor_f32("Y", 4, None) + .build() + } + + #[test] + fn test_clip_config_with_attributes() { + let node = create_test_node_with_attributes(Some(-1.0), Some(1.0)); + let (min, max) = clip_config(&node); + assert_eq!(min, Some(-1.0)); + assert_eq!(max, Some(1.0)); + } + + #[test] + fn test_clip_config_with_attributes_min_only() { + let node = create_test_node_with_attributes(Some(-1.0), None); + let (min, max) = clip_config(&node); + assert_eq!(min, Some(-1.0)); + assert_eq!(max, None); + } + + #[test] + fn test_clip_config_with_attributes_max_only() { + let node = create_test_node_with_attributes(None, Some(1.0)); + let (min, max) = clip_config(&node); + assert_eq!(min, None); + assert_eq!(max, Some(1.0)); + } + + #[test] + fn test_clip_config_with_inputs() { + let node = create_test_node_with_inputs(Some(-1.0), Some(1.0)); + let (min, max) = clip_config(&node); + assert_eq!(min, Some(-1.0)); + assert_eq!(max, Some(1.0)); + } + 
+ #[test] + fn test_clip_config_with_inputs_min_only() { + let node = create_test_node_with_inputs(Some(-1.0), None); + let (min, max) = clip_config(&node); + assert_eq!(min, Some(-1.0)); + assert_eq!(max, None); + } + + #[test] + fn test_clip_config_with_inputs_max_only() { + let node = create_test_node_with_inputs(None, Some(1.0)); + let (min, max) = clip_config(&node); + assert_eq!(min, None); + assert_eq!(max, Some(1.0)); + } + + #[test] + #[should_panic(expected = "Clip: min and max values must be either attributes or inputs")] + fn test_clip_config_no_min_max() { + let node = create_test_node_with_attributes(None, None); + let _ = clip_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/comparison.rs b/crates/onnx-ir/src/node/comparison.rs new file mode 100644 index 0000000000..fca4cf0621 --- /dev/null +++ b/crates/onnx-ir/src/node/comparison.rs @@ -0,0 +1,85 @@ +use crate::ir::{ArgType, ElementType, Node, TensorType}; + +/// Update output type for comparison operations (e.g., Equal, Greater) to max input rank. +pub fn elementwise_comparison_outputs(node: &mut Node) { + log::debug!("Elementwise comparison for node {}", node.name); + + let max_rank = node.inputs.iter().fold(0, |acc, input| match &input.ty { + ArgType::Tensor(tensor) => acc.max(tensor.rank), + ArgType::Scalar(_) => acc, + _ => panic!("Invalid input type for comparison op"), + }); + + log::debug!("Max rank for comparison node {}: {}", node.name, max_rank); + + if max_rank == 0 { + node.outputs[0].ty = ArgType::Scalar(ElementType::Bool); + log::debug!("Scalar boolean result for node {}", node.name); + } else { + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type: ElementType::Bool, + rank: max_rank, + static_shape: None, + }); + log::debug!( + "Tensor boolean result for node {} with rank {}", + node.name, + max_rank + ); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; + + fn create_test_node(input1_rank: usize, input2_rank: usize) -> Node { + NodeBuilder::new(NodeType::Equal, "test_comparison") + .input_tensor_f32("A", input1_rank, None) + .input_tensor_f32("B", input2_rank, None) + .output_tensor_bool("result", 0, None) // rank will be updated + .build() + } + + #[test] + fn test_comparison_rank_broadcasting() { + let mut node = create_test_node(2, 3); + elementwise_comparison_outputs(&mut node); + + match &node.outputs[0].ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.elem_type, ElementType::Bool); + assert_eq!(tensor.rank, 3); // max(2, 3) = 3 + } + _ => panic!("Expected tensor output"), + } + } + + #[test] + fn test_comparison_scalar_result() { + let mut node = create_test_node(0, 0); + + // Convert inputs to scalars + node.inputs[0].ty = ArgType::Scalar(ElementType::Float32); + node.inputs[1].ty = ArgType::Scalar(ElementType::Float32); + + elementwise_comparison_outputs(&mut node); + + match &node.outputs[0].ty { + ArgType::Scalar(elem_type) => { + assert_eq!(*elem_type, ElementType::Bool); + } + _ => panic!("Expected scalar output"), + } + } + + #[test] + #[should_panic(expected = "Invalid input type for comparison op")] + fn test_comparison_invalid_input() { + let mut node = create_test_node(2, 2); + node.inputs[0].ty = ArgType::Shape(2); + elementwise_comparison_outputs(&mut node); + } +} diff --git a/crates/onnx-ir/src/node/concat.rs b/crates/onnx-ir/src/node/concat.rs new file mode 100644 index 0000000000..71a4ee328c --- /dev/null +++ b/crates/onnx-ir/src/node/concat.rs @@ -0,0 +1,88 @@ +use 
crate::ir::{ArgType, Node, TensorType}; + +/// Update output rank for Concat (same as first tensor input). +pub fn concat_update_outputs(node: &mut Node) { + log::debug!("Concat rank inference for node {}", node.name); + + let tensor = node + .inputs + .iter() + .find_map(|input| match &input.ty { + ArgType::Tensor(tensor) => Some(tensor.clone()), + _ => None, + }) + .unwrap(); + + log::debug!("Concat using input rank for {}: {}", node.name, tensor.rank); + + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type: tensor.elem_type, + rank: tensor.rank, + static_shape: None, + }); + + log::debug!("Concat output rank for {}: {}", node.name, tensor.rank); +} + +/// Create concat config from the attributes of the node +pub fn concat_config(node: &Node) -> usize { + // the axis is the last dimension (Default: 1 per ONNX spec) + let mut axis: i64 = 1; + + // extract the shape of the input tensor + let tensor = match node.inputs.first().unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + // extract the attributes + for (key, value) in node.attrs.iter() { + if key.as_str() == "axis" { + axis = value.clone().into_i64() + } + } + + // if axis is negative, it is counted from the end + if axis < 0 { + axis += tensor.rank as i64; + } + + axis as usize +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; + + fn create_test_node(axis: i64, input_rank: usize, num_inputs: usize) -> Node { + NodeBuilder::new(NodeType::Concat, "test_concat") + .input_tensors_f32::>("data", num_inputs, input_rank, None) + .output_tensor_f32("output", input_rank, None) + .attr_int("axis", axis) + .build() + } + + #[test] + fn test_concat_config_basic() { + let node = create_test_node(1, 3, 2); + let config = concat_config(&node); + assert_eq!(config, 1); + } + + #[test] + fn test_concat_config_negative_axis() { + let node = create_test_node(-2, 3, 2); + let config = concat_config(&node); + assert_eq!(config, 1); // -2 + 3 = 1 + } + + #[test] + #[should_panic(expected = "Only tensor input is valid")] + fn test_concat_config_invalid_input() { + let mut node = create_test_node(1, 3, 1); + node.inputs[0].ty = ArgType::Shape(1); + let _ = concat_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/constant.rs b/crates/onnx-ir/src/node/constant.rs new file mode 100644 index 0000000000..2af9aabb04 --- /dev/null +++ b/crates/onnx-ir/src/node/constant.rs @@ -0,0 +1,125 @@ +use crate::ir::{ArgType, AttributeValue, ElementType, Node, TensorType}; + +/// Update output type for constant nodes based on attribute values, focusing on rank only. 
+pub fn constant_update_outputs(node: &mut Node) { + log::debug!("Constant rank inference for node {}", node.name); + + let keys = [ + "value", + "value_float", + "value_floats", + "value_int", + "value_ints", + "value_string", + "value_strings", + "sparse_value", + ]; + + let matched_value = keys.iter().find_map(|&key| node.attrs.get(key).cloned()); + log::debug!("Constant found attribute: {}", matched_value.is_some()); + + node.outputs[0].ty = match matched_value { + Some(value) => match &value { + AttributeValue::Tensor(tensor) if tensor.shape.is_empty() => { + log::debug!("Constant as scalar for {}", node.name); + ArgType::Scalar(tensor.elem_type()) + } + AttributeValue::Tensor(tensor) => { + log::debug!( + "Constant tensor with rank {} for {}", + tensor.shape.len(), + node.name + ); + ArgType::Tensor(TensorType { + elem_type: tensor.elem_type(), + rank: tensor.shape.len(), + static_shape: None, + }) + } + AttributeValue::Float32(_) => { + log::debug!("Constant Float32 scalar for {}", node.name); + ArgType::Scalar(ElementType::Float32) + } + AttributeValue::Float32s(_) => { + log::debug!("Constant Float32s tensor with rank 1 for {}", node.name); + ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 1, + static_shape: None, + }) + } + AttributeValue::Int64(_) => { + log::debug!("Constant Int64 scalar for {}", node.name); + ArgType::Scalar(ElementType::Int64) + } + AttributeValue::Int64s(_) => { + log::debug!("Constant Int64s tensor with rank 1 for {}", node.name); + ArgType::Tensor(TensorType { + elem_type: ElementType::Int64, + rank: 1, + static_shape: None, + }) + } + ty => panic!("Constant value of {:?} is not supported", ty), + }, + None => panic!("Constant node must have a value attribute"), + }; +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{NodeType, TensorData}; + use crate::node::test_utils::NodeBuilder; + + fn create_test_node() -> Node { + NodeBuilder::new(NodeType::Constant, "test_constant") + .output_tensor_f32("output", 0, None) // This will be overwritten + .build() + } + + #[test] + fn test_constant_scalar_float() { + let mut node = create_test_node(); + node.attrs + .insert("value_float".to_string(), AttributeValue::Float32(6.14)); + + constant_update_outputs(&mut node); + + match &node.outputs[0].ty { + ArgType::Scalar(elem_type) => { + assert_eq!(*elem_type, ElementType::Float32); + } + _ => panic!("Expected scalar output"), + } + } + + #[test] + fn test_constant_tensor() { + let mut node = create_test_node(); + node.attrs.insert( + "value".to_string(), + AttributeValue::Tensor(TensorData { + shape: vec![2, 3], + data: crate::ir::Data::Float32s(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]), + }), + ); + + constant_update_outputs(&mut node); + + match &node.outputs[0].ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.elem_type, ElementType::Float32); + assert_eq!(tensor.rank, 2); + } + _ => panic!("Expected tensor output"), + } + } + + #[test] + #[should_panic(expected = "Constant node must have a value attribute")] + fn test_constant_missing_value() { + let mut node = create_test_node(); + constant_update_outputs(&mut node); + } +} diff --git a/crates/onnx-ir/src/node/constant_of_shape.rs b/crates/onnx-ir/src/node/constant_of_shape.rs new file mode 100644 index 0000000000..723a788562 --- /dev/null +++ b/crates/onnx-ir/src/node/constant_of_shape.rs @@ -0,0 +1,138 @@ +use crate::ir::{ArgType, ElementType, Node, TensorType}; + +/// Updates the output rank for a ConstantOfShape node based on the rank of its input. 
+pub fn constant_of_shape_update_output(node: &mut Node) { + log::debug!("ConstantOfShape rank inference for node {}", node.name); + + let value_type = node + .attrs + .get("value") + .map(|v| v.clone().into_tensor().elem_type()) + .unwrap_or(ElementType::Float32); // If not given, defaults to 0 as float32 + log::debug!( + "ConstantOfShape value type for {}: {:?}", + node.name, + value_type + ); + + let rank = match &node.inputs[0].ty { + ArgType::Shape(rank) => { + log::debug!( + "ConstantOfShape input is Shape with rank {} for {}", + rank, + node.name + ); + *rank + } + ArgType::Tensor(tensor_type) => { + log::debug!("ConstantOfShape input is Tensor for {}", node.name); + let r = tensor_type + .static_shape + .as_ref() + .and_then(|shape| shape.first()) + .copied() + .expect( + "ConstantOfShape node must have a Tensor with a non-empty static shape value", + ); + log::debug!( + "ConstantOfShape derived rank from tensor: {} for {}", + r, + node.name + ); + r + } + _ => panic!("ConstantOfShape node requires a Tensor or Shape type as input"), + }; + + // Update the input type to be a shape + node.inputs[0].ty = ArgType::Shape(rank); + log::debug!( + "ConstantOfShape updated input to Shape({}) for {}", + rank, + node.name + ); + + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type: value_type, + rank, + static_shape: None, + }); + log::debug!("ConstantOfShape output rank for {}: {}", node.name, rank); +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{AttributeValue, Data, NodeType, TensorData}; + use crate::node::test_utils::NodeBuilder; + + fn create_test_node(input_ty: ArgType) -> Node { + NodeBuilder::new(NodeType::ConstantOfShape, "test_constantofshape") + .add_input("shape", input_ty) + .output_tensor_f32("output", 0, None) // Will be updated + .build() + } + + #[test] + fn test_shape_input() { + let mut node = create_test_node(ArgType::Shape(3)); + + constant_of_shape_update_output(&mut node); + + match &node.outputs[0].ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.elem_type, ElementType::Float32); + assert_eq!(tensor.rank, 3); + } + _ => panic!("Expected tensor output"), + } + } + + #[test] + fn test_tensor_input_with_static_shape() { + let mut node = create_test_node(ArgType::Tensor(TensorType { + elem_type: ElementType::Int64, + rank: 1, + static_shape: Some(vec![4]), + })); + + constant_of_shape_update_output(&mut node); + + match &node.outputs[0].ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.elem_type, ElementType::Float32); + assert_eq!(tensor.rank, 4); + } + _ => panic!("Expected tensor output"), + } + } + + #[test] + fn test_custom_value_type() { + let mut node = create_test_node(ArgType::Shape(2)); + node.attrs.insert( + "value".to_string(), + AttributeValue::Tensor(TensorData { + shape: vec![], + data: Data::Int64s(vec![7]), // Int64 value + }), + ); + + constant_of_shape_update_output(&mut node); + + match &node.outputs[0].ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.elem_type, ElementType::Int64); + assert_eq!(tensor.rank, 2); + } + _ => panic!("Expected tensor output"), + } + } + + #[test] + #[should_panic(expected = "ConstantOfShape node requires a Tensor or Shape type as input")] + fn test_invalid_input_type() { + let mut node = create_test_node(ArgType::Scalar(ElementType::Float32)); + constant_of_shape_update_output(&mut node); + } +} diff --git a/crates/onnx-ir/src/node/conv1d.rs b/crates/onnx-ir/src/node/conv1d.rs new file mode 100644 index 0000000000..3ab2dcf1ef --- /dev/null +++ 
b/crates/onnx-ir/src/node/conv1d.rs @@ -0,0 +1,215 @@ +use crate::ir::Node; + +use super::padding::{PaddingConfig1d, padding_config_1d}; + +/// Configuration for Conv1d operations extracted from ONNX nodes +#[derive(Debug, Clone)] +pub struct Conv1dConfig { + /// Input channels + pub channels_in: usize, + /// Output channels + pub channels_out: usize, + /// Kernel size + pub kernel_size: usize, + /// Stride + pub stride: usize, + /// Dilation + pub dilation: usize, + /// Number of groups + pub groups: usize, + /// Whether bias is used + pub bias: bool, + /// Padding configuration + pub padding: PaddingConfig1d, +} + +impl Conv1dConfig { + /// Create a new Conv1dConfig + #[allow(clippy::too_many_arguments)] + pub fn new( + channels_in: usize, + channels_out: usize, + kernel_size: usize, + stride: usize, + padding: PaddingConfig1d, + dilation: usize, + groups: usize, + bias: bool, + ) -> Self { + Self { + channels_in, + channels_out, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + } + } +} + +/// Create a Conv1dConfig from the attributes of the node +pub fn conv1d_config(curr: &Node) -> Conv1dConfig { + let mut kernel_shape = Vec::new(); // TODO default inferred from weight tensor per spec + let mut strides = vec![1]; + let mut pads = vec![0, 0]; + let mut dilations = vec![1]; + let mut group: usize = 1; + + let weight_shape = curr.inputs[1] + .value + .as_ref() + .expect("Conv1d: weight tensor must be present") + .shape + .clone(); + + // check if the bias is present + let bias = curr.inputs.len() == 3; + + for (key, value) in curr.attrs.iter() { + match key.as_str() { + "kernel_shape" => kernel_shape = value.clone().into_i64s(), + "strides" => strides = value.clone().into_i64s(), + "pads" => pads = value.clone().into_i64s(), + "dilations" => dilations = value.clone().into_i64s(), + "group" => group = value.clone().into_i64() as usize, + _ => {} + } + } + + // the channels are inverted in the weight tensor + let channels_in = weight_shape[1] * group; + let channels_out = weight_shape[0]; + + let padding = padding_config_1d(&pads); + + Conv1dConfig { + channels_in, + channels_out, + kernel_size: kernel_shape[0] as usize, + stride: strides[0] as usize, + dilation: dilations[0] as usize, + groups: group, + bias, + padding, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; + + fn create_test_node( + kernel_shape: Vec, + strides: Vec, + pads: Vec, + dilations: Vec, + group: i64, + has_bias: bool, + ) -> Node { + // Create weight tensor data + let weight_data = vec![0.1; 16]; + + // Start building the node with input and weight + let mut builder = NodeBuilder::new(NodeType::Conv1d, "test_conv1d") + .input_tensor_f32("data", 3, None) + .input_tensor_f32_data( + "weight", + weight_data, + vec![2, 2, 4], // [out_channels, in_channels, kernel_size] + ) + .output_tensor_f32("output", 3, None); + + // Add bias if needed + if has_bias { + builder = builder.input_tensor_f32_data("bias", vec![0.1, 0.2], vec![2]); + } + + // Add attributes + builder = builder + .attr_ints("kernel_shape", kernel_shape) + .attr_ints("strides", strides) + .attr_ints("pads", pads) + .attr_ints("dilations", dilations) + .attr_int("group", group); + + builder.build() + } + + #[test] + fn test_conv1d_config_basic() { + let node = create_test_node(vec![4], vec![1], vec![0, 0], vec![1], 1, false); + let config = conv1d_config(&node); + + assert_eq!(config.channels_in, 2); + assert_eq!(config.channels_out, 2); + 
assert_eq!(config.kernel_size, 4); + assert_eq!(config.stride, 1); + assert_eq!(config.dilation, 1); + assert_eq!(config.groups, 1); + assert!(!config.bias); + assert!(matches!(config.padding, PaddingConfig1d::Valid)); + } + + #[test] + fn test_conv1d_config_with_padding() { + let node = create_test_node(vec![4], vec![2], vec![2, 2], vec![1], 1, true); + let config = conv1d_config(&node); + + assert_eq!(config.channels_in, 2); + assert_eq!(config.channels_out, 2); + assert_eq!(config.kernel_size, 4); + assert_eq!(config.stride, 2); + assert_eq!(config.dilation, 1); + assert_eq!(config.groups, 1); + assert!(config.bias); + assert!(matches!(config.padding, PaddingConfig1d::Explicit(2))); + } + + #[test] + fn test_conv1d_config_with_dilation() { + let node = create_test_node(vec![4], vec![1], vec![0, 0], vec![2], 1, false); + let config = conv1d_config(&node); + + assert_eq!(config.channels_in, 2); + assert_eq!(config.channels_out, 2); + assert_eq!(config.kernel_size, 4); + assert_eq!(config.stride, 1); + assert_eq!(config.dilation, 2); + assert_eq!(config.groups, 1); + assert!(!config.bias); + assert!(matches!(config.padding, PaddingConfig1d::Valid)); + } + + #[test] + fn test_conv1d_config_with_groups() { + let node = create_test_node(vec![4], vec![1], vec![0, 0], vec![1], 2, false); + let config = conv1d_config(&node); + + assert_eq!(config.channels_in, 4); + assert_eq!(config.channels_out, 2); + assert_eq!(config.kernel_size, 4); + assert_eq!(config.stride, 1); + assert_eq!(config.dilation, 1); + assert_eq!(config.groups, 2); + assert!(!config.bias); + assert!(matches!(config.padding, PaddingConfig1d::Valid)); + } + + #[test] + #[should_panic(expected = "Asymmetric padding is not supported")] + fn test_conv1d_config_asymmetric_padding() { + let node = create_test_node(vec![4], vec![1], vec![1, 2], vec![1], 1, false); + let _ = conv1d_config(&node); + } + + #[test] + #[should_panic(expected = "Negative pad values are not supported")] + fn test_conv1d_config_negative_padding() { + let node = create_test_node(vec![4], vec![1], vec![-1, -1], vec![1], 1, false); + let _ = conv1d_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/conv2d.rs b/crates/onnx-ir/src/node/conv2d.rs new file mode 100644 index 0000000000..6715069e26 --- /dev/null +++ b/crates/onnx-ir/src/node/conv2d.rs @@ -0,0 +1,195 @@ +use crate::ir::Node; +use crate::node::padding::{PaddingConfig2d, padding_config_2d}; + +/// Configuration for Conv2d operations +#[derive(Debug, Clone)] +pub struct Conv2dConfig { + /// Channels [in, out] + pub channels: [usize; 2], + /// Kernel size [height, width] + pub kernel_size: [usize; 2], + /// Stride [height, width] + pub stride: [usize; 2], + /// Padding configuration + pub padding: PaddingConfig2d, + /// Dilation [height, width] + pub dilation: [usize; 2], + /// Number of groups + pub groups: usize, + /// Whether bias is used + pub bias: bool, +} + +impl Conv2dConfig { + /// Create a new Conv2dConfig + pub fn new( + channels: [usize; 2], + kernel_size: [usize; 2], + stride: [usize; 2], + padding: PaddingConfig2d, + dilation: [usize; 2], + groups: usize, + bias: bool, + ) -> Self { + Self { + channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + } + } +} + +/// Create a Conv2dConfig from the attributes of the node +pub fn conv2d_config(curr: &Node) -> Conv2dConfig { + let mut kernel_shape = Vec::new(); // TODO default inferred from weight tensor per spec + let mut strides = vec![1, 1]; + let mut pads = vec![0, 0, 0, 0]; + let mut dilations = vec![1, 1]; + let 
mut group: usize = 1; + + let weight_shape = curr.inputs[1] + .value + .as_ref() + .expect("Conv2d: weight tensor must be present") + .shape + .clone(); + + // check if the bias is present + let bias = curr.inputs.len() == 3; + + for (key, value) in curr.attrs.iter() { + match key.as_str() { + "kernel_shape" => kernel_shape = value.clone().into_i64s(), + "strides" => strides = value.clone().into_i64s(), + "pads" => pads = value.clone().into_i64s(), + "dilations" => dilations = value.clone().into_i64s(), + "group" => group = value.clone().into_i64() as usize, + _ => panic!("Unexpected attribute for Conv2d: {key}"), + } + } + + // the channels are inverted in the weight tensor + let channels_in = weight_shape[1] * group; + let channels_out = weight_shape[0]; + + let padding = padding_config_2d(&pads); + + Conv2dConfig::new( + [channels_in, channels_out], + [kernel_shape[0] as usize, kernel_shape[1] as usize], + [strides[0] as usize, strides[1] as usize], + padding, + [dilations[0] as usize, dilations[1] as usize], + group, + bias, + ) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; + + fn create_test_node( + kernel_shape: Vec, + strides: Vec, + pads: Vec, + dilations: Vec, + group: i64, + has_bias: bool, + ) -> Node { + // Weight tensor data - not important for the test + let weight_data = vec![0.0; 16]; + // [output_channels, input_channels/groups, k_h, k_w] + let weight_shape = vec![4, 2, 2, 2]; + + let mut builder = NodeBuilder::new(NodeType::Conv2d, "test_conv2d") + .input_tensor_f32("data", 4, None) + .input_tensor_f32_data("weight", weight_data.clone(), weight_shape) + .output_tensor_f32("output", 4, None) + .attr_ints("kernel_shape", kernel_shape) + .attr_ints("strides", strides) + .attr_ints("pads", pads) + .attr_ints("dilations", dilations) + .attr_int("group", group); + + if has_bias { + builder = builder.input_tensor_f32("bias", 1, None); + } + + builder.build() + } + + #[test] + fn test_conv2d_config_basic() { + let node = create_test_node( + vec![2, 2], + vec![1, 1], + vec![0, 0, 0, 0], + vec![1, 1], + 1, + false, + ); + let config = conv2d_config(&node); + + assert_eq!(config.channels, [2, 4]); + assert_eq!(config.kernel_size, [2, 2]); + assert_eq!(config.stride, [1, 1]); + assert_eq!(config.dilation, [1, 1]); + assert_eq!(config.groups, 1); + assert!(!config.bias); + assert!(matches!(config.padding, PaddingConfig2d::Valid)); + } + + #[test] + fn test_conv2d_config_with_padding() { + let node = create_test_node( + vec![3, 3], + vec![1, 1], + vec![1, 1, 1, 1], + vec![1, 1], + 1, + false, + ); + let config = conv2d_config(&node); + + assert_eq!(config.kernel_size, [3, 3]); + assert!(matches!(config.padding, PaddingConfig2d::Explicit(1, 1))); + } + + #[test] + fn test_conv2d_config_with_groups() { + let node = create_test_node( + vec![2, 2], + vec![1, 1], + vec![0, 0, 0, 0], + vec![1, 1], + 2, + false, + ); + let config = conv2d_config(&node); + + assert_eq!(config.groups, 2); + assert_eq!(config.channels, [4, 4]); // channels_in is adjusted by groups + } + + #[test] + fn test_conv2d_config_with_bias() { + let node = create_test_node( + vec![2, 2], + vec![1, 1], + vec![0, 0, 0, 0], + vec![1, 1], + 1, + true, + ); + let config = conv2d_config(&node); + + assert!(config.bias); + } +} diff --git a/crates/onnx-ir/src/node/conv3d.rs b/crates/onnx-ir/src/node/conv3d.rs new file mode 100644 index 0000000000..043ec6a5d6 --- /dev/null +++ b/crates/onnx-ir/src/node/conv3d.rs @@ -0,0 +1,211 @@ +use crate::ir::Node; +use 
crate::node::padding::{PaddingConfig3d, padding_config_3d}; + +/// Configuration for Conv3d operations. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Conv3dConfig { + /// Input and output channels [in, out]. + pub channels: [usize; 2], + /// Size of the kernel. + pub kernel_size: [usize; 3], + /// Stride of the convolutional kernel. + pub stride: [usize; 3], + /// Dilation of the convolutional kernel. + pub dilation: [usize; 3], + /// Groups. + pub groups: usize, + /// Use bias. + pub bias: bool, + /// Padding. + pub padding: PaddingConfig3d, +} + +impl Conv3dConfig { + /// Create a new configuration for a Conv3d. + pub fn new( + channels: [usize; 2], + kernel_size: [usize; 3], + stride: [usize; 3], + dilation: [usize; 3], + groups: usize, + bias: bool, + padding: PaddingConfig3d, + ) -> Self { + Self { + channels, + kernel_size, + stride, + dilation, + groups, + bias, + padding, + } + } +} + +/// Create a Conv3dConfig from the attributes of the node +pub fn conv3d_config(curr: &Node) -> Conv3dConfig { + let mut kernel_shape = Vec::new(); // TODO default inferred from weight tensor per spec + let mut strides = vec![1, 1, 1]; + let mut pads = vec![0, 0, 0, 0, 0, 0]; + let mut dilations = vec![1, 1, 1]; + let mut group: usize = 1; + + let weight_shape = curr.inputs[1] + .value + .as_ref() + .expect("Conv3d: weight tensor must be present") + .shape + .clone(); + + // check if the bias is present + let bias = curr.inputs.len() == 3; + + for (key, value) in curr.attrs.iter() { + match key.as_str() { + "kernel_shape" => kernel_shape = value.clone().into_i64s(), + "strides" => strides = value.clone().into_i64s(), + "pads" => pads = value.clone().into_i64s(), + "dilations" => dilations = value.clone().into_i64s(), + "group" => group = value.clone().into_i64() as usize, + _ => panic!("Unexpected attribute for Conv3d: {key}"), + } + } + + // the channels are inverted in the weight tensor + let channels_in = weight_shape[1] * group; + let channels_out = weight_shape[0]; + + let padding = padding_config_3d(&pads); + + Conv3dConfig::new( + [channels_in, channels_out], + [ + kernel_shape[0] as usize, + kernel_shape[1] as usize, + kernel_shape[2] as usize, + ], + [ + strides[0] as usize, + strides[1] as usize, + strides[2] as usize, + ], + [ + dilations[0] as usize, + dilations[1] as usize, + dilations[2] as usize, + ], + group, + bias, + padding, + ) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; + + fn create_test_node( + kernel_shape: Vec, + strides: Vec, + pads: Vec, + dilations: Vec, + group: i64, + has_bias: bool, + ) -> Node { + // Create weight tensor data (not important for the test) + let weight_data = vec![0.0; 32]; + let weight_shape = vec![4, 2, 2, 2, 2]; // [output_channels, input_channels/groups, k_d, k_h, k_w] + + // Start building the node with input and weight + let mut builder = NodeBuilder::new(NodeType::Conv3d, "test_conv3d") + .input_tensor_f32("data", 5, None) + .input_tensor_f32_data("weight", weight_data, weight_shape) + .output_tensor_f32("output", 5, None); + + // Add bias if needed + if has_bias { + builder = builder.input_tensor_f32("bias", 1, None); + } + + // Add attributes + builder = builder + .attr_ints("kernel_shape", kernel_shape) + .attr_ints("strides", strides) + .attr_ints("pads", pads) + .attr_ints("dilations", dilations) + .attr_int("group", group); + + builder.build() + } + + #[test] + fn test_conv3d_config_basic() { + let node = create_test_node( + vec![2, 2, 2], + vec![1, 1, 1], + 
vec![0, 0, 0, 0, 0, 0], + vec![1, 1, 1], + 1, + false, + ); + let config = conv3d_config(&node); + + assert_eq!(config.channels, [2, 4]); + assert_eq!(config.kernel_size, [2, 2, 2]); + assert_eq!(config.stride, [1, 1, 1]); + assert_eq!(config.dilation, [1, 1, 1]); + assert_eq!(config.groups, 1); + assert!(!config.bias); + assert!(matches!(config.padding, PaddingConfig3d::Valid)); + } + + #[test] + fn test_conv3d_config_with_padding() { + let node = create_test_node( + vec![3, 3, 3], + vec![1, 1, 1], + vec![1, 1, 1, 1, 1, 1], + vec![1, 1, 1], + 1, + false, + ); + let config = conv3d_config(&node); + + assert_eq!(config.kernel_size, [3, 3, 3]); + assert!(matches!(config.padding, PaddingConfig3d::Explicit(1, 1, 1))); + } + + #[test] + fn test_conv3d_config_with_groups() { + let node = create_test_node( + vec![2, 2, 2], + vec![1, 1, 1], + vec![0, 0, 0, 0, 0, 0], + vec![1, 1, 1], + 2, + false, + ); + let config = conv3d_config(&node); + + assert_eq!(config.groups, 2); + assert_eq!(config.channels, [4, 4]); // channels_in is adjusted by groups + } + + #[test] + fn test_conv3d_config_with_bias() { + let node = create_test_node( + vec![2, 2, 2], + vec![1, 1, 1], + vec![0, 0, 0, 0, 0, 0], + vec![1, 1, 1], + 1, + true, + ); + let config = conv3d_config(&node); + + assert!(config.bias); + } +} diff --git a/crates/onnx-ir/src/node/conv_transpose1d.rs b/crates/onnx-ir/src/node/conv_transpose1d.rs new file mode 100644 index 0000000000..961de340cc --- /dev/null +++ b/crates/onnx-ir/src/node/conv_transpose1d.rs @@ -0,0 +1,194 @@ +use crate::ir::Node; + +/// Configuration for ConvTranspose1d operations extracted from ONNX nodes +#[derive(Debug, Clone)] +pub struct ConvTranspose1dConfig { + /// Input channels + pub channels_in: usize, + /// Output channels + pub channels_out: usize, + /// Kernel size + pub kernel_size: usize, + /// Stride + pub stride: usize, + /// Dilation + pub dilation: usize, + /// Number of groups + pub groups: usize, + /// Whether bias is used + pub bias: bool, + /// Padding size + pub padding: usize, + /// Output padding size + pub padding_out: usize, +} + +impl ConvTranspose1dConfig { + /// Create a new ConvTranspose1dConfig + #[allow(clippy::too_many_arguments)] + pub fn new( + channels_in: usize, + channels_out: usize, + kernel_size: usize, + stride: usize, + padding: usize, + dilation: usize, + groups: usize, + bias: bool, + padding_out: usize, + ) -> Self { + Self { + channels_in, + channels_out, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_out, + } + } +} + +/// Create a ConvTranspose1dConfig from the attributes of the node +pub fn conv_transpose1d_config(curr: &Node) -> ConvTranspose1dConfig { + let mut kernel_shape = Vec::new(); // Default to empty vector + let mut stride = vec![1]; // Default stride to 1 + let mut pads = vec![0, 0]; // Default padding to 0 + let mut dilations = vec![1]; // Default dilation to 1 + let mut group: usize = 1; // Default group to 1 + let mut output_padding = vec![0]; // Default output padding to 0 + + // Extract attributes + for (key, value) in curr.attrs.iter() { + match key.as_str() { + "kernel_shape" => kernel_shape = value.clone().into_i64s(), + "strides" => stride = value.clone().into_i64s(), + "pads" => pads = value.clone().into_i64s(), + "dilations" => dilations = value.clone().into_i64s(), + "group" => group = value.clone().into_i64() as usize, + "output_padding" => output_padding = value.clone().into_i64s(), + _ => panic!("Unexpected attribute for ConvTranspose1d: {key}"), + } + } + + // Check the pads are 
symmetric + if pads.len() != 2 || pads[0] != pads[1] { + panic!( + "Asymmetric padding is not supported for ConvTranspose1d: {:?}", + pads + ); + } + + let weight_shape = curr.inputs[1] + .value + .as_ref() + .expect("ConvTranspose1d: weight tensor must be present") + .shape + .clone(); + + // Check if bias is present (third input) + let bias = curr.inputs.len() == 3; + + // Extract channels from the weight tensor shape [out_channels, in_channels] + let channels_in = weight_shape[1] * group; + let channels_out = weight_shape[0]; + + ConvTranspose1dConfig { + channels_in, + channels_out, + kernel_size: kernel_shape[0] as usize, + stride: stride[0] as usize, + padding: pads[0] as usize, + dilation: dilations[0] as usize, + padding_out: output_padding[0] as usize, + groups: group, + bias, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; + + fn create_test_node( + kernel_shape: Vec, + stride: Vec, + pads: Vec, + dilations: Vec, + group: i64, + output_padding: Vec, + has_bias: bool, + ) -> Node { + // Create weight tensor data + let weight_data = vec![0.1; 16]; + + // Start building the node with input and weight + let mut builder = NodeBuilder::new(NodeType::ConvTranspose1d, "test_conv_transpose1d") + .input_tensor_f32("data", 3, None) + .input_tensor_f32_data( + "weight", + weight_data, + vec![2, 2, 4], // [out_channels, in_channels, kernel_size] + ) + .output_tensor_f32("output", 3, None); + + // Add bias if needed + if has_bias { + builder = builder.input_tensor_f32_data("bias", vec![0.1, 0.2], vec![2]); + } + + // Add attributes + builder = builder + .attr_ints("kernel_shape", kernel_shape) + .attr_ints("strides", stride) + .attr_ints("pads", pads) + .attr_ints("dilations", dilations) + .attr_int("group", group) + .attr_ints("output_padding", output_padding); + + builder.build() + } + + #[test] + fn test_conv_transpose1d_config_basic() { + let node = create_test_node(vec![4], vec![1], vec![0, 0], vec![1], 1, vec![0], false); + let config = conv_transpose1d_config(&node); + + assert_eq!(config.channels_in, 2); + assert_eq!(config.channels_out, 2); + assert_eq!(config.kernel_size, 4); + assert_eq!(config.stride, 1); + assert_eq!(config.padding, 0); + assert_eq!(config.dilation, 1); + assert_eq!(config.padding_out, 0); + assert_eq!(config.groups, 1); + assert!(!config.bias); + } + + #[test] + fn test_conv_transpose1d_config_with_params() { + let node = create_test_node(vec![4], vec![2], vec![1, 1], vec![2], 2, vec![1], true); + let config = conv_transpose1d_config(&node); + + assert_eq!(config.channels_in, 4); // weight_shape[1] * group = 2 * 2 + assert_eq!(config.channels_out, 2); + assert_eq!(config.kernel_size, 4); + assert_eq!(config.stride, 2); + assert_eq!(config.padding, 1); + assert_eq!(config.dilation, 2); + assert_eq!(config.padding_out, 1); + assert_eq!(config.groups, 2); + assert!(config.bias); + } + + #[test] + #[should_panic(expected = "Asymmetric padding is not supported")] + fn test_conv_transpose1d_config_asymmetric_padding() { + let node = create_test_node(vec![4], vec![1], vec![1, 2], vec![1], 1, vec![0], false); + let _ = conv_transpose1d_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/conv_transpose2d.rs b/crates/onnx-ir/src/node/conv_transpose2d.rs new file mode 100644 index 0000000000..f813ee9136 --- /dev/null +++ b/crates/onnx-ir/src/node/conv_transpose2d.rs @@ -0,0 +1,253 @@ +use crate::ir::Node; + +/// Configuration for ConvTranspose2d operations. 
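+///
+/// The channel pair is read from the first two dimensions of the ONNX weight tensor
+/// (scaled by the `group` attribute); the remaining fields mirror the node's
+/// `kernel_shape`, `strides`, `dilations`, `pads` and `output_padding` attributes, as
+/// extracted by `conv_transpose2d_config` below. A minimal construction sketch (values
+/// are illustrative only):
+///
+/// ```ignore
+/// // [in, out] channels, 2x2 kernel, unit stride/dilation, no padding, 1 group, no bias.
+/// let config = ConvTranspose2dConfig::new([4, 2], [2, 2], [1, 1], [1, 1], [0, 0], [0, 0], 1, false);
+/// ```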
+#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ConvTranspose2dConfig { + /// Input and output channels [in, out]. + pub channels: [usize; 2], + /// Size of the kernel. + pub kernel_size: [usize; 2], + /// Stride of the convolutional kernel. + pub stride: [usize; 2], + /// Dilation of the convolutional kernel. + pub dilation: [usize; 2], + /// Padding. + pub padding: [usize; 2], + /// Output padding. + pub padding_out: [usize; 2], + /// Groups. + pub groups: usize, + /// Use bias. + pub bias: bool, +} + +impl ConvTranspose2dConfig { + /// Create a new configuration for a ConvTranspose2d. + #[allow(clippy::too_many_arguments)] + pub fn new( + channels: [usize; 2], + kernel_size: [usize; 2], + stride: [usize; 2], + dilation: [usize; 2], + padding: [usize; 2], + padding_out: [usize; 2], + groups: usize, + bias: bool, + ) -> Self { + Self { + channels, + kernel_size, + stride, + dilation, + padding, + padding_out, + groups, + bias, + } + } +} + +/// Create a ConvTranspose2dConfig from the attributes of the node +pub fn conv_transpose2d_config(curr: &Node) -> ConvTranspose2dConfig { + let mut kernel_shape = Vec::new(); // Default to empty vector + let mut stride = vec![1, 1]; // Default stride to 1 + let mut pads = vec![0, 0, 0, 0]; // Default padding to 0 + let mut dilations = vec![1, 1]; // Default dilation to 1 + let mut group: usize = 1; // Default group to 1 + let mut output_padding = vec![0, 0]; // Default output padding to 0 + + // Extract attributes + for (key, value) in curr.attrs.iter() { + match key.as_str() { + "kernel_shape" => kernel_shape = value.clone().into_i64s(), + "strides" => stride = value.clone().into_i64s(), + "pads" => pads = value.clone().into_i64s(), + "dilations" => dilations = value.clone().into_i64s(), + "group" => group = value.clone().into_i64() as usize, + "output_padding" => output_padding = value.clone().into_i64s(), + _ => panic!("Unexpected attribute for ConvTranspose2d: {key}"), + } + } + + // Check the pads are symmetric. 
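+    // ONNX lists `pads` as the begin values for every spatial axis followed by the end
+    // values, while Burn expects a single padding value per axis, so each begin/end pair
+    // must agree: e.g. [2, 1, 2, 1] is accepted, [1, 1, 2, 2] is rejected below.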
+ let [left, top, right, bottom] = [pads[0], pads[1], pads[2], pads[3]]; + if left < 0 || top < 0 || right < 0 || bottom < 0 { + panic!("Negative pad values are not supported"); + } else if (left != right) || (top != bottom) { + panic!("Asymmetric padding is not supported"); + } + + let weight_shape = curr.inputs[1] + .value + .as_ref() + .expect("ConvTranspose2d: weight tensor must be present") + .shape + .clone(); + + // check if the bias is present + let bias = curr.inputs.len() == 3; + + // the channels are inverted in the weight tensor + let channels: [usize; 2] = [weight_shape[1] * group, weight_shape[0]]; + + ConvTranspose2dConfig::new( + channels, + [kernel_shape[0] as usize, kernel_shape[1] as usize], + [stride[0] as usize, stride[1] as usize], + [dilations[0] as usize, dilations[1] as usize], + [pads[0] as usize, pads[1] as usize], + [output_padding[0] as usize, output_padding[1] as usize], + group, + bias, + ) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; + + fn create_test_node( + kernel_shape: Vec, + strides: Vec, + pads: Vec, + dilations: Vec, + output_padding: Vec, + group: i64, + has_bias: bool, + ) -> Node { + // Create weight tensor data + let weight_data = vec![0.0; 16]; // Not important for the test + + // Start building the node with input and weight + let mut builder = NodeBuilder::new(NodeType::ConvTranspose2d, "test_convtranspose2d") + .input_tensor_f32("data", 4, None) + .input_tensor_f32_data( + "weight", + weight_data, + vec![2, 4, 2, 2], // [out_channels, in_channels, k_h, k_w] + ) + .output_tensor_f32("output", 4, None); + + // Add bias if needed + if has_bias { + builder = builder.input_tensor_f32("bias", 1, None); + } + + // Add attributes + builder = builder + .attr_ints("kernel_shape", kernel_shape) + .attr_ints("strides", strides) + .attr_ints("pads", pads) + .attr_ints("dilations", dilations) + .attr_ints("output_padding", output_padding) + .attr_int("group", group); + + builder.build() + } + + #[test] + fn test_conv_transpose2d_config_basic() { + let node = create_test_node( + vec![2, 2], + vec![1, 1], + vec![0, 0, 0, 0], + vec![1, 1], + vec![0, 0], + 1, + false, + ); + let config = conv_transpose2d_config(&node); + + assert_eq!(config.channels, [4, 2]); + assert_eq!(config.kernel_size, [2, 2]); + assert_eq!(config.stride, [1, 1]); + assert_eq!(config.dilation, [1, 1]); + assert_eq!(config.padding, [0, 0]); + assert_eq!(config.padding_out, [0, 0]); + assert_eq!(config.groups, 1); + assert!(!config.bias); + } + + #[test] + fn test_conv_transpose2d_config_with_padding() { + let node = create_test_node( + vec![3, 3], + vec![2, 2], + vec![1, 1, 1, 1], + vec![1, 1], + vec![0, 0], + 1, + false, + ); + let config = conv_transpose2d_config(&node); + + assert_eq!(config.padding, [1, 1]); + assert_eq!(config.stride, [2, 2]); + } + + #[test] + fn test_conv_transpose2d_config_with_output_padding() { + let node = create_test_node( + vec![2, 2], + vec![2, 2], + vec![0, 0, 0, 0], + vec![1, 1], + vec![1, 1], + 1, + false, + ); + let config = conv_transpose2d_config(&node); + + assert_eq!(config.padding_out, [1, 1]); + } + + #[test] + fn test_conv_transpose2d_config_with_groups() { + let node = create_test_node( + vec![2, 2], + vec![1, 1], + vec![0, 0, 0, 0], + vec![1, 1], + vec![0, 0], + 2, + false, + ); + let config = conv_transpose2d_config(&node); + + assert_eq!(config.groups, 2); + assert_eq!(config.channels, [8, 2]); // channels_in is adjusted by groups + } + + #[test] + fn 
test_conv_transpose2d_config_with_bias() { + let node = create_test_node( + vec![2, 2], + vec![1, 1], + vec![0, 0, 0, 0], + vec![1, 1], + vec![0, 0], + 1, + true, + ); + let config = conv_transpose2d_config(&node); + + assert!(config.bias); + } + + #[test] + #[should_panic(expected = "Asymmetric padding is not supported")] + fn test_conv_transpose2d_config_with_asymmetric_padding() { + let node = create_test_node( + vec![2, 2], + vec![1, 1], + vec![1, 1, 2, 2], + vec![1, 1], + vec![0, 0], + 1, + false, + ); + let _ = conv_transpose2d_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/conv_transpose3d.rs b/crates/onnx-ir/src/node/conv_transpose3d.rs new file mode 100644 index 0000000000..288800776e --- /dev/null +++ b/crates/onnx-ir/src/node/conv_transpose3d.rs @@ -0,0 +1,267 @@ +use crate::ir::Node; + +/// Configuration for ConvTranspose3d operations. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ConvTranspose3dConfig { + /// Input and output channels [in, out]. + pub channels: [usize; 2], + /// Size of the kernel. + pub kernel_size: [usize; 3], + /// Stride of the convolutional kernel. + pub stride: [usize; 3], + /// Dilation of the convolutional kernel. + pub dilation: [usize; 3], + /// Padding. + pub padding: [usize; 3], + /// Output padding. + pub padding_out: [usize; 3], + /// Groups. + pub groups: usize, + /// Use bias. + pub bias: bool, +} + +impl ConvTranspose3dConfig { + /// Create a new configuration for a ConvTranspose3d. + #[allow(clippy::too_many_arguments)] + pub fn new( + channels: [usize; 2], + kernel_size: [usize; 3], + stride: [usize; 3], + dilation: [usize; 3], + padding: [usize; 3], + padding_out: [usize; 3], + groups: usize, + bias: bool, + ) -> Self { + Self { + channels, + kernel_size, + stride, + dilation, + padding, + padding_out, + groups, + bias, + } + } +} + +/// Create a ConvTranspose3dConfig from the attributes of the node +pub fn conv_transpose3d_config(curr: &Node) -> ConvTranspose3dConfig { + let mut kernel_shape = Vec::new(); // Default to empty vector + let mut stride = vec![1, 1, 1]; // Default stride to 1 + let mut pads = vec![0, 0, 0, 0, 0, 0]; // Default padding to 0 + let mut dilations = vec![1, 1, 1]; // Default dilation to 1 + let mut group: usize = 1; // Default group to 1 + let mut output_padding = vec![0, 0, 0]; // Default output padding to 0 + + // Extract attributes + for (key, value) in curr.attrs.iter() { + match key.as_str() { + "kernel_shape" => kernel_shape = value.clone().into_i64s(), + "strides" => stride = value.clone().into_i64s(), + "pads" => pads = value.clone().into_i64s(), + "dilations" => dilations = value.clone().into_i64s(), + "group" => group = value.clone().into_i64() as usize, + "output_padding" => output_padding = value.clone().into_i64s(), + _ => panic!("Unexpected attribute for ConvTranspose3d: {key}"), + } + } + + // Check the pads are symmetric. 
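+    // As in the 2-D case, the first three entries are the begin pads and the last three
+    // the end pads for the three spatial axes; each begin/end pair must match because
+    // Burn only takes one padding value per axis.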
+ let [left, top, front, right, bottom, back] = + [pads[0], pads[1], pads[2], pads[3], pads[4], pads[5]]; + + if left < 0 || top < 0 || front < 0 || right < 0 || bottom < 0 || back < 0 { + panic!("Negative pad values are not supported"); + } else if (left != right) || (top != bottom) || (front != back) { + panic!("Asymmetric padding is not supported"); + } + + let weight_shape = curr.inputs[1] + .value + .as_ref() + .expect("ConvTranspose3d: weight tensor must be present") + .shape + .clone(); + + // check if the bias is present + let bias = curr.inputs.len() == 3; + + // the channels are inverted in the weight tensor + let channels: [usize; 2] = [weight_shape[1] * group, weight_shape[0]]; + + ConvTranspose3dConfig::new( + channels, + [ + kernel_shape[0] as usize, + kernel_shape[1] as usize, + kernel_shape[2] as usize, + ], + [stride[0] as usize, stride[1] as usize, stride[2] as usize], + [ + dilations[0] as usize, + dilations[1] as usize, + dilations[2] as usize, + ], + [pads[0] as usize, pads[1] as usize, pads[2] as usize], + [ + output_padding[0] as usize, + output_padding[1] as usize, + output_padding[2] as usize, + ], + group, + bias, + ) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; + + fn create_test_node( + kernel_shape: Vec, + strides: Vec, + pads: Vec, + dilations: Vec, + output_padding: Vec, + group: i64, + has_bias: bool, + ) -> Node { + // Create weight tensor data + let weight_data = vec![0.0; 32]; // Not important for the test + + // Start building the node with input and weight + let mut builder = NodeBuilder::new(NodeType::ConvTranspose3d, "test_convtranspose3d") + .input_tensor_f32("data", 5, None) + .input_tensor_f32_data( + "weight", + weight_data, + vec![2, 4, 2, 2, 2], // [out_channels, in_channels, k_d, k_h, k_w] + ) + .output_tensor_f32("output", 5, None); + + // Add bias if needed + if has_bias { + builder = builder.input_tensor_f32("bias", 1, None); + } + + // Add attributes + builder = builder + .attr_ints("kernel_shape", kernel_shape) + .attr_ints("strides", strides) + .attr_ints("pads", pads) + .attr_ints("dilations", dilations) + .attr_ints("output_padding", output_padding) + .attr_int("group", group); + + builder.build() + } + + #[test] + fn test_conv_transpose3d_config_basic() { + let node = create_test_node( + vec![2, 2, 2], + vec![1, 1, 1], + vec![0, 0, 0, 0, 0, 0], + vec![1, 1, 1], + vec![0, 0, 0], + 1, + false, + ); + let config = conv_transpose3d_config(&node); + + assert_eq!(config.channels, [4, 2]); + assert_eq!(config.kernel_size, [2, 2, 2]); + assert_eq!(config.stride, [1, 1, 1]); + assert_eq!(config.dilation, [1, 1, 1]); + assert_eq!(config.padding, [0, 0, 0]); + assert_eq!(config.padding_out, [0, 0, 0]); + assert_eq!(config.groups, 1); + assert!(!config.bias); + } + + #[test] + fn test_conv_transpose3d_config_with_padding() { + let node = create_test_node( + vec![3, 3, 3], + vec![2, 2, 2], + vec![1, 1, 1, 1, 1, 1], + vec![1, 1, 1], + vec![0, 0, 0], + 1, + false, + ); + let config = conv_transpose3d_config(&node); + + assert_eq!(config.padding, [1, 1, 1]); + assert_eq!(config.stride, [2, 2, 2]); + } + + #[test] + fn test_conv_transpose3d_config_with_output_padding() { + let node = create_test_node( + vec![2, 2, 2], + vec![2, 2, 2], + vec![0, 0, 0, 0, 0, 0], + vec![1, 1, 1], + vec![1, 1, 1], + 1, + false, + ); + let config = conv_transpose3d_config(&node); + + assert_eq!(config.padding_out, [1, 1, 1]); + } + + #[test] + fn test_conv_transpose3d_config_with_groups() { + let node 
= create_test_node( + vec![2, 2, 2], + vec![1, 1, 1], + vec![0, 0, 0, 0, 0, 0], + vec![1, 1, 1], + vec![0, 0, 0], + 2, + false, + ); + let config = conv_transpose3d_config(&node); + + assert_eq!(config.groups, 2); + assert_eq!(config.channels, [8, 2]); // channels_in is adjusted by groups + } + + #[test] + fn test_conv_transpose3d_config_with_bias() { + let node = create_test_node( + vec![2, 2, 2], + vec![1, 1, 1], + vec![0, 0, 0, 0, 0, 0], + vec![1, 1, 1], + vec![0, 0, 0], + 1, + true, + ); + let config = conv_transpose3d_config(&node); + + assert!(config.bias); + } + + #[test] + #[should_panic(expected = "Asymmetric padding is not supported")] + fn test_conv_transpose3d_config_with_asymmetric_padding() { + let node = create_test_node( + vec![2, 2, 2], + vec![1, 1, 1], + vec![1, 1, 1, 2, 2, 2], + vec![1, 1, 1], + vec![0, 0, 0], + 1, + false, + ); + let _ = conv_transpose3d_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/dropout.rs b/crates/onnx-ir/src/node/dropout.rs new file mode 100644 index 0000000000..05b3edb2d2 --- /dev/null +++ b/crates/onnx-ir/src/node/dropout.rs @@ -0,0 +1,90 @@ +use crate::ir::{Data, Node}; + +/// Configuration for Dropout operations +#[derive(Debug, Clone)] +pub struct DropoutConfig { + /// Probability of dropping out a unit + pub prob: f64, +} + +impl DropoutConfig { + /// Create a new DropoutConfig + pub fn new(prob: f64) -> Self { + Self { prob } + } +} + +/// Create a DropoutConfig from an attribute and state of the node +pub fn dropout_config(node: &Node) -> DropoutConfig { + // Opset 7 and older store probability as an attribute + if node.attrs.contains_key("ratio") { + let prob = node.attrs.get("ratio").unwrap().clone().into_f32(); + return DropoutConfig::new(prob as f64); + } + + if node.inputs.len() < 2 { + panic!("Dropout configuration must have at least 2 inputs"); + } + + let ratio = node.inputs[1] + .value + .clone() + .expect("Dropout ratio must be passed in the second input") + .data + .into_scalar(); + + let prob = match ratio { + Data::Float16(ratio) => f64::from(f32::from(ratio)), + Data::Float32(ratio) => ratio as f64, + Data::Float64(ratio) => ratio, + _ => panic!("Dropout ratio must be a float"), + }; + + DropoutConfig::new(prob) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; + + fn create_test_node_with_attr(ratio: f32) -> Node { + NodeBuilder::new(NodeType::Dropout, "test_dropout") + .input_tensor_f32("data", 3, None) + .output_tensor_f32("output", 3, None) + .attr_float("ratio", ratio) + .build() + } + + fn create_test_node_with_input(ratio: f32) -> Node { + NodeBuilder::new(NodeType::Dropout, "test_dropout") + .input_tensor_f32("data", 3, None) + .input_scalar_tensor_f32("ratio", Some(ratio)) + .output_tensor_f32("output", 3, None) + .build() + } + + #[test] + fn test_dropout_config_with_attr() { + let node = create_test_node_with_attr(0.3); + let config = dropout_config(&node); + assert!(f64::abs(config.prob - 0.3) < 1e-6); + } + + #[test] + fn test_dropout_config_with_input() { + let node = create_test_node_with_input(0.5); + let config = dropout_config(&node); + assert!(f64::abs(config.prob - 0.5) < 1e-6); + } + + #[test] + #[should_panic(expected = "Dropout configuration must have at least 2 inputs")] + fn test_dropout_config_missing_input() { + let mut node = create_test_node_with_input(0.5); + node.attrs.clear(); // Remove attributes + node.inputs.remove(1); // Remove ratio input + let _ = dropout_config(&node); + } +} diff --git 
a/crates/onnx-ir/src/node/expand.rs b/crates/onnx-ir/src/node/expand.rs new file mode 100644 index 0000000000..a78377de4d --- /dev/null +++ b/crates/onnx-ir/src/node/expand.rs @@ -0,0 +1,269 @@ +use crate::{ + Argument, ElementType, TensorData, + ir::{ArgType, Data, Node, TensorType}, +}; + +/// Updates the output rank and shape for the Expand operation based on the provided shape input. +/// If the shape is a constant, the rank and static shape of the output are set accordingly. +/// If the shape is dynamic, the rank is inferred from the static shape of the shape input. +pub fn expand_update_outputs(node: &mut Node) { + let shape = if node.inputs.len() == 2 { + match &node.inputs[1].value { + Some(value) => match &value.data { + Data::Int64s(shape) => Some(shape.clone()), + _ => panic!("Expand operation encountered invalid input types"), + }, + None => None, + } + } else { + panic!("Expand operation requires exactly two inputs"); + }; + + let output = match &node.outputs[0].ty { + ArgType::Tensor(tensor) => tensor.clone(), + _ => panic!("Expand operation encountered invalid output types"), + }; + + if let Some(shape) = shape { + node.outputs[0].ty = ArgType::Tensor(TensorType { + rank: shape.len(), + static_shape: Some(shape.into_iter().map(|dim| dim as usize).collect()), + ..output + }); + } else { + // When the shape cannot be determined statically (i.e., the second argument 'shape' is passed dynamically), + // infer the rank from the static shape of the input tensor. + let output_rank = match &node.inputs[1].ty { + ArgType::Tensor(tensor) => tensor + .static_shape + .as_ref() + .expect("Shape input must have a static shape defined") + .first() + .copied() + .expect("Static shape must contain at least one element"), + ArgType::Shape(rank) => *rank, + _ => panic!("Shape input must be of tensor or shape type",), + }; + + node.outputs[0].ty = ArgType::Tensor(TensorType { + rank: output_rank, + static_shape: None, // The exact shape cannot be determined statically + ..output + }); + } +} + +/// Shape information for the Expand operation. +#[derive(Debug, Clone)] +pub enum ExpandShape { + /// Static shape information known at compile time. + Static(Vec), + /// Runtime shape that will be determined during execution. + Runtime(Argument), +} + +/// Creates an ExpandShape configuration from the given Node. +/// +/// Extracts shape information from the node's second input to determine +/// whether to use static or runtime shape expansion. +pub fn expand_config(node: &Node) -> ExpandShape { + match &node.inputs[1].ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.rank, 1, "Expand: shape tensor must be 1D"); + assert!( + matches!(tensor.elem_type, ElementType::Int64), + "Expand: shape tensor must have element type int64" + ); + } + ArgType::Shape(_) => { + // Shapes are always 1-D int64 data, so nothing to assert here + } + _ => panic!("Only tensor input is valid for shape"), + } + + match &node.inputs[1].value { + Some(TensorData { + data: Data::Int64s(shape), + .. 
+ }) => ExpandShape::Static(shape.clone()), + None => { + // we were unable to statically determine the input value, so we'll need to fetch it at runtime + ExpandShape::Runtime(node.inputs[1].clone()) + } + _ => panic!( + "Shape data type must be int64, is {:?}", + &node.inputs[1].value + ), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{ElementType, NodeType, TensorData}; + use crate::node::test_utils::NodeBuilder; + + fn create_test_node( + input_rank: usize, + shape_value: Option>, + shape_type: Option, + ) -> Node { + let mut builder = NodeBuilder::new(NodeType::Expand, "test_expand") + .input_tensor_f32("input", input_rank, None) + .output_tensor_f32("output", 0, None); // Rank 0 will be updated + + if let Some(shape) = shape_value { + builder = builder.input_tensor_i64_data("shape", shape.clone(), vec![shape.len()]); + } else if let Some(st) = shape_type { + // Use the provided custom shape type + builder = builder.add_input("shape", st); + } else { + // Default case with dynamic shape + builder = builder.input_tensor_i64("shape", 1, Some(vec![3])); + } + + builder.build() + } + + #[test] + fn test_expand_with_constant_shape() { + let mut node = create_test_node(2, Some(vec![2, 3, 4]), None); + + expand_update_outputs(&mut node); + + match &node.outputs[0].ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.elem_type, ElementType::Float32); + assert_eq!(tensor.rank, 3); + assert_eq!(tensor.static_shape, Some(vec![2, 3, 4])); + } + _ => panic!("Expected tensor output"), + } + } + + #[test] + fn test_expand_with_dynamic_shape() { + let mut node = create_test_node(2, None, None); + + expand_update_outputs(&mut node); + + match &node.outputs[0].ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.elem_type, ElementType::Float32); + assert_eq!(tensor.rank, 3); + assert_eq!(tensor.static_shape, None); + } + _ => panic!("Expected tensor output"), + } + } + + #[test] + #[should_panic(expected = "Expand operation requires exactly two inputs")] + fn test_expand_with_incorrect_inputs() { + let mut node = create_test_node(2, Some(vec![2, 3, 4]), None); + node.inputs.pop(); // Remove one input + + expand_update_outputs(&mut node); + } + + // Tests for expand_config function + + #[test] + fn test_expand_config_with_static_shape() { + let node = create_test_node(2, Some(vec![2, 3, 4]), None); + let config = expand_config(&node); + + match config { + ExpandShape::Static(shape) => { + assert_eq!(shape, vec![2, 3, 4]); + } + ExpandShape::Runtime(_) => panic!("Expected Static config, got Runtime"), + } + } + + #[test] + fn test_expand_config_with_runtime_shape() { + let node = create_test_node(2, None, None); + let config = expand_config(&node); + + match config { + ExpandShape::Static(_) => panic!("Expected Runtime config, got Static"), + ExpandShape::Runtime(arg) => { + assert_eq!(arg.name, "shape"); + match arg.ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.elem_type, ElementType::Int64); + assert_eq!(tensor.rank, 1); + } + _ => panic!("Expected tensor type for runtime shape"), + } + } + } + } + + #[test] + fn test_expand_config_with_shape_type() { + let shape_type = ArgType::Shape(3); + let node = create_test_node(2, None, Some(shape_type)); + let config = expand_config(&node); + + match config { + ExpandShape::Static(_) => panic!("Expected Runtime config, got Static"), + ExpandShape::Runtime(arg) => { + assert_eq!(arg.name, "shape"); + match arg.ty { + ArgType::Shape(rank) => { + assert_eq!(rank, 3); + } + _ => panic!("Expected shape type for runtime shape"), 
+ } + } + } + } + + #[test] + #[should_panic(expected = "Expand: shape tensor must be 1D")] + fn test_expand_config_with_invalid_shape_rank() { + let invalid_shape_type = ArgType::Tensor(TensorType { + elem_type: ElementType::Int64, + rank: 2, // Invalid rank, should be 1 + static_shape: None, + }); + let node = create_test_node(2, None, Some(invalid_shape_type)); + let _ = expand_config(&node); + } + + #[test] + #[should_panic(expected = "Expand: shape tensor must have element type int64")] + fn test_expand_config_with_invalid_shape_type() { + let invalid_shape_type = ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, // Invalid element type, should be Int64 + rank: 1, + static_shape: None, + }); + let node = create_test_node(2, None, Some(invalid_shape_type)); + let _ = expand_config(&node); + } + + #[test] + #[should_panic(expected = "Only tensor input is valid for shape")] + fn test_expand_config_with_invalid_input_type() { + let invalid_shape_type = ArgType::Scalar(ElementType::Int64); + let node = create_test_node(2, None, Some(invalid_shape_type)); + let _ = expand_config(&node); + } + + #[test] + #[should_panic(expected = "Shape data type must be int64")] + fn test_expand_config_with_invalid_value_type() { + let mut node = create_test_node(2, None, None); + + // Replace the value with a non-Int64s value + node.inputs[1].value = Some(TensorData { + shape: vec![1], + data: Data::Float32s(vec![1.0]), // Invalid data type + }); + + let _ = expand_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/flatten.rs b/crates/onnx-ir/src/node/flatten.rs new file mode 100644 index 0000000000..394f01dc17 --- /dev/null +++ b/crates/onnx-ir/src/node/flatten.rs @@ -0,0 +1,123 @@ +use crate::ir::{ArgType, Node, TensorType}; + +/// Update output type for Flatten operation (rank 2). 
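+///
+/// ONNX `Flatten` collapses the dimensions before `axis` into the first output dimension
+/// and the remaining dimensions into the second, so the result is always rank 2 whatever
+/// the input rank (e.g. a rank-4 input with `axis = 2` still flattens to a 2-D tensor).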
+pub fn flatten_update_outputs(node: &mut Node) { + if node.inputs.len() != 1 { + panic!("Flatten: multiple inputs are not supported"); + } + let tensor = node + .inputs + .iter() + .find_map(|input| match &input.ty { + ArgType::Tensor(tensor) => Some(tensor), + _ => None, + }) + .unwrap(); + + // Flatten to a 2D tensor + node.outputs[0].ty = ArgType::Tensor(TensorType { + rank: 2, + ..tensor.clone() + }); +} + +/// Create a FlattenConfig from the attributes of the node +pub fn flatten_config(curr: &Node) -> usize { + // the begin dimension is the first dimension (Default: 1 per ONNX spec) + let mut axis: i64 = 1; + + // check if the node has only one input + if curr.inputs.len() != 1 { + panic!( + "Flatten: multiple inputs are not supported (got {:?})", + curr.inputs.len() + ); + } + + // extract the shape of the input tensor + let tensor = match curr.inputs.first().unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + // check if the input tensor has at least 2 dimensions + if tensor.rank < 2 { + panic!( + "Flatten: input tensor must have at least 2 dimensions (got {:?})", + tensor.rank + ); + } + + // extract the attributes + for (key, value) in curr.attrs.iter() { + if key.as_str() == "axis" { + axis = value.clone().into_i64() + } + } + + // if beg_dim is negative, it is counted from the end + if axis < 0 { + axis += tensor.rank as i64; + } + + axis as usize +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; + + fn create_test_node(axis: i64) -> Node { + NodeBuilder::new(NodeType::Flatten, "test_flatten") + .input_tensor_f32("data", 4, None) + .output_tensor_f32("output", 2, None) + .attr_int("axis", axis) + .build() + } + + #[test] + fn test_flatten_config_basic() { + let node = create_test_node(1); + let config = flatten_config(&node); + assert_eq!(config, 1); + } + + #[test] + fn test_flatten_config_with_negative_axis() { + let node = create_test_node(-2); + let config = flatten_config(&node); + assert_eq!(config, 2); // -2 + 4 = 2 + } + + #[test] + #[should_panic(expected = "Flatten: input tensor must have at least 2 dimensions")] + fn test_flatten_config_with_low_rank() { + let mut node = create_test_node(1); + // Replace the input with one that has lower rank + let input = NodeBuilder::new(NodeType::Identity, "temp") + .input_tensor_f32("x", 1, None) + .build() + .inputs + .pop() + .unwrap(); + node.inputs[0] = input; + let _ = flatten_config(&node); + } + + #[test] + #[should_panic(expected = "Flatten: multiple inputs are not supported")] + fn test_flatten_config_with_multiple_inputs() { + let mut node = create_test_node(1); + // Add an extra input + let extra_input = NodeBuilder::new(NodeType::Identity, "temp") + .input_tensor_f32("extra", 1, None) + .build() + .inputs + .pop() + .unwrap(); + node.inputs.push(extra_input); + let _ = flatten_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/gather.rs b/crates/onnx-ir/src/node/gather.rs new file mode 100644 index 0000000000..d785f35652 --- /dev/null +++ b/crates/onnx-ir/src/node/gather.rs @@ -0,0 +1,160 @@ +use crate::ir::{ArgType, ElementType, Node, TensorType}; + +/// Update output rank for Gather based on input and indices ranks. 
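+///
+/// Per the ONNX spec the output rank is `q + r - 1`, where `q` is the rank of `indices`
+/// and `r` the rank of `data`: rank-2 indices gathered from a rank-3 tensor yield a
+/// rank-4 output, while scalar indices on a 1-D input (or on a `Shape`) collapse to a scalar.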
+pub fn gather_update_outputs(node: &mut Node) { + log::debug!("Gather rank inference for node {}", node.name); + + if node.inputs.len() != 2 { + panic!("Gather requires two inputs: data and indices"); + } + + let indices_rank = match &node.inputs[1].ty { + ArgType::Tensor(tensor) => tensor.rank, + ArgType::Scalar(_) => 0, + _ => panic!("Only tensor indices is valid, got {:?}", node.inputs[1].ty), + }; + log::debug!("Gather indices rank for {}: {}", node.name, indices_rank); + + match &node.inputs[0].ty { + ArgType::Tensor(input_tensor) => { + log::debug!( + "Gather input tensor rank for {}: {}", + node.name, + input_tensor.rank + ); + // Output of rank q+(r-1), where q is rank of indices tensor and r is rank of input + let output_rank = indices_rank + input_tensor.rank - 1; + log::debug!("Gather output rank for {}: {}", node.name, output_rank); + + if output_rank == 0 { + node.outputs[0].ty = ArgType::Scalar(input_tensor.elem_type.clone()); + log::debug!("Gather result for {} is scalar", node.name); + } else { + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type: input_tensor.elem_type.clone(), + rank: output_rank, + static_shape: None, + }); + log::debug!( + "Gather result for {} is tensor with rank {}", + node.name, + output_rank + ); + } + } + ArgType::Shape(_) => { + log::debug!("Gather input is shape for {}", node.name); + let shape_rank = 1; + // Output of rank q+(r-1), where q is rank of indices tensor and r is rank of input + let output_rank = indices_rank + shape_rank - 1; + log::debug!( + "Gather output rank for {} with shape input: {}", + node.name, + output_rank + ); + + if output_rank == 0 { + node.outputs[0].ty = ArgType::Scalar(ElementType::Int64); + log::debug!("Gather result for {} is scalar (from shape)", node.name); + } else { + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type: ElementType::Int64, + rank: output_rank, + static_shape: None, + }); + log::debug!( + "Gather result for {} is tensor with rank {} (from shape)", + node.name, + output_rank + ); + } + } + ty => panic!("Only tensor/shape input is valid but received: {:?}", ty), + } +} + +/// Create a GatherConfig from the attributes of the node +pub fn gather_config(curr: &Node) -> usize { + // Default: 0 per ONNX spec + let mut dim: i64 = 0; + + // check if the node has only one input + if curr.inputs.len() != 2 { + panic!("Gather: index tensor must be present"); + } + + // extract the shape of the input tensor + let input_dim = match curr.inputs.first().unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor.rank as i64, + ArgType::Shape(_shape) => 1, // Shape is always 1-D + other => panic!("Only tensor or shape input is valid, got {:?}", other), + }; + + // extract the attributes + for (key, value) in curr.attrs.iter() { + if key.as_str() == "axis" { + dim = value.clone().into_i64() + } + } + + // if dim is negative, it is counted from the end + if dim < 0 { + dim += input_dim; + } + + dim as usize +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; + + fn create_test_node(axis: i64, input_rank: usize, is_shape: bool) -> Node { + // Start building the node with the appropriate input type + let mut builder = NodeBuilder::new(NodeType::Gather, "test_gather").attr_int("axis", axis); + + if is_shape { + builder = builder.add_input("data", ArgType::Shape(1)); + } else { + builder = builder.input_tensor_f32("data", input_rank, None); + } + + // Add indices and output + builder = builder + .input_tensor_i64("indices", 1, 
None) + .output_tensor_f32("output", input_rank, None); + + builder.build() + } + + #[test] + fn test_gather_config_basic() { + let node = create_test_node(0, 3, false); + let config = gather_config(&node); + assert_eq!(config, 0); + } + + #[test] + fn test_gather_config_negative_axis() { + let node = create_test_node(-2, 3, false); + let config = gather_config(&node); + assert_eq!(config, 1); // -2 + 3 = 1 + } + + #[test] + fn test_gather_config_shape_input() { + let node = create_test_node(0, 0, true); + let config = gather_config(&node); + assert_eq!(config, 0); + } + + #[test] + #[should_panic(expected = "Gather: index tensor must be present")] + fn test_gather_config_missing_index() { + let mut node = create_test_node(0, 3, false); + node.inputs.pop(); // Remove the indices input + let _ = gather_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/gemm.rs b/crates/onnx-ir/src/node/gemm.rs new file mode 100644 index 0000000000..b81d5ce436 --- /dev/null +++ b/crates/onnx-ir/src/node/gemm.rs @@ -0,0 +1,119 @@ +use crate::ir::{ArgType, Node, TensorType}; +use core::cmp::max; + +/// Update output shape for Gemm operation based on input ranks. +pub fn gemm_output_shape(node: &mut Node) { + log::debug!("Gemm rank inference for node {}", node.name); + + let a_rank = match &node.inputs[0].ty { + ArgType::Tensor(tensor) => tensor.rank, + _ => panic!("Input A should be a tensor!"), + }; + let b_rank = match &node.inputs[1].ty { + ArgType::Tensor(tensor) => tensor.rank, + _ => panic!("Input B should be a tensor!"), + }; + + log::debug!( + "Gemm input ranks for {}: a_rank={}, b_rank={}", + node.name, + a_rank, + b_rank + ); + + let output_rank = max(a_rank, b_rank); + log::debug!("Gemm output rank for {}: {}", node.name, output_rank); + + node.outputs[0].ty = ArgType::Tensor(TensorType { + rank: output_rank, + static_shape: None, + elem_type: match &node.inputs[0].ty { + ArgType::Tensor(t) => t.elem_type.clone(), + _ => panic!("Unexpected type for input A"), + }, + }); +} + +pub fn gemm_config(curr: &Node) -> (f32, f32, i64, i64) { + let mut alpha: f32 = 1.0; + let mut beta: f32 = 1.0; + let mut trans_a: i64 = 0; + let mut trans_b: i64 = 0; + + for (key, value) in curr.attrs.iter() { + match key.as_str() { + "alpha" => alpha = value.clone().into_f32(), + "beta" => beta = value.clone().into_f32(), + "transA" => trans_a = value.clone().into_i64(), + "transB" => trans_b = value.clone().into_i64(), + _ => panic!("Unexpected attribute for Gemm: {key}"), + } + } + + (alpha, beta, trans_a, trans_b) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; + + fn create_test_node( + alpha: Option, + beta: Option, + trans_a: Option, + trans_b: Option, + ) -> Node { + let mut builder = NodeBuilder::new(NodeType::Gemm, "test_gemm") + .input_tensor_f32("A", 2, None) + .input_tensor_f32("B", 2, None) + .input_tensor_f32("C", 2, None) + .output_tensor_f32("Y", 2, None); + + if let Some(alpha_val) = alpha { + builder = builder.attr_float("alpha", alpha_val); + } + if let Some(beta_val) = beta { + builder = builder.attr_float("beta", beta_val); + } + if let Some(trans_a_val) = trans_a { + builder = builder.attr_int("transA", trans_a_val); + } + if let Some(trans_b_val) = trans_b { + builder = builder.attr_int("transB", trans_b_val); + } + + builder.build() + } + + #[test] + fn test_gemm_config_defaults() { + let node = create_test_node(None, None, None, None); + let (alpha, beta, trans_a, trans_b) = gemm_config(&node); + assert_eq!(alpha, 1.0); + 
assert_eq!(beta, 1.0); + assert_eq!(trans_a, 0); + assert_eq!(trans_b, 0); + } + + #[test] + fn test_gemm_config_with_attrs() { + let node = create_test_node(Some(2.0), Some(3.0), Some(1), Some(1)); + let (alpha, beta, trans_a, trans_b) = gemm_config(&node); + assert_eq!(alpha, 2.0); + assert_eq!(beta, 3.0); + assert_eq!(trans_a, 1); + assert_eq!(trans_b, 1); + } + + #[test] + fn test_gemm_config_partial_attrs() { + let node = create_test_node(Some(0.5), None, Some(1), None); + let (alpha, beta, trans_a, trans_b) = gemm_config(&node); + assert_eq!(alpha, 0.5); + assert_eq!(beta, 1.0); // default + assert_eq!(trans_a, 1); + assert_eq!(trans_b, 0); // default + } +} diff --git a/crates/onnx-ir/src/node/hard_sigmoid.rs b/crates/onnx-ir/src/node/hard_sigmoid.rs new file mode 100644 index 0000000000..7b1a517f00 --- /dev/null +++ b/crates/onnx-ir/src/node/hard_sigmoid.rs @@ -0,0 +1,50 @@ +use crate::ir::Node; + +/// Create a HardSigmoidConfig from the alpha and beta attributes of the node +pub fn hard_sigmoid_config(node: &Node) -> (f64, f64) { + let mut alpha = 0.2; + let mut beta = 0.5; + + for (key, value) in node.attrs.iter() { + match key.as_str() { + "alpha" => alpha = value.clone().into_f32() as f64, + "beta" => beta = value.clone().into_f32() as f64, + _ => {} + } + } + + (alpha, beta) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; + + fn create_test_node(alpha: f32, beta: f32) -> Node { + NodeBuilder::new(NodeType::HardSigmoid, "test_hard_sigmoid") + .input_tensor_f32("X", 4, None) + .output_tensor_f32("Y", 4, None) + .attr_float("alpha", alpha) + .attr_float("beta", beta) + .build() + } + + #[test] + fn test_hard_sigmoid_config_with_attrs() { + let node = create_test_node(0.3, 0.6); + let (alpha, beta) = hard_sigmoid_config(&node); + assert!((alpha - 0.3).abs() < 1e-6); + assert!((beta - 0.6).abs() < 1e-6); + } + + #[test] + fn test_hard_sigmoid_config_default() { + let mut node = create_test_node(0.3, 0.6); + node.attrs.clear(); // Remove all attributes + let (alpha, beta) = hard_sigmoid_config(&node); + assert_eq!(alpha, 0.2); // Check default values + assert_eq!(beta, 0.5); + } +} diff --git a/crates/onnx-ir/src/node/layer_norm.rs b/crates/onnx-ir/src/node/layer_norm.rs new file mode 100644 index 0000000000..e97e0d085a --- /dev/null +++ b/crates/onnx-ir/src/node/layer_norm.rs @@ -0,0 +1,122 @@ +use crate::ir::Node; + +/// Configuration for LayerNorm operations +#[derive(Debug, Clone)] +pub struct LayerNormConfig { + /// Number of features/model dimension + pub d_model: usize, + /// Small constant added for numerical stability + pub epsilon: f64, +} + +impl LayerNormConfig { + /// Create a new LayerNormConfig + pub fn new(d_model: usize) -> Self { + Self { + d_model, + epsilon: 1e-5, + } + } + + /// Set the epsilon value + pub fn with_epsilon(mut self, epsilon: f64) -> Self { + self.epsilon = epsilon; + self + } +} + +/// Create a LayerNormConfig from the attributes of the node +pub fn layer_norm_config(node: &Node) -> (LayerNormConfig, bool) { + let weight_shape = node.inputs[1] + .value + .as_ref() + .expect("LayerNorm: weight tensor must be present") + .shape + .clone(); + + let num_features = weight_shape[0]; + + // When `stash_type` is `1` (default), perform operations in 32-bit float and + // cast the results back to original dtype + let mut stash_type = 1; + let mut axis = -1; + let mut epsilon = 1e-5; + + for (key, value) in node.attrs.iter() { + match key.as_str() { + "axis" => axis = 
value.clone().into_i64(), + "epsilon" => epsilon = value.clone().into_f32(), + "stash_type" => stash_type = value.clone().into_i64(), + _ => panic!("Unexpected attribute for LayerNorm: {key}"), + } + } + + if axis != -1 && axis != weight_shape.len() as i64 - 1 { + panic!("LayerNorm: normalization is only supported on the last axis right now") + } + + ( + LayerNormConfig::new(num_features).with_epsilon(epsilon as f64), + stash_type == 1, + ) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; + + fn create_test_node(epsilon: f32, axis: i64, stash_type: i64, num_features: usize) -> Node { + let weight_data = vec![1.0; num_features]; // Not important for the test + let bias_data = vec![0.0; num_features]; // Not important for the test + + NodeBuilder::new(NodeType::LayerNormalization, "test_layernorm") + .input_tensor_f32("X", 3, None) + .input_tensor_f32_data("scale", weight_data, vec![num_features]) + .input_tensor_f32_data("bias", bias_data, vec![num_features]) + .output_tensor_f32("output", 3, None) + .attr_float("epsilon", epsilon) + .attr_int("axis", axis) + .attr_int("stash_type", stash_type) + .build() + } + + #[test] + fn test_layer_norm_config_basic() { + let node = create_test_node(1e-5, -1, 1, 64); + let (config, stash_type_flag) = layer_norm_config(&node); + + assert_eq!(config.d_model, 64); + assert!(f64::abs(config.epsilon - 1e-5) < 1e-6); + assert!(stash_type_flag); + } + + #[test] + fn test_layer_norm_config_no_stash_type() { + let node = create_test_node(1e-5, -1, 0, 32); + let (config, stash_type_flag) = layer_norm_config(&node); + + assert_eq!(config.d_model, 32); + assert!(!stash_type_flag); + } + + #[test] + #[should_panic] + fn test_layer_norm_config_invalid_axis() { + // For a 1D weight tensor with shape [num_features], + // both axis=0 (the first and only dim) and axis=-1 (the last dim) are valid + // So we need to use a 2D weight tensor to test the invalid axis case + + // Create a custom node with a 2D weight tensor + let mut node = create_test_node(1e-5, 0, 1, 64); + + // Modify the weight tensor to be 2D + if let Some(ref mut tensor) = node.inputs[1].value { + tensor.shape = vec![64, 64]; // Make it 2D + } + + // Now axis=0 should trigger a panic since it's not the last dimension + let _ = layer_norm_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/leaky_relu.rs b/crates/onnx-ir/src/node/leaky_relu.rs new file mode 100644 index 0000000000..ebeff2464a --- /dev/null +++ b/crates/onnx-ir/src/node/leaky_relu.rs @@ -0,0 +1,44 @@ +use crate::ir::Node; + +/// Create a LeakyReluConfig from the alpha attribute of the node +pub fn leaky_relu_config(node: &Node) -> f64 { + let mut alpha = 0.01; + + for (key, value) in node.attrs.iter() { + if key.as_str() == "alpha" { + alpha = value.clone().into_f32() as f64 + } + } + + alpha +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; + + fn create_test_node(alpha: f32) -> Node { + NodeBuilder::new(NodeType::LeakyRelu, "test_leaky_relu") + .input_tensor_f32("X", 4, None) + .output_tensor_f32("Y", 4, None) + .attr_float("alpha", alpha) + .build() + } + + #[test] + fn test_leaky_relu_config_with_alpha() { + let node = create_test_node(0.2); + let alpha = leaky_relu_config(&node); + assert!((alpha - 0.2).abs() < 1e-6); + } + + #[test] + fn test_leaky_relu_config_default() { + let mut node = create_test_node(0.2); + node.attrs.clear(); // Remove all attributes + let alpha = 
leaky_relu_config(&node); + assert_eq!(alpha, 0.01); // Check default value + } +} diff --git a/crates/onnx-ir/src/node/linear.rs b/crates/onnx-ir/src/node/linear.rs new file mode 100644 index 0000000000..a9afa761ac --- /dev/null +++ b/crates/onnx-ir/src/node/linear.rs @@ -0,0 +1,138 @@ +use crate::ir::{ArgType, Node, TensorType}; + +/// Configuration for Linear operations +#[derive(Debug, Clone)] +pub struct LinearConfig { + /// Input dimension (features) + pub d_input: usize, + /// Output dimension (features) + pub d_output: usize, + /// Whether bias is used + pub bias: bool, +} + +impl LinearConfig { + /// Create a new LinearConfig + pub fn new(d_input: usize, d_output: usize) -> Self { + Self { + d_input, + d_output, + bias: true, + } + } + + /// Set whether bias is used + pub fn with_bias(mut self, bias: bool) -> Self { + self.bias = bias; + self + } +} + +/// Update output rank for Linear operations (same as input rank). +pub fn linear_update_outputs(node: &mut Node) { + log::debug!("Linear rank inference for node {}", node.name); + + if let ArgType::Tensor(tensor) = &node.inputs[0].ty { + log::debug!("Linear input rank for {}: {}", node.name, tensor.rank); + + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type: tensor.elem_type.clone(), + rank: tensor.rank, + static_shape: None, + }); + + log::debug!("Linear output rank for {}: {}", node.name, tensor.rank); + } else { + panic!("Only tensor input is valid"); + } +} + +/// Create a LinearConfig from the attributes of the node +pub fn linear_config(node: &Node) -> LinearConfig { + if node.inputs.len() < 2 { + panic!("Linear: missing weight tensor"); + } + + let weight_shape = node.inputs[1] + .value + .as_ref() + .expect("Linear: weight tensor must be present") + .shape + .clone(); + + // check if the weight tensor has at least 2 dimensions + if weight_shape.len() < 2 { + panic!( + "Linear: weight tensor must have at least 2 dimensions (got {:?})", + weight_shape.len() + ); + } + + let (in_size, out_size) = (weight_shape[0], weight_shape[1]); + + // check if the bias is present + let bias = node.inputs.len() == 3 && node.inputs[2].value.is_some(); + + LinearConfig::new(in_size, out_size).with_bias(bias) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; + + fn create_test_node(has_bias: bool, weight_dims: Vec) -> Node { + // Create weight tensor data + let weight_data = vec![0.0; weight_dims.iter().product()]; // Not important for the test + + // Start building the node with input and weight + let mut builder = NodeBuilder::new(NodeType::Gemm, "test_linear") + .input_tensor_f32("input", 2, None) + .input_tensor_f32_data("weight", weight_data, weight_dims.clone()) + .output_tensor_f32("output", 2, None); + + // Add bias if needed + if has_bias { + let bias_data = vec![0.0; weight_dims[1]]; // bias size equals output size + builder = builder.input_tensor_f32_data("bias", bias_data, vec![weight_dims[1]]); + } + + builder.build() + } + + #[test] + fn test_linear_config_basic() { + let node = create_test_node(false, vec![10, 5]); + let config = linear_config(&node); + + assert_eq!(config.d_input, 10); + assert_eq!(config.d_output, 5); + assert!(!config.bias); + } + + #[test] + fn test_linear_config_with_bias() { + let node = create_test_node(true, vec![10, 5]); + let config = linear_config(&node); + + assert_eq!(config.d_input, 10); + assert_eq!(config.d_output, 5); + assert!(config.bias); + } + + #[test] + #[should_panic(expected = "Linear: weight tensor must 
have at least 2 dimensions")] + fn test_linear_config_invalid_weight_dims() { + let node = create_test_node(false, vec![10]); + let _ = linear_config(&node); + } + + #[test] + #[should_panic(expected = "Linear: missing weight tensor")] + fn test_linear_config_missing_weight() { + let mut node = create_test_node(false, vec![10, 5]); + node.inputs.remove(1); + let _ = linear_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/log_softmax.rs b/crates/onnx-ir/src/node/log_softmax.rs new file mode 100644 index 0000000000..c06e4b74e1 --- /dev/null +++ b/crates/onnx-ir/src/node/log_softmax.rs @@ -0,0 +1,79 @@ +use crate::ir::{ArgType, Node}; + +/// Create log_softmax config from the attributes of the node +pub fn log_softmax_config(node: &Node) -> usize { + // the axis is the last dimension (Default: 1 per ONNX spec) + let mut axis: i64 = -1; + + // check if the node has only one input + if node.inputs.len() != 1 { + panic!( + "LogSoftmax: multiple inputs are not supported (got {:?})", + node.inputs.len() + ); + } + + // extract the shape of the input tensor + let tensor = match node.inputs.first().unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + // extract the attributes + for (key, value) in node.attrs.iter() { + if key.as_str() == "axis" { + axis = value.clone().into_i64() + } + } + + // if axis is negative, it is counted from the end + if axis < 0 { + axis += tensor.rank as i64; + } + + axis as usize +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; + + fn create_test_node(axis: i64, input_rank: usize) -> Node { + NodeBuilder::new(NodeType::LogSoftmax, "test_log_softmax") + .input_tensor_f32("data", input_rank, None) + .output_tensor_f32("output", input_rank, None) + .attr_int("axis", axis) + .build() + } + + #[test] + fn test_log_softmax_config_basic() { + let node = create_test_node(-1, 3); + let config = log_softmax_config(&node); + assert_eq!(config, 2); // -1 + 3 = 2 (last dimension) + } + + #[test] + fn test_log_softmax_config_explicit_axis() { + let node = create_test_node(1, 3); + let config = log_softmax_config(&node); + assert_eq!(config, 1); + } + + #[test] + #[should_panic(expected = "LogSoftmax: multiple inputs are not supported")] + fn test_log_softmax_config_multiple_inputs() { + let mut node = create_test_node(1, 3); + // Add an extra input + let extra_input = NodeBuilder::new(NodeType::Identity, "temp") + .input_tensor_f32("extra", 1, None) + .build() + .inputs + .pop() + .unwrap(); + node.inputs.push(extra_input); + let _ = log_softmax_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/matmul.rs b/crates/onnx-ir/src/node/matmul.rs new file mode 100644 index 0000000000..403f5d4942 --- /dev/null +++ b/crates/onnx-ir/src/node/matmul.rs @@ -0,0 +1,103 @@ +use crate::ir::{ArgType, Node, TensorType}; +use core::cmp::max; + +/// Update output rank for MatMul based on input ranks. 
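+///
+/// Mirrors NumPy matmul semantics: the output rank is the larger of the two input ranks,
+/// except that a 1-D operand is treated as a vector whose axis is contracted away,
+/// lowering the output rank by one (e.g. rank-1 x rank-2 produces a rank-1 result).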
+pub fn matmul_update_outputs(node: &mut Node) { + log::debug!("MatMul rank inference for node {}", node.name); + + match (&node.inputs[0].ty, &node.inputs[1].ty) { + (ArgType::Tensor(a), ArgType::Tensor(b)) => { + log::debug!( + "MatMul input ranks for {}: a.rank={}, b.rank={}", + node.name, + a.rank, + b.rank + ); + + let mut out_rank = max(a.rank, b.rank); + if (a.rank >= 2 && b.rank == 1) || (a.rank == 1 && b.rank >= 2) { + out_rank -= 1; + log::debug!( + "MatMul special case for node {}: reducing output rank", + node.name + ); + } + + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type: a.elem_type.clone(), + rank: out_rank, + static_shape: None, + }); + + log::debug!("MatMul output rank for {}: {}", node.name, out_rank); + } + _ => panic!("Only tensor inputs are valid"), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{ElementType, NodeType}; + use crate::node::test_utils::NodeBuilder; + + fn create_test_node(a_rank: usize, b_rank: usize) -> Node { + NodeBuilder::new(NodeType::MatMul, "test_matmul") + .input_tensor_f32("A", a_rank, None) + .input_tensor_f32("B", b_rank, None) + .output_tensor_f32("C", 0, None) // Rank will be updated + .build() + } + + #[test] + fn test_matmul_standard_case() { + let mut node = create_test_node(2, 2); + matmul_update_outputs(&mut node); + + match &node.outputs[0].ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.elem_type, ElementType::Float32); + assert_eq!(tensor.rank, 2); + } + _ => panic!("Expected tensor output"), + } + } + + #[test] + fn test_matmul_broadcasting() { + let mut node = create_test_node(3, 2); + matmul_update_outputs(&mut node); + + match &node.outputs[0].ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.elem_type, ElementType::Float32); + assert_eq!(tensor.rank, 3); + } + _ => panic!("Expected tensor output"), + } + } + + #[test] + fn test_matmul_vector_matrix() { + // When multiplying a vector (rank 1) by a matrix (rank 2) + // the result should have rank 1 (vector) + let mut node = create_test_node(1, 2); + matmul_update_outputs(&mut node); + + match &node.outputs[0].ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.elem_type, ElementType::Float32); + assert_eq!(tensor.rank, 1); + } + _ => panic!("Expected tensor output"), + } + } + + #[test] + #[should_panic(expected = "Only tensor inputs are valid")] + fn test_matmul_invalid_input() { + let mut node = create_test_node(2, 2); + node.inputs[0].ty = ArgType::Scalar(ElementType::Float32); + matmul_update_outputs(&mut node); + } +} diff --git a/crates/onnx-ir/src/node/max_pool1d.rs b/crates/onnx-ir/src/node/max_pool1d.rs new file mode 100644 index 0000000000..6112e7fba1 --- /dev/null +++ b/crates/onnx-ir/src/node/max_pool1d.rs @@ -0,0 +1,145 @@ +use crate::{ir::Node, node::padding::padding_config_1d}; + +use super::padding::PaddingConfig1d; + +/// Configuration for MaxPool1d operations extracted from ONNX nodes +#[derive(Debug, Clone)] +pub struct MaxPool1dConfig { + /// Kernel size + pub kernel_size: usize, + /// Stride + pub stride: usize, + /// Dilation + pub dilation: usize, + /// Padding configuration + pub padding: PaddingConfig1d, +} + +impl MaxPool1dConfig { + /// Create a new MaxPool1dConfig + pub fn new(kernel_size: usize) -> Self { + Self { + kernel_size, + stride: 1, + padding: PaddingConfig1d::Valid, + dilation: 1, + } + } + + /// Set the stride + pub fn with_stride(mut self, stride: usize) -> Self { + self.stride = stride; + self + } + + /// Set the padding configuration + pub fn with_padding(mut self, padding: 
PaddingConfig1d) -> Self { + self.padding = padding; + self + } + + /// Set the dilation + pub fn with_dilation(mut self, dilation: usize) -> Self { + self.dilation = dilation; + self + } +} + +/// Create a MaxPool1dConfig from the attributes of the node +pub fn max_pool1d_config(curr: &Node) -> MaxPool1dConfig { + let mut kernel_shape = Vec::new(); + let mut stride = vec![1]; + let mut pads = vec![0, 0]; + let mut dilation = vec![1]; + + for (key, value) in curr.attrs.iter() { + match key.as_str() { + "kernel_shape" => kernel_shape = value.clone().into_i64s(), + "strides" => stride = value.clone().into_i64s(), + "pads" => pads = value.clone().into_i64s(), + "dilations" => dilation = value.clone().into_i64s(), + // These are attributes that are allowed but not used in this implementation + "auto_pad" | "ceil_mode" | "storage_order" => {} + _ => panic!("Unexpected attribute for MaxPool1d: {key}"), + } + } + + assert_eq!( + kernel_shape.len(), + 1, + "MaxPool1d: kernel shape must have length 1" + ); + assert_eq!(dilation.len(), 1, "MaxPool1d: dilation must have length 1"); + assert_eq!(stride.len(), 1, "MaxPool1d: stride must have length 1"); + + let padding = padding_config_1d(&pads); + + MaxPool1dConfig { + kernel_size: kernel_shape[0] as usize, + stride: stride[0] as usize, + dilation: dilation[0] as usize, + padding, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ir::NodeType, node::padding::PaddingConfig1d, node::test_utils::NodeBuilder}; + + fn create_test_node( + kernel_shape: Vec, + stride: Vec, + pads: Vec, + dilation: Vec, + ) -> Node { + NodeBuilder::new(NodeType::MaxPool1d, "test_maxpool1d") + .input_tensor_f32("data", 3, None) + .output_tensor_f32("output", 3, None) + .attr_ints("kernel_shape", kernel_shape) + .attr_ints("strides", stride) + .attr_ints("pads", pads) + .attr_ints("dilations", dilation) + .build() + } + + #[test] + fn test_max_pool1d_config_basic() { + let node = create_test_node(vec![4], vec![1], vec![0, 0], vec![1]); + let config = max_pool1d_config(&node); + + assert_eq!(config.kernel_size, 4); + assert_eq!(config.stride, 1); + assert_eq!(config.dilation, 1); + assert!(matches!(config.padding, PaddingConfig1d::Valid)); + } + + #[test] + fn test_max_pool1d_config_with_padding() { + let node = create_test_node(vec![4], vec![2], vec![2, 2], vec![1]); + let config = max_pool1d_config(&node); + + assert_eq!(config.kernel_size, 4); + assert_eq!(config.stride, 2); + assert_eq!(config.dilation, 1); + assert!(matches!(config.padding, PaddingConfig1d::Explicit(2))); + } + + #[test] + fn test_max_pool1d_config_with_dilation() { + let node = create_test_node(vec![4], vec![1], vec![0, 0], vec![2]); + let config = max_pool1d_config(&node); + + assert_eq!(config.kernel_size, 4); + assert_eq!(config.stride, 1); + assert_eq!(config.dilation, 2); + assert!(matches!(config.padding, PaddingConfig1d::Valid)); + } + + #[test] + #[should_panic(expected = "Asymmetric padding is not supported")] + fn test_max_pool1d_config_asymmetric_padding() { + let node = create_test_node(vec![4], vec![1], vec![1, 2], vec![1]); + let _ = max_pool1d_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/max_pool2d.rs b/crates/onnx-ir/src/node/max_pool2d.rs new file mode 100644 index 0000000000..9883f86b6a --- /dev/null +++ b/crates/onnx-ir/src/node/max_pool2d.rs @@ -0,0 +1,128 @@ +use crate::ir::Node; +use crate::node::padding::{PaddingConfig2d, padding_config_2d}; + +/// Configuration for MaxPool2d operations +#[derive(Debug, Clone)] +pub struct MaxPool2dConfig { + /// Kernel 
size [height, width] + pub kernel_size: [usize; 2], + /// Stride [height, width] + pub strides: [usize; 2], + /// Padding configuration + pub padding: PaddingConfig2d, + /// Dilation [height, width] + pub dilation: [usize; 2], +} + +impl MaxPool2dConfig { + /// Create a new MaxPool2dConfig + pub fn new(kernel_size: [usize; 2]) -> Self { + Self { + kernel_size, + strides: [1, 1], + padding: PaddingConfig2d::Valid, + dilation: [1, 1], + } + } + + /// Set the strides + pub fn with_strides(mut self, strides: [usize; 2]) -> Self { + self.strides = strides; + self + } + + /// Set the padding configuration + pub fn with_padding(mut self, padding: PaddingConfig2d) -> Self { + self.padding = padding; + self + } + + /// Set the dilation + pub fn with_dilation(mut self, dilation: [usize; 2]) -> Self { + self.dilation = dilation; + self + } +} + +/// Create a MaxPool2dConfig from the attributes of the node +pub fn max_pool2d_config(curr: &Node) -> MaxPool2dConfig { + let mut kernel_shape = Vec::new(); + let mut strides = vec![1, 1]; + let mut pads = vec![0, 0, 0, 0]; + let mut dilations = vec![1, 1]; + + for (key, value) in curr.attrs.iter() { + match key.as_str() { + "kernel_shape" => kernel_shape = value.clone().into_i64s(), + "strides" => strides = value.clone().into_i64s(), + "pads" => pads = value.clone().into_i64s(), + "dilations" => dilations = value.clone().into_i64s(), + // These are attributes that are allowed but not used in this implementation + "auto_pad" | "ceil_mode" | "storage_order" => {} + _ => panic!("Unexpected attribute for MaxPool2d: {key}"), + } + } + + let padding = padding_config_2d(&pads); + + MaxPool2dConfig::new([kernel_shape[0] as usize, kernel_shape[1] as usize]) + .with_strides([strides[0] as usize, strides[1] as usize]) + .with_padding(padding) + .with_dilation([dilations[0] as usize, dilations[1] as usize]) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; + + fn create_test_node( + kernel_shape: Vec, + strides: Vec, + pads: Vec, + dilations: Vec, + ) -> Node { + NodeBuilder::new(NodeType::MaxPool2d, "test_maxpool2d") + .input_tensor_f32("data", 4, None) + .output_tensor_f32("output", 4, None) + .attr_ints("kernel_shape", kernel_shape) + .attr_ints("strides", strides) + .attr_ints("pads", pads) + .attr_ints("dilations", dilations) + .build() + } + + #[test] + fn test_max_pool2d_config_basic() { + let node = create_test_node(vec![3, 3], vec![1, 1], vec![0, 0, 0, 0], vec![1, 1]); + let config = max_pool2d_config(&node); + + assert_eq!(config.kernel_size, [3, 3]); + assert_eq!(config.strides, [1, 1]); + assert_eq!(config.dilation, [1, 1]); + assert!(matches!(config.padding, PaddingConfig2d::Valid)); + } + + #[test] + fn test_max_pool2d_config_with_padding() { + let node = create_test_node(vec![2, 2], vec![2, 2], vec![1, 1, 1, 1], vec![1, 1]); + let config = max_pool2d_config(&node); + + assert_eq!(config.kernel_size, [2, 2]); + assert_eq!(config.strides, [2, 2]); + assert_eq!(config.dilation, [1, 1]); + assert!(matches!(config.padding, PaddingConfig2d::Explicit(1, 1))); + } + + #[test] + fn test_max_pool2d_config_with_dilation() { + let node = create_test_node(vec![3, 3], vec![1, 1], vec![0, 0, 0, 0], vec![2, 2]); + let config = max_pool2d_config(&node); + + assert_eq!(config.kernel_size, [3, 3]); + assert_eq!(config.strides, [1, 1]); + assert_eq!(config.dilation, [2, 2]); + assert!(matches!(config.padding, PaddingConfig2d::Valid)); + } +} diff --git a/crates/onnx-ir/src/node/mod.rs 
b/crates/onnx-ir/src/node/mod.rs
index 913812d439..f47a06d9a6 100644
--- a/crates/onnx-ir/src/node/mod.rs
+++ b/crates/onnx-ir/src/node/mod.rs
@@ -1 +1,65 @@
+//! Node module contains implementations of ONNX operations.
+//!
+//! Each submodule implements a specific ONNX operation, providing:
+//! - Operation configuration and parameters
+//! - Rank inference functionality
+//!
+//! This modular structure allows for clean separation of operation implementations
+//! and facilitates easier maintenance and extension of the ONNX operation set.
+
+#[cfg(test)]
+pub mod test_utils;
+
+pub mod argmax;
+pub mod avg_pool1d;
+pub mod avg_pool2d;
+pub mod batch_norm;
+pub mod cast;
+pub mod clip;
+pub mod comparison;
+pub mod concat;
+pub mod constant;
+pub mod constant_of_shape;
+pub mod conv1d;
+pub mod conv2d;
+pub mod conv3d;
+pub mod conv_transpose1d;
+pub mod conv_transpose2d;
+pub mod conv_transpose3d;
+pub mod dropout;
+pub mod expand;
+pub mod flatten;
+pub mod gather;
+pub mod gemm;
+pub mod hard_sigmoid;
+pub mod layer_norm;
+pub mod leaky_relu;
+pub mod linear;
+pub mod log_softmax;
+pub mod matmul;
+pub mod max_pool1d;
+pub mod max_pool2d;
+pub mod one_hot;
+pub mod pad;
+pub mod padding;
+pub mod random;
+pub mod random_like;
+pub mod range;
+pub mod reduce_max;
+pub mod reduce_mean;
+pub mod reduce_min;
+pub mod reduce_prod;
+pub mod reduce_sum;
+pub mod reshape;
+pub mod resize;
+pub mod shape;
 pub mod slice;
+pub mod softmax;
+pub mod split;
+pub mod squeeze;
+pub mod tile;
+pub mod topk;
+pub mod transpose;
+pub mod trilu;
+pub mod unsqueeze;
+pub mod where_op;
diff --git a/crates/onnx-ir/src/node/one_hot.rs b/crates/onnx-ir/src/node/one_hot.rs
new file mode 100644
index 0000000000..1ccff14d45
--- /dev/null
+++ b/crates/onnx-ir/src/node/one_hot.rs
@@ -0,0 +1,109 @@
+use crate::ir::{ArgType, Node, TensorType};
+
+pub fn one_hot_config(curr: &Node) -> (usize, [f32; 2], i64) {
+    let depth = curr.inputs[1]
+        .value
+        .clone()
+        .expect("OneHot: Only constant depth is currently supported")
+        .data
+        .into_i64();
+
+    let values = curr.inputs[2]
+        .value
+        .clone()
+        .expect("OneHot: Only constant on/off values is currently supported")
+        .data
+        .into_f32s();
+
+    let axis = curr
+        .attrs
+        .get("axis")
+        .map(|val| val.clone().into_i64())
+        .unwrap_or(-1);
+
+    (depth as usize, values.try_into().unwrap(), axis)
+}
+
+/// Update output rank for OneHot (input rank + 1).
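+///
+/// For example, a rank-2 `indices` tensor of shape `[batch, seq]` yields a rank-3 output
+/// such as `[batch, seq, depth]` when the default `axis = -1` appends the new dimension last.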
+pub fn one_hot_output_shape(node: &mut Node) { + log::debug!("OneHot rank inference for node {}", node.name); + + let input_rank = match &node.inputs[0].ty { + ArgType::Tensor(tensor) => tensor.rank, + _ => panic!("OneHot: invalid input type"), + }; + log::debug!("OneHot input rank for {}: {}", node.name, input_rank); + + let output_rank = input_rank + 1; + log::debug!("OneHot output rank for {}: {}", node.name, output_rank); + + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type: node.outputs[0].ty.elem_type().clone(), + rank: output_rank, + static_shape: None, + }); +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; + + fn create_test_node(depth: i64, values: Vec, axis: Option) -> Node { + let mut builder = NodeBuilder::new(NodeType::OneHot, "test_one_hot") + .input_tensor_i64("indices", 2, None) + .input_scalar_tensor_i64("depth", depth) + .input_tensor_f32_data("values", values.clone(), vec![2]) // always [off_value, on_value] + .output_tensor_f32("output", 3, None); // rank increases by 1 + + if let Some(axis_val) = axis { + builder = builder.attr_int("axis", axis_val); + } + + builder.build() + } + + #[test] + fn test_one_hot_config_basic() { + let node = create_test_node(5, vec![0.0, 1.0], None); + let (depth, values, axis) = one_hot_config(&node); + assert_eq!(depth, 5); + assert_eq!(values, [0.0, 1.0]); + assert_eq!(axis, -1); // default axis + } + + #[test] + fn test_one_hot_config_with_axis() { + let node = create_test_node(5, vec![0.0, 1.0], Some(1)); + let (depth, values, axis) = one_hot_config(&node); + assert_eq!(depth, 5); + assert_eq!(values, [0.0, 1.0]); + assert_eq!(axis, 1); + } + + #[test] + fn test_one_hot_config_custom_values() { + let node = create_test_node(10, vec![-1.0, 2.0], None); + let (depth, values, axis) = one_hot_config(&node); + assert_eq!(depth, 10); + assert_eq!(values, [-1.0, 2.0]); // custom off/on values + assert_eq!(axis, -1); + } + + #[test] + #[should_panic(expected = "Only constant depth is currently supported")] + fn test_one_hot_config_no_depth_value() { + let mut node = create_test_node(5, vec![0.0, 1.0], None); + node.inputs[1].value = None; // Remove depth value + let _ = one_hot_config(&node); + } + + #[test] + #[should_panic(expected = "Only constant on/off values is currently supported")] + fn test_one_hot_config_no_values() { + let mut node = create_test_node(5, vec![0.0, 1.0], None); + node.inputs[2].value = None; // Remove values + let _ = one_hot_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/pad.rs b/crates/onnx-ir/src/node/pad.rs new file mode 100644 index 0000000000..c1eceda76f --- /dev/null +++ b/crates/onnx-ir/src/node/pad.rs @@ -0,0 +1,351 @@ +use crate::ir::{ArgType, AttributeValue, Data, Node, TensorData}; + +/// Configuration for the Pad operation. +#[derive(Debug, Clone, PartialEq)] +pub struct PadConfig { + /// The paddings to be applied to each dimension. + pub pads: Vec, + /// The constant value to fill the padded areas with. + pub constant_value: f32, +} + +impl PadConfig { + pub fn new(pads: Vec, constant_value: f32) -> Self { + PadConfig { + pads, + constant_value, + } + } +} + +/// Creates a PadConfig from the node attributes and inputs. +pub fn pad_config(node: &Node) -> PadConfig { + fn get_pads_input(node: &Node) -> Vec { + if node.inputs.len() <= 1 { + return Vec::new(); + } + + match &node.inputs[1].value { + Some(TensorData { data, .. 
}) => data.clone().into_i64s(), + _ => Vec::new(), + } + } + fn get_pads(node: &Node) -> Vec { + if node.inputs.is_empty() { + panic!("Pad: must provide data as input") + } + if node.inputs.len() >= 4 { + panic!("Pad: axes input is not supported") + } + + let input_dim = match &node.inputs.first().unwrap().ty { + ArgType::Tensor(tensor) => tensor.rank, + _ => panic!("Pad: Only tensor input is valid"), + }; + + // TODO: Handle more possible attributes + let mut pads: Vec = get_pads_input(node) + .into_iter() + .map(|x| x as usize) + .collect(); + + for (key, value) in node.attrs.iter() { + match key.as_str() { + "pads" => { + pads = value + .clone() + .into_i64s() + .iter() + .map(|&x| { + if x < 0 { + panic!("Pad: Negative pad is not supported"); + } + x as usize + }) + .collect() + } + "mode" => { + let mode = value.clone().into_string(); + if mode != "constant" { + panic!("only constant mode is supported, given mode is {}", mode); + } + } + + _ => {} + } + } + + if pads.is_empty() { + panic!("Pad: pads should be given as attribute or as input"); + } + + if pads.len() != input_dim * 2 { + panic!("Pad: pads should be a 1D tensor of shape [2 * num_axes]"); + } + // TODO: Burn's pad should support 1D tensor + if input_dim < 2 { + panic!("Pad: input tensor should be rank 2 or higher"); + } + + let left_index = input_dim - 1; + let top_index = input_dim - 2; + let right_index = pads.len() - 1; + let bottom_index = pads.len() - 2; + let index_list = [left_index, top_index, right_index, bottom_index]; + + for (index, &item) in pads.iter().enumerate() { + if !index_list.contains(&index) && item != 0 { + panic!( + "Pad: padding will only be applied to the last two dimensions but found non zero padding for other dimensions" + ); + } + } + + let left = pads[left_index]; + let top = pads[top_index]; + let right = pads[right_index]; + let bottom = pads[bottom_index]; + vec![left, right, top, bottom] + } + fn get_constant_value(node: &Node) -> f32 { + // TODO: Support int, boolean + let mut constant_value = node.inputs + .get(2) + .and_then(|input| match &input.value.as_ref().expect("Value input must be present").data { + Data::Float16s(constant_value) => { + constant_value.first().map(|&f| f32::from(f)) + } + Data::Float32s(constant_value) => { + constant_value.first().copied() + } + Data::Float64s(constant_value) => { + constant_value.first().map(|&f| f as f32) + } + Data::Float16(constant_value) => Some(f32::from(*constant_value)), + Data::Float32(constant_value) => Some(*constant_value), + Data::Float64(constant_value) => Some(*constant_value as f32), + _ => panic!("Pad: only float values are currently supported for constant value, submit an issue on github"), + }) + .unwrap_or(0.0); + + if node.attrs.contains_key("value") { + constant_value = node.attrs.get("value").map(|value| match value { + AttributeValue::Float32(value) => *value, + _ => panic!("Pad: only float32 values are currently supported for constant value as attribute, submit an issue on github"), + }).expect("constant_value should have had a value now"); + } + constant_value + } + + let pads = get_pads(node); + let constant_value = get_constant_value(node); + + PadConfig::new(pads, constant_value) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{ArgType, Argument, Data, ElementType, NodeType, TensorData, TensorType}; + use crate::node::test_utils::NodeBuilder; + + fn create_test_node( + pad_attrs: Option>, + pad_inputs: Option>, + constant_value_attr: Option, + constant_value_input: Option, + mode: Option<&str>, + 
rank: usize, + ) -> Node { + let mut builder = NodeBuilder::new(NodeType::Pad, "test_pad") + .input_tensor_f32("data", rank, None) + .output_tensor_f32("output", rank, None); + + // Add pad inputs if provided + if let Some(pads) = pad_inputs.clone() { + builder = builder.input_tensor_i64_data("pads", pads, vec![]); + } + + // Add constant value input if provided + if let Some(value) = constant_value_input { + builder = builder.input_scalar_tensor_f32("constant_value", Some(value)); + } + + // Add attributes if provided + if let Some(pads) = pad_attrs { + builder = builder.attr_ints("pads", pads); + } + + if let Some(value) = constant_value_attr { + builder = builder.attr_float("value", value); + } + + if let Some(mode_val) = mode { + builder = builder.attr_string("mode", mode_val); + } + + builder.build() + } + + #[test] + fn test_pad_config_with_attrs() { + // Test for 2D tensor (rank 2) + let pads = vec![0, 0, 1, 1]; + let node = create_test_node( + Some(pads.clone()), + None, + Some(0.0), + None, + Some("constant"), + 2, + ); + let config = pad_config(&node); + assert_eq!( + config, + PadConfig { + pads: vec![0, 1, 0, 1], + constant_value: 0.0 + } + ); + } + + #[test] + fn test_pad_config_with_inputs() { + // For a 2D tensor, pads should have 4 values (2*rank) + let pads = vec![0, 0, 1, 1]; + let node = create_test_node(None, Some(pads.clone()), None, Some(1.0), None, 2); + let config = pad_config(&node); + assert_eq!( + config, + PadConfig { + pads: vec![0, 1, 0, 1], + constant_value: 1.0 + } + ); + } + + #[test] + fn test_pad_config_with_3d_tensor() { + // For a 3D tensor, pads should have 6 values (2*rank) + let pads = vec![0, 0, 0, 0, 1, 1]; + let node = create_test_node( + Some(pads.clone()), + None, + Some(0.5), + None, + Some("constant"), + 3, + ); + let config = pad_config(&node); + assert_eq!( + config, + PadConfig { + pads: vec![0, 1, 0, 1], + constant_value: 0.5 + } + ); + } + + #[test] + fn test_pad_config_attrs_override_inputs() { + // Attributes should override inputs + let attr_pads = vec![0, 0, 2, 2]; + let input_pads = vec![0, 0, 1, 1]; + let node = create_test_node( + Some(attr_pads.clone()), + Some(input_pads), + Some(0.0), + Some(1.0), + Some("constant"), + 2, + ); + let config = pad_config(&node); + assert_eq!( + config, + PadConfig { + pads: vec![0, 2, 0, 2], + constant_value: 0.0 + } + ); + } + + #[test] + #[should_panic(expected = "Pad: must provide data as input")] + fn test_pad_config_no_inputs() { + let mut node = create_test_node(None, None, None, None, None, 2); + node.inputs = vec![]; + let _ = pad_config(&node); + } + + #[test] + #[should_panic(expected = "Pad: Only tensor input is valid")] + fn test_pad_config_invalid_input_type() { + let mut node = create_test_node(Some(vec![0, 0, 1, 1]), None, None, None, None, 2); + node.inputs[0].ty = ArgType::Scalar(ElementType::Float32); + let _ = pad_config(&node); + } + + #[test] + #[should_panic(expected = "Pad: axes input is not supported")] + fn test_pad_config_with_axes_input() { + // Create node with 4 inputs (including axes) + let mut node = create_test_node(None, Some(vec![0, 0, 1, 1]), None, Some(0.0), None, 2); + node.inputs.push(Argument { + name: "axes".to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Int64, + rank: 1, + static_shape: None, + }), + value: Some(TensorData { + data: Data::Int64s(vec![0, 1]), + shape: vec![], + }), + passed: true, + }); + let _ = pad_config(&node); + } + + #[test] + #[should_panic(expected = "Pad: Negative pad is not supported")] + fn 
test_pad_config_negative_pad() { + let node = create_test_node(Some(vec![0, 0, -1, 1]), None, None, None, None, 2); + let _ = pad_config(&node); + } + + #[test] + #[should_panic(expected = "only constant mode is supported")] + fn test_pad_config_unsupported_mode() { + let node = create_test_node(Some(vec![0, 0, 1, 1]), None, None, None, Some("reflect"), 2); + let _ = pad_config(&node); + } + + #[test] + #[should_panic(expected = "Pad: pads should be given as attribute or as input")] + fn test_pad_config_no_pads() { + let node = create_test_node(None, None, None, None, None, 2); + let _ = pad_config(&node); + } + + #[test] + #[should_panic(expected = "Pad: pads should be a 1D tensor of shape [2 * num_axes]")] + fn test_pad_config_invalid_pads_length() { + let node = create_test_node(Some(vec![0, 0, 1]), None, None, None, None, 2); + let _ = pad_config(&node); + } + + #[test] + #[should_panic(expected = "Pad: input tensor should be rank 2 or higher")] + fn test_pad_config_invalid_tensor_rank() { + let node = create_test_node(Some(vec![0, 1]), None, None, None, None, 1); + let _ = pad_config(&node); + } + + #[test] + #[should_panic(expected = "Pad: padding will only be applied to the last two dimensions")] + fn test_pad_config_non_zero_padding_on_other_dimensions() { + // For a 3D tensor, we try to set non-zero padding on first dimension + let node = create_test_node(Some(vec![1, 0, 0, 0, 1, 1]), None, None, None, None, 3); + let _ = pad_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/padding.rs b/crates/onnx-ir/src/node/padding.rs new file mode 100644 index 0000000000..3fba910027 --- /dev/null +++ b/crates/onnx-ir/src/node/padding.rs @@ -0,0 +1,231 @@ +use std::fmt; + +/// Padding configuration for 1D operations such as convolution +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum PaddingConfig1d { + /// No padding (valid padding) + Valid, + /// Explicit padding with a specific size + Explicit(usize), +} + +impl fmt::Display for PaddingConfig1d { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + PaddingConfig1d::Valid => write!(f, "Valid"), + PaddingConfig1d::Explicit(size) => write!(f, "Explicit({})", size), + } + } +} + +/// Calculate the padding configuration for a 1D operations such as Convolution and Pooling. +/// +/// # Arguments +/// +/// * `pads` - The padding values +/// +/// # Panics +/// +/// * If the padding is negative +/// * If the padding is not symmetric +/// +/// # Returns +/// +/// * The padding configuration +/// +/// # Remarks +/// +/// This function is used when the padding is specified as a list of integers, +/// and not used when the padding is specified as a string, e.g. "SAME_UPPER". +pub fn padding_config_1d(pads: &[i64]) -> PaddingConfig1d { + let [left, right] = [pads[0], pads[1]]; + + if left < 0 || right < 0 { + panic!("Negative pad values are not supported"); + } else if left != right { + panic!("Asymmetric padding is not supported"); + } else if left == 0 && right == 0 { + // i.e. [0, 0] + PaddingConfig1d::Valid + } else if left == right { + // i.e. 
[2, 2]
+        PaddingConfig1d::Explicit(left as usize)
+    } else {
+        // Unaccounted for padding configuration
+        panic!("Padding configuration ({:?}) not supported", pads);
+    }
+}
+
+/// Padding configuration for 2D operations such as convolution
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub enum PaddingConfig2d {
+    /// No padding (valid padding)
+    Valid,
+    /// Explicit padding with specific width and height
+    Explicit(usize, usize),
+}
+
+impl fmt::Display for PaddingConfig2d {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match self {
+            PaddingConfig2d::Valid => write!(f, "Valid"),
+            PaddingConfig2d::Explicit(width, height) => {
+                write!(f, "Explicit({}, {})", width, height)
+            }
+        }
+    }
+}
+
+/// Padding configuration for 3D operations such as convolution
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub enum PaddingConfig3d {
+    /// No padding (valid padding)
+    Valid,
+    /// Explicit padding with specific width, height, and depth
+    Explicit(usize, usize, usize),
+}
+
+impl fmt::Display for PaddingConfig3d {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match self {
+            PaddingConfig3d::Valid => write!(f, "Valid"),
+            PaddingConfig3d::Explicit(width, height, depth) => {
+                write!(f, "Explicit({}, {}, {})", width, height, depth)
+            }
+        }
+    }
+}
+
+/// Calculate the padding configuration for 2D operations such as Convolution and Pooling.
+///
+/// # Arguments
+///
+/// * `pads` - The padding values [top, left, bottom, right]
+///
+/// # Panics
+///
+/// * If the padding is negative
+/// * If the padding is not symmetric
+///
+/// # Returns
+///
+/// * The padding configuration
+///
+/// # Remarks
+///
+/// This function is used when the padding is specified as a list of integers,
+/// and not used when the padding is specified as a string, e.g. "SAME_UPPER".
+pub fn padding_config_2d(pads: &[i64]) -> PaddingConfig2d {
+    let [top, left, bottom, right] = [pads[0], pads[1], pads[2], pads[3]];
+
+    if left < 0 || right < 0 || top < 0 || bottom < 0 {
+        panic!("Negative pad values are not supported");
+    } else if left != right || top != bottom {
+        panic!("Asymmetric padding is not supported");
+    } else if left == 0 && right == 0 && top == 0 && bottom == 0 {
+        PaddingConfig2d::Valid
+    } else if left == right && top == bottom {
+        PaddingConfig2d::Explicit(top as usize, left as usize)
+    } else {
+        // Unaccounted for padding configuration
+        panic!("Padding configuration ({:?}) not supported", pads);
+    }
+}
+
+/// Calculate the padding configuration for 3D operations such as Convolution and Pooling.
+///
+/// # Arguments
+///
+/// * `pads` - The padding values [front, top, left, back, bottom, right]
+///
+/// # Panics
+///
+/// * If the padding is negative
+/// * If the padding is not symmetric
+///
+/// # Returns
+///
+/// * The padding configuration
+///
+/// # Remarks
+///
+/// This function is used when the padding is specified as a list of integers,
+/// and not used when the padding is specified as a string, e.g. "SAME_UPPER".
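+///
+/// For example, `pads = [1, 1, 1, 1, 1, 1]` (symmetric padding of 1 on every side) maps to
+/// `PaddingConfig3d::Explicit(1, 1, 1)`, while an all-zero `pads` maps to `PaddingConfig3d::Valid`.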
+pub fn padding_config_3d(pads: &[i64]) -> PaddingConfig3d { + let [front, top, left, back, bottom, right] = + [pads[0], pads[1], pads[2], pads[3], pads[4], pads[5]]; + + if left < 0 || right < 0 || top < 0 || bottom < 0 || front < 0 || back < 0 { + panic!("Negative pad values are not supported"); + } else if left != right || top != bottom || front != back { + panic!("Asymmetric padding is not supported"); + } else if left == 0 && right == 0 && top == 0 && bottom == 0 && front == 0 && back == 0 { + PaddingConfig3d::Valid + } else if left == right && top == bottom && front == back { + PaddingConfig3d::Explicit(front as usize, top as usize, left as usize) + } else { + // Unaccounted for padding configuration + panic!("Padding configuration ({:?}) not supported", pads); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_padding_config_2d_valid() { + let pads = vec![0, 0, 0, 0]; + let config = padding_config_2d(&pads); + assert!(matches!(config, PaddingConfig2d::Valid)); + } + + #[test] + fn test_padding_config_2d_explicit() { + let pads = vec![2, 2, 2, 2]; + let config = padding_config_2d(&pads); + assert!(matches!(config, PaddingConfig2d::Explicit(2, 2))); + } + + #[test] + #[should_panic(expected = "Asymmetric padding is not supported")] + fn test_padding_config_2d_asymmetric() { + let pads = vec![2, 3, 2, 2]; + let _ = padding_config_2d(&pads); + } + + #[test] + #[should_panic(expected = "Negative pad values are not supported")] + fn test_padding_config_2d_negative() { + let pads = vec![-1, -1, -1, -1]; + let _ = padding_config_2d(&pads); + } + + #[test] + fn test_padding_config_3d_valid() { + let pads = vec![0, 0, 0, 0, 0, 0]; + let config = padding_config_3d(&pads); + assert!(matches!(config, PaddingConfig3d::Valid)); + } + + #[test] + fn test_padding_config_3d_explicit() { + let pads = vec![2, 3, 1, 2, 3, 1]; + let config = padding_config_3d(&pads); + assert!(matches!(config, PaddingConfig3d::Explicit(2, 3, 1))); + } + + #[test] + #[should_panic(expected = "Asymmetric padding is not supported")] + fn test_padding_config_3d_asymmetric() { + let pads = vec![2, 3, 1, 3, 3, 1]; + let _ = padding_config_3d(&pads); + } + + #[test] + #[should_panic(expected = "Negative pad values are not supported")] + fn test_padding_config_3d_negative() { + let pads = vec![-1, -1, -1, -1, -1, -1]; + let _ = padding_config_3d(&pads); + } +} diff --git a/crates/onnx-ir/src/node/random.rs b/crates/onnx-ir/src/node/random.rs new file mode 100644 index 0000000000..c441cea478 --- /dev/null +++ b/crates/onnx-ir/src/node/random.rs @@ -0,0 +1,98 @@ +use crate::ir::{ArgType, ElementType, Node, TensorType}; +use crate::protos::tensor_proto::DataType; +use protobuf::Enum; + +/// Update output rank for Random operations with explicit shape attribute. 
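+///
+/// For example, a `shape` attribute of `[2, 3, 4]` yields a rank-3 output, and the default
+/// `dtype` of FLOAT maps to `ElementType::Float32`.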
+pub fn random_update_output(node: &mut Node) { + log::debug!("Random rank inference for node {}", node.name); + + let dtype = node + .attrs + .get("dtype") + .map(|val| DataType::from_i32(val.clone().into_i32()).unwrap()) + .unwrap_or(DataType::FLOAT); + log::debug!("Random dtype for {}: {:?}", node.name, dtype); + + let shape = node + .attrs + .get("shape") + .expect("required shape attribute missing") + .clone() + .into_i64s(); + log::debug!("Random shape for {}: {:?}", node.name, shape); + + let elem_type = match dtype { + DataType::FLOAT => ElementType::Float32, + DataType::DOUBLE => ElementType::Float64, + _ => panic!("tensor with type {dtype:?} not supported for random output"), + }; + + let rank = shape.len(); + log::debug!("Random output rank for {}: {}", node.name, rank); + + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type, + rank, + static_shape: None, + }); +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; + use crate::protos::tensor_proto::DataType; + + fn create_test_node(dtype: i32, shape: Vec) -> Node { + NodeBuilder::new(NodeType::RandomNormal, "test_random") + .output_tensor_f32("output", 0, None) // Rank 0 will be updated + .attr_int("dtype", dtype as i64) + .attr_ints("shape", shape) + .build() + } + + #[test] + fn test_random_normal_float() { + let mut node = create_test_node(DataType::FLOAT.value(), vec![2, 3, 4]); + random_update_output(&mut node); + + match &node.outputs[0].ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.elem_type, ElementType::Float32); + assert_eq!(tensor.rank, 3); + } + _ => panic!("Expected tensor output"), + } + } + + #[test] + fn test_random_normal_double() { + let mut node = create_test_node(DataType::DOUBLE.value(), vec![5]); + random_update_output(&mut node); + + match &node.outputs[0].ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.elem_type, ElementType::Float64); + assert_eq!(tensor.rank, 1); + } + _ => panic!("Expected tensor output"), + } + } + + #[test] + #[should_panic(expected = "required shape attribute missing")] + fn test_random_normal_missing_shape() { + // Create node and then manually remove the shape attribute + let mut node = create_test_node(DataType::FLOAT.value(), vec![2, 3]); + node.attrs.remove("shape"); + random_update_output(&mut node); + } + + #[test] + #[should_panic(expected = "tensor with type INT32 not supported for random output")] + fn test_random_normal_unsupported_type() { + let mut node = create_test_node(DataType::INT32.value(), vec![2, 3]); + random_update_output(&mut node); + } +} diff --git a/crates/onnx-ir/src/node/random_like.rs b/crates/onnx-ir/src/node/random_like.rs new file mode 100644 index 0000000000..25a09b8d70 --- /dev/null +++ b/crates/onnx-ir/src/node/random_like.rs @@ -0,0 +1,96 @@ +use crate::ir::{ArgType, ElementType, Node, TensorType}; +use crate::protos::tensor_proto::DataType; +use protobuf::Enum; + +/// Update output rank for RandomLike operations based on input rank. 
+pub fn random_like_update_output(node: &mut Node) { + log::debug!("RandomLike rank inference for node {}", node.name); + + let dtype = node + .attrs + .get("dtype") + .map(|val| DataType::from_i32(val.clone().into_i32()).unwrap()) + .unwrap_or(DataType::FLOAT); + log::debug!("RandomLike dtype for {}: {:?}", node.name, dtype); + + let elem_type = match dtype { + DataType::FLOAT => ElementType::Float32, + DataType::FLOAT16 => ElementType::Float16, + DataType::DOUBLE => ElementType::Float64, + _ => panic!("Tensor with type {dtype:?} not supported for random output"), + }; + + if let ArgType::Tensor(tensor) = &node.inputs[0].ty { + log::debug!("RandomLike input rank for {}: {}", node.name, tensor.rank); + + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type, + rank: tensor.rank, + static_shape: tensor.static_shape.clone(), + }); + + log::debug!("RandomLike output rank for {}: {}", node.name, tensor.rank); + } else { + panic!("Only tensor input is valid"); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; + use crate::protos::tensor_proto::DataType; + + fn create_test_node(dtype: i32, input_rank: usize, static_shape: Option>) -> Node { + NodeBuilder::new(NodeType::RandomNormalLike, "test_random_like") + .input_tensor_f32("input", input_rank, static_shape) + .output_tensor_f32("output", 0, None) // Rank 0 will be updated + .attr_int("dtype", dtype as i64) + .build() + } + + #[test] + fn test_random_like_float() { + let mut node = create_test_node(DataType::FLOAT.value(), 3, None); + random_like_update_output(&mut node); + + match &node.outputs[0].ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.elem_type, ElementType::Float32); + assert_eq!(tensor.rank, 3); + } + _ => panic!("Expected tensor output"), + } + } + + #[test] + fn test_random_like_double() { + let mut node = create_test_node(DataType::DOUBLE.value(), 2, Some(vec![5, 10])); + random_like_update_output(&mut node); + + match &node.outputs[0].ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.elem_type, ElementType::Float64); + assert_eq!(tensor.rank, 2); + assert_eq!(tensor.static_shape, Some(vec![5, 10])); + } + _ => panic!("Expected tensor output"), + } + } + + #[test] + #[should_panic(expected = "Only tensor input is valid")] + fn test_random_like_invalid_input() { + let mut node = create_test_node(DataType::FLOAT.value(), 2, None); + node.inputs[0].ty = ArgType::Scalar(ElementType::Float32); + random_like_update_output(&mut node); + } + + #[test] + #[should_panic(expected = "Tensor with type INT32 not supported for random output")] + fn test_random_like_unsupported_type() { + let mut node = create_test_node(DataType::INT32.value(), 2, None); + random_like_update_output(&mut node); + } +} diff --git a/crates/onnx-ir/src/node/range.rs b/crates/onnx-ir/src/node/range.rs new file mode 100644 index 0000000000..8a21ca5086 --- /dev/null +++ b/crates/onnx-ir/src/node/range.rs @@ -0,0 +1,60 @@ +use crate::ir::{ArgType, ElementType, Node, TensorType}; + +/// Update output rank for Range (always rank 1). 
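+///
+/// Only the element count depends on the `start`, `limit`, and `delta` values
+/// (e.g. start 0, limit 10, delta 2 would give 5 elements); the rank is always 1.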
+pub fn range_update_outputs(node: &mut Node) { + log::debug!("Range rank inference for node {}", node.name); + + if node.inputs.len() != 3 { + panic!("Range: expected 3 inputs, found {}", node.inputs.len()); + } + log::debug!( + "Range operation always produces rank 1 tensor for {}", + node.name + ); + + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type: ElementType::Int64, + rank: 1, + static_shape: None, + }); + + log::debug!("Range output rank for {}: 1", node.name); +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; + + fn create_test_node() -> Node { + NodeBuilder::new(NodeType::Range, "test_range") + .input_scalar_i64("start") + .input_scalar_i64("limit") + .input_scalar_i64("delta") + .output_tensor_i64("output", 0, None) // Rank 0 will be updated + .build() + } + + #[test] + fn test_range_output() { + let mut node = create_test_node(); + range_update_outputs(&mut node); + + match &node.outputs[0].ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.elem_type, ElementType::Int64); + assert_eq!(tensor.rank, 1); + } + _ => panic!("Expected tensor output"), + } + } + + #[test] + #[should_panic(expected = "Range: expected 3 inputs, found 2")] + fn test_range_missing_inputs() { + let mut node = create_test_node(); + node.inputs.pop(); + range_update_outputs(&mut node); + } +} diff --git a/crates/onnx-ir/src/node/reduce_max.rs b/crates/onnx-ir/src/node/reduce_max.rs new file mode 100644 index 0000000000..2cae7dfa99 --- /dev/null +++ b/crates/onnx-ir/src/node/reduce_max.rs @@ -0,0 +1,137 @@ +use crate::ir::{ArgType, AttributeValue, Node, TensorType}; + +/// Create a ReduceMaxConfig from the attributes of the node +pub fn reduce_max_config(node: &Node) -> Option { + let mut axes = Vec::new(); + let mut keepdims = 1; + + let tensor = match node.inputs.first().unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + // Extract the attributes + for (key, value) in node.attrs.iter() { + match key.as_str() { + "axes" => axes = value.clone().into_i64s(), + "keepdims" => keepdims = value.clone().into_i64(), + _ => {} + } + } + + if axes.len() > 1 { + panic!("ReduceMax: reducing on multiple dimensions is not supported") + } + + if axes.is_empty() && keepdims == 1 { + panic!("ReduceMax: axes must be provided with keepdims") + } + + if !axes.is_empty() && keepdims == 0 { + // Not supported in Burn + panic!("ReduceMax: the reduce operation must preserve the reduced dimension") + } + + if axes.is_empty() { + None + } else { + let mut dim = axes[0]; + + if dim < 0 { + // Accepted range is [-r, r-1] where r = rank(data) but Burn only supports positive dim + dim += tensor.rank as i64; + } + Some(dim as usize) + } +} + +/// Update output rank for ReduceMax based on axes. 
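+///
+/// With a single reduction axis the rank is preserved (keepdims semantics), e.g. a rank-3
+/// input stays rank 3; without an axis the output collapses to rank 1.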
+pub fn reduce_max_update_outputs(node: &mut Node) { + log::debug!("ReduceMax rank inference for node {}", node.name); + + if node.inputs.len() != 1 { + panic!("ReduceMax: multiple inputs are not supported"); + } + let tensor = match &node.inputs[0].ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + log::debug!("ReduceMax input rank for {}: {}", node.name, tensor.rank); + + let dim_only = match node.attrs.get("axes") { + Some(value) => match &value { + AttributeValue::Int64(_) => true, + AttributeValue::Int64s(ints) => ints.len() == 1, + _ => false, + }, + None => false, + }; + + let output_rank = if dim_only { tensor.rank } else { 1 }; + log::debug!("ReduceMax output rank for {}: {}", node.name, output_rank); + + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type: tensor.elem_type.clone(), + rank: output_rank, + static_shape: None, + }); +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; + + fn create_test_node(axes: Option>, keepdims: Option) -> Node { + let mut builder = NodeBuilder::new(NodeType::ReduceMax, "test_reduce_max") + .input_tensor_f32("data", 3, None) + .output_tensor_f32("reduced", 3, None); + + if let Some(axes_val) = axes { + builder = builder.attr_ints("axes", axes_val); + } + if let Some(kd) = keepdims { + builder = builder.attr_int("keepdims", kd); + } + + builder.build() + } + + #[test] + fn test_reduce_max_config_basic() { + let node = create_test_node(Some(vec![1]), Some(1)); + let dim = reduce_max_config(&node); + assert_eq!(dim, Some(1)); + } + + #[test] + fn test_reduce_max_config_negative_axis() { + let node = create_test_node(Some(vec![-2]), Some(1)); + let dim = reduce_max_config(&node); + assert_eq!(dim, Some(1)); // -2 + 3 = 1 + } + + #[test] + #[should_panic(expected = "ReduceMax: axes must be provided with keepdims")] + fn test_reduce_max_config_no_axes() { + let node = create_test_node(None, Some(1)); + let _ = reduce_max_config(&node); + } + + #[test] + #[should_panic(expected = "ReduceMax: reducing on multiple dimensions is not supported")] + fn test_reduce_max_config_multiple_axes() { + let node = create_test_node(Some(vec![0, 1]), Some(1)); + let _ = reduce_max_config(&node); + } + + #[test] + #[should_panic( + expected = "ReduceMax: the reduce operation must preserve the reduced dimension" + )] + fn test_reduce_max_config_no_keepdims() { + let node = create_test_node(Some(vec![1]), Some(0)); + let _ = reduce_max_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/reduce_mean.rs b/crates/onnx-ir/src/node/reduce_mean.rs new file mode 100644 index 0000000000..81849d2504 --- /dev/null +++ b/crates/onnx-ir/src/node/reduce_mean.rs @@ -0,0 +1,136 @@ +use crate::ir::{ArgType, AttributeValue, Node, TensorType}; + +/// Create a ReduceMeanConfig from the attributes of the node +pub fn reduce_mean_config(node: &Node) -> Option { + let mut axes = Vec::new(); + let mut keepdims = 1; + + let tensor = match node.inputs.first().unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + // Extract the attributes + for (key, value) in node.attrs.iter() { + match key.as_str() { + "axes" => axes = value.clone().into_i64s(), + "keepdims" => keepdims = value.clone().into_i64(), + _ => {} + } + } + + if axes.len() > 1 { + panic!("ReduceMean: reducing on multiple dimensions is not supported") + } + + if axes.is_empty() && keepdims == 1 { + panic!("ReduceMean: axes must be provided with 
keepdims") + } + + if !axes.is_empty() && keepdims == 0 { + // Not supported in Burn + panic!("ReduceMean: the reduce operation must preserve the reduced dimension") + } + + if axes.is_empty() { + None + } else { + let mut dim = axes[0]; + + if dim < 0 { + // Accepted range is [-r, r-1] where r = rank(data) but Burn only supports positive dim + dim += tensor.rank as i64; + } + Some(dim as usize) + } +} + +/// Update output rank for ReduceMean based on axes. +pub fn reduce_mean_update_outputs(node: &mut Node) { + log::debug!("ReduceMean rank inference for node {}", node.name); + + if node.inputs.len() != 1 { + panic!("ReduceMean: multiple inputs are not supported"); + } + let tensor = match &node.inputs[0].ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + let dim_only = match node.attrs.get("axes") { + Some(value) => match &value { + AttributeValue::Int64(_) => true, + AttributeValue::Int64s(ints) => ints.len() == 1, + _ => false, + }, + None => false, + }; + + let output_rank = if dim_only { tensor.rank } else { 1 }; + log::debug!("ReduceMean output rank for {}: {}", node.name, output_rank); + + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type: tensor.elem_type.clone(), + rank: output_rank, + static_shape: None, + }); +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; + + fn create_test_node(axes: Option>, keepdims: Option) -> Node { + let mut builder = NodeBuilder::new(NodeType::ReduceMean, "test_reduce_mean") + .input_tensor_f32("data", 3, None) + .output_tensor_f32("reduced", 3, None); + + if let Some(axes_val) = axes { + builder = builder.attr_ints("axes", axes_val); + } + if let Some(kd) = keepdims { + builder = builder.attr_int("keepdims", kd); + } + + builder.build() + } + + #[test] + fn test_reduce_mean_config_basic() { + let node = create_test_node(Some(vec![1]), Some(1)); + let dim = reduce_mean_config(&node); + assert_eq!(dim, Some(1)); + } + + #[test] + fn test_reduce_mean_config_negative_axis() { + let node = create_test_node(Some(vec![-2]), Some(1)); + let dim = reduce_mean_config(&node); + assert_eq!(dim, Some(1)); // -2 + 3 = 1 + } + + #[test] + #[should_panic(expected = "ReduceMean: axes must be provided with keepdims")] + fn test_reduce_mean_config_no_axes() { + let node = create_test_node(None, Some(1)); + let _ = reduce_mean_config(&node); + } + + #[test] + #[should_panic(expected = "ReduceMean: reducing on multiple dimensions is not supported")] + fn test_reduce_mean_config_multiple_axes() { + let node = create_test_node(Some(vec![0, 1]), Some(1)); + let _ = reduce_mean_config(&node); + } + + #[test] + #[should_panic( + expected = "ReduceMean: the reduce operation must preserve the reduced dimension" + )] + fn test_reduce_mean_config_no_keepdims() { + let node = create_test_node(Some(vec![1]), Some(0)); + let _ = reduce_mean_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/reduce_min.rs b/crates/onnx-ir/src/node/reduce_min.rs new file mode 100644 index 0000000000..494079454f --- /dev/null +++ b/crates/onnx-ir/src/node/reduce_min.rs @@ -0,0 +1,135 @@ +use crate::ir::{ArgType, AttributeValue, Node, TensorType}; + +/// Create a ReduceMinConfig from the attributes of the node +pub fn reduce_min_config(node: &Node) -> Option { + let mut axes = Vec::new(); + let mut keepdims = 1; + + let tensor = match node.inputs.first().unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + // 
Extract the attributes + for (key, value) in node.attrs.iter() { + match key.as_str() { + "axes" => axes = value.clone().into_i64s(), + "keepdims" => keepdims = value.clone().into_i64(), + _ => {} + } + } + + if axes.len() > 1 { + panic!("ReduceMin: reducing on multiple dimensions is not supported") + } + + if axes.is_empty() && keepdims == 1 { + panic!("ReduceMin: axes must be provided with keepdims") + } + + if !axes.is_empty() && keepdims == 0 { + panic!("ReduceMin: the reduce operation must preserve the reduced dimension") + } + + if axes.is_empty() { + None + } else { + let mut dim = axes[0]; + + if dim < 0 { + dim += tensor.rank as i64; + } + Some(dim as usize) + } +} + +/// Update output rank for ReduceMin based on axes. +pub fn reduce_min_update_outputs(node: &mut Node) { + log::debug!("ReduceMin rank inference for node {}", node.name); + + if node.inputs.len() != 1 { + panic!("ReduceMin: multiple inputs are not supported"); + } + let tensor = match &node.inputs[0].ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + log::debug!("ReduceMin input rank for {}: {}", node.name, tensor.rank); + + let dim_only = match node.attrs.get("axes") { + Some(value) => match &value { + AttributeValue::Int64(_) => true, + AttributeValue::Int64s(ints) => ints.len() == 1, + _ => false, + }, + None => false, + }; + + let output_rank = if dim_only { tensor.rank } else { 1 }; + log::debug!("ReduceMin output rank for {}: {}", node.name, output_rank); + + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type: tensor.elem_type.clone(), + rank: output_rank, + static_shape: None, + }); +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; + + fn create_test_node(axes: Option>, keepdims: Option) -> Node { + let mut builder = NodeBuilder::new(NodeType::ReduceMin, "test_reduce_min") + .input_tensor_f32("data", 3, None) + .output_tensor_f32("reduced", 3, None); + + if let Some(axes_val) = axes { + builder = builder.attr_ints("axes", axes_val); + } + if let Some(kd) = keepdims { + builder = builder.attr_int("keepdims", kd); + } + + builder.build() + } + + #[test] + fn test_reduce_min_config_basic() { + let node = create_test_node(Some(vec![1]), Some(1)); + let dim = reduce_min_config(&node); + assert_eq!(dim, Some(1)); + } + + #[test] + fn test_reduce_min_config_negative_axis() { + let node = create_test_node(Some(vec![-2]), Some(1)); + let dim = reduce_min_config(&node); + assert_eq!(dim, Some(1)); // -2 + 3 = 1 + } + + #[test] + #[should_panic(expected = "ReduceMin: axes must be provided with keepdims")] + fn test_reduce_min_config_no_axes() { + let node = create_test_node(None, Some(1)); + let _ = reduce_min_config(&node); + } + + #[test] + #[should_panic(expected = "ReduceMin: reducing on multiple dimensions is not supported")] + fn test_reduce_min_config_multiple_axes() { + let node = create_test_node(Some(vec![0, 1]), Some(1)); + let _ = reduce_min_config(&node); + } + + #[test] + #[should_panic( + expected = "ReduceMin: the reduce operation must preserve the reduced dimension" + )] + fn test_reduce_min_config_no_keepdims() { + let node = create_test_node(Some(vec![1]), Some(0)); + let _ = reduce_min_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/reduce_prod.rs b/crates/onnx-ir/src/node/reduce_prod.rs new file mode 100644 index 0000000000..62196a7b4b --- /dev/null +++ b/crates/onnx-ir/src/node/reduce_prod.rs @@ -0,0 +1,138 @@ +use crate::ir::{ArgType, AttributeValue, Node, 
TensorType}; + +/// Create a ReduceProdConfig from the attributes of the node +pub fn reduce_prod_config(node: &Node) -> Option { + let mut axes = Vec::new(); + let mut keepdims = 1; + + let tensor = match node.inputs.first().unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + // Extract the attributes + for (key, value) in node.attrs.iter() { + match key.as_str() { + "axes" => axes = value.clone().into_i64s(), + "keepdims" => keepdims = value.clone().into_i64(), + // TODO: handle noop_with_empty_axes (opset 18) + _ => {} + } + } + + if axes.len() > 1 { + panic!("ReduceProd: reducing on multiple dimensions is not supported") + } + + if axes.is_empty() && keepdims == 1 { + panic!("ReduceProd: axes must be provided with keepdims") + } + + if !axes.is_empty() && keepdims == 0 { + // Not supported in Burn + panic!("ReduceProd: the reduce operation must preserve the reduced dimension") + } + + if axes.is_empty() { + None + } else { + let mut dim = axes[0]; + + if dim < 0 { + // Accepted range is [-r, r-1] where r = rank(data) but Burn only supports positive dim + dim += tensor.rank as i64; + } + Some(dim as usize) + } +} + +/// Update output rank for ReduceProd based on axes. +pub fn reduce_prod_update_outputs(node: &mut Node) { + log::debug!("ReduceProd rank inference for node {}", node.name); + + if node.inputs.len() != 1 { + panic!("ReduceProd: multiple inputs are not supported"); + } + let tensor = match &node.inputs[0].ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + log::debug!("ReduceProd input rank for {}: {}", node.name, tensor.rank); + + let dim_only = match node.attrs.get("axes") { + Some(value) => match &value { + AttributeValue::Int64(_) => true, + AttributeValue::Int64s(ints) => ints.len() == 1, + _ => false, + }, + None => false, + }; + + let output_rank = if dim_only { tensor.rank } else { 1 }; + log::debug!("ReduceProd output rank for {}: {}", node.name, output_rank); + + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type: tensor.elem_type.clone(), + rank: output_rank, + static_shape: None, + }); +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; + + fn create_test_node(axes: Option>, keepdims: Option) -> Node { + let mut builder = NodeBuilder::new(NodeType::ReduceProd, "test_reduce_prod") + .input_tensor_f32("data", 3, None) + .output_tensor_f32("reduced", 3, None); + + if let Some(axes_val) = axes { + builder = builder.attr_ints("axes", axes_val); + } + if let Some(kd) = keepdims { + builder = builder.attr_int("keepdims", kd); + } + + builder.build() + } + + #[test] + fn test_reduce_prod_config_basic() { + let node = create_test_node(Some(vec![1]), Some(1)); + let dim = reduce_prod_config(&node); + assert_eq!(dim, Some(1)); + } + + #[test] + fn test_reduce_prod_config_negative_axis() { + let node = create_test_node(Some(vec![-2]), Some(1)); + let dim = reduce_prod_config(&node); + assert_eq!(dim, Some(1)); // -2 + 3 = 1 + } + + #[test] + #[should_panic(expected = "ReduceProd: axes must be provided with keepdims")] + fn test_reduce_prod_config_no_axes() { + let node = create_test_node(None, Some(1)); + let _ = reduce_prod_config(&node); + } + + #[test] + #[should_panic(expected = "ReduceProd: reducing on multiple dimensions is not supported")] + fn test_reduce_prod_config_multiple_axes() { + let node = create_test_node(Some(vec![0, 1]), Some(1)); + let _ = reduce_prod_config(&node); + } + + 
#[test] + #[should_panic( + expected = "ReduceProd: the reduce operation must preserve the reduced dimension" + )] + fn test_reduce_prod_config_no_keepdims() { + let node = create_test_node(Some(vec![1]), Some(0)); + let _ = reduce_prod_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/reduce_sum.rs b/crates/onnx-ir/src/node/reduce_sum.rs new file mode 100644 index 0000000000..6b2fae1283 --- /dev/null +++ b/crates/onnx-ir/src/node/reduce_sum.rs @@ -0,0 +1,170 @@ +use crate::ir::{ArgType, AttributeValue, Data, Node, TensorType}; + +/// Create a ReduceSumConfig from the attributes of the node +pub fn reduce_sum_config(node: &Node) -> Option { + let mut axes = Vec::new(); + let mut keepdims = 1; + + let tensor = match node.inputs.first().unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + // Extract the attributes + for (key, value) in node.attrs.iter() { + match key.as_str() { + "keepdims" => keepdims = value.clone().into_i64(), + "axes" => axes = value.clone().into_i64s(), + // TODO: handle noop_with_empty_axes + _ => {} + } + } + + // Process axes from additional input (if available) + if let Some(value) = node + .inputs + .get(1) + .and_then(|argument| argument.value.as_ref()) + { + axes = value.clone().data.into_i64s(); + } + + if axes.len() > 1 { + panic!("ReduceSum: reducing on multiple dimensions is not supported") + } + + if axes.is_empty() && keepdims == 1 { + panic!("ReduceSum: axes must be provided with keepdims") + } + + if !axes.is_empty() && keepdims == 0 { + // Not supported in Burn + panic!("ReduceSum: the reduce operation must preserve the reduced dimension") + } + + if axes.is_empty() { + None + } else { + let mut dim = axes[0]; + + if dim < 0 { + // Accepted range is [-r, r-1] where r = rank(data) but Burn only supports positive dim + dim += tensor.rank as i64; + } + Some(dim as usize) + } +} + +/// Update output rank for ReduceSum based on axes. 
+pub fn reduce_sum_update_outputs(node: &mut Node) { + log::debug!("ReduceSum rank inference for node {}", node.name); + + let tensor = match &node.inputs[0].ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + log::debug!("ReduceSum input rank for {}: {}", node.name, tensor.rank); + + let dim_only = match node.attrs.get("axes") { + Some(value) => match &value { + AttributeValue::Int64(_) => true, + AttributeValue::Int64s(ints) => ints.len() == 1, + _ => false, + }, + None => false, + } || match node.inputs.get(1).and_then(|arg| arg.value.as_ref()) { + Some(value) => match &value.data { + Data::Int64(_) => true, + Data::Int64s(ints) => ints.len() == 1, + _ => false, + }, + None => false, + }; + + let output_rank = if dim_only { tensor.rank } else { 1 }; + log::debug!("ReduceSum output rank for {}: {}", node.name, output_rank); + + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type: tensor.elem_type.clone(), + rank: output_rank, + static_shape: None, + }); +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; + + fn create_test_node( + axes: Option>, + keepdims: Option, + with_axes_input: bool, + ) -> Node { + let mut builder = NodeBuilder::new(NodeType::ReduceSum, "test_reduce_sum") + .input_tensor_f32("data", 3, None) + .output_tensor_f32("reduced", 3, None); + + // Add axes input if requested + if with_axes_input && axes.is_some() { + let axes_vec = axes.clone().unwrap(); + builder = builder.input_tensor_i64_data("axes", axes_vec.clone(), vec![axes_vec.len()]); + } + + // Add attributes + if !with_axes_input && axes.is_some() { + builder = builder.attr_ints("axes", axes.clone().unwrap()); + } + + if let Some(kd) = keepdims { + builder = builder.attr_int("keepdims", kd); + } + + builder.build() + } + + #[test] + fn test_reduce_sum_config_basic() { + let node = create_test_node(Some(vec![1]), Some(1), false); + let dim = reduce_sum_config(&node); + assert_eq!(dim, Some(1)); + } + + #[test] + fn test_reduce_sum_config_with_input_axes() { + let node = create_test_node(Some(vec![1]), Some(1), true); + let dim = reduce_sum_config(&node); + assert_eq!(dim, Some(1)); + } + + #[test] + fn test_reduce_sum_config_negative_axis() { + let node = create_test_node(Some(vec![-2]), Some(1), false); + let dim = reduce_sum_config(&node); + assert_eq!(dim, Some(1)); // -2 + 3 = 1 + } + + #[test] + #[should_panic(expected = "ReduceSum: axes must be provided with keepdims")] + fn test_reduce_sum_config_no_axes() { + let node = create_test_node(None, Some(1), false); + let _ = reduce_sum_config(&node); + } + + #[test] + #[should_panic(expected = "ReduceSum: reducing on multiple dimensions is not supported")] + fn test_reduce_sum_config_multiple_axes() { + let node = create_test_node(Some(vec![0, 1]), Some(1), false); + let _ = reduce_sum_config(&node); + } + + #[test] + #[should_panic( + expected = "ReduceSum: the reduce operation must preserve the reduced dimension" + )] + fn test_reduce_sum_config_no_keepdims() { + let node = create_test_node(Some(vec![1]), Some(0), false); + let _ = reduce_sum_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/reshape.rs b/crates/onnx-ir/src/node/reshape.rs new file mode 100644 index 0000000000..b48b4176b4 --- /dev/null +++ b/crates/onnx-ir/src/node/reshape.rs @@ -0,0 +1,135 @@ +use crate::ir::{ArgType, Data, Node, TensorData, TensorType}; + +/// Update output rank for Reshape based on shape input if constant, otherwise use input rank. 
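+///
+/// For example, a constant shape input of `[2, -1]` yields a rank-2 output; when the shape
+/// is only known at runtime, the rank already recorded on the output argument is kept.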
+pub fn reshape_update_outputs(node: &mut Node) { + log::debug!("Reshape rank inference for node {}", node.name); + + let shape = if node.inputs.len() == 2 { + log::debug!("Reshape node {} has shape as second input", node.name); + match &node.inputs[1].value { + Some(value) => match &value.data { + Data::Int64s(shape) => { + log::debug!("Reshape node {} has constant shape: {:?}", node.name, shape); + Some(shape.clone()) + } + _ => panic!("Reshape: invalid input types"), + }, + None => { + log::debug!( + "Reshape node {} has dynamic shape as second input", + node.name + ); + None + } + } + } else { + log::debug!("Reshape node {} using shape from attributes", node.name); + node.attrs.get("shape").cloned().map(|v| { + let shape = v.into_i64s(); + log::debug!("Reshape node {} shape attribute: {:?}", node.name, shape); + shape + }) + }; + + let output = match &node.outputs[0].ty { + ArgType::Tensor(tensor) => tensor.clone(), + _ => panic!("Reshape: invalid output types"), + }; + + let rank = match &shape { + Some(s) => s.len(), + None => output.rank, + }; + + log::debug!("Reshape output rank for node {}: {}", node.name, rank); + + node.outputs[0].ty = ArgType::Tensor(TensorType { + rank, + static_shape: None, + ..output + }); +} + +pub fn reshape_config(node: &Node) -> Vec { + let mut allowzero = 0; + + for (key, value) in node.attrs.iter() { + match key.as_str() { + "allowzero" => allowzero = value.clone().into_i64(), + "shape" => {} // This can be used when shape is not provided as input - handled elsewhere + _ => panic!("Unexpected attribute for Reshape: {key}"), + } + } + + // Burn does not support zero size shape (0 means false in ONNX) + // (see https://onnx.ai/onnx/operators/onnx__Reshape.html#attributes) + if allowzero != 0 { + panic!("Zero shape size is not supported"); + } + + // TODO: check "shape" attribute + if node.inputs.len() != 2 || node.inputs[1].value.is_none() { + panic!("Reshape: shape tensor must be present for {:?}", node); + } + + match &node.inputs[1].value { + Some(TensorData { data, shape, .. 
}) => { + assert_eq!(shape.len(), 1, "Reshape: shape tensor must be 1D"); + data.clone().into_i64s() + } + _ => panic!("Only tensor input is valid for shape"), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; + + fn create_test_node(allowzero: i64, shape_vec: Vec) -> Node { + let mut builder = NodeBuilder::new(NodeType::Reshape, "test_reshape") + .input_tensor_f32("data", 4, None) + .input_tensor_i64_data("shape", shape_vec.clone(), vec![shape_vec.len()]) + .output_tensor_f32("reshaped", 2, None); + + if allowzero != 0 { + builder = builder.attr_int("allowzero", allowzero); + } + + builder.build() + } + + #[test] + fn test_reshape_config_basic() { + let node = create_test_node(0, vec![2, 3]); + let shape = reshape_config(&node); + assert_eq!(shape, vec![2, 3]); + } + + #[test] + #[should_panic(expected = "Zero shape size is not supported")] + fn test_reshape_config_allowzero_not_supported() { + let node = create_test_node(1, vec![2, 3]); + let _ = reshape_config(&node); + } + + #[test] + #[should_panic(expected = "shape tensor must be present")] + fn test_reshape_config_no_shape_input() { + let mut node = create_test_node(0, vec![2, 3]); + node.inputs.pop(); // Remove the shape input + let _ = reshape_config(&node); + } + + #[test] + #[should_panic(expected = "shape tensor must be 1D")] + fn test_reshape_config_invalid_shape_dim() { + let mut node = create_test_node(0, vec![2, 3]); + // Modify the shape tensor's shape to be 2D + if let Some(tensor_data) = &mut node.inputs[1].value { + tensor_data.shape = vec![2, 1]; + } + let _ = reshape_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/resize.rs b/crates/onnx-ir/src/node/resize.rs new file mode 100644 index 0000000000..4d43e58824 --- /dev/null +++ b/crates/onnx-ir/src/node/resize.rs @@ -0,0 +1,232 @@ +use crate::ir::{ArgType, Node, TensorData}; + +pub fn resize_config(node: &Node) -> (String, Vec, Vec) { + let mut mode: String = "".to_string(); + + let mut scales: Vec; + let mut sizes: Vec; + + let input = if let ArgType::Tensor(tensor) = &node + .inputs + .first() + .expect("Resize: Input tensor must be present") + .ty + { + tensor + } else { + panic!("Resize: input must be a tensor") + }; + + // Note: we are ignoring some attributes because results are approximately the same + // and we are not supporting all the attributes of the Resize operator. + // However, some attributes are important to be checked and we are checking + // against the default values of the attributes. 
+    // TODO revisit this when we have more Resize operators in the model
+    for (key, value) in node.attrs.iter() {
+        match key.as_str() {
+            "antialias" => assert_eq!(
+                value.clone().into_i32(),
+                0,
+                "Resize: antialias other than 0 is not supported"
+            ),
+            "axes" => panic!("Resize: custom axes attribute is not supported"),
+            "coordinate_transformation_mode" => {
+                log::warn!("Resize: coordinate_transformation_mode is ignored")
+            }
+
+            "cubic_coeff_a" => log::warn!("Resize: cubic_coeff_a is ignored"),
+            "exclude_outside" => assert_eq!(
+                value.clone().into_i32(),
+                0,
+                "Resize: exclude_outside other than 0 is not supported"
+            ),
+            "extrapolation_value" => assert_eq!(
+                value.clone().into_f32(),
+                0.0,
+                "Resize: extrapolation_value other than 0.0 is not supported"
+            ),
+            "keep_aspect_ratio_policy" => {
+                assert_eq!(
+                    value.clone().into_string().to_lowercase(),
+                    "stretch",
+                    "Resize: keep_aspect_ratio_policy other than 'stretch' is not supported"
+                )
+            }
+            "mode" => mode = value.clone().into_string().to_lowercase(),
+            "nearest_mode" => log::warn!("Resize: nearest_mode is ignored"),
+
+            _ => {}
+        }
+    }
+
+    let roi: Vec<f32> = node
+        .inputs
+        .get(1)
+        .map(|input| {
+            if let Some(TensorData { data, .. }) = &input.value {
+                data.clone().into_f32s()
+            } else {
+                vec![]
+            }
+        })
+        .unwrap_or_default();
+
+    scales = node
+        .inputs
+        .get(2)
+        .map(|input| {
+            if let Some(TensorData { data, .. }) = &input.value {
+                data.clone().into_f32s()
+            } else {
+                vec![]
+            }
+        })
+        .unwrap_or_default();
+
+    sizes = node
+        .inputs
+        .get(3)
+        .map(|input| {
+            if let Some(TensorData { data, .. }) = &input.value {
+                data.clone()
+                    .into_i64s()
+                    .iter()
+                    .map(|&x| x as usize)
+                    .collect()
+            } else {
+                vec![]
+            }
+        })
+        .unwrap_or_default();
+
+    if mode.is_empty() {
+        panic!("Resize: mode attribute is required")
+    }
+
+    if !roi.is_empty() {
+        panic!("Resize: roi input is not supported")
+    }
+
+    if scales.is_empty() && sizes.is_empty() {
+        panic!("Resize: either scales or sizes input is required")
+    }
+
+    if !scales.is_empty() {
+        assert!(scales.len() == input.rank);
+        // ignore the first two items from scales
+        // because they are the batch and channel dimensions
+        scales = scales.iter().skip(2).cloned().collect();
+    }
+
+    if !sizes.is_empty() {
+        assert!(sizes.len() == input.rank);
+        // ignore the first two items from sizes
+        // because they are the batch and channel dimensions
+        sizes = sizes.iter().skip(2).cloned().collect();
+    }
+
+    (mode, scales, sizes)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::ir::NodeType;
+    use crate::node::test_utils::NodeBuilder;
+
+    fn create_test_node(
+        mode: &str,
+        scales: Option<Vec<f32>>,
+        sizes: Option<Vec<i64>>,
+        roi: Option<Vec<f32>>,
+    ) -> Node {
+        let mut builder = NodeBuilder::new(NodeType::Resize, "test_resize")
+            .input_tensor_f32("X", 4, None) // N,C,H,W format
+            .output_tensor_f32("Y", 4, None)
+            .attr_string("mode", mode);
+
+        // Add ROI input if provided
+        if let Some(roi_data) = roi {
+            builder = builder.input_tensor_f32_data("roi", roi_data.clone(), vec![8]);
+            // For 4D input (start x, start y, end x, end y)
+        } else {
+            // Empty ROI still needs to be added as a placeholder
+            builder = builder.input_tensor_f32("roi", 1, None);
+        }
+
+        // Add scales input if provided
+        if let Some(scales_data) = scales {
+            builder = builder.input_tensor_f32_data("scales", scales_data.clone(), vec![4]);
+            // N,C,H,W scales
+        } else {
+            // Empty scales still needs to be added as a placeholder
+            builder = builder.input_tensor_f32("scales", 1, None);
+        }
+
+        // Add sizes input if provided
+        if let Some(sizes_data) = sizes {
+            builder = builder.input_tensor_i64_data("sizes", sizes_data.clone(), vec![4]);
+            // N,C,H,W sizes
+        } else {
+            // Empty sizes still needs to be added as a placeholder
+            builder = builder.input_tensor_i64("sizes", 1, None);
+        }
+
+        builder.build()
+    }
+
+    #[test]
+    fn test_resize_config_with_scales() {
+        let node = create_test_node(
+            "nearest",
+            Some(vec![1.0, 1.0, 2.0, 2.0]), // Keep N,C same, double H,W
+            None,
+            None,
+        );
+        let (mode, scales, sizes) = resize_config(&node);
+        assert_eq!(mode, "nearest");
+        assert_eq!(scales, vec![2.0, 2.0]); // Only the spatial scales (H,W)
+        assert!(sizes.is_empty());
+    }
+
+    #[test]
+    fn test_resize_config_with_sizes() {
+        let node = create_test_node(
+            "linear",
+            None,
+            Some(vec![1, 3, 224, 224]), // Fixed output size
+            None,
+        );
+        let (mode, scales, sizes) = resize_config(&node);
+        assert_eq!(mode, "linear");
+        assert!(scales.is_empty());
+        assert_eq!(sizes, vec![224, 224]); // Only the spatial sizes (H,W)
+    }
+
+    #[test]
+    #[should_panic(expected = "Resize: roi input is not supported")]
+    fn test_resize_config_with_roi() {
+        let node = create_test_node(
+            "nearest",
+            Some(vec![1.0, 1.0, 2.0, 2.0]),
+            None,
+            Some(vec![0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0]), // ROI values
+        );
+        let _ = resize_config(&node);
+    }
+
+    #[test]
+    #[should_panic(expected = "Resize: either scales or sizes input is required")]
+    fn test_resize_config_no_scales_or_sizes() {
+        let node = create_test_node("nearest", None, None, None);
+        let _ = resize_config(&node);
+    }
+
+    #[test]
+    #[should_panic(expected = "Resize: mode attribute is required")]
+    fn test_resize_config_no_mode() {
+        let mut node = create_test_node("nearest", Some(vec![1.0, 1.0, 2.0, 2.0]), None, None);
+        node.attrs.clear(); // Remove all attributes including mode
+        let _ = resize_config(&node);
+    }
+}
diff --git a/crates/onnx-ir/src/node/shape.rs b/crates/onnx-ir/src/node/shape.rs
new file mode 100644
index 0000000000..1d4a9770e1
--- /dev/null
+++ b/crates/onnx-ir/src/node/shape.rs
@@ -0,0 +1,137 @@
+use crate::ir::{ArgType, Node};
+
+pub fn shape_config(curr: &Node) -> (usize, usize) {
+    if curr.inputs.len() != 1 {
+        panic!(
+            "Shape: multiple inputs are not supported (got {:?})",
+            curr.inputs.len()
+        );
+    }
+
+    // Extract the shape of the input tensor
+    let tensor = match curr.inputs.first().unwrap().clone().ty {
+        ArgType::Tensor(tensor) => tensor,
+        _ => panic!("Only tensor input is valid"),
+    };
+
+    // Default: all axes up to the last one (included)
+    let mut start_dim: i64 = 0;
+    let mut end_dim: i64 = tensor.rank as i64;
+
+    // Extract the attributes
+    for (key, value) in curr.attrs.iter() {
+        match key.as_str() {
+            "start" => start_dim = value.clone().into_i64(),
+            "end" => end_dim = value.clone().into_i64(),
+            _ => {}
+        }
+    }
+
+    // If dim is negative, it is counted from the end
+    if start_dim < 0 {
+        start_dim += tensor.rank as i64;
+    }
+    if end_dim < 0 {
+        end_dim += tensor.rank as i64;
+    }
+
+    (start_dim as usize, end_dim as usize)
+}
+
+/// Update output type for Shape operation (rank 1).
+pub fn shape_update_outputs(node: &mut Node) { + if node.inputs.len() != 1 { + panic!("Shape: multiple inputs are not supported: {:?}", node); + } + let (start, end) = shape_config(node); + let dim = end - start; + log::debug!( + "Shape operation for node {}: start={}, end={}, dim={}", + node.name, + start, + end, + dim + ); + node.outputs[0].ty = ArgType::Shape(dim); +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; + + fn create_test_node(start: Option, end: Option, rank: usize) -> Node { + let mut builder = NodeBuilder::new(NodeType::Shape, "test_shape") + .input_tensor_f32("data", rank, None) + .output_tensor_i64("shape", 1, None); + + if let Some(start_val) = start { + builder = builder.attr_int("start", start_val); + } + + if let Some(end_val) = end { + builder = builder.attr_int("end", end_val); + } + + builder.build() + } + + #[test] + fn test_shape_config_defaults() { + let node = create_test_node(None, None, 4); + let (start, end) = shape_config(&node); + assert_eq!(start, 0); + assert_eq!(end, 4); + } + + #[test] + fn test_shape_config_with_start() { + let node = create_test_node(Some(1), None, 4); + let (start, end) = shape_config(&node); + assert_eq!(start, 1); + assert_eq!(end, 4); + } + + #[test] + fn test_shape_config_with_end() { + let node = create_test_node(None, Some(3), 4); + let (start, end) = shape_config(&node); + assert_eq!(start, 0); + assert_eq!(end, 3); + } + + #[test] + fn test_shape_config_with_start_and_end() { + let node = create_test_node(Some(1), Some(3), 4); + let (start, end) = shape_config(&node); + assert_eq!(start, 1); + assert_eq!(end, 3); + } + + #[test] + fn test_shape_config_negative_dims() { + let node = create_test_node(Some(-2), Some(-1), 4); + let (start, end) = shape_config(&node); + assert_eq!(start, 2); // -2 + 4 = 2 + assert_eq!(end, 3); // -1 + 4 = 3 + } + + #[test] + #[should_panic(expected = "Shape: multiple inputs are not supported")] + fn test_shape_config_multiple_inputs() { + let mut node = create_test_node(None, None, 4); + // Add an extra input to cause the expected panic + node.inputs.push(crate::ir::Argument { + name: "extra".to_string(), + ty: crate::ir::ArgType::Tensor(crate::ir::TensorType { + elem_type: crate::ir::ElementType::Float32, + rank: 4, + static_shape: None, + }), + value: None, + passed: true, + }); + let _ = shape_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/slice.rs b/crates/onnx-ir/src/node/slice.rs index 1424095be7..f1441db485 100644 --- a/crates/onnx-ir/src/node/slice.rs +++ b/crates/onnx-ir/src/node/slice.rs @@ -131,9 +131,8 @@ pub fn slice_update_output_rank(node: &mut Node) { #[cfg(test)] mod tests { - use std::collections::HashMap; - - use crate::ir::{Argument, AttributeValue, ElementType, NodeType, TensorType}; + use crate::ir::{ElementType, NodeType}; + use crate::node::test_utils::NodeBuilder; use super::*; @@ -143,151 +142,40 @@ mod tests { axes: Option>, use_attrs: bool, ) -> Node { - let mut inputs = vec![Argument { - name: "data".to_string(), - ty: crate::ir::ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 3, - static_shape: None, - }), - value: None, - passed: true, - }]; + let mut builder = NodeBuilder::new(NodeType::Slice, "test_slice") + .input_tensor_f32("data", 3, None) + .output_default("output"); if !use_attrs { // Add inputs as tensors - inputs.push(Argument { - name: "starts".to_string(), - ty: crate::ir::ArgType::Tensor(TensorType { - elem_type: ElementType::Int64, - rank: 1, - 
static_shape: Some(vec![starts.len()]), - }), - value: Some(TensorData { - data: Data::Int64s(starts.clone()), - shape: vec![starts.len()], - }), - passed: true, - }); - - inputs.push(Argument { - name: "ends".to_string(), - ty: crate::ir::ArgType::Tensor(TensorType { - elem_type: ElementType::Int64, - rank: 1, - static_shape: Some(vec![ends.len()]), - }), - value: Some(TensorData { - data: Data::Int64s(ends.clone()), - shape: vec![ends.len()], - }), - passed: true, - }); + builder = builder.input_tensor_i64_data("starts", starts.clone(), vec![starts.len()]); + builder = builder.input_tensor_i64_data("ends", ends.clone(), vec![ends.len()]); - if let Some(axes_vec) = &axes { - inputs.push(Argument { - name: "axes".to_string(), - ty: crate::ir::ArgType::Tensor(TensorType { - elem_type: ElementType::Int64, - rank: 1, - static_shape: Some(vec![axes_vec.len()]), - }), - value: Some(TensorData { - data: Data::Int64s(axes_vec.clone()), - shape: vec![axes_vec.len()], - }), - passed: true, - }); + if let Some(axes_vec) = axes.clone() { + builder = + builder.input_tensor_i64_data("axes", axes_vec.clone(), vec![axes_vec.len()]); } - } + } else { + // Add attributes + builder = builder.attr_ints("starts", starts); + builder = builder.attr_ints("ends", ends); - let mut attrs = HashMap::new(); - if use_attrs { - attrs.insert("starts".to_string(), AttributeValue::Int64s(starts)); - attrs.insert("ends".to_string(), AttributeValue::Int64s(ends)); if let Some(axes_vec) = axes { - attrs.insert("axes".to_string(), AttributeValue::Int64s(axes_vec)); + builder = builder.attr_ints("axes", axes_vec); } } - Node { - node_type: NodeType::Slice, - name: "test_slice".to_string(), - inputs, - outputs: vec![Argument { - name: "output".to_string(), - ty: ArgType::default(), - value: None, - passed: true, - }], - attrs, - } + builder.build() } fn create_shape_input_node(start: i64, end: i64) -> Node { - let mut node = Node { - node_type: NodeType::Slice, - name: "test_slice_shape".to_string(), - inputs: vec![Argument { - name: "data".to_string(), - ty: ArgType::Shape(5), // 1-dimensional shape (important: matches what the tests expect) - value: None, - passed: true, - }], - outputs: vec![Argument { - name: "output".to_string(), - ty: ArgType::default(), - value: None, - passed: true, - }], - attrs: HashMap::new(), - }; - - // Add starts and ends as tensors - node.inputs.push(Argument { - name: "starts".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Int64, - rank: 1, - static_shape: Some(vec![1]), - }), - value: Some(TensorData { - data: Data::Int64s(vec![start]), - shape: vec![1], - }), - passed: true, - }); - - node.inputs.push(Argument { - name: "ends".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Int64, - rank: 1, - static_shape: Some(vec![1]), - }), - value: Some(TensorData { - data: Data::Int64s(vec![end]), - shape: vec![1], - }), - passed: true, - }); - - // Add axes tensor to specify dimension 0 - node.inputs.push(Argument { - name: "axes".to_string(), - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Int64, - rank: 1, - static_shape: Some(vec![1]), - }), - value: Some(TensorData { - data: Data::Int64s(vec![0]), - shape: vec![1], - }), - passed: true, - }); - - node + NodeBuilder::new(NodeType::Slice, "test_slice_shape") + .input_shape("data", 5) + .input_tensor_i64_data("starts", vec![start], vec![1]) + .input_tensor_i64_data("ends", vec![end], vec![1]) + .input_tensor_i64_data("axes", vec![0], vec![1]) + .output_default("output") + .build() } 
#[test] diff --git a/crates/onnx-ir/src/node/softmax.rs b/crates/onnx-ir/src/node/softmax.rs new file mode 100644 index 0000000000..60f018a2ae --- /dev/null +++ b/crates/onnx-ir/src/node/softmax.rs @@ -0,0 +1,79 @@ +use crate::ir::{ArgType, Node}; + +/// Create softmax config from the attributes of the node +pub fn softmax_config(node: &Node) -> usize { + // the axis is the last dimension (Default: 1 per ONNX spec) + let mut axis: i64 = -1; + + // check if the node has only one input + if node.inputs.len() != 1 { + panic!( + "Softmax: multiple inputs are not supported (got {:?})", + node.inputs.len() + ); + } + + // extract the shape of the input tensor + let tensor = match node.inputs.first().unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + // extract the attributes + for (key, value) in node.attrs.iter() { + if key.as_str() == "axis" { + axis = value.clone().into_i64() + } + } + + // if axis is negative, it is counted from the end + if axis < 0 { + axis += tensor.rank as i64; + } + + axis as usize +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; + + fn create_test_node(axis: i64, input_rank: usize) -> Node { + NodeBuilder::new(NodeType::Softmax, "test_softmax") + .input_tensor_f32("data", input_rank, None) + .output_tensor_f32("output", input_rank, None) + .attr_int("axis", axis) + .build() + } + + #[test] + fn test_softmax_config_basic() { + let node = create_test_node(-1, 3); + let config = softmax_config(&node); + assert_eq!(config, 2); // -1 + 3 = 2 (last dimension) + } + + #[test] + fn test_softmax_config_explicit_axis() { + let node = create_test_node(1, 3); + let config = softmax_config(&node); + assert_eq!(config, 1); + } + + #[test] + #[should_panic(expected = "Softmax: multiple inputs are not supported")] + fn test_softmax_config_multiple_inputs() { + let mut node = create_test_node(1, 3); + // Add an extra input + let extra_input = NodeBuilder::new(NodeType::Identity, "temp") + .input_tensor_f32("extra", 1, None) + .build() + .inputs + .pop() + .unwrap(); + node.inputs.push(extra_input); + let _ = softmax_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/split.rs b/crates/onnx-ir/src/node/split.rs new file mode 100644 index 0000000000..8f5f70af9a --- /dev/null +++ b/crates/onnx-ir/src/node/split.rs @@ -0,0 +1,399 @@ +use crate::ir::{ArgType, Node, TensorType}; + +/// Update output rank for Split (same as input). +pub fn split_update_outputs(node: &mut Node) { + log::debug!("Split rank inference for node {}", node.name); + + let tensor = match &node.inputs[0].ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Split: Input must be a tensor"), + }; + log::debug!("Split input rank for {}: {}", node.name, tensor.rank); + log::debug!( + "Split will generate {} outputs for {}", + node.outputs.len(), + node.name + ); + + for (i, output_arg) in node.outputs.iter_mut().enumerate() { + output_arg.ty = ArgType::Tensor(TensorType { + elem_type: tensor.elem_type.clone(), + rank: tensor.rank, + static_shape: None, + }); + log::debug!("Split output {} rank for {}: {}", i, node.name, tensor.rank); + } +} + +/// Configuration for the Split operation. +#[derive(Clone, Debug)] +pub struct SplitConfig { + /// The axis along which to split the input tensor. + pub axis: usize, + /// The uniform size of each split when splitting evenly. + pub split_size: Option, + /// Custom sizes for each split when splitting unevenly. 
+ pub split_sizes: Option>, +} + +impl SplitConfig { + pub fn new(axis: usize, split_size: Option, split_sizes: Option>) -> Self { + SplitConfig { + axis, + split_size, + split_sizes, + } + } +} + +/// Creates a SplitConfig from the node attributes and inputs. +pub fn split_config(node: &Node) -> SplitConfig { + // Initialize the axis to split along (default is 0 as per ONNX specification) + let mut axis: i64 = 0; + // Holds the uniform split size if calculated or provided + let mut split_size: Option = None; + // Holds the custom split sizes if provided as input + let mut split_sizes: Option> = None; + + // Extract the input tensor type to determine rank and shape + let tensor = match node.inputs.first().unwrap().ty { + ArgType::Tensor(ref tensor) => tensor, + _ => panic!("Split: Input must be a valid tensor"), + }; + + // Optionally store the number of outputs if provided as an attribute + let mut num_outputs: Option = None; + + // Iterate through node attributes to extract relevant values + for (key, value) in node.attrs.iter() { + match key.as_str() { + "axis" => axis = value.clone().into_i64(), + "num_outputs" => num_outputs = Some(value.clone().into_i64() as usize), + _ => {} + } + } + + // Handle the case when num_outputs is provided to calculate uniform split size + if let Some(num_outputs) = num_outputs { + if num_outputs == 0 { + panic!("Split: 'num_outputs' must be a positive value greater than zero"); + } + + let dim_size = tensor + .static_shape + .as_ref() + .expect("Split: Static shape must be known to calculate split size")[axis as usize]; + + // Calculate the split size considering any remainder for non-evenly divisible dimensions + let calculated_split_size = + dim_size / (num_outputs - (dim_size % num_outputs != 0) as usize); + + if calculated_split_size == 0 { + panic!( + "Split: Calculated split size is zero. Please ensure 'num_outputs' is valid for the dimension size" + ); + } + + // Assign the calculated split size + split_size = Some(calculated_split_size); + } + + // Adjust axis if negative to count from the end as per ONNX spec + if axis < 0 { + axis += tensor.rank as i64; + } + + // Check for custom split sizes provided as a second input + if node.inputs.len() > 1 && node.inputs[1].value.is_some() { + let sizes = node.inputs[1] + .value + .as_ref() + .unwrap() + .data + .clone() + .into_usizes(); + + if !sizes.is_empty() { + split_sizes = Some(sizes); + } + } + + // Ensure that only one of 'split_sizes' or 'num_outputs' is specified + if split_sizes.is_some() && split_size.is_some() { + panic!( + "Split: Cannot specify both 'split' input and 'num_outputs' attribute simultaneously" + ); + } + + // Infer split_size if neither custom split_sizes nor split_size is provided + if split_sizes.is_none() && split_size.is_none() { + let num_outputs = node.outputs.len(); + let dim_size = tensor + .static_shape + .as_ref() + .expect("Split: Static shape must be known to infer split size")[axis as usize]; + + // Calculate inferred split size based on number of outputs + let calculated_split_size = + dim_size / (num_outputs - (dim_size % num_outputs != 0) as usize); + + if calculated_split_size == 0 { + panic!( + "Split: Inferred split size is zero. 
Please ensure the number of outputs is valid for the dimension size" + ); + } + + split_size = Some(calculated_split_size); + } + + // Return the configuration for splitting operation + SplitConfig { + axis: axis as usize, + split_size, + split_sizes, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{ArgType, AttributeValue, ElementType, NodeType}; + use crate::node::test_utils::NodeBuilder; + use std::collections::HashMap; + + fn create_test_node( + input_rank: usize, + num_outputs: usize, + static_shape: Option>, + attrs: Option>, + split_sizes_input: Option>, + ) -> Node { + // Start with input tensor + let mut builder = NodeBuilder::new(NodeType::Split, "test_split").input_tensor_f32( + "input", + input_rank, + static_shape, + ); + + // Add split sizes input if provided + if let Some(sizes) = split_sizes_input { + builder = builder.input_tensor_i64_data("split", sizes.clone(), vec![sizes.len()]); + } + + // Add output tensors + for i in 0..num_outputs { + builder = builder.output_tensor_f32( + &format!("output_{}", i), + 0, // Will be updated + None, + ); + } + + // Add attributes if provided + let mut node = builder.build(); + + if let Some(attributes) = attrs { + node.attrs = attributes; + } + + node + } + + #[test] + fn test_split_single_output() { + let mut node = create_test_node(3, 1, None, None, None); + split_update_outputs(&mut node); + + assert_eq!(node.outputs.len(), 1); + match &node.outputs[0].ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.elem_type, ElementType::Float32); + assert_eq!(tensor.rank, 3); + } + _ => panic!("Expected tensor output"), + } + } + + #[test] + fn test_split_multiple_outputs() { + let mut node = create_test_node(4, 3, None, None, None); + split_update_outputs(&mut node); + + assert_eq!(node.outputs.len(), 3); + for output in &node.outputs { + match &output.ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.elem_type, ElementType::Float32); + assert_eq!(tensor.rank, 4); + } + _ => panic!("Expected tensor output"), + } + } + } + + #[test] + #[should_panic(expected = "Split: Input must be a tensor")] + fn test_split_invalid_input() { + let mut node = create_test_node(3, 2, None, None, None); + node.inputs[0].ty = ArgType::Scalar(ElementType::Float32); + split_update_outputs(&mut node); + } + + // Tests for split_config function + + #[test] + fn test_split_config_default_axis() { + // Create a node with static shape and 2 outputs + let static_shape = Some(vec![10, 20, 30]); + let node = create_test_node(3, 2, static_shape, None, None); + + let config = split_config(&node); + + // Default axis should be 0, and split_size should be calculated + assert_eq!(config.axis, 0); + assert_eq!(config.split_size, Some(5)); // 10 / 2 = 5 + assert_eq!(config.split_sizes, None); + } + + #[test] + fn test_split_config_specified_axis() { + // Create a node with static shape, 2 outputs, and a specified axis + let static_shape = Some(vec![10, 20, 30]); + let mut attrs = HashMap::new(); + attrs.insert("axis".to_string(), AttributeValue::Int64(1)); // Split along axis 1 + + let node = create_test_node(3, 2, static_shape, Some(attrs), None); + + let config = split_config(&node); + + assert_eq!(config.axis, 1); + assert_eq!(config.split_size, Some(10)); // 20 / 2 = 10 + assert_eq!(config.split_sizes, None); + } + + #[test] + fn test_split_config_negative_axis() { + // Test with negative axis (should count from the end) + let static_shape = Some(vec![10, 20, 30]); + let mut attrs = HashMap::new(); + attrs.insert("axis".to_string(), 
AttributeValue::Int64(-1)); // Last axis (index 2) + + let node = create_test_node(3, 3, static_shape, Some(attrs), None); + + let config = split_config(&node); + + assert_eq!(config.axis, 2); // -1 should be converted to 2 + assert_eq!(config.split_size, Some(10)); // 30 / 3 = 10 + assert_eq!(config.split_sizes, None); + } + + #[test] + fn test_split_config_num_outputs_attr() { + // Test with explicitly specified num_outputs attribute + let static_shape = Some(vec![12, 24, 36]); + let mut attrs = HashMap::new(); + attrs.insert("num_outputs".to_string(), AttributeValue::Int64(4)); + + let node = create_test_node(3, 4, static_shape, Some(attrs), None); + + let config = split_config(&node); + + assert_eq!(config.axis, 0); + assert_eq!(config.split_size, Some(3)); // 12 / 4 = 3 + assert_eq!(config.split_sizes, None); + } + + #[test] + fn test_split_config_with_split_sizes_input() { + // Test with explicit split sizes provided as second input + let static_shape = Some(vec![10, 20, 30]); + let split_sizes = vec![5, 15]; // Custom split sizes along default axis + + let node = create_test_node(3, 2, static_shape, None, Some(split_sizes.clone())); + + let config = split_config(&node); + + assert_eq!(config.axis, 0); + assert_eq!(config.split_size, None); + assert_eq!(config.split_sizes, Some(vec![5, 15])); + } + + #[test] + #[should_panic( + expected = "Split: Cannot specify both 'split' input and 'num_outputs' attribute simultaneously" + )] + fn test_split_config_both_splits_and_num_outputs() { + // Test with both split sizes input and num_outputs attribute (should panic) + let static_shape = Some(vec![10, 20, 30]); + let mut attrs = HashMap::new(); + attrs.insert("num_outputs".to_string(), AttributeValue::Int64(2)); + let split_sizes = vec![3, 7]; + + let node = create_test_node(3, 2, static_shape, Some(attrs), Some(split_sizes)); + + let _ = split_config(&node); + } + + #[test] + #[should_panic(expected = "Split: 'num_outputs' must be a positive value greater than zero")] + fn test_split_config_zero_num_outputs() { + // Test with num_outputs attribute set to 0 (should panic) + let static_shape = Some(vec![10, 20, 30]); + let mut attrs = HashMap::new(); + attrs.insert("num_outputs".to_string(), AttributeValue::Int64(0)); + + let node = create_test_node(3, 0, static_shape, Some(attrs), None); + + let _ = split_config(&node); + } + + #[test] + #[should_panic(expected = "Split: Calculated split size is zero")] + fn test_split_config_invalid_num_outputs() { + // Test with num_outputs larger than the dimension size (should result in split_size = 0) + let static_shape = Some(vec![5, 10, 15]); + let mut attrs = HashMap::new(); + attrs.insert("num_outputs".to_string(), AttributeValue::Int64(10)); // Larger than dim 0 size + + let node = create_test_node(3, 10, static_shape, Some(attrs), None); + + let _ = split_config(&node); + } + + #[test] + #[should_panic(expected = "Split: Static shape must be known to calculate split size")] + fn test_split_config_no_static_shape() { + // Test with no static shape available + let mut attrs = HashMap::new(); + attrs.insert("num_outputs".to_string(), AttributeValue::Int64(2)); + + let node = create_test_node(3, 2, None, Some(attrs), None); + + let _ = split_config(&node); + } + + #[test] + #[should_panic(expected = "Split: Input must be a valid tensor")] + fn test_split_config_invalid_input_type() { + // Test with invalid input type + let mut node = create_test_node(3, 2, Some(vec![10, 20, 30]), None, None); + node.inputs[0].ty = 
ArgType::Scalar(ElementType::Float32); + + let _ = split_config(&node); + } + + #[test] + fn test_split_config_non_even_split() { + // Test with non-evenly divisible dimension size + let static_shape = Some(vec![11, 22, 33]); // 11 is not evenly divisible by 3 + let mut attrs = HashMap::new(); + attrs.insert("axis".to_string(), AttributeValue::Int64(0)); + + let node = create_test_node(3, 3, static_shape, Some(attrs), None); + + let config = split_config(&node); + + // 11 / (3-1) = 5, since the dimension is not evenly divisible + assert_eq!(config.split_size, Some(5)); + } +} diff --git a/crates/onnx-ir/src/node/squeeze.rs b/crates/onnx-ir/src/node/squeeze.rs new file mode 100644 index 0000000000..cdb4c9c8c4 --- /dev/null +++ b/crates/onnx-ir/src/node/squeeze.rs @@ -0,0 +1,94 @@ +use crate::ir::{ArgType, Data, Node, TensorType}; + +pub fn squeeze_config(curr: &Node) -> Vec { + let axes = curr + .attrs + .iter() + .filter_map(|(key, value)| { + if key == "axes" { + Some(value.clone().into_i64s()) + } else { + None + } + }) + .next() + .unwrap_or_else(Vec::new); + + match curr.inputs.first().unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + axes +} + +/// Update output rank for Squeeze based on axes. +pub fn squeeze_update_output(node: &mut Node) { + log::debug!("Squeeze rank inference for node {}", node.name); + + let axes = if node.inputs.len() == 2 { + match &node.inputs[1].value { + Some(value) => match &value.data { + Data::Int64s(axes) => Some(axes.clone()), + _ => panic!("Squeeze: invalid input types"), + }, + None => None, + } + } else { + node.attrs.get("axes").cloned().map(|v| v.into_i64s()) + }; + + let axes = axes.unwrap_or_else(|| panic!("Squeeze must specify an axis")); + log::debug!("Squeeze axes for {}: {:?}", node.name, axes); + + let input_rank = match &node.inputs[0].ty { + ArgType::Tensor(tensor) => tensor.rank, + ty => panic!("Squeeze: invalid input type: {:?}", ty), + }; + + log::debug!("Squeeze input rank for {}: {}", node.name, input_rank); + + let output_rank = input_rank - axes.len(); + log::debug!("Squeeze output rank for {}: {}", node.name, output_rank); + + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type: node.inputs[0].ty.elem_type().clone(), + rank: output_rank, + static_shape: None, + }); +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; + + fn create_test_node(axes: Option>, rank: usize) -> Node { + let output_rank = rank - (axes.as_ref().map_or(0, |a| a.len())); + + let mut builder = NodeBuilder::new(NodeType::Squeeze, "test_squeeze") + .input_tensor_f32("data", rank, None) + .output_tensor_f32("squeezed", output_rank, None); + + if let Some(axes_val) = axes { + builder = builder.attr_ints("axes", axes_val); + } + + builder.build() + } + + #[test] + fn test_squeeze_config_with_axes() { + let node = create_test_node(Some(vec![0, 2]), 4); + let axes = squeeze_config(&node); + assert_eq!(axes, vec![0, 2]); + } + + #[test] + fn test_squeeze_config_no_axes() { + let node = create_test_node(None, 4); + let axes = squeeze_config(&node); + assert!(axes.is_empty()); + } +} diff --git a/crates/onnx-ir/src/node/test_utils.rs b/crates/onnx-ir/src/node/test_utils.rs new file mode 100644 index 0000000000..d6df87f50f --- /dev/null +++ b/crates/onnx-ir/src/node/test_utils.rs @@ -0,0 +1,506 @@ +use crate::ir::{ + ArgType, Argument, AttributeValue, Data, ElementType, Node, NodeType, TensorData, TensorType, +}; +use 
std::collections::HashMap; + +/// Builder for creating test node instances with convenient defaults and simple API. +pub struct NodeBuilder { + node_type: NodeType, + name: String, + inputs: Vec, + outputs: Vec, + attrs: HashMap, +} + +impl NodeBuilder { + /// Create a new builder with the specified node type and name + pub fn new(node_type: NodeType, name: &str) -> Self { + Self { + node_type, + name: name.to_string(), + inputs: Vec::new(), + outputs: Vec::new(), + attrs: HashMap::new(), + } + } + + /// Add a generic input with the given name and type + /// + /// Note: Prefer using the specialized methods like `input_tensor_f32`, + /// `input_scalar_f32`, etc. for better readability and type safety. + #[doc(hidden)] + pub fn add_input(mut self, name: &str, ty: ArgType) -> Self { + self.inputs.push(Argument { + name: name.to_string(), + ty, + value: None, + passed: true, + }); + self + } + + /// Add a float32 tensor input with the given name and rank + pub fn input_tensor_f32( + self, + name: &str, + rank: usize, + static_shape: Option>, + ) -> Self { + self.add_input( + name, + ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank, + static_shape, + }), + ) + } + + /// Add a float64 tensor input with the given name and rank + pub fn input_tensor_f64( + self, + name: &str, + rank: usize, + static_shape: Option>, + ) -> Self { + self.add_input( + name, + ArgType::Tensor(TensorType { + elem_type: ElementType::Float64, + rank, + static_shape, + }), + ) + } + + /// Add an int32 tensor input with the given name and rank + pub fn input_tensor_i32( + self, + name: &str, + rank: usize, + static_shape: Option>, + ) -> Self { + self.add_input( + name, + ArgType::Tensor(TensorType { + elem_type: ElementType::Int32, + rank, + static_shape, + }), + ) + } + + /// Add an int64 tensor input with the given name and rank + pub fn input_tensor_i64( + self, + name: &str, + rank: usize, + static_shape: Option>, + ) -> Self { + self.add_input( + name, + ArgType::Tensor(TensorType { + elem_type: ElementType::Int64, + rank, + static_shape, + }), + ) + } + + /// Add a bool tensor input with the given name and rank + pub fn input_tensor_bool( + self, + name: &str, + rank: usize, + static_shape: Option>, + ) -> Self { + self.add_input( + name, + ArgType::Tensor(TensorType { + elem_type: ElementType::Bool, + rank, + static_shape, + }), + ) + } + + /// Add a float16 tensor input with the given name and rank + pub fn input_tensor_f16( + self, + name: &str, + rank: usize, + static_shape: Option>, + ) -> Self { + self.add_input( + name, + ArgType::Tensor(TensorType { + elem_type: ElementType::Float16, + rank, + static_shape, + }), + ) + } + + /// Add a string tensor input with the given name and rank + pub fn input_tensor_string( + self, + name: &str, + rank: usize, + static_shape: Option>, + ) -> Self { + self.add_input( + name, + ArgType::Tensor(TensorType { + elem_type: ElementType::String, + rank, + static_shape, + }), + ) + } + + /// Add a scalar input with the given name and element type + pub fn input_scalar(self, name: &str, elem_type: ElementType) -> Self { + self.add_input(name, ArgType::Scalar(elem_type)) + } + + /// Add a float32 scalar input with the given name + pub fn input_scalar_f32(self, name: &str) -> Self { + self.input_scalar(name, ElementType::Float32) + } + + /// Add an int64 scalar input with the given name + pub fn input_scalar_i64(self, name: &str) -> Self { + self.input_scalar(name, ElementType::Int64) + } + + /// Add a shape input with the given name and rank + pub fn 
input_shape(self, name: &str, rank: usize) -> Self { + self.add_input(name, ArgType::Shape(rank)) + } + + /// Add a tensor input with data value + pub fn input_tensor_with_data( + mut self, + name: &str, + elem_type: ElementType, + rank: usize, + data: Data, + shape: Vec, + ) -> Self { + let arg = Argument { + name: name.to_string(), + ty: ArgType::Tensor(TensorType { + elem_type, + rank, + static_shape: None, + }), + value: Some(TensorData { data, shape }), + passed: true, + }; + self.inputs.push(arg); + self + } + + /// Add a float32 tensor input with data values + pub fn input_tensor_f32_data(self, name: &str, data: Vec, shape: Vec) -> Self { + self.input_tensor_with_data( + name, + ElementType::Float32, + shape.len(), + Data::Float32s(data), + shape, + ) + } + + /// Add an int64 tensor input with data values + pub fn input_tensor_i64_data(self, name: &str, data: Vec, shape: Vec) -> Self { + self.input_tensor_with_data( + name, + ElementType::Int64, + shape.len(), + Data::Int64s(data), + shape, + ) + } + + /// Add a float32 scalar tensor input (rank 0) + pub fn input_scalar_tensor_f32(mut self, name: &str, value: Option) -> Self { + let arg = Argument { + name: name.to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank: 0, + static_shape: None, + }), + value: value.map(|val| TensorData { + data: Data::Float32(val), + shape: vec![], + }), + passed: true, + }; + self.inputs.push(arg); + self + } + + /// Add an int64 scalar tensor input (rank 0) + pub fn input_scalar_tensor_i64(mut self, name: &str, value: i64) -> Self { + let arg = Argument { + name: name.to_string(), + ty: ArgType::Tensor(TensorType { + elem_type: ElementType::Int64, + rank: 0, + static_shape: None, + }), + value: Some(TensorData { + data: Data::Int64(value), + shape: vec![], + }), + passed: true, + }; + self.inputs.push(arg); + self + } + + /// Add multiple tensor inputs with the same type but different names + pub fn input_tensors_f32( + mut self, + name_prefix: &str, + count: usize, + rank: usize, + static_shape: Option>, + ) -> Self { + for i in 0..count { + self = self.input_tensor_f32( + &format!("{}_{}", name_prefix, i), + rank, + static_shape.clone(), + ); + } + self + } + + /// Add a generic output with the given name and type + /// + /// Note: Prefer using the specialized methods like `output_tensor_f32`, + /// `output_scalar_f32`, etc. for better readability and type safety. 
+ #[doc(hidden)] + pub fn add_output(mut self, name: &str, ty: ArgType) -> Self { + self.outputs.push(Argument { + name: name.to_string(), + ty, + value: None, + passed: true, + }); + self + } + + /// Add a float32 tensor output with the given name and rank + pub fn output_tensor_f32( + self, + name: &str, + rank: usize, + static_shape: Option>, + ) -> Self { + self.add_output( + name, + ArgType::Tensor(TensorType { + elem_type: ElementType::Float32, + rank, + static_shape, + }), + ) + } + + /// Add a float64 tensor output with the given name and rank + pub fn output_tensor_f64( + self, + name: &str, + rank: usize, + static_shape: Option>, + ) -> Self { + self.add_output( + name, + ArgType::Tensor(TensorType { + elem_type: ElementType::Float64, + rank, + static_shape, + }), + ) + } + + /// Add an int32 tensor output with the given name and rank + pub fn output_tensor_i32( + self, + name: &str, + rank: usize, + static_shape: Option>, + ) -> Self { + self.add_output( + name, + ArgType::Tensor(TensorType { + elem_type: ElementType::Int32, + rank, + static_shape, + }), + ) + } + + /// Add an int64 tensor output with the given name and rank + pub fn output_tensor_i64( + self, + name: &str, + rank: usize, + static_shape: Option>, + ) -> Self { + self.add_output( + name, + ArgType::Tensor(TensorType { + elem_type: ElementType::Int64, + rank, + static_shape, + }), + ) + } + + /// Add a bool tensor output with the given name and rank + pub fn output_tensor_bool( + self, + name: &str, + rank: usize, + static_shape: Option>, + ) -> Self { + self.add_output( + name, + ArgType::Tensor(TensorType { + elem_type: ElementType::Bool, + rank, + static_shape, + }), + ) + } + + /// Add a float16 tensor output with the given name and rank + pub fn output_tensor_f16( + self, + name: &str, + rank: usize, + static_shape: Option>, + ) -> Self { + self.add_output( + name, + ArgType::Tensor(TensorType { + elem_type: ElementType::Float16, + rank, + static_shape, + }), + ) + } + + /// Add a string tensor output with the given name and rank + pub fn output_tensor_string( + self, + name: &str, + rank: usize, + static_shape: Option>, + ) -> Self { + self.add_output( + name, + ArgType::Tensor(TensorType { + elem_type: ElementType::String, + rank, + static_shape, + }), + ) + } + + /// Add a scalar output with the given name and element type + pub fn output_scalar(self, name: &str, elem_type: ElementType) -> Self { + self.add_output(name, ArgType::Scalar(elem_type)) + } + + /// Add a float32 scalar output with the given name + pub fn output_scalar_f32(self, name: &str) -> Self { + self.output_scalar(name, ElementType::Float32) + } + + /// Add an int64 scalar output with the given name + pub fn output_scalar_i64(self, name: &str) -> Self { + self.output_scalar(name, ElementType::Int64) + } + + /// Add a shape output with the given name and rank + pub fn output_shape(self, name: &str, rank: usize) -> Self { + self.add_output(name, ArgType::Shape(rank)) + } + + /// Add an integer attribute + pub fn attr_int(mut self, name: &str, value: i64) -> Self { + self.attrs + .insert(name.to_string(), AttributeValue::Int64(value)); + self + } + + /// Add a float attribute + pub fn attr_float(mut self, name: &str, value: f32) -> Self { + self.attrs + .insert(name.to_string(), AttributeValue::Float32(value)); + self + } + + /// Add a string attribute + pub fn attr_string(mut self, name: &str, value: &str) -> Self { + self.attrs + .insert(name.to_string(), AttributeValue::String(value.to_string())); + self + } + + /// Add an integer array 
attribute + pub fn attr_ints(mut self, name: &str, values: Vec) -> Self { + self.attrs + .insert(name.to_string(), AttributeValue::Int64s(values)); + self + } + + /// Add a float array attribute + pub fn attr_floats(mut self, name: &str, values: Vec) -> Self { + self.attrs + .insert(name.to_string(), AttributeValue::Float32s(values)); + self + } + + /// Add a string array attribute + pub fn attr_strings(mut self, name: &str, values: Vec) -> Self { + self.attrs + .insert(name.to_string(), AttributeValue::Strings(values)); + self + } + + /// Add a tensor attribute + pub fn attr_tensor(mut self, name: &str, tensor: TensorData) -> Self { + self.attrs + .insert(name.to_string(), AttributeValue::Tensor(tensor)); + self + } + + /// Add a default output with the given name + pub fn output_default(mut self, name: &str) -> Self { + self.outputs.push(Argument { + name: name.to_string(), + ty: ArgType::default(), + value: None, + passed: true, + }); + self + } + + /// Build the node + pub fn build(self) -> Node { + Node { + node_type: self.node_type, + name: self.name, + inputs: self.inputs, + outputs: self.outputs, + attrs: self.attrs, + } + } +} diff --git a/crates/onnx-ir/src/node/tile.rs b/crates/onnx-ir/src/node/tile.rs new file mode 100644 index 0000000000..9538fd635c --- /dev/null +++ b/crates/onnx-ir/src/node/tile.rs @@ -0,0 +1,158 @@ +use crate::{Node, TensorData}; + +/// Configuration for the Tile operation. +#[derive(Debug, Clone, PartialEq)] +pub struct TileConfig { + /// The number of times to repeat each dimension. + pub repeats: Vec, +} + +impl TileConfig { + pub fn new(repeats: Vec) -> Self { + TileConfig { repeats } + } +} + +/// Creates a TileConfig from the node attributes and inputs. +pub fn tile_config(node: &Node) -> TileConfig { + let repeat = node + .inputs + .get(1) + .map(|input| { + if let Some(TensorData { data, .. 
}) = &input.value { + data.clone() + .into_i64s() + .iter() + .map(|&x| x as usize) + .collect() + } else { + vec![] + } + }) + .unwrap_or_default(); + TileConfig::new(repeat) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; + + /// Helper function to create test nodes with different repeat values + fn create_test_node(repeats: Option>, input_rank: usize) -> Node { + let mut builder = NodeBuilder::new(NodeType::Tile, "test_tile") + .input_tensor_f32("input", input_rank, None) + .output_tensor_f32("output", input_rank, None); // Same rank as input initially + + // Add repeats input if provided + if let Some(reps) = repeats { + builder = builder.input_tensor_i64_data("repeats", reps.clone(), vec![reps.len()]); + } + + builder.build() + } + + #[test] + fn test_tile_config_with_repeats() { + // Test with normal repeats values + let repeats = vec![2, 3, 4]; + let node = create_test_node(Some(repeats.clone()), 3); + + let config = tile_config(&node); + + // Should extract repeats correctly + assert_eq!(config.repeats, vec![2, 3, 4]); + } + + #[test] + fn test_tile_config_with_single_repeat() { + // Test with single repeat value + let repeats = vec![5]; + let node = create_test_node(Some(repeats.clone()), 1); + + let config = tile_config(&node); + + assert_eq!(config.repeats, vec![5]); + } + + #[test] + fn test_tile_config_with_zero_repeats() { + // Test with repeats including zeros + let repeats = vec![0, 1, 0]; + let node = create_test_node(Some(repeats.clone()), 3); + + let config = tile_config(&node); + + assert_eq!(config.repeats, vec![0, 1, 0]); + } + + #[test] + fn test_tile_config_with_large_repeats() { + // Test with large repeats values + let repeats = vec![100, 200]; + let node = create_test_node(Some(repeats.clone()), 2); + + let config = tile_config(&node); + + assert_eq!(config.repeats, vec![100, 200]); + } + + #[test] + fn test_tile_config_without_repeats_input() { + // Test when repeats input is missing + let node = create_test_node(None, 3); + + let config = tile_config(&node); + + // Should return empty repeats + assert_eq!(config.repeats, vec![]); + } + + #[test] + fn test_tile_config_with_negative_repeats() { + // Test with negative repeats values (will be converted to usize) + let repeats = vec![-1, 2, -3]; + let node = create_test_node(Some(repeats), 3); + + let config = tile_config(&node); + + // Negative values get converted to very large positive values due to usize conversion + // This is expected behavior for this function (though may cause issues elsewhere) + assert!(config.repeats[0] > 0); + assert_eq!(config.repeats[1], 2); + assert!(config.repeats[2] > 0); + } + + #[test] + fn test_tile_config_with_empty_repeats() { + // Test with empty repeats array + let repeats = vec![]; + let node = create_test_node(Some(repeats), 3); + + let config = tile_config(&node); + + assert_eq!(config.repeats, vec![]); + } + + #[test] + fn test_tile_config_with_missing_value() { + // Test with repeats input that has no value + let mut node = create_test_node(None, 3); + + // Add repeats input with no value + node.inputs.push( + NodeBuilder::new(NodeType::Identity, "temp") + .input_tensor_i64("repeats", 1, Some(vec![3])) + .build() + .inputs + .pop() + .unwrap(), + ); + + let config = tile_config(&node); + + // Should return empty repeats + assert_eq!(config.repeats, vec![]); + } +} diff --git a/crates/onnx-ir/src/node/topk.rs b/crates/onnx-ir/src/node/topk.rs new file mode 100644 index 0000000000..c4a7219156 --- 
/dev/null +++ b/crates/onnx-ir/src/node/topk.rs @@ -0,0 +1,309 @@ +use crate::ir::{ArgType, ElementType, Node, TensorType}; + +/// Update output rank for TopK (same as input rank). +pub fn top_k_update_output(node: &mut Node) { + log::debug!("TopK rank inference for node {}", node.name); + + let rank = match &node.inputs[0].ty { + ArgType::Tensor(tensor) => tensor.rank, + _ => panic!("TopK: invalid input type"), + }; + log::debug!("TopK input rank for {}: {}", node.name, rank); + + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type: node.inputs[0].ty.elem_type().clone(), + rank, + static_shape: None, + }); + node.outputs[1].ty = ArgType::Tensor(TensorType { + elem_type: ElementType::Int64, + rank, + static_shape: None, + }); + + log::debug!( + "TopK output rank for {}: {} (both outputs)", + node.name, + rank + ); +} + +/// Configuration for the TopK operation. +#[derive(Debug, Clone, PartialEq)] +pub struct TopKConfig { + /// The axis along which to perform the top-k selection. + pub axis: usize, + /// The number of top elements to select. + pub k: usize, +} + +impl TopKConfig { + /// Creates a new TopKConfig. + pub fn new(axis: usize, k: usize) -> Self { + Self { axis, k } + } +} + +/// Creates a TopKConfig from the node attributes and inputs. +pub fn top_k_config(node: &Node) -> TopKConfig { + // Extract the shape of the input data tensor + let data_tensor = match node.inputs.first().unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + let k = match node.inputs.get(1) { + Some(k_tensor) => k_tensor + .clone() + .value + .expect("TopK: only constant 'k' tensor is currently supported") + .data + .into_i64s()[0], + _ => node + .attrs + .get("k") + .expect("TopK: number of top elements 'k' is missing") + .clone() + .into_i64(), + }; + + let mut axis = match node.attrs.get("axis") { + Some(axis) => axis.clone().into_i64(), + None => -1, + }; + + // If axis is negative, it is counted from the end + if axis < 0 { + axis += data_tensor.rank as i64; + } + + if let Some(largest) = node.attrs.get("largest") { + if largest.clone().into_i64() != 1 { + unimplemented!("TopK: only largest elements is supported") + } + }; + + if let Some(sorted) = node.attrs.get("sorted") { + if sorted.clone().into_i64() != 1 { + unimplemented!("TopK: only sorted elements is supported") + } + }; + + TopKConfig::new(axis as usize, k as usize) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{AttributeValue, NodeType}; + use crate::node::test_utils::NodeBuilder; + use std::collections::HashMap; + + fn create_test_node( + input_rank: usize, + attrs: Option>, + k_input_value: Option, + ) -> Node { + let mut builder = NodeBuilder::new(NodeType::TopK, "test_topk") + .input_tensor_f32("X", input_rank, None) + .output_tensor_f32("Values", 0, None) // Rank will be updated + .output_tensor_i64("Indices", 0, None); // Rank will be updated + + // Add K input if provided + if let Some(k) = k_input_value { + builder = builder.input_tensor_i64_data("K", vec![k], vec![]); + } + + // Add attributes if provided + if let Some(attr_map) = attrs { + for (key, value) in attr_map { + match value { + AttributeValue::Int64(val) => builder = builder.attr_int(&key, val), + AttributeValue::Int64s(vals) => builder = builder.attr_ints(&key, vals), + AttributeValue::Float32(val) => builder = builder.attr_float(&key, val), + AttributeValue::Float32s(vals) => builder = builder.attr_floats(&key, vals), + AttributeValue::String(val) => builder = 
builder.attr_string(&key, &val), + AttributeValue::Strings(vals) => builder = builder.attr_strings(&key, vals), + _ => panic!("Unsupported attribute type"), + } + } + } + + builder.build() + } + + #[test] + fn test_topk_basic() { + let mut node = create_test_node(3, None, None); + // Add K attribute since we didn't provide K input + node.attrs.insert("k".to_string(), AttributeValue::Int64(5)); + + top_k_update_output(&mut node); + + assert_eq!(node.outputs.len(), 2); + + // Check first output (values) + match &node.outputs[0].ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.elem_type, ElementType::Float32); + assert_eq!(tensor.rank, 3); + } + _ => panic!("Expected tensor output for values"), + } + + // Check second output (indices) + match &node.outputs[1].ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.elem_type, ElementType::Int64); + assert_eq!(tensor.rank, 3); + } + _ => panic!("Expected tensor output for indices"), + } + } + + #[test] + #[should_panic(expected = "TopK: invalid input type")] + fn test_topk_invalid_input() { + let mut node = create_test_node(3, None, None); + node.attrs.insert("k".to_string(), AttributeValue::Int64(5)); + node.inputs[0].ty = ArgType::Scalar(ElementType::Float32); + top_k_update_output(&mut node); + } + + // Tests for top_k_config function + + #[test] + fn test_top_k_config_with_k_attribute() { + // Test when k is provided as an attribute + let mut attrs = HashMap::new(); + attrs.insert("k".to_string(), AttributeValue::Int64(10)); + let node = create_test_node(3, Some(attrs), None); + + let config = top_k_config(&node); + + // Default axis should be -1 which gets converted to rank-1 + assert_eq!(config, TopKConfig { axis: 2, k: 10 }); + } + + #[test] + fn test_top_k_config_with_k_input() { + // Test when k is provided as an input + let node = create_test_node(4, None, Some(5)); + + let config = top_k_config(&node); + + // Default axis should be -1 which gets converted to rank-1 + assert_eq!(config, TopKConfig { axis: 3, k: 5 }); + } + + #[test] + fn test_top_k_config_with_explicit_axis() { + // Test with explicitly specified axis + let mut attrs = HashMap::new(); + attrs.insert("k".to_string(), AttributeValue::Int64(3)); + attrs.insert("axis".to_string(), AttributeValue::Int64(1)); + let node = create_test_node(3, Some(attrs), None); + + let config = top_k_config(&node); + + assert_eq!(config, TopKConfig { axis: 1, k: 3 }); + } + + #[test] + fn test_top_k_config_with_negative_axis() { + // Test with negative axis (counts from the end) + let mut attrs = HashMap::new(); + attrs.insert("k".to_string(), AttributeValue::Int64(5)); + attrs.insert("axis".to_string(), AttributeValue::Int64(-2)); // Second-to-last axis + let node = create_test_node(4, Some(attrs), None); + + let config = top_k_config(&node); + + // For rank 4, axis -2 should be 2 + assert_eq!(config, TopKConfig { axis: 2, k: 5 }); + } + + #[test] + fn test_top_k_config_with_largest_attribute() { + // Test with largest attribute set to 1 (default supported behavior) + let mut attrs = HashMap::new(); + attrs.insert("k".to_string(), AttributeValue::Int64(7)); + attrs.insert("largest".to_string(), AttributeValue::Int64(1)); + let node = create_test_node(2, Some(attrs), None); + + let config = top_k_config(&node); + + assert_eq!(config, TopKConfig { axis: 1, k: 7 }); + } + + #[test] + fn test_top_k_config_with_sorted_attribute() { + // Test with sorted attribute set to 1 (default supported behavior) + let mut attrs = HashMap::new(); + attrs.insert("k".to_string(), 
AttributeValue::Int64(2)); + attrs.insert("sorted".to_string(), AttributeValue::Int64(1)); + let node = create_test_node(3, Some(attrs), None); + + let config = top_k_config(&node); + + assert_eq!(config, TopKConfig { axis: 2, k: 2 }); + } + + #[test] + #[should_panic(expected = "only largest elements is supported")] + fn test_top_k_config_with_largest_false() { + // Test with largest attribute set to 0 (unsupported) + let mut attrs = HashMap::new(); + attrs.insert("k".to_string(), AttributeValue::Int64(3)); + attrs.insert("largest".to_string(), AttributeValue::Int64(0)); + let node = create_test_node(2, Some(attrs), None); + + let _ = top_k_config(&node); + } + + #[test] + #[should_panic(expected = "only sorted elements is supported")] + fn test_top_k_config_with_sorted_false() { + // Test with sorted attribute set to 0 (unsupported) + let mut attrs = HashMap::new(); + attrs.insert("k".to_string(), AttributeValue::Int64(3)); + attrs.insert("sorted".to_string(), AttributeValue::Int64(0)); + let node = create_test_node(2, Some(attrs), None); + + let _ = top_k_config(&node); + } + + #[test] + #[should_panic(expected = "Only tensor input is valid")] + fn test_top_k_config_with_invalid_input_type() { + // Test with invalid input type + let mut node = create_test_node(2, None, None); + node.attrs.insert("k".to_string(), AttributeValue::Int64(3)); + node.inputs[0].ty = ArgType::Scalar(ElementType::Float32); + + let _ = top_k_config(&node); + } + + #[test] + #[should_panic(expected = "TopK: number of top elements 'k' is missing")] + fn test_top_k_config_without_k() { + // Test when k is neither provided as input nor attribute + let node = create_test_node(3, None, None); + + let _ = top_k_config(&node); + } + + #[test] + fn test_top_k_config_with_both_k_input_and_attribute() { + // Test when k is provided both as input and attribute + // Input should take precedence + let mut attrs = HashMap::new(); + attrs.insert("k".to_string(), AttributeValue::Int64(10)); + let node = create_test_node(3, Some(attrs), Some(5)); + + let config = top_k_config(&node); + + // K from input should be used (5), not from attribute (10) + assert_eq!(config, TopKConfig { axis: 2, k: 5 }); + } +} diff --git a/crates/onnx-ir/src/node/transpose.rs b/crates/onnx-ir/src/node/transpose.rs new file mode 100644 index 0000000000..fbff00c2c9 --- /dev/null +++ b/crates/onnx-ir/src/node/transpose.rs @@ -0,0 +1,76 @@ +use crate::ir::{ArgType, Node}; + +pub fn transpose_config(curr: &Node) -> Vec { + if curr.inputs.len() != 1 { + panic!( + "Transpose: multiple inputs are not supported (got {:?})", + curr.inputs.len() + ); + } + + // Extract the shape of the input tensor + let tensor = match curr.inputs.first().unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + // Default: reverse the dimensions + let mut perm = (0..tensor.rank as i64).rev().collect::>(); + + if let Some(axes) = curr.attrs.get("perm") { + perm = axes.clone().into_i64s(); + } + + perm +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; + + fn create_test_node(perm: Option>, rank: usize) -> Node { + let mut builder = NodeBuilder::new(NodeType::Transpose, "test_transpose") + .input_tensor_f32("data", rank, None) + .output_tensor_f32("transposed", rank, None); + + if let Some(perm_val) = perm { + builder = builder.attr_ints("perm", perm_val); + } + + builder.build() + } + + #[test] + fn test_transpose_config_default() { + let node = 
create_test_node(None, 3); + let perm = transpose_config(&node); + assert_eq!(perm, vec![2, 1, 0]); // Default is to reverse the dimensions + } + + #[test] + fn test_transpose_config_with_perm() { + let node = create_test_node(Some(vec![0, 2, 1]), 3); + let perm = transpose_config(&node); + assert_eq!(perm, vec![0, 2, 1]); + } + + #[test] + #[should_panic(expected = "Transpose: multiple inputs are not supported")] + fn test_transpose_config_multiple_inputs() { + let mut node = create_test_node(None, 3); + // Add an extra input to cause the expected panic + node.inputs.push(crate::ir::Argument { + name: "extra".to_string(), + ty: crate::ir::ArgType::Tensor(crate::ir::TensorType { + elem_type: crate::ir::ElementType::Float32, + rank: 3, + static_shape: None, + }), + value: None, + passed: true, + }); + let _ = transpose_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/trilu.rs b/crates/onnx-ir/src/node/trilu.rs new file mode 100644 index 0000000000..41a022f777 --- /dev/null +++ b/crates/onnx-ir/src/node/trilu.rs @@ -0,0 +1,195 @@ +use crate::{Data, Node, TensorData}; + +/// Configuration for the Trilu operation. +#[derive(Debug, Clone, PartialEq)] +pub struct TriluConfig { + /// Whether to return the upper triangular matrix. + pub upper: bool, + /// The diagonal offset. + pub diagonal: i64, +} + +impl TriluConfig { + /// Creates a TriluConfig from the node attributes and inputs. + pub fn new(upper: bool, diagonal: i64) -> Self { + Self { upper, diagonal } + } +} + +/// Creates a TriluConfig from the node attributes and inputs. +pub fn trilu_config(node: &Node) -> TriluConfig { + let mut upper = true; + let mut diagonal = 0; + for (key, value) in node.attrs.iter() { + if key.as_str() == "upper" { + upper = value.clone().into_i64() != 0 + } + } + // The second input of the Trilu node is the diagonal value, coming from a constant node + if let Some(diagonal_arg) = node.inputs.get(1) { + if let Some(TensorData { + data: Data::Int64(diagonal_val), + .. 
+        }) = &diagonal_arg.value
+        {
+            diagonal = *diagonal_val;
+        }
+    }
+    TriluConfig::new(upper, diagonal)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::ir::NodeType;
+    use crate::node::test_utils::NodeBuilder;
+
+    /// Helper function to create test nodes for Trilu tests
+    fn create_test_node(upper_attr: Option<i64>, diagonal_input: Option<i64>) -> Node {
+        let mut builder = NodeBuilder::new(NodeType::Trilu, "test_trilu")
+            .input_tensor_f32("X", 2, None) // Typically a matrix
+            .output_tensor_f32("Y", 2, None);
+
+        // Add diagonal input if provided
+        if let Some(diag) = diagonal_input {
+            builder = builder.input_scalar_tensor_i64("k", diag);
+        }
+
+        // Add upper attribute if provided
+        if let Some(upper) = upper_attr {
+            builder = builder.attr_int("upper", upper);
+        }
+
+        builder.build()
+    }
+
+    #[test]
+    fn test_trilu_config_default() {
+        // Test with no attributes or inputs - should use defaults (upper=true, diagonal=0)
+        let node = create_test_node(None, None);
+
+        let config = trilu_config(&node);
+
+        assert_eq!(
+            config,
+            TriluConfig {
+                upper: true,
+                diagonal: 0
+            }
+        );
+    }
+
+    #[test]
+    fn test_trilu_config_upper_true() {
+        // Test with upper=1 attribute
+        let node = create_test_node(Some(1), None);
+
+        let config = trilu_config(&node);
+
+        assert_eq!(
+            config,
+            TriluConfig {
+                upper: true,
+                diagonal: 0
+            }
+        );
+    }
+
+    #[test]
+    fn test_trilu_config_upper_false() {
+        // Test with upper=0 attribute (lower triangular)
+        let node = create_test_node(Some(0), None);
+
+        let config = trilu_config(&node);
+
+        assert_eq!(
+            config,
+            TriluConfig {
+                upper: false,
+                diagonal: 0
+            }
+        );
+    }
+
+    #[test]
+    fn test_trilu_config_with_diagonal() {
+        // Test with diagonal=2 input (offset 2 above main diagonal)
+        let node = create_test_node(None, Some(2));
+
+        let config = trilu_config(&node);
+
+        assert_eq!(
+            config,
+            TriluConfig {
+                upper: true,
+                diagonal: 2
+            }
+        );
+    }
+
+    #[test]
+    fn test_trilu_config_with_negative_diagonal() {
+        // Test with diagonal=-3 input (offset 3 below main diagonal)
+        let node = create_test_node(None, Some(-3));
+
+        let config = trilu_config(&node);
+
+        assert_eq!(
+            config,
+            TriluConfig {
+                upper: true,
+                diagonal: -3
+            }
+        );
+    }
+
+    #[test]
+    fn test_trilu_config_both_params() {
+        // Test with both upper attribute and diagonal input
+        let node = create_test_node(Some(0), Some(1));
+
+        let config = trilu_config(&node);
+
+        assert_eq!(
+            config,
+            TriluConfig {
+                upper: false,
+                diagonal: 1
+            }
+        );
+    }
+
+    #[test]
+    fn test_trilu_config_non_binary_upper() {
+        // Test with non-binary values for the upper attribute
+        // Any non-zero value should be treated as true
+        let node = create_test_node(Some(42), None);
+
+        let config = trilu_config(&node);
+
+        assert_eq!(
+            config,
+            TriluConfig {
+                upper: true,
+                diagonal: 0
+            }
+        );
+    }
+
+    #[test]
+    fn test_trilu_config_negative_non_binary_upper() {
+        // Test with negative values for the upper attribute
+        // Any non-zero value should be treated as true
+        let node = create_test_node(Some(-5), None);
+
+        let config = trilu_config(&node);
+
+        assert_eq!(
+            config,
+            TriluConfig {
+                upper: true,
+                diagonal: 0
+            }
+        );
+    }
+}
diff --git a/crates/onnx-ir/src/node/unsqueeze.rs b/crates/onnx-ir/src/node/unsqueeze.rs
new file mode 100644
index 0000000000..38f58a5700
--- /dev/null
+++ b/crates/onnx-ir/src/node/unsqueeze.rs
@@ -0,0 +1,303 @@
+use crate::{
+    Argument, TensorData,
+    ir::{ArgType, Data, Node, TensorType},
+};
+
+/// Update output rank for Unsqueeze based on axes.
+/// Update the output tensor dimension based on the "axes" attribute or the second input
+pub fn unsqueeze_update_output(node: &mut Node) {
+    log::debug!("Unsqueeze rank inference for node {}", node.name);
+
+    let axes = if node.inputs.len() == 2 {
+        match &node.inputs[1].value {
+            Some(value) => match &value.data {
+                Data::Int64s(a) => Some(a.clone()),
+                _ => panic!("Unsqueeze: invalid input types"),
+            },
+            None => None,
+        }
+    } else {
+        let axes = node.attrs.get("axes").cloned().map(|v| {
+            let axes = v.into_i64s();
+            log::debug!(
+                "Unsqueeze axes from attribute for {}: {:?}",
+                node.name,
+                axes
+            );
+            axes
+        });
+        axes
+    };
+
+    let input_rank = match &node.inputs[0].ty {
+        ArgType::Tensor(tensor) => tensor.rank,
+        ArgType::Scalar(_) => {
+            0 // treat scalar as 0-dim tensor
+        }
+        _ => panic!("Unsqueeze: invalid input type"),
+    };
+
+    let output_elem = match &node.outputs[0].ty {
+        ArgType::Tensor(_) => node.inputs[0].ty.elem_type().clone(),
+        ArgType::Scalar(elem_type) => elem_type.clone(),
+        _ => panic!("Unsqueeze: invalid output type"),
+    };
+
+    let output_rank = if let Some(axes) = axes {
+        input_rank + axes.len()
+    } else if let ArgType::Tensor(tensor) = &node.inputs[1].ty {
+        if let Some(static_shape) = &tensor.static_shape {
+            input_rank + *static_shape.first().expect("Empty shape")
+        } else {
+            panic!("Unsqueeze: should have static shape")
+        }
+    } else {
+        panic!("Unsqueeze: missing axes information")
+    };
+
+    node.outputs[0].ty = ArgType::Tensor(TensorType {
+        rank: output_rank,
+        static_shape: None, // shape is tracked and calculated at runtime
+        elem_type: output_elem,
+    });
+
+    log::debug!("Unsqueeze output rank for {}: {}", node.name, output_rank);
+}
+
+/// Axes specification for the Unsqueeze operation.
+#[derive(Debug, Clone)]
+pub enum UnsqueezeConfig {
+    /// Static axes known at compile time.
+    Static(Vec<i64>),
+    /// Runtime axes that will be determined during execution.
+    Runtime(Argument),
+}
+
+/// Creates an UnsqueezeConfig from the node attributes.
+///
+/// Note: This function should only execute if the second input is a constant.
+/// If it wasn't and the output shape was known, unsqueeze has been remapped to reshape.
+pub fn unsqueeze_config(node: &Node) -> UnsqueezeConfig {
+    // Check if axes attribute exists
+    for (key, value) in node.attrs.iter() {
+        if key.as_str() == "axes" {
+            return UnsqueezeConfig::Static(value.clone().into_i64s());
+        }
+    }
+
+    assert!(
+        !node.inputs.is_empty(),
+        "Unsqueeze: axes tensor must be present"
+    );
+
+    let input_value = &node.inputs[1];
+
+    match &node.inputs[1].ty {
+        ArgType::Tensor(tensor) => {
+            assert_eq!(tensor.rank, 1, "Unsqueeze: axes tensor must be 1D");
+            if let Some(TensorData {
+                data: Data::Int64s(shape),
+                ..
+            }) = input_value.value.as_ref()
+            {
+                UnsqueezeConfig::Static(shape.clone())
+            } else {
+                UnsqueezeConfig::Runtime(node.inputs[1].clone())
+            }
+        }
+        _ => panic!("Arg for unsqueeze must be tensor or scalar"),
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::ir::{ElementType, NodeType};
+    use crate::node::test_utils::NodeBuilder;
+
+    // Implement custom equality for UnsqueezeConfig to make testing easier
+    impl PartialEq for UnsqueezeConfig {
+        fn eq(&self, other: &UnsqueezeConfig) -> bool {
+            match (self, other) {
+                (UnsqueezeConfig::Static(a), UnsqueezeConfig::Static(b)) => a == b,
+                (UnsqueezeConfig::Runtime(a), UnsqueezeConfig::Runtime(b)) => a.name == b.name,
+                _ => false,
+            }
+        }
+    }
+
+    fn create_test_node_with_attr(input_rank: usize, axes: Vec<i64>) -> Node {
+        let builder = NodeBuilder::new(NodeType::Unsqueeze, "test_unsqueeze")
+            .input_tensor_f32("X", input_rank, None)
+            .output_tensor_f32("Y", 0, None) // Will be updated
+            .attr_ints("axes", axes);
+
+        builder.build()
+    }
+
+    fn create_test_node_with_input(input_rank: usize, axes: Vec<i64>, with_value: bool) -> Node {
+        let axes_len = axes.len();
+        let mut builder = NodeBuilder::new(NodeType::Unsqueeze, "test_unsqueeze")
+            .input_tensor_f32("X", input_rank, None)
+            .output_tensor_f32("Y", 0, None); // Will be updated
+
+        // Add axes input with or without value
+        if with_value {
+            builder = builder.input_tensor_i64_data("axes", axes.clone(), vec![axes_len]);
+        } else {
+            // Input without value
+            builder = builder.input_tensor_i64("axes", 1, Some(vec![axes_len]));
+        }
+
+        builder.build()
+    }
+
+    // Tests for unsqueeze_update_output function
+
+    #[test]
+    fn test_unsqueeze_with_attr() {
+        let mut node = create_test_node_with_attr(2, vec![0, 3]);
+        unsqueeze_update_output(&mut node);
+
+        match &node.outputs[0].ty {
+            ArgType::Tensor(tensor) => {
+                assert_eq!(tensor.elem_type, ElementType::Float32);
+                assert_eq!(tensor.rank, 4); // 2 + 2 = 4
+            }
+            _ => panic!("Expected tensor output"),
+        }
+    }
+
+    #[test]
+    fn test_unsqueeze_with_input() {
+        let mut node = create_test_node_with_input(3, vec![1, 2, 4], true);
+        unsqueeze_update_output(&mut node);
+
+        match &node.outputs[0].ty {
+            ArgType::Tensor(tensor) => {
+                assert_eq!(tensor.elem_type, ElementType::Float32);
+                assert_eq!(tensor.rank, 6); // 3 + 3 = 6
+            }
+            _ => panic!("Expected tensor output"),
+        }
+    }
+
+    #[test]
+    fn test_unsqueeze_scalar() {
+        let mut node = create_test_node_with_attr(0, vec![0]);
+        node.inputs[0].ty = ArgType::Scalar(ElementType::Float32);
+        unsqueeze_update_output(&mut node);
+
+        match &node.outputs[0].ty {
+            ArgType::Tensor(tensor) => {
+                assert_eq!(tensor.elem_type, ElementType::Float32);
+                assert_eq!(tensor.rank, 1); // 0 + 1 = 1
+            }
+            _ => panic!("Expected tensor output"),
+        }
+    }
+
+    #[test]
+    #[should_panic(expected = "Unsqueeze: invalid input type")]
+    fn test_unsqueeze_invalid_input() {
+        let mut node = create_test_node_with_attr(2, vec![0]);
+        node.inputs[0].ty = ArgType::Shape(1);
+        unsqueeze_update_output(&mut node);
+    }
+
+    // Tests for unsqueeze_config function
+
+    #[test]
+    fn test_unsqueeze_config_with_attr() {
+        // Test with axes provided as attribute
+        let axes = vec![0, 2, 4];
+        let node = create_test_node_with_attr(3, axes.clone());
+
+        let config = unsqueeze_config(&node);
+
+        assert_eq!(config, UnsqueezeConfig::Static(axes));
+    }
+
+    #[test]
+    fn test_unsqueeze_config_with_static_input() {
+        // Test with axes provided as input tensor with static value
+        let axes = vec![1, 3];
+        let node = create_test_node_with_input(2, axes.clone(),
true); + + let config = unsqueeze_config(&node); + + assert_eq!(config, UnsqueezeConfig::Static(axes)); + } + + #[test] + fn test_unsqueeze_config_with_runtime_input() { + // Test with axes provided as input tensor but without static value + let axes = vec![0, 2]; + let node = create_test_node_with_input(2, axes.clone(), false); + + let config = unsqueeze_config(&node); + + // Should return a Runtime config since the axes are only known at runtime + match config { + UnsqueezeConfig::Static(_) => panic!("Expected Runtime config"), + UnsqueezeConfig::Runtime(arg) => { + assert_eq!(arg.name, "axes"); + } + } + } + + #[test] + fn test_unsqueeze_config_negative_axes() { + // Test with negative axes (should be handled by the caller) + let axes = vec![-1, -3]; + let node = create_test_node_with_attr(3, axes.clone()); + + let config = unsqueeze_config(&node); + + assert_eq!(config, UnsqueezeConfig::Static(axes)); + } + + #[test] + fn test_unsqueeze_config_empty_axes() { + // Test with empty axes array (edge case) + let axes = vec![]; + let node = create_test_node_with_attr(2, axes.clone()); + + let config = unsqueeze_config(&node); + + assert_eq!(config, UnsqueezeConfig::Static(axes)); + } + + #[test] + #[should_panic(expected = "index out of bounds")] + fn test_unsqueeze_config_missing_axes() { + // Test with neither axes attribute nor input + let mut node = create_test_node_with_attr(2, vec![0]); + node.attrs.clear(); // Remove the axes attribute + node.inputs = vec![node.inputs[0].clone()]; // Remove the axes input + + let _ = unsqueeze_config(&node); + } + + #[test] + #[should_panic(expected = "Unsqueeze: axes tensor must be 1D")] + fn test_unsqueeze_config_invalid_axes_rank() { + // Test with axes tensor that is not 1D + let mut node = create_test_node_with_input(2, vec![0, 1], true); + if let ArgType::Tensor(ref mut tensor) = node.inputs[1].ty { + tensor.rank = 2; // Invalid rank for axes + } + + let _ = unsqueeze_config(&node); + } + + #[test] + #[should_panic(expected = "Arg for unsqueeze must be tensor or scalar")] + fn test_unsqueeze_config_invalid_axes_type() { + // Test with axes input that is not a tensor + let mut node = create_test_node_with_input(2, vec![0], false); + node.inputs[1].ty = ArgType::Shape(1); // Invalid type for axes + + let _ = unsqueeze_config(&node); + } +} diff --git a/crates/onnx-ir/src/node/where_op.rs b/crates/onnx-ir/src/node/where_op.rs new file mode 100644 index 0000000000..18a0bde986 --- /dev/null +++ b/crates/onnx-ir/src/node/where_op.rs @@ -0,0 +1,126 @@ +use crate::ir::{ArgType, ElementType, Node, TensorType}; +use core::cmp::max; + +/// Update output rank for Where to max input rank. +pub fn where_update_outputs(node: &mut Node) { + log::debug!("Where rank inference for node {}", node.name); + + let condition = &node.inputs[0].ty; + let x = &node.inputs[1].ty; + let y = &node.inputs[2].ty; + let elem_type = x.elem_type().clone(); + assert_eq!( + *condition.elem_type(), + ElementType::Bool, + "Where condition must be boolean!" + ); + assert_eq!( + elem_type, + *y.elem_type(), + "Where x and y have different element types!" 
+ ); + + log::debug!( + "Where input ranks for {}: condition={}, x={}, y={}", + node.name, + condition.rank(), + x.rank(), + y.rank() + ); + + let output_rank = max(condition.rank(), max(x.rank(), y.rank())); + log::debug!("Where output rank for {}: {}", node.name, output_rank); + + if output_rank == 0 { + node.outputs[0].ty = ArgType::Scalar(elem_type); + log::debug!("Where result for {} is scalar", node.name); + } else { + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type, + rank: output_rank, + static_shape: None, + }); + log::debug!( + "Where result for {} is tensor with rank {}", + node.name, + output_rank + ); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; + + fn create_test_node(condition_rank: usize, x_rank: usize, y_rank: usize) -> Node { + NodeBuilder::new(NodeType::Where, "test_where") + .input_tensor_bool("condition", condition_rank, None) + .input_tensor_f32("X", x_rank, None) + .input_tensor_f32("Y", y_rank, None) + .output_tensor_f32("output", 0, None) // Rank will be updated + .build() + } + + #[test] + fn test_where_basic() { + let mut node = create_test_node(2, 3, 2); + where_update_outputs(&mut node); + + match &node.outputs[0].ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.elem_type, ElementType::Float32); + assert_eq!(tensor.rank, 3); // max(2, max(3, 2)) = 3 + } + _ => panic!("Expected tensor output"), + } + } + + #[test] + fn test_where_scalar_result() { + let mut node = create_test_node(0, 0, 0); + where_update_outputs(&mut node); + + match &node.outputs[0].ty { + ArgType::Scalar(elem_type) => { + assert_eq!(*elem_type, ElementType::Float32); + } + _ => panic!("Expected scalar output"), + } + } + + #[test] + #[should_panic(expected = "Where condition must be boolean!")] + fn test_where_invalid_condition() { + let mut node = create_test_node(2, 2, 2); + + // Replace condition with non-boolean tensor + let non_bool_input = NodeBuilder::new(NodeType::Identity, "temp") + .input_tensor_f32("x", 2, None) + .build() + .inputs + .pop() + .unwrap(); + + node.inputs[0] = non_bool_input; + where_update_outputs(&mut node); + } + + #[test] + #[should_panic(expected = "Where x and y have different element types!")] + fn test_where_mismatched_types() { + let mut node = create_test_node(2, 2, 2); + + // Replace Y with int64 tensor (different from X's float32) + let int64_input = NodeBuilder::new(NodeType::Identity, "temp") + .input_tensor_i64("y", 2, None) + .build() + .inputs + .pop() + .unwrap(); + + node.inputs[2] = int64_input; + where_update_outputs(&mut node); + } +} diff --git a/crates/onnx-ir/src/rank_inference.rs b/crates/onnx-ir/src/rank_inference.rs index b3a0bd2bbf..1e8adeedd8 100644 --- a/crates/onnx-ir/src/rank_inference.rs +++ b/crates/onnx-ir/src/rank_inference.rs @@ -1,13 +1,21 @@ -use core::cmp::max; -use core::panic; - -use protobuf::Enum; - use crate::{ - ir::{ArgType, AttributeValue, Data, ElementType, Node, NodeType, TensorType}, - node::slice::slice_update_output_rank, - protos::tensor_proto::DataType, - util::shape_config, + ir::{Node, NodeType}, + node::{ + argmax::argmax_update_outputs, cast::cast_update_outputs, + comparison::elementwise_comparison_outputs, concat::concat_update_outputs, + constant::constant_update_outputs, constant_of_shape::constant_of_shape_update_output, + expand::expand_update_outputs, flatten::flatten_update_outputs, + gather::gather_update_outputs, gemm::gemm_output_shape, linear::linear_update_outputs, + matmul::matmul_update_outputs, 
one_hot::one_hot_output_shape, random::random_update_output, + random_like::random_like_update_output, range::range_update_outputs, + reduce_max::reduce_max_update_outputs, reduce_mean::reduce_mean_update_outputs, + reduce_min::reduce_min_update_outputs, reduce_prod::reduce_prod_update_outputs, + reduce_sum::reduce_sum_update_outputs, reshape::reshape_update_outputs, + shape::shape_update_outputs, slice::slice_update_output_rank, split::split_update_outputs, + squeeze::squeeze_update_output, topk::top_k_update_output, + unsqueeze::unsqueeze_update_output, where_op::where_update_outputs, + }, + util::{same_as_input, same_as_input_broadcast, temporary_pass_through_stub}, }; /// Infer the rank of each output tensor and update them based solely on rank inference. @@ -25,8 +33,8 @@ pub fn rank_inference(node: &mut Node) { NodeType::Concat => concat_update_outputs(node), NodeType::Constant => constant_update_outputs(node), NodeType::ConstantOfShape => constant_of_shape_update_output(node), - NodeType::Conv1d => conv1d_update_outputs(node), - NodeType::Conv2d => conv2d_update_outputs(node), + NodeType::Conv1d => same_as_input(node), + NodeType::Conv2d => same_as_input(node), NodeType::Cos => same_as_input(node), NodeType::Cosh => same_as_input(node), NodeType::Div => same_as_input_broadcast(node), @@ -45,8 +53,8 @@ pub fn rank_inference(node: &mut Node) { NodeType::GreaterOrEqual => elementwise_comparison_outputs(node), NodeType::HardSigmoid => same_as_input(node), NodeType::GlobalAveragePool => same_as_input(node), - NodeType::ConvTranspose1d => conv_transpose1d_update_outputs(node), - NodeType::ConvTranspose2d => conv_transpose2d_update_outputs(node), + NodeType::ConvTranspose1d => same_as_input(node), + NodeType::ConvTranspose2d => same_as_input(node), NodeType::LayerNormalization => same_as_input(node), NodeType::LeakyRelu => same_as_input(node), NodeType::Less => elementwise_comparison_outputs(node), @@ -108,1187 +116,3 @@ pub fn rank_inference(node: &mut Node) { node.outputs ); } - -/// Update output type for constant nodes based on attribute values, focusing on rank only. 
-fn constant_update_outputs(node: &mut Node) { - log::debug!("Constant rank inference for node {}", node.name); - - let keys = [ - "value", - "value_float", - "value_floats", - "value_int", - "value_ints", - "value_string", - "value_strings", - "sparse_value", - ]; - - let matched_value = keys.iter().find_map(|&key| node.attrs.get(key).cloned()); - log::debug!("Constant found attribute: {}", matched_value.is_some()); - - node.outputs[0].ty = match matched_value { - Some(value) => match &value { - AttributeValue::Tensor(tensor) if tensor.shape.is_empty() => { - log::debug!("Constant as scalar for {}", node.name); - ArgType::Scalar(tensor.elem_type()) - } - AttributeValue::Tensor(tensor) => { - log::debug!( - "Constant tensor with rank {} for {}", - tensor.shape.len(), - node.name - ); - ArgType::Tensor(TensorType { - elem_type: tensor.elem_type(), - rank: tensor.shape.len(), - static_shape: None, - }) - } - AttributeValue::Float32(_) => { - log::debug!("Constant Float32 scalar for {}", node.name); - ArgType::Scalar(ElementType::Float32) - } - AttributeValue::Float32s(_) => { - log::debug!("Constant Float32s tensor with rank 1 for {}", node.name); - ArgType::Tensor(TensorType { - elem_type: ElementType::Float32, - rank: 1, - static_shape: None, - }) - } - AttributeValue::Int64(_) => { - log::debug!("Constant Int64 scalar for {}", node.name); - ArgType::Scalar(ElementType::Int64) - } - AttributeValue::Int64s(_) => { - log::debug!("Constant Int64s tensor with rank 1 for {}", node.name); - ArgType::Tensor(TensorType { - elem_type: ElementType::Int64, - rank: 1, - static_shape: None, - }) - } - ty => panic!("Constant value of {:?} is not supported", ty), - }, - None => panic!("Constant node must have a value attribute"), - }; -} - -/// Updates the output rank for a ConstantOfShape node based on the rank of its input. -fn constant_of_shape_update_output(node: &mut Node) { - log::debug!("ConstantOfShape rank inference for node {}", node.name); - - let value_type = node - .attrs - .get("value") - .map(|v| v.clone().into_tensor().elem_type()) - .unwrap_or(ElementType::Float32); // If not given, defaults to 0 as float32 - log::debug!( - "ConstantOfShape value type for {}: {:?}", - node.name, - value_type - ); - - let rank = match &node.inputs[0].ty { - ArgType::Shape(rank) => { - log::debug!( - "ConstantOfShape input is Shape with rank {} for {}", - rank, - node.name - ); - *rank - } - ArgType::Tensor(tensor_type) => { - log::debug!("ConstantOfShape input is Tensor for {}", node.name); - let r = tensor_type - .static_shape - .as_ref() - .and_then(|shape| shape.first()) - .copied() - .expect( - "ConstantOfShape node must have a Tensor with a non-empty static shape value", - ); - log::debug!( - "ConstantOfShape derived rank from tensor: {} for {}", - r, - node.name - ); - r - } - _ => panic!("ConstantOfShape node requires a Tensor or Shape type as input"), - }; - - // Update the input type to be a shape - node.inputs[0].ty = ArgType::Shape(rank); - log::debug!( - "ConstantOfShape updated input to Shape({}) for {}", - rank, - node.name - ); - - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type: value_type, - rank, - static_shape: None, - }); - log::debug!("ConstantOfShape output rank for {}: {}", node.name, rank); -} - -/// Update output rank for Random operations with explicit shape attribute. 
-fn random_update_output(node: &mut Node) { - log::debug!("Random rank inference for node {}", node.name); - - let dtype = node - .attrs - .get("dtype") - .map(|val| DataType::from_i32(val.clone().into_i32()).unwrap()) - .unwrap_or(DataType::FLOAT); - log::debug!("Random dtype for {}: {:?}", node.name, dtype); - - let shape = node - .attrs - .get("shape") - .expect("required shape attribute missing") - .clone() - .into_i64s(); - log::debug!("Random shape for {}: {:?}", node.name, shape); - - let elem_type = match dtype { - DataType::FLOAT => ElementType::Float32, - DataType::DOUBLE => ElementType::Float64, - _ => panic!("tensor with type {dtype:?} not supported for random output"), - }; - - let rank = shape.len(); - log::debug!("Random output rank for {}: {}", node.name, rank); - - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type, - rank, - static_shape: None, - }); -} - -/// Update output rank for RandomLike operations based on input rank. -fn random_like_update_output(node: &mut Node) { - log::debug!("RandomLike rank inference for node {}", node.name); - - let dtype = node - .attrs - .get("dtype") - .map(|val| DataType::from_i32(val.clone().into_i32()).unwrap()) - .unwrap_or(DataType::FLOAT); - log::debug!("RandomLike dtype for {}: {:?}", node.name, dtype); - - let elem_type = match dtype { - DataType::FLOAT => ElementType::Float32, - DataType::FLOAT16 => ElementType::Float16, - DataType::DOUBLE => ElementType::Float64, - _ => panic!("Tensor with type {dtype:?} not supported for random output"), - }; - - if let ArgType::Tensor(tensor) = &node.inputs[0].ty { - log::debug!("RandomLike input rank for {}: {}", node.name, tensor.rank); - - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type, - rank: tensor.rank, - static_shape: tensor.static_shape.clone(), - }); - - log::debug!("RandomLike output rank for {}: {}", node.name, tensor.rank); - } else { - panic!("Only tensor input is valid"); - } -} - -/// Update output rank for Linear operations (same as input rank). -fn linear_update_outputs(node: &mut Node) { - log::debug!("Linear rank inference for node {}", node.name); - - if let ArgType::Tensor(tensor) = &node.inputs[0].ty { - log::debug!("Linear input rank for {}: {}", node.name, tensor.rank); - - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type: tensor.elem_type.clone(), - rank: tensor.rank, - static_shape: None, - }); - - log::debug!("Linear output rank for {}: {}", node.name, tensor.rank); - } else { - panic!("Only tensor input is valid"); - } -} - -/// Update output type for Cast operations, preserving rank. 
-fn cast_update_outputs(node: &mut Node) { - if node.inputs.len() != 1 { - panic!("Cast: multiple inputs are not supported"); - } - let input = &mut node.inputs[0]; - let output = &mut node.outputs[0]; - - let elem_type = match node.attrs.get("to") { - Some(value) => match &value { - AttributeValue::Int64(type_id) => match DataType::from_i32(*type_id as i32).unwrap() { - DataType::FLOAT => ElementType::Float32, - DataType::INT32 => ElementType::Int32, - DataType::INT64 => ElementType::Int64, - DataType::DOUBLE => ElementType::Float64, - DataType::BOOL => ElementType::Bool, - _ => panic!("Cast: unsupported type"), - }, - _ => panic!("'to' attribute must be an Int64"), - }, - None => panic!("Cast node must have a 'to' attribute"), - }; - - match input.ty.clone() { - ArgType::Tensor(tensor) => { - if tensor.rank == 0 { - // treat 0-dim tensor as scalar - output.ty = ArgType::Scalar(elem_type); - input.ty = ArgType::Scalar(tensor.elem_type); - } else { - // Cast input and output are the same shape, but possibly different types - output.ty = ArgType::Tensor(TensorType { - elem_type, - rank: tensor.rank, - static_shape: None, - }); - } - } - ArgType::Scalar(_) => output.ty = ArgType::Scalar(elem_type), - _ => panic!("Cast: only scalar and tensor inputs are valid"), - } -} - -/// Update output rank for Concat (same as first tensor input). -fn concat_update_outputs(node: &mut Node) { - log::debug!("Concat rank inference for node {}", node.name); - - let tensor = node - .inputs - .iter() - .find_map(|input| match &input.ty { - ArgType::Tensor(tensor) => Some(tensor.clone()), - _ => None, - }) - .unwrap(); - - log::debug!("Concat using input rank for {}: {}", node.name, tensor.rank); - - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type: tensor.elem_type, - rank: tensor.rank, - static_shape: None, - }); - - log::debug!("Concat output rank for {}: {}", node.name, tensor.rank); -} - -/// Update output rank for Reshape based on shape input if constant, otherwise use input rank. -fn reshape_update_outputs(node: &mut Node) { - log::debug!("Reshape rank inference for node {}", node.name); - - let shape = if node.inputs.len() == 2 { - log::debug!("Reshape node {} has shape as second input", node.name); - match &node.inputs[1].value { - Some(value) => match &value.data { - Data::Int64s(shape) => { - log::debug!("Reshape node {} has constant shape: {:?}", node.name, shape); - Some(shape.clone()) - } - _ => panic!("Reshape: invalid input types"), - }, - None => { - log::debug!( - "Reshape node {} has dynamic shape as second input", - node.name - ); - None - } - } - } else { - log::debug!("Reshape node {} using shape from attributes", node.name); - node.attrs.get("shape").cloned().map(|v| { - let shape = v.into_i64s(); - log::debug!("Reshape node {} shape attribute: {:?}", node.name, shape); - shape - }) - }; - - let output = match &node.outputs[0].ty { - ArgType::Tensor(tensor) => tensor.clone(), - _ => panic!("Reshape: invalid output types"), - }; - - let rank = match &shape { - Some(s) => s.len(), - None => output.rank, - }; - - log::debug!("Reshape output rank for node {}: {}", node.name, rank); - - node.outputs[0].ty = ArgType::Tensor(TensorType { - rank, - static_shape: None, - ..output - }); -} - -/// Update output rank for ReduceMean based on axes. 
-fn reduce_mean_update_outputs(node: &mut Node) { - log::debug!("ReduceMean rank inference for node {}", node.name); - - if node.inputs.len() != 1 { - panic!("ReduceMean: multiple inputs are not supported"); - } - let tensor = match &node.inputs[0].ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - let dim_only = match node.attrs.get("axes") { - Some(value) => match &value { - AttributeValue::Int64(_) => true, - AttributeValue::Int64s(ints) => ints.len() == 1, - _ => false, - }, - None => false, - }; - - let output_rank = if dim_only { tensor.rank } else { 1 }; - log::debug!("ReduceMean output rank for {}: {}", node.name, output_rank); - - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type: tensor.elem_type.clone(), - rank: output_rank, - static_shape: None, - }); -} - -/// Update output rank for ArgMax (same as input rank). -fn argmax_update_outputs(node: &mut Node) { - log::debug!("ArgMax rank inference for node {}", node.name); - - if node.inputs.len() != 1 { - panic!("ArgMax: multiple inputs are not supported"); - } - let tensor = match &node.inputs[0].ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - - log::debug!("ArgMax input rank for {}: {}", node.name, tensor.rank); - - // Note: argmax in burn does not support keepdims=false - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type: ElementType::Int64, - rank: tensor.rank, - static_shape: None, - }); - - log::debug!("ArgMax output rank for {}: {}", node.name, tensor.rank); -} - -/// Update output rank for Squeeze based on axes. -fn squeeze_update_output(node: &mut Node) { - log::debug!("Squeeze rank inference for node {}", node.name); - - let axes = if node.inputs.len() == 2 { - match &node.inputs[1].value { - Some(value) => match &value.data { - Data::Int64s(axes) => Some(axes.clone()), - _ => panic!("Squeeze: invalid input types"), - }, - None => None, - } - } else { - node.attrs.get("axes").cloned().map(|v| v.into_i64s()) - }; - - let axes = axes.unwrap_or_else(|| panic!("Squeeze must specify an axis")); - log::debug!("Squeeze axes for {}: {:?}", node.name, axes); - - let input_rank = match &node.inputs[0].ty { - ArgType::Tensor(tensor) => tensor.rank, - ty => panic!("Squeeze: invalid input type: {:?}", ty), - }; - - log::debug!("Squeeze input rank for {}: {}", node.name, input_rank); - - let output_rank = input_rank - axes.len(); - log::debug!("Squeeze output rank for {}: {}", node.name, output_rank); - - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type: node.inputs[0].ty.elem_type().clone(), - rank: output_rank, - static_shape: None, - }); -} - -/// Update output rank for broadcasting operations (e.g., Add, Sub) to max input rank. 
-fn same_as_input_broadcast(node: &mut Node) { - log::debug!("Broadcasting operation for node {}", node.name); - - let max_rank = node.inputs.iter().fold(0, |acc, input| match &input.ty { - ArgType::Tensor(tensor) => acc.max(tensor.rank), - ArgType::Scalar(_) => acc, - _ => panic!("Unsupported input type for broadcasting operation"), - }); - - log::debug!("Max rank for broadcasting node {}: {}", node.name, max_rank); - - if max_rank == 0 { - node.outputs[0].ty = ArgType::Scalar(node.inputs[0].ty.elem_type().clone()); - log::debug!("Scalar result for node {}", node.name); - } else { - let elem_type = node - .inputs - .iter() - .find_map(|input| match &input.ty { - ArgType::Tensor(tensor) => Some(tensor.elem_type.clone()), - _ => None, - }) - .unwrap_or_else(|| node.inputs[0].ty.elem_type().clone()); - - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type, - rank: max_rank, - static_shape: None, - // Removed call to set_broadcasting_output_shape - }); - log::debug!( - "Tensor result for node {} with rank {}", - node.name, - max_rank - ); - } -} - -/// Update output rank for Unsqueeze based on axes. -/// Update the output tensor dimension based on the "axes" attribute or the second input -fn unsqueeze_update_output(node: &mut Node) { - log::debug!("Unsqueeze rank inference for node {}", node.name); - - let axes = if node.inputs.len() == 2 { - match &node.inputs[1].value { - Some(value) => match &value.data { - Data::Int64s(a) => Some(a.clone()), - _ => panic!("Unsqueeze: invalid input types"), - }, - None => None, - } - } else { - let axes = node.attrs.get("axes").cloned().map(|v| { - let axes = v.into_i64s(); - log::debug!( - "Unsqueeze axes from attribute for {}: {:?}", - node.name, - axes - ); - axes - }); - axes - }; - - let input_rank = match &node.inputs[0].ty { - ArgType::Tensor(tensor) => tensor.rank, - ArgType::Scalar(_) => { - 0 // treat scalar as 0-dim tensor - } - _ => panic!("Unsqueeze: invalid input type"), - }; - - let output_elem = match &node.outputs[0].ty { - ArgType::Tensor(_) => node.inputs[0].ty.elem_type().clone(), - ArgType::Scalar(elem_type) => elem_type.clone(), - _ => panic!("Unsqueeze: invalid output type"), - }; - - let output_rank = if let Some(axes) = axes { - input_rank + axes.len() - } else if let ArgType::Tensor(tensor) = &node.inputs[1].ty { - if let Some(static_shape) = &tensor.static_shape { - input_rank + *static_shape.first().expect("Empty shape") - } else { - panic!("Unsqueeze: should have static shape") - } - } else { - panic!("Unsqueeze: missing axes information") - }; - - node.outputs[0].ty = ArgType::Tensor(TensorType { - rank: output_rank, - static_shape: None, // shape is tracked and calculated at runtime - elem_type: output_elem, - }); - - log::debug!("Unsqueeze output rank for {}: {}", node.name, output_rank); -} - -/// Preserve input rank for operations like Relu, Sigmoid, etc. -fn same_as_input(node: &mut Node) { - log::debug!("Copying input type to output for node {}", node.name); - - if let ArgType::Tensor(tensor) = &node.inputs[0].ty { - log::debug!("Input rank for {}: {}", node.name, tensor.rank); - } else if let ArgType::Scalar(_) = &node.inputs[0].ty { - log::debug!("Input is scalar for {}", node.name); - } - - node.outputs[0].ty = node.inputs[0].ty.clone(); - log::debug!("Output type is same as input for {}", node.name); -} - -/// Update output rank for TopK (same as input rank). 
-fn top_k_update_output(node: &mut Node) { - log::debug!("TopK rank inference for node {}", node.name); - - let rank = match &node.inputs[0].ty { - ArgType::Tensor(tensor) => tensor.rank, - _ => panic!("TopK: invalid input type"), - }; - log::debug!("TopK input rank for {}: {}", node.name, rank); - - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type: node.inputs[0].ty.elem_type().clone(), - rank, - static_shape: None, - }); - node.outputs[1].ty = ArgType::Tensor(TensorType { - elem_type: ElementType::Int64, - rank, - static_shape: None, - }); - - log::debug!( - "TopK output rank for {}: {} (both outputs)", - node.name, - rank - ); -} - -/// Temporary stub preserves input type for unhandled operations. -fn temporary_pass_through_stub(node: &mut Node) { - log::warn!( - "Must implement rank inference for node type {:?} (name: {})", - node.node_type, - node.name - ); - - if let Some(input_rank) = node.inputs.first().map(|input| match &input.ty { - ArgType::Tensor(tensor) => tensor.rank, - ArgType::Scalar(_) => 0, - _ => 0, - }) { - log::debug!( - "Passing through input rank {} for unhandled node {}", - input_rank, - node.name - ); - } - - node.outputs[0].ty = node.inputs[0].ty.clone(); - log::debug!( - "Using pass-through inference for unhandled node type {:?} ({})", - node.node_type, - node.name - ); -} - -/// Update output type for comparison operations (e.g., Equal, Greater) to max input rank. -fn elementwise_comparison_outputs(node: &mut Node) { - log::debug!("Elementwise comparison for node {}", node.name); - - let max_rank = node.inputs.iter().fold(0, |acc, input| match &input.ty { - ArgType::Tensor(tensor) => acc.max(tensor.rank), - ArgType::Scalar(_) => acc, - _ => panic!("Invalid input type for comparison op"), - }); - - log::debug!("Max rank for comparison node {}: {}", node.name, max_rank); - - if max_rank == 0 { - node.outputs[0].ty = ArgType::Scalar(ElementType::Bool); - log::debug!("Scalar boolean result for node {}", node.name); - } else { - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type: ElementType::Bool, - rank: max_rank, - static_shape: None, - }); - log::debug!( - "Tensor boolean result for node {} with rank {}", - node.name, - max_rank - ); - } -} - -/// Updates the output rank and shape for the Expand operation based on the provided shape input. -/// If the shape is a constant, the rank and static shape of the output are set accordingly. -/// If the shape is dynamic, the rank is inferred from the static shape of the shape input. -fn expand_update_outputs(node: &mut Node) { - let shape = if node.inputs.len() == 2 { - match &node.inputs[1].value { - Some(value) => match &value.data { - Data::Int64s(shape) => Some(shape.clone()), - _ => panic!("Expand operation encountered invalid input types"), - }, - None => None, - } - } else { - panic!("Expand operation requires exactly two inputs"); - }; - - let output = match &node.outputs[0].ty { - ArgType::Tensor(tensor) => tensor.clone(), - _ => panic!("Expand operation encountered invalid output types"), - }; - - if let Some(shape) = shape { - node.outputs[0].ty = ArgType::Tensor(TensorType { - rank: shape.len(), - static_shape: Some(shape.into_iter().map(|dim| dim as usize).collect()), - ..output - }); - } else { - // When the shape cannot be determined statically (i.e., the second argument 'shape' is passed dynamically), - // infer the rank from the static shape of the input tensor. 
- let output_rank = match &node.inputs[1].ty { - ArgType::Tensor(tensor) => tensor - .static_shape - .as_ref() - .expect("Shape input must have a static shape defined") - .first() - .copied() - .expect("Static shape must contain at least one element"), - ArgType::Shape(rank) => *rank, - _ => panic!("Shape input must be of tensor or shape type",), - }; - - node.outputs[0].ty = ArgType::Tensor(TensorType { - rank: output_rank, - static_shape: None, // The exact shape cannot be determined statically - ..output - }); - } -} - -/// Update output type for Shape operation (rank 1). -fn shape_update_outputs(node: &mut Node) { - if node.inputs.len() != 1 { - panic!("Shape: multiple inputs are not supported: {:?}", node); - } - let (start, end) = shape_config(node); - let dim = end - start; - log::debug!( - "Shape operation for node {}: start={}, end={}, dim={}", - node.name, - start, - end, - dim - ); - node.outputs[0].ty = ArgType::Shape(dim); -} - -/// Update output type for Flatten operation (rank 2). -fn flatten_update_outputs(node: &mut Node) { - if node.inputs.len() != 1 { - panic!("Flatten: multiple inputs are not supported"); - } - let tensor = node - .inputs - .iter() - .find_map(|input| match &input.ty { - ArgType::Tensor(tensor) => Some(tensor), - _ => None, - }) - .unwrap(); - - // Flatten to a 2D tensor - node.outputs[0].ty = ArgType::Tensor(TensorType { - rank: 2, - ..tensor.clone() - }); -} - -/// Update output rank for Conv1d (same as input). -fn conv1d_update_outputs(node: &mut Node) { - log::debug!("Conv1d rank inference for node {}", node.name); - - if let ArgType::Tensor(tensor) = &node.inputs[0].ty { - log::debug!("Conv1d input rank for {}: {}", node.name, tensor.rank); - - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type: tensor.elem_type.clone(), - rank: tensor.rank, - static_shape: None, - }); - - log::debug!("Conv1d output rank for {}: {}", node.name, tensor.rank); - } else { - panic!("Only tensor input is valid"); - } -} - -/// Update output rank for Conv2d (same as input). -fn conv2d_update_outputs(node: &mut Node) { - log::debug!("Conv2d rank inference for node {}", node.name); - - if let ArgType::Tensor(tensor) = &node.inputs[0].ty { - log::debug!("Conv2d input rank for {}: {}", node.name, tensor.rank); - - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type: tensor.elem_type.clone(), - rank: tensor.rank, - static_shape: None, - }); - - log::debug!("Conv2d output rank for {}: {}", node.name, tensor.rank); - } else { - panic!("Only tensor input is valid"); - } -} - -/// Update output rank for ConvTranspose1d (same as input). -fn conv_transpose1d_update_outputs(node: &mut Node) { - log::debug!("ConvTranspose1d rank inference for node {}", node.name); - - if let ArgType::Tensor(tensor) = &node.inputs[0].ty { - log::debug!( - "ConvTranspose1d input rank for {}: {}", - node.name, - tensor.rank - ); - - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type: tensor.elem_type.clone(), - rank: tensor.rank, - static_shape: None, - }); - - log::debug!( - "ConvTranspose1d output rank for {}: {}", - node.name, - tensor.rank - ); - } else { - panic!("Only tensor input is valid"); - } -} - -/// Update output rank for ConvTranspose2d (same as input). 
-fn conv_transpose2d_update_outputs(node: &mut Node) { - log::debug!("ConvTranspose2d rank inference for node {}", node.name); - - if let ArgType::Tensor(tensor) = &node.inputs[0].ty { - log::debug!( - "ConvTranspose2d input rank for {}: {}", - node.name, - tensor.rank - ); - - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type: tensor.elem_type.clone(), - rank: tensor.rank, - static_shape: None, - }); - - log::debug!( - "ConvTranspose2d output rank for {}: {}", - node.name, - tensor.rank - ); - } else { - panic!("Only tensor input is valid"); - } -} - -/// Update output rank for MatMul based on input ranks. -fn matmul_update_outputs(node: &mut Node) { - log::debug!("MatMul rank inference for node {}", node.name); - - match (&node.inputs[0].ty, &node.inputs[1].ty) { - (ArgType::Tensor(a), ArgType::Tensor(b)) => { - log::debug!( - "MatMul input ranks for {}: a.rank={}, b.rank={}", - node.name, - a.rank, - b.rank - ); - - let mut out_rank = max(a.rank, b.rank); - if (a.rank >= 2 && b.rank == 1) || (a.rank == 1 && b.rank >= 2) { - out_rank -= 1; - log::debug!( - "MatMul special case for node {}: reducing output rank", - node.name - ); - } - - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type: a.elem_type.clone(), - rank: out_rank, - static_shape: None, - }); - - log::debug!("MatMul output rank for {}: {}", node.name, out_rank); - } - _ => panic!("Only tensor inputs are valid"), - } -} - -/// Update output rank for Range (always rank 1). -fn range_update_outputs(node: &mut Node) { - log::debug!("Range rank inference for node {}", node.name); - - if node.inputs.len() != 3 { - panic!("Range: expected 3 inputs, found {}", node.inputs.len()); - } - log::debug!( - "Range operation always produces rank 1 tensor for {}", - node.name - ); - - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type: ElementType::Int64, - rank: 1, - static_shape: None, - }); - - log::debug!("Range output rank for {}: 1", node.name); -} - -/// Update output rank for ReduceMax based on axes. -fn reduce_max_update_outputs(node: &mut Node) { - log::debug!("ReduceMax rank inference for node {}", node.name); - - if node.inputs.len() != 1 { - panic!("ReduceMax: multiple inputs are not supported"); - } - let tensor = match &node.inputs[0].ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - log::debug!("ReduceMax input rank for {}: {}", node.name, tensor.rank); - - let dim_only = match node.attrs.get("axes") { - Some(value) => match &value { - AttributeValue::Int64(_) => true, - AttributeValue::Int64s(ints) => ints.len() == 1, - _ => false, - }, - None => false, - }; - - let output_rank = if dim_only { tensor.rank } else { 1 }; - log::debug!("ReduceMax output rank for {}: {}", node.name, output_rank); - - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type: tensor.elem_type.clone(), - rank: output_rank, - static_shape: None, - }); -} - -/// Update output rank for ReduceMin based on axes. 
-fn reduce_min_update_outputs(node: &mut Node) { - log::debug!("ReduceMin rank inference for node {}", node.name); - - if node.inputs.len() != 1 { - panic!("ReduceMin: multiple inputs are not supported"); - } - let tensor = match &node.inputs[0].ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - log::debug!("ReduceMin input rank for {}: {}", node.name, tensor.rank); - - let dim_only = match node.attrs.get("axes") { - Some(value) => match &value { - AttributeValue::Int64(_) => true, - AttributeValue::Int64s(ints) => ints.len() == 1, - _ => false, - }, - None => false, - }; - - let output_rank = if dim_only { tensor.rank } else { 1 }; - log::debug!("ReduceMin output rank for {}: {}", node.name, output_rank); - - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type: tensor.elem_type.clone(), - rank: output_rank, - static_shape: None, - }); -} - -/// Update output rank for ReduceProd based on axes. -fn reduce_prod_update_outputs(node: &mut Node) { - log::debug!("ReduceProd rank inference for node {}", node.name); - - if node.inputs.len() != 1 { - panic!("ReduceProd: multiple inputs are not supported"); - } - let tensor = match &node.inputs[0].ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - log::debug!("ReduceProd input rank for {}: {}", node.name, tensor.rank); - - let dim_only = match node.attrs.get("axes") { - Some(value) => match &value { - AttributeValue::Int64(_) => true, - AttributeValue::Int64s(ints) => ints.len() == 1, - _ => false, - }, - None => false, - }; - - let output_rank = if dim_only { tensor.rank } else { 1 }; - log::debug!("ReduceProd output rank for {}: {}", node.name, output_rank); - - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type: tensor.elem_type.clone(), - rank: output_rank, - static_shape: None, - }); -} - -/// Update output rank for ReduceSum based on axes. -fn reduce_sum_update_outputs(node: &mut Node) { - log::debug!("ReduceSum rank inference for node {}", node.name); - - let tensor = match &node.inputs[0].ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), - }; - log::debug!("ReduceSum input rank for {}: {}", node.name, tensor.rank); - - let dim_only = match node.attrs.get("axes") { - Some(value) => match &value { - AttributeValue::Int64(_) => true, - AttributeValue::Int64s(ints) => ints.len() == 1, - _ => false, - }, - None => false, - } || match node.inputs.get(1).and_then(|arg| arg.value.as_ref()) { - Some(value) => match &value.data { - Data::Int64(_) => true, - Data::Int64s(ints) => ints.len() == 1, - _ => false, - }, - None => false, - }; - - let output_rank = if dim_only { tensor.rank } else { 1 }; - log::debug!("ReduceSum output rank for {}: {}", node.name, output_rank); - - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type: tensor.elem_type.clone(), - rank: output_rank, - static_shape: None, - }); -} - -/// Update output rank for Where to max input rank. -fn where_update_outputs(node: &mut Node) { - log::debug!("Where rank inference for node {}", node.name); - - let condition = &node.inputs[0].ty; - let x = &node.inputs[1].ty; - let y = &node.inputs[2].ty; - let elem_type = x.elem_type().clone(); - assert_eq!( - *condition.elem_type(), - ElementType::Bool, - "Where condition must be boolean!" - ); - assert_eq!( - elem_type, - *y.elem_type(), - "Where x and y have different element types!" 
- ); - - log::debug!( - "Where input ranks for {}: condition={}, x={}, y={}", - node.name, - condition.rank(), - x.rank(), - y.rank() - ); - - let output_rank = max(condition.rank(), max(x.rank(), y.rank())); - log::debug!("Where output rank for {}: {}", node.name, output_rank); - - if output_rank == 0 { - node.outputs[0].ty = ArgType::Scalar(elem_type); - log::debug!("Where result for {} is scalar", node.name); - } else { - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type, - rank: output_rank, - static_shape: None, - }); - log::debug!( - "Where result for {} is tensor with rank {}", - node.name, - output_rank - ); - } -} - -/// Update output rank for Gather based on input and indices ranks. -fn gather_update_outputs(node: &mut Node) { - log::debug!("Gather rank inference for node {}", node.name); - - if node.inputs.len() != 2 { - panic!("Gather requires two inputs: data and indices"); - } - - let indices_rank = match &node.inputs[1].ty { - ArgType::Tensor(tensor) => tensor.rank, - ArgType::Scalar(_) => 0, - _ => panic!("Only tensor indices is valid, got {:?}", node.inputs[1].ty), - }; - log::debug!("Gather indices rank for {}: {}", node.name, indices_rank); - - match &node.inputs[0].ty { - ArgType::Tensor(input_tensor) => { - log::debug!( - "Gather input tensor rank for {}: {}", - node.name, - input_tensor.rank - ); - // Output of rank q+(r-1), where q is rank of indices tensor and r is rank of input - let output_rank = indices_rank + input_tensor.rank - 1; - log::debug!("Gather output rank for {}: {}", node.name, output_rank); - - if output_rank == 0 { - node.outputs[0].ty = ArgType::Scalar(input_tensor.elem_type.clone()); - log::debug!("Gather result for {} is scalar", node.name); - } else { - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type: input_tensor.elem_type.clone(), - rank: output_rank, - static_shape: None, - }); - log::debug!( - "Gather result for {} is tensor with rank {}", - node.name, - output_rank - ); - } - } - ArgType::Shape(_) => { - log::debug!("Gather input is shape for {}", node.name); - let shape_rank = 1; - // Output of rank q+(r-1), where q is rank of indices tensor and r is rank of input - let output_rank = indices_rank + shape_rank - 1; - log::debug!( - "Gather output rank for {} with shape input: {}", - node.name, - output_rank - ); - - if output_rank == 0 { - node.outputs[0].ty = ArgType::Scalar(ElementType::Int64); - log::debug!("Gather result for {} is scalar (from shape)", node.name); - } else { - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type: ElementType::Int64, - rank: output_rank, - static_shape: None, - }); - log::debug!( - "Gather result for {} is tensor with rank {} (from shape)", - node.name, - output_rank - ); - } - } - ty => panic!("Only tensor/shape input is valid but received: {:?}", ty), - } -} - -/// Update output rank for Split (same as input). 
-fn split_update_outputs(node: &mut Node) { - log::debug!("Split rank inference for node {}", node.name); - - let tensor = match &node.inputs[0].ty { - ArgType::Tensor(tensor) => tensor, - _ => panic!("Split: Input must be a tensor"), - }; - log::debug!("Split input rank for {}: {}", node.name, tensor.rank); - log::debug!( - "Split will generate {} outputs for {}", - node.outputs.len(), - node.name - ); - - for (i, output_arg) in node.outputs.iter_mut().enumerate() { - output_arg.ty = ArgType::Tensor(TensorType { - elem_type: tensor.elem_type.clone(), - rank: tensor.rank, - static_shape: None, - }); - log::debug!("Split output {} rank for {}: {}", i, node.name, tensor.rank); - } -} - -/// Update output rank for OneHot (input rank + 1). -fn one_hot_output_shape(node: &mut Node) { - log::debug!("OneHot rank inference for node {}", node.name); - - let input_rank = match &node.inputs[0].ty { - ArgType::Tensor(tensor) => tensor.rank, - _ => panic!("OneHot: invalid input type"), - }; - log::debug!("OneHot input rank for {}: {}", node.name, input_rank); - - let output_rank = input_rank + 1; - log::debug!("OneHot output rank for {}: {}", node.name, output_rank); - - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type: node.outputs[0].ty.elem_type().clone(), - rank: output_rank, - static_shape: None, - }); -} - -fn gemm_output_shape(node: &mut Node) { - log::debug!("Gemm rank inference for node {}", node.name); - - let a_rank = match &node.inputs[0].ty { - ArgType::Tensor(tensor) => tensor.rank, - _ => panic!("Input A should be a tensor!"), - }; - let b_rank = match &node.inputs[1].ty { - ArgType::Tensor(tensor) => tensor.rank, - _ => panic!("Input B should be a tensor!"), - }; - - log::debug!( - "Gemm input ranks for {}: a_rank={}, b_rank={}", - node.name, - a_rank, - b_rank - ); - - let output_rank = max(a_rank, b_rank); - log::debug!("Gemm output rank for {}: {}", node.name, output_rank); - - node.outputs[0].ty = ArgType::Tensor(TensorType { - rank: output_rank, - static_shape: None, - elem_type: match &node.inputs[0].ty { - ArgType::Tensor(t) => t.elem_type.clone(), - _ => panic!("Unexpected type for input A"), - }, - }); -} diff --git a/crates/onnx-ir/src/util.rs b/crates/onnx-ir/src/util.rs index 6843b3c9be..fe63f651a1 100644 --- a/crates/onnx-ir/src/util.rs +++ b/crates/onnx-ir/src/util.rs @@ -1,4 +1,5 @@ -use crate::ir::{ArgType, Node}; +use crate::ir::{ArgType, Node, TensorType}; + use crate::protos::OperatorSetIdProto; pub fn shape_config(curr: &Node) -> (usize, usize) { @@ -81,3 +82,190 @@ pub fn verify_opsets(opsets: &[OperatorSetIdProto], min_version: i64) -> bool { } true } + +/// Preserve input rank for operations like Relu, Sigmoid, etc. +pub fn same_as_input(node: &mut Node) { + log::debug!("Copying input type to output for node {}", node.name); + + if let ArgType::Tensor(tensor) = &node.inputs[0].ty { + log::debug!("Input rank for {}: {}", node.name, tensor.rank); + } else if let ArgType::Scalar(_) = &node.inputs[0].ty { + log::debug!("Input is scalar for {}", node.name); + } + + node.outputs[0].ty = node.inputs[0].ty.clone(); + log::debug!("Output type is same as input for {}", node.name); +} + +/// Update output rank for broadcasting operations (e.g., Add, Sub) to max input rank. 
+pub fn same_as_input_broadcast(node: &mut Node) {
+    log::debug!("Broadcasting operation for node {}", node.name);
+
+    let max_rank = node.inputs.iter().fold(0, |acc, input| match &input.ty {
+        ArgType::Tensor(tensor) => acc.max(tensor.rank),
+        ArgType::Scalar(_) => acc,
+        _ => panic!("Unsupported input type for broadcasting operation"),
+    });
+
+    log::debug!("Max rank for broadcasting node {}: {}", node.name, max_rank);
+
+    if max_rank == 0 {
+        node.outputs[0].ty = ArgType::Scalar(node.inputs[0].ty.elem_type().clone());
+        log::debug!("Scalar result for node {}", node.name);
+    } else {
+        let elem_type = node
+            .inputs
+            .iter()
+            .find_map(|input| match &input.ty {
+                ArgType::Tensor(tensor) => Some(tensor.elem_type.clone()),
+                _ => None,
+            })
+            .unwrap_or_else(|| node.inputs[0].ty.elem_type().clone());
+
+        node.outputs[0].ty = ArgType::Tensor(TensorType {
+            elem_type,
+            rank: max_rank,
+            static_shape: None,
+        });
+        log::debug!(
+            "Tensor result for node {} with rank {}",
+            node.name,
+            max_rank
+        );
+    }
+}
+
+/// Temporary stub preserves input type for unhandled operations.
+pub fn temporary_pass_through_stub(node: &mut Node) {
+    log::warn!(
+        "Must implement rank inference for node type {:?} (name: {})",
+        node.node_type,
+        node.name
+    );
+
+    if let Some(input_rank) = node.inputs.first().map(|input| match &input.ty {
+        ArgType::Tensor(tensor) => tensor.rank,
+        ArgType::Scalar(_) => 0,
+        _ => 0,
+    }) {
+        log::debug!(
+            "Passing through input rank {} for unhandled node {}",
+            input_rank,
+            node.name
+        );
+    }
+
+    node.outputs[0].ty = node.inputs[0].ty.clone();
+    log::debug!(
+        "Using pass-through inference for unhandled node type {:?} ({})",
+        node.node_type,
+        node.name
+    );
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::ir::{Argument, ElementType, NodeType};
+    use std::collections::HashMap;
+
+    fn create_test_node(op_type: NodeType, input_ranks: Vec<usize>) -> Node {
+        let mut inputs = Vec::new();
+
+        for (i, rank) in input_ranks.iter().enumerate() {
+            inputs.push(Argument {
+                name: format!("input_{}", i),
+                ty: ArgType::Tensor(TensorType {
+                    elem_type: ElementType::Float32,
+                    rank: *rank,
+                    static_shape: None,
+                }),
+                value: None,
+                passed: true,
+            });
+        }
+
+        let outputs = vec![Argument {
+            name: "output".to_string(),
+            ty: ArgType::Tensor(TensorType {
+                elem_type: ElementType::Float32,
+                rank: 0, // Will be updated
+                static_shape: None,
+            }),
+            value: None,
+            passed: true,
+        }];
+
+        Node {
+            node_type: op_type.clone(),
+            name: format!("test_{:?}", op_type).to_lowercase(),
+            inputs,
+            outputs,
+            attrs: HashMap::new(),
+        }
+    }
+
+    #[test]
+    fn test_same_as_input() {
+        let mut node = create_test_node(NodeType::Relu, vec![3]);
+        same_as_input(&mut node);
+
+        match &node.outputs[0].ty {
+            ArgType::Tensor(tensor) => {
+                assert_eq!(tensor.elem_type, ElementType::Float32);
+                assert_eq!(tensor.rank, 3);
+            }
+            _ => panic!("Expected tensor output"),
+        }
+    }
+
+    #[test]
+    fn test_same_as_input_broadcast_max_rank() {
+        let mut node = create_test_node(NodeType::Add, vec![2, 4, 3]);
+        same_as_input_broadcast(&mut node);
+
+        match &node.outputs[0].ty {
+            ArgType::Tensor(tensor) => {
+                assert_eq!(tensor.elem_type, ElementType::Float32);
+                assert_eq!(tensor.rank, 4); // max(2, 4, 3) = 4
+            }
+            _ => panic!("Expected tensor output"),
+        }
+    }
+
+    #[test]
+    fn test_same_as_input_broadcast_with_scalar() {
+        let mut node = create_test_node(NodeType::Add, vec![3]);
+        // Add a scalar input
+        node.inputs.push(Argument {
+            name: "scalar_input".to_string(),
+            ty: ArgType::Scalar(ElementType::Float32),
+
value: None, + passed: true, + }); + + same_as_input_broadcast(&mut node); + + match &node.outputs[0].ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.elem_type, ElementType::Float32); + assert_eq!(tensor.rank, 3); // Scalar doesn't affect rank + } + _ => panic!("Expected tensor output"), + } + } + + #[test] + fn test_temporary_pass_through_stub() { + let mut node = create_test_node(NodeType::Identity, vec![5]); + temporary_pass_through_stub(&mut node); + + match &node.outputs[0].ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.elem_type, ElementType::Float32); + assert_eq!(tensor.rank, 5); + } + _ => panic!("Expected tensor output"), + } + } +}
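The broadcasting rank rule exercised by the `util.rs` tests above can be summarized as: the output rank of a broadcasting operation is the maximum rank across its inputs, and a result of rank 0 is treated as a scalar. The following is a minimal, self-contained sketch of that rule only; `broadcast_output_rank` and `OutKind` are illustrative names and are not part of the `onnx-ir` API or of this patch.

```rust
// Sketch only: mirrors the rank rule implemented by `same_as_input_broadcast`,
// not the real onnx-ir types or functions.
#[derive(Debug, PartialEq)]
enum OutKind {
    Scalar,
    Tensor { rank: usize },
}

// Output rank of a broadcasting op is the maximum input rank; rank 0 means scalar.
fn broadcast_output_rank(input_ranks: &[usize]) -> OutKind {
    match input_ranks.iter().copied().max().unwrap_or(0) {
        0 => OutKind::Scalar,
        rank => OutKind::Tensor { rank },
    }
}

fn main() {
    // Same case as `test_same_as_input_broadcast_max_rank`: max(2, 4, 3) = 4.
    assert_eq!(broadcast_output_rank(&[2, 4, 3]), OutKind::Tensor { rank: 4 });
    // All-scalar inputs produce a scalar output.
    assert_eq!(broadcast_output_rank(&[0, 0]), OutKind::Scalar);
}
```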