Describe the issue
In ORT training, we apply graph transformations/optimizations to forward_model_ and then call 'Model::Load(forward_model_->ToProto(), gradient_model_, nullptr, *logger_)' to obtain gradient_model_.
If the shape of a NodeArg is changed during the transformation/optimization step, the change is not visible in the graph of gradient_model_.
This is because the shape change is stored only in Graph::node_args_, which is not synced when the graph is serialized to a proto.
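To make the failure mode concrete, here is a minimal self-contained sketch that models it with toy types: the in-memory NodeArg shapes and the serialized form are two separate copies, and only a resolve step copies one into the other. ToyGraph, ArgShapes, and ShapeSurvivesRoundTrip are hypothetical names, not ONNX Runtime API, and the toy only assumes the behavior described above (a shape set on a NodeArg does not reach the ToProto() output unless the graph is resolved again).

```cpp
#include <cassert>
#include <map>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy types, NOT the ONNX Runtime API.
using Shape = std::vector<long long>;
// Maps each NodeArg name to its (possibly unknown) shape.
using ArgShapes = std::map<std::string, std::optional<Shape>>;

class ToyGraph {
 public:
  // "Loading" copies the serialized shapes into both the in-memory args and the cache.
  explicit ToyGraph(const ArgShapes& loaded) : node_args_(loaded), cached_proto_(loaded) {}

  // Like NodeArg::SetShape in spirit: only the in-memory arg changes.
  void SetShape(const std::string& name, Shape shape) {
    assert(node_args_.count(name) == 1 && "unknown NodeArg");
    node_args_[name] = std::move(shape);
  }

  // Like a Resolve() pass: regenerates the serialized form from the in-memory args.
  void Resolve() { cached_proto_ = node_args_; }

  // Like ToProto() as described in this report: returns the cached form
  // without consulting node_args_.
  ArgShapes ToProto() const { return cached_proto_; }

  std::optional<Shape> GetShape(const std::string& name) const { return node_args_.at(name); }

 private:
  ArgShapes node_args_;     // plays the role of Graph::node_args_
  ArgShapes cached_proto_;  // plays the role of the graph's serialized proto
};

// Replays the reproduction: set Y's shape, serialize, reload as the "gradient" graph.
bool ShapeSurvivesRoundTrip(bool resolve_before_serialize) {
  ToyGraph forward({{"X", std::nullopt}, {"Y", std::nullopt}, {"Z", std::nullopt}});
  forward.SetShape("Y", {2, 3});
  if (resolve_before_serialize) forward.Resolve();
  ToyGraph gradient(forward.ToProto());  // like Model::Load(forward_model->ToProto(), ...)
  return gradient.GetShape("Y").has_value();
}
```

ShapeSurvivesRoundTrip(false) returns false, matching the failing assertion in the reproduction; ShapeSurvivesRoundTrip(true) returns true, matching the commented-out Resolve() workaround.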
To reproduce
The ONNX file can be created with the following code:
onnxruntime::Model original_model("test", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
{{kOnnxDomain, 12}}, {}, DefaultLoggingManager().DefaultLogger());
onnxruntime::Graph& graph = original_model.MainGraph();
TypeProto tensor_float;
tensor_float.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
onnxruntime::NodeArg input_def("X", &tensor_float), inter_def("Y", &tensor_float), output_def("Z", &tensor_float);
onnxruntime::Node* node = &graph.AddNode("node1", "Identity", "Identity operator", ArgMap{&input_def}, ArgMap{&inter_def});
node->SetExecutionProviderType(kCpuExecutionProvider);
onnxruntime::Node* node2 = &graph.AddNode("node2", "Identity", "Identity operator", ArgMap{&inter_def}, ArgMap{&output_def});
node2->SetExecutionProviderType(kCpuExecutionProvider);
ASSERT_STATUS_OK(Model::Save(original_model, "./test.onnx"));
The following test code reproduces the issue:
constexpr const ORTCHAR_T* model_uri = ORT_TSTR("test.onnx");
std::shared_ptr<Model> forward_model;
ASSERT_STATUS_OK(Model::Load(model_uri, forward_model, nullptr, *logger_));
Graph& graph = forward_model->MainGraph();
NodeArg* arg = graph.GetNodeArg("Y");
ASSERT_TRUE(arg->Shape() == nullptr);
onnx::TensorShapeProto new_shape;
new_shape.add_dim()->set_dim_value(2);
new_shape.add_dim()->set_dim_value(3);
arg->SetShape(new_shape);  // Sets the shape, but the change is not synced and is not visible in gradient_model.
// The snippet below syncs the shape change.
// If uncommented, the shape change is visible in the gradient graph and this test succeeds,
// but current ORT code does not do this.
// graph.Set_is_loaded_from_model_file(false);
// graph.SetGraphResolveNeeded();
// Graph::ResolveOptions resolve_options;
// ASSERT_STATUS_OK(graph.Resolve(resolve_options));
std::shared_ptr<Model> gradient_model;
ASSERT_STATUS_OK(Model::Load(forward_model->ToProto(), gradient_model, nullptr, *logger_));
Graph& gradient_graph = gradient_model->MainGraph();
NodeArg* gradient_arg = gradient_graph.GetNodeArg("Y");
ASSERT_TRUE(gradient_arg->Shape() != nullptr);  // Fails: the shape change is not visible in the gradient graph.
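The workaround in the test above depends on every transformation remembering to re-resolve the graph before ToProto(). One possible direction, sketched on a toy model in which NodeArg shapes and the serialized form are separate copies, is to have each shape edit mark the serialized form as stale so that serialization regenerates it lazily. SyncOnWriteGraph and its members are hypothetical names; this is not how ONNX Runtime currently behaves and not a confirmed fix.

```cpp
#include <cassert>
#include <map>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy types, NOT the ONNX Runtime API.
using Shape = std::vector<long long>;
using ArgShapes = std::map<std::string, std::optional<Shape>>;

// Any shape edit marks the cached serialized form as stale, so serialization
// never returns outdated shapes even if no resolve pass runs.
class SyncOnWriteGraph {
 public:
  explicit SyncOnWriteGraph(const ArgShapes& loaded) : node_args_(loaded), cached_proto_(loaded) {}

  void SetShape(const std::string& name, Shape shape) {
    assert(node_args_.count(name) == 1 && "unknown NodeArg");
    node_args_[name] = std::move(shape);
    proto_sync_needed_ = true;  // the cached form no longer matches node_args_
  }

  // Rebuilds the cached form only when something changed since the last call.
  const ArgShapes& ToProto() const {
    if (proto_sync_needed_) {
      cached_proto_ = node_args_;
      proto_sync_needed_ = false;
    }
    return cached_proto_;
  }

  std::optional<Shape> GetShape(const std::string& name) const { return node_args_.at(name); }

 private:
  ArgShapes node_args_;
  mutable ArgShapes cached_proto_;  // mutable: serialization is logically const
  mutable bool proto_sync_needed_ = false;
};
```

The trade-off is that the first ToProto() call after any edit copies the whole argument table, and the cache has to be mutable because serializing is logically a read-only operation.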
Urgency
No response
ONNX Runtime Installation
Built from Source
ONNX Runtime Version or Commit ID
PyTorch Version
2.2.0
Execution Provider
Default CPU
Execution Provider Library Version
No response