
[Training] Shape change is not synced when serializing graph to proto #19741

Open
@guyang3532

Description


Describe the issue

In ORT training, we apply graph transformations/optimizations to `forward_model_` and then call `Model::Load(forward_model_->ToProto(), gradient_model_, nullptr, *logger_)` to obtain `gradient_model_`.
If the shapes of NodeArgs are changed during the transformation/optimization step, the change is not visible in the graph of `gradient_model_`.
This is because the shape change is stored only in `Graph::node_args_`, which is not synced when the graph is serialized to proto.
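To make the mechanism concrete, here is a minimal, self-contained C++ toy model. It is not ORT code: `ToyGraph`, `ToyNodeArg`, `ToyProto` and `SerializedShapeIsStale` are hypothetical names, and the sketch assumes that serialization returns a cached proto that is rebuilt only when a "sync needed" flag is set, which is one way the behavior described above can arise. Editing a node arg's shape in place never sets that flag, so the second serialization returns the stale cached copy.

```cpp
#include <cassert>
#include <map>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a NodeArg: a name plus an optional shape.
struct ToyNodeArg {
  std::string name;
  std::optional<std::vector<long>> shape;
  // Mutates only this object; the owning graph is not told about the change.
  void SetShape(std::vector<long> s) { shape = std::move(s); }
};

// Hypothetical serialized form: arg name -> shape (empty vector = unknown).
using ToyProto = std::map<std::string, std::vector<long>>;

class ToyGraph {
 public:
  ToyNodeArg& AddNodeArg(const std::string& name) {
    proto_sync_needed_ = true;  // structural edits mark the cached proto stale
    return node_args_[name] = ToyNodeArg{name, std::nullopt};
  }
  ToyNodeArg* GetNodeArg(const std::string& name) {
    auto it = node_args_.find(name);
    return it == node_args_.end() ? nullptr : &it->second;
  }
  // Rebuilds the cached proto only when a sync has been flagged (assumption).
  const ToyProto& ToProto() {
    if (proto_sync_needed_) {
      cached_proto_.clear();
      for (const auto& [name, arg] : node_args_)
        cached_proto_[name] = arg.shape.value_or(std::vector<long>{});
      proto_sync_needed_ = false;
    }
    return cached_proto_;
  }

 private:
  std::map<std::string, ToyNodeArg> node_args_;
  ToyProto cached_proto_;
  bool proto_sync_needed_ = true;
};

// Mirrors the repro: serialize once, change Y's shape in place, serialize again.
bool SerializedShapeIsStale() {
  ToyGraph graph;
  graph.AddNodeArg("Y");
  graph.ToProto();  // cache is built while Y has no shape
  ToyNodeArg* y = graph.GetNodeArg("Y");
  assert(y != nullptr);
  y->SetShape({2, 3});  // updates the node arg only
  // Y's shape is known in the graph but still unknown in the serialized form.
  return y->shape.has_value() && graph.ToProto().at("Y").empty();
}
```

In this toy, `SerializedShapeIsStale()` returns `true`: the in-memory node arg has shape `{2, 3}` while the serialized form still reports it as unknown, which matches the symptom seen in `gradient_model_`.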

To reproduce

I construct the following graph:
[Image: X → node1 (Identity) → Y → node2 (Identity) → Z]

The ONNX file can be created with the following code:

```cpp
onnxruntime::Model original_model("test", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
                                  {{kOnnxDomain, 12}}, {}, DefaultLoggingManager().DefaultLogger());
onnxruntime::Graph& graph = original_model.MainGraph();
TypeProto tensor_float;
tensor_float.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
onnxruntime::NodeArg input_def("X", &tensor_float), inter_def("Y", &tensor_float), output_def("Z", &tensor_float);

// Two chained Identity nodes: X -> node1 -> Y -> node2 -> Z
onnxruntime::Node* node = &graph.AddNode("node1", "Identity", "Identity operator", ArgMap{&input_def}, ArgMap{&inter_def});
node->SetExecutionProviderType(kCpuExecutionProvider);
onnxruntime::Node* node2 = &graph.AddNode("node2", "Identity", "Identity operator", ArgMap{&inter_def}, ArgMap{&output_def});
node2->SetExecutionProviderType(kCpuExecutionProvider);
// Resolve so the graph inputs/outputs (X, Z) are inferred before saving.
ASSERT_STATUS_OK(graph.Resolve());
ASSERT_STATUS_OK(Model::Save(original_model, "./test.onnx"));
```

And the test code that reproduces the issue:

```cpp
constexpr const ORTCHAR_T* model_uri = ORT_TSTR("test.onnx");
std::shared_ptr<Model> forward_model;
ASSERT_STATUS_OK(Model::Load(model_uri, forward_model, nullptr, *logger_));
Graph& graph = forward_model->MainGraph();
NodeArg* arg = graph.GetNodeArg("Y");
ASSERT_TRUE(arg->Shape() == nullptr);
onnx::TensorShapeProto new_shape;
new_shape.add_dim()->set_dim_value(2);
new_shape.add_dim()->set_dim_value(3);
arg->SetShape(new_shape);  // Sets the shape, but the change is not synced to the proto, so gradient_model cannot see it.

// The snippet below syncs the shape change. If it is uncommented, the shape change
// is visible in the gradient graph and this test passes, but current ORT code
// does not do this.
// graph.Set_is_loaded_from_model_file(false);
// graph.SetGraphResolveNeeded();
// Graph::ResolveOptions resolve_options;
// ASSERT_STATUS_OK(graph.Resolve(resolve_options));

std::shared_ptr<Model> gradient_model;
ASSERT_STATUS_OK(Model::Load(forward_model->ToProto(), gradient_model, nullptr, *logger_));
Graph& gradient_graph = gradient_model->MainGraph();
NodeArg* gradient_arg = gradient_graph.GetNodeArg("Y");
ASSERT_TRUE(gradient_arg->Shape() != nullptr);  // Fails: the shape change is not visible in the gradient graph.
```
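One possible direction for a fix, sketched below as a hedged toy model (not ORT code, and not something this issue prescribes): have each node arg notify its owning graph when its shape changes, so the next serialization rebuilds the proto automatically. This plays the same role as the commented-out `SetGraphResolveNeeded()`/`Resolve()` calls in the test, without requiring the caller to remember them. `NotifyingNodeArg`, `SyncingGraph` and `SerializedShapeOfYAfterEdit` are hypothetical names, and the cached-proto-plus-flag design is an assumption.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical node arg that notifies its owner whenever its shape changes.
class NotifyingNodeArg {
 public:
  NotifyingNodeArg(std::string name, std::function<void()> on_change)
      : name_(std::move(name)), on_change_(std::move(on_change)) {}
  void SetShape(std::vector<long> shape) {
    shape_ = std::move(shape);
    on_change_();  // tell the owning graph its serialized form is stale
  }
  const std::vector<long>& Shape() const { return shape_; }

 private:
  std::string name_;
  std::vector<long> shape_;  // empty means "unknown" in this toy
  std::function<void()> on_change_;
};

class SyncingGraph {
 public:
  SyncingGraph() = default;
  // Node args capture `this`, so copying the graph would leave dangling callbacks.
  SyncingGraph(const SyncingGraph&) = delete;
  SyncingGraph& operator=(const SyncingGraph&) = delete;

  NotifyingNodeArg& AddNodeArg(const std::string& name) {
    proto_sync_needed_ = true;
    auto it = node_args_.try_emplace(name, name, [this] { proto_sync_needed_ = true; }).first;
    return it->second;
  }
  // Rebuilds the cached proto only when a sync has been flagged.
  const std::map<std::string, std::vector<long>>& ToProto() {
    if (proto_sync_needed_) {
      cached_proto_.clear();
      for (const auto& [name, arg] : node_args_) cached_proto_[name] = arg.Shape();
      proto_sync_needed_ = false;
    }
    return cached_proto_;
  }

 private:
  std::map<std::string, NotifyingNodeArg> node_args_;
  std::map<std::string, std::vector<long>> cached_proto_;
  bool proto_sync_needed_ = true;
};

// Serialize, change Y's shape in place, then serialize again.
std::vector<long> SerializedShapeOfYAfterEdit() {
  SyncingGraph graph;
  NotifyingNodeArg& y = graph.AddNodeArg("Y");
  graph.ToProto();                          // cache built while Y has no shape
  assert(graph.ToProto().at("Y").empty());  // second call hits the cache
  y.SetShape({2, 3});                       // the callback flags the cache as stale
  return graph.ToProto().at("Y");           // rebuilt, so {2, 3}
}
```

The trade-off in this design is that every in-place shape update invalidates the whole cached proto; a real implementation would have to decide whether that cost is acceptable or whether shape edits should be synced more selectively.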

Urgency

No response

ONNX Runtime Installation

Built from Source

ONNX Runtime Version or Commit ID

ed550b5

PyTorch Version

2.2.0

Execution Provider

Default CPU

Execution Provider Library Version

No response

Metadata

Assignees: No one assigned

Labels:
- stale: issues that have not been addressed in a while; categorized by a bot
- training: issues related to ONNX Runtime training; typically submitted using template
