From 87a875fed5c9c2ab7c02df55ab0f209569b94009 Mon Sep 17 00:00:00 2001
From: Yulong Wang <7679871+fs-eire@users.noreply.github.com>
Date: Mon, 16 Feb 2026 14:12:35 -0800
Subject: [PATCH] Prevent cross-EP Cast fusion in RemoveDuplicateCastTransformer

---
 .../core/optimizer/insert_cast_transformer.cc | 29 ++++++++--
 .../framework/insert_cast_transformer_test.cc | 58 +++++++++++++++++++
 2 files changed, 82 insertions(+), 5 deletions(-)

diff --git a/onnxruntime/core/optimizer/insert_cast_transformer.cc b/onnxruntime/core/optimizer/insert_cast_transformer.cc
index b1665c7172549..807021e67dee5 100644
--- a/onnxruntime/core/optimizer/insert_cast_transformer.cc
+++ b/onnxruntime/core/optimizer/insert_cast_transformer.cc
@@ -433,15 +433,34 @@ class RemoveDuplicateCastTransformer : public GraphTransformer {
 
       // If all the child nodes are either removed or another Cast node and we're not providing graph output,
       // we can remove this node. Connect those remaining child Cast nodes to current Cast node's input.
+      //
+      // However, we must NOT do this if any kept Cast child is on a different EP than the current node.
+      // Fusing across EP boundaries can produce a node whose input type is not supported by its EP.
+      // For example, Cast(int64->float, CPU) -> Cast(float->float16, WebGPU) would become
+      // Cast(int64->float16, WebGPU), but WebGPU doesn't support int64 inputs.
+      // See: https://github.com/microsoft/onnxruntime/issues/27291
       if (num_children > 0 && nodes_to_remove.size() + cast_nodes_to_keep.size() == num_children &&
           graph_outputs.find(node.OutputDefs()[0]) == graph_outputs_end) {
-        for (auto& n : cast_nodes_to_keep) {
-          Node& cast_node_to_keep = n;
-          graph.SetNodeArgType(*cast_node_to_keep.MutableInputDefs()[0], *node.InputDefs()[0]->TypeAsProto());
+        // Check that all kept Cast children are on the same EP as the current node.
+        bool cross_ep = false;
+        const auto& current_ep = node.GetExecutionProviderType();
+        for (const auto& n : cast_nodes_to_keep) {
+          const Node& kept_node = n;
+          if (kept_node.GetExecutionProviderType() != current_ep) {
+            cross_ep = true;
+            break;
+          }
         }
 
-        removed = graph_utils::RemoveNode(graph, node);
-        modified = true;
+        if (!cross_ep) {
+          for (auto& n : cast_nodes_to_keep) {
+            Node& cast_node_to_keep = n;
+            graph.SetNodeArgType(*cast_node_to_keep.MutableInputDefs()[0], *node.InputDefs()[0]->TypeAsProto());
+          }
+
+          removed = graph_utils::RemoveNode(graph, node);
+          modified = true;
+        }
       }
     }
 
diff --git a/onnxruntime/test/framework/insert_cast_transformer_test.cc b/onnxruntime/test/framework/insert_cast_transformer_test.cc
index c4b0f3ffd15d9..b2a3f9ee329e5 100644
--- a/onnxruntime/test/framework/insert_cast_transformer_test.cc
+++ b/onnxruntime/test/framework/insert_cast_transformer_test.cc
@@ -371,5 +371,63 @@ TEST(TransformerTest, IsIsolatedFp16NodeOnCpuTest) {
   EXPECT_EQ(ops["Cast"], 4);
 }
 
+// Verify that RemoveDuplicateCastTransformer does not fuse consecutive Cast nodes
+// that are assigned to different execution providers.
+// Regression test for https://github.com/microsoft/onnxruntime/issues/27291
+TEST(TransformerTest, CrossEpCastNodesNotFused) {
+  auto model = std::make_shared<onnxruntime::Model>("test", false, DefaultLoggingManager().DefaultLogger());
+  onnxruntime::Graph& graph = model->MainGraph();
+
+  // Build: X(int64) -> Cast(int64->float32) -> Cast(float32->float16) -> Y(float16)
+  TypeProto tensor_int64;
+  tensor_int64.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT64);
+  TypeProto tensor_float32;
+  tensor_float32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
+  TypeProto tensor_float16;
+  tensor_float16.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT16);
+
+  onnxruntime::NodeArg x_def("X", &tensor_int64);
+  onnxruntime::NodeArg mid_def("mid", &tensor_float32);
+  onnxruntime::NodeArg y_def("Y", &tensor_float16);
+
+  NodeAttributes cast1_attrs = {
+      {"to", utils::MakeAttribute("to",
+                                  static_cast<int64_t>(TensorProto_DataType_FLOAT))}};
+  NodeAttributes cast2_attrs = {
+      {"to", utils::MakeAttribute("to",
+                                  static_cast<int64_t>(TensorProto_DataType_FLOAT16))}};
+
+  // Cast_1 on CPU EP, Cast_2 on WebGPU EP.
+  auto& cast1 = graph.AddNode("Cast_1", "Cast", "int64 to float32",
+                              ArgMap{&x_def}, ArgMap{&mid_def}, &cast1_attrs);
+  cast1.SetExecutionProviderType(onnxruntime::kCpuExecutionProvider);
+
+  auto& cast2 = graph.AddNode("Cast_2", "Cast", "float32 to float16",
+                              ArgMap{&mid_def}, ArgMap{&y_def}, &cast2_attrs);
+  cast2.SetExecutionProviderType(onnxruntime::kWebGpuExecutionProvider);
+
+  auto status = graph.Resolve();
+  ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
+
+  // Run InsertCastTransformer (which internally runs RemoveDuplicateCastTransformer)
+  InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get());
+
+  bool modified = false;
+  status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger());
+  EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
+  status = graph.Resolve();
+  EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
+
+  // Both Cast nodes must survive - they should NOT be fused across EP boundaries.
+  std::map<std::string, int> op_counts = CountOpsInGraph(graph);
+  EXPECT_EQ(op_counts["Cast"], 2) << "Cast nodes on different EPs must not be fused";
+
+  // Verify Cast_2's input is still float32 (not changed to int64)
+  const auto* cast2_input_type = cast2.InputDefs()[0]->TypeAsProto();
+  ASSERT_NE(cast2_input_type, nullptr);
+  EXPECT_EQ(cast2_input_type->tensor_type().elem_type(), TensorProto_DataType_FLOAT)
+      << "Cast_2 input should remain float32, not be changed to int64";
+}
+
 }  // namespace test
 }  // namespace onnxruntime