Description
Describe the issue
When trying to generate artifacts for a model using GRU and Squeeze, it seems that there is a misinterpretation of the GRU output shape during shape inference when layout == 1 and direction == 'forward'.
With layout == 1, the expected shapes are:
X.shape: [batch, seq_length, input_size]
Y.shape: [batch, seq_length, num_directions, hidden_size]
After my GRU operation, I want to squeeze the num_directions axis. However, this results in an error stating that dimension 2 must be 1 instead of 5, even though axis 2 corresponds to num_directions, which should always be 1 when direction == 'forward'.
Error message:
[ONNXRuntimeError] : 1 : FAIL : Node () Op (Squeeze) [ShapeInferenceError] Dimension of input 2 must be 1 instead of 5.
I noticed that layout == 1 is not supported for CPU execution, so I attempted to use layout == 0 instead.
However, this resulted in another error:
10 : INVALID_GRAPH : This is an invalid model. Type Error: Type 'tensor(float)' of input parameter (gru_output) of operator (GRUGrad) in node (_training_Grad/GRUGrad_0) is invalid.
Is there a known issue with shape inference in this case, or am I missing something?
To reproduce
First Error with Squeeze()
import onnx
import numpy as np
from onnx import helper, TensorProto
from onnxruntime.training import artifacts
def create_gru_squeeze_model():
    """Build, shape-infer, check and save a GRU -> Squeeze ONNX model.

    The graph contains a single forward GRU with ``layout=1`` whose only
    output feeds a Squeeze over axis 2. The shape-inferred model is validated
    with ``onnx.checker`` and written to ``gru_squeeze.onnx`` in the current
    working directory. Nothing is returned.
    """
    # Tensor names used to wire the graph together.
    input_name = "input"
    gru_output_name = "gru_output"
    squeeze_output_name = "squeeze_output"

    # Dimensions shared by the value infos and the weight initializers.
    seq_length = 5
    input_size = 3
    hidden_size = 7
    num_directions = 1  # direction="forward"

    # layout=1: X is [batch, seq_length, input_size]; after squeezing the
    # num_directions axis the output is [batch, seq_length, hidden_size].
    input_tensor = helper.make_tensor_value_info(
        input_name, TensorProto.FLOAT, ["batch", seq_length, input_size]
    )
    output_tensor = helper.make_tensor_value_info(
        squeeze_output_name, TensorProto.FLOAT, ["batch", seq_length, hidden_size]
    )

    # GRU weights (initializers). Every tensor carries a leading
    # num_directions axis.
    # W: [num_directions, 3*hidden_size, input_size]
    W = np.random.randn(num_directions, 3 * hidden_size, input_size).astype(np.float32)
    # R: [num_directions, 3*hidden_size, hidden_size]
    R = np.random.randn(num_directions, 3 * hidden_size, hidden_size).astype(np.float32)
    # B: [num_directions, 2*3*hidden_size]
    B = np.random.randn(num_directions, 2 * 3 * hidden_size).astype(np.float32)
    W_initializer = helper.make_tensor("W", TensorProto.FLOAT, W.shape, W.flatten())
    R_initializer = helper.make_tensor("R", TensorProto.FLOAT, R.shape, R.flatten())
    B_initializer = helper.make_tensor("B", TensorProto.FLOAT, B.shape, B.flatten())

    # GRU node.
    # NOTE(review): the ONNX GRU spec orders activations as [f, g] with
    # defaults Sigmoid, Tanh; this list puts Tanh first - confirm intended.
    gru_node = helper.make_node(
        "GRU",
        inputs=[input_name, "W", "R", "B"],
        outputs=[gru_output_name],
        direction="forward",
        hidden_size=hidden_size,
        activations=["Tanh", "Sigmoid"],
        layout=1,
        linear_before_reset=1,
    )

    # Squeeze axis 2: with layout=1 the GRU output Y is
    # [batch, seq_length, num_directions, hidden_size].
    squeeze_axes = helper.make_tensor("squeeze_axes", TensorProto.INT64, [1], [2])
    squeeze_node = helper.make_node(
        "Squeeze",
        inputs=[gru_output_name, "squeeze_axes"],
        outputs=[squeeze_output_name],
    )

    # Assemble the graph.
    graph = helper.make_graph(
        [gru_node, squeeze_node],
        "GRU_Squeeze_Model",
        [input_tensor],
        [output_tensor],
        [W_initializer, R_initializer, B_initializer, squeeze_axes],
    )

    # Create the ONNX model with opset version 17, then infer shapes and
    # validate the result.
    model = helper.make_model(
        graph,
        producer_name="onnx-gru-squeeze-generator",
        opset_imports=[helper.make_opsetid("", 17)],
    )
    inferred_model = onnx.shape_inference.infer_shapes(model)
    onnx.checker.check_model(inferred_model)

    # Save the model.
    onnx.save(inferred_model, "gru_squeeze.onnx")
    print("gru_squeeze.onnx' generated successfully.")
# Build the forward-only model and write gru_squeeze.onnx.
create_gru_squeeze_model()

# Read the saved forward-only model back in.
model = onnx.load("gru_squeeze.onnx")

# Produce the training artifacts: W and R are trainable, B is frozen,
# with an MSE loss and the AdamW optimizer.
artifacts.generate_artifacts(
    model,
    requires_grad=["W", "R"],
    frozen_params=["B"],
    loss=artifacts.LossType.MSELoss,
    optimizer=artifacts.OptimType.AdamW,
    artifact_directory="./artefact_test",
)
Second Error with GRUGrad
import onnx
import numpy as np
from onnx import helper, TensorProto
from onnxruntime.training import artifacts
def create_gru_squeeze_model():
    """Build, shape-infer and save a GRU -> Squeeze ONNX model (layout=0).

    The graph contains a single forward GRU with ``layout=0`` whose only
    output feeds a Squeeze over axis 1. The shape-inferred model is written
    to ``gru_squeeze_layout_0.onnx`` in the current working directory; the
    ``onnx.checker`` validation is not run in this variant. Nothing is
    returned.
    """
    # Tensor names used to wire the graph together.
    input_name = "input"
    gru_output_name = "gru_output"
    squeeze_output_name = "squeeze_output"

    # Dimensions shared by the value infos and the weight initializers.
    seq_length = 5
    input_size = 3
    hidden_size = 7
    num_directions = 1  # direction="forward"

    # layout=0: X is [seq_length, batch, input_size]; after squeezing the
    # num_directions axis the output is [seq_length, batch, hidden_size].
    input_tensor = helper.make_tensor_value_info(
        input_name, TensorProto.FLOAT, [seq_length, "batch", input_size]
    )
    output_tensor = helper.make_tensor_value_info(
        squeeze_output_name, TensorProto.FLOAT, [seq_length, "batch", hidden_size]
    )

    # GRU weights (initializers). Every tensor carries a leading
    # num_directions axis.
    # W: [num_directions, 3*hidden_size, input_size]
    W = np.random.randn(num_directions, 3 * hidden_size, input_size).astype(np.float32)
    # R: [num_directions, 3*hidden_size, hidden_size]
    R = np.random.randn(num_directions, 3 * hidden_size, hidden_size).astype(np.float32)
    # B: [num_directions, 2*3*hidden_size]
    B = np.random.randn(num_directions, 2 * 3 * hidden_size).astype(np.float32)
    W_initializer = helper.make_tensor("W", TensorProto.FLOAT, W.shape, W.flatten())
    R_initializer = helper.make_tensor("R", TensorProto.FLOAT, R.shape, R.flatten())
    B_initializer = helper.make_tensor("B", TensorProto.FLOAT, B.shape, B.flatten())

    # GRU node.
    # NOTE(review): the ONNX GRU spec orders activations as [f, g] with
    # defaults Sigmoid, Tanh; this list puts Tanh first - confirm intended.
    gru_node = helper.make_node(
        "GRU",
        inputs=[input_name, "W", "R", "B"],
        outputs=[gru_output_name],
        direction="forward",
        hidden_size=hidden_size,
        activations=["Tanh", "Sigmoid"],
        layout=0,
        linear_before_reset=1,
    )

    # Squeeze axis 1: with layout=0 the GRU output Y is
    # [seq_length, num_directions, batch, hidden_size].
    squeeze_axes = helper.make_tensor("squeeze_axes", TensorProto.INT64, [1], [1])
    squeeze_node = helper.make_node(
        "Squeeze",
        inputs=[gru_output_name, "squeeze_axes"],
        outputs=[squeeze_output_name],
    )

    # Assemble the graph.
    graph = helper.make_graph(
        [gru_node, squeeze_node],
        "GRU_Squeeze_Model",
        [input_tensor],
        [output_tensor],
        [W_initializer, R_initializer, B_initializer, squeeze_axes],
    )

    # Create the ONNX model with opset version 17, then infer shapes.
    model = helper.make_model(
        graph,
        producer_name="onnx-gru-squeeze-generator",
        opset_imports=[helper.make_opsetid("", 17)],
    )
    inferred_model = onnx.shape_inference.infer_shapes(model)
    # NOTE(review): the checker call below was left disabled in this variant;
    # the reason is not stated - confirm before re-enabling.
    # onnx.checker.check_model(inferred_model)

    # Save the model.
    onnx.save(inferred_model, "gru_squeeze_layout_0.onnx")
    print("gru_squeeze_layout_0.onnx' generated successfully.")
# Build the forward-only model and write gru_squeeze_layout_0.onnx.
create_gru_squeeze_model()

# Read the saved forward-only model back in.
model = onnx.load("gru_squeeze_layout_0.onnx")

# Produce the training artifacts: W and R are trainable, B is frozen,
# with an MSE loss and the AdamW optimizer.
artifacts.generate_artifacts(
    model,
    requires_grad=["W", "R"],
    frozen_params=["B"],
    loss=artifacts.LossType.MSELoss,
    optimizer=artifacts.OptimType.AdamW,
    artifact_directory="./artefact_test_layout_0",
)
Urgency
This is not urgent, but I wanted to report the issue in case it helps improve the model handling.
ONNX Runtime Installation
Built from Source
ONNX Runtime Version or Commit ID
1.21.0
PyTorch Version
2.6.0
Execution Provider
Default CPU
Execution Provider Library Version
No response