Commit 87a875f

Prevent cross-EP Cast fusion in RemoveDuplicateCastTransformer

1 parent: 0df5dbc

2 files changed: +82 / -5 lines

onnxruntime/core/optimizer/insert_cast_transformer.cc (24 additions, 5 deletions)

@@ -433,15 +433,34 @@ class RemoveDuplicateCastTransformer : public GraphTransformer {
 
       // If all the child nodes are either removed or another Cast node and we're not providing graph output,
       // we can remove this node. Connect those remaining child Cast nodes to current Cast node's input.
+      //
+      // However, we must NOT do this if any kept Cast child is on a different EP than the current node.
+      // Fusing across EP boundaries can produce a node whose input type is not supported by its EP.
+      // For example, Cast(int64->float, CPU) -> Cast(float->float16, WebGPU) would become
+      // Cast(int64->float16, WebGPU), but WebGPU doesn't support int64 inputs.
+      // See: https://github.com/microsoft/onnxruntime/issues/27291
       if (num_children > 0 && nodes_to_remove.size() + cast_nodes_to_keep.size() == num_children &&
           graph_outputs.find(node.OutputDefs()[0]) == graph_outputs_end) {
-        for (auto& n : cast_nodes_to_keep) {
-          Node& cast_node_to_keep = n;
-          graph.SetNodeArgType(*cast_node_to_keep.MutableInputDefs()[0], *node.InputDefs()[0]->TypeAsProto());
+        // Check that all kept Cast children are on the same EP as the current node.
+        bool cross_ep = false;
+        const auto& current_ep = node.GetExecutionProviderType();
+        for (const auto& n : cast_nodes_to_keep) {
+          const Node& kept_node = n;
+          if (kept_node.GetExecutionProviderType() != current_ep) {
+            cross_ep = true;
+            break;
+          }
         }
 
-        removed = graph_utils::RemoveNode(graph, node);
-        modified = true;
+        if (!cross_ep) {
+          for (auto& n : cast_nodes_to_keep) {
+            Node& cast_node_to_keep = n;
+            graph.SetNodeArgType(*cast_node_to_keep.MutableInputDefs()[0], *node.InputDefs()[0]->TypeAsProto());
+          }
+
+          removed = graph_utils::RemoveNode(graph, node);
+          modified = true;
+        }
       }
     }
 
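Distilled from the hunk above, the new guard is just a predicate over the kept Cast children. A minimal standalone sketch of that check (the helper name AllKeptCastsOnSameEp and the std::reference_wrapper container type are illustrative assumptions, not part of the commit):

#include <functional>
#include <vector>

#include "core/graph/graph.h"  // onnxruntime::Node

// Illustrative restatement of the guard: fusion is allowed only when every
// kept Cast child runs on the same execution provider (EP) as the Cast node
// being removed. Otherwise, e.g.
//   Cast(int64->float, CPU) -> Cast(float->float16, WebGPU)
// would fuse into Cast(int64->float16, WebGPU), whose int64 input WebGPU
// cannot consume.
bool AllKeptCastsOnSameEp(const onnxruntime::Node& node,
                          const std::vector<std::reference_wrapper<onnxruntime::Node>>& kept) {
  const auto& current_ep = node.GetExecutionProviderType();
  for (const onnxruntime::Node& kept_node : kept) {
    if (kept_node.GetExecutionProviderType() != current_ep) {
      return false;  // crossing an EP boundary: do not fuse
    }
  }
  return true;
}

With that predicate, the body of the hunk reads as: fuse (retype the kept children and remove the node) only when the predicate holds; otherwise leave both Cast nodes in place.
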
onnxruntime/test/framework/insert_cast_transformer_test.cc (58 additions, 0 deletions)

@@ -371,5 +371,63 @@ TEST(TransformerTest, IsIsolatedFp16NodeOnCpuTest) {
   EXPECT_EQ(ops["Cast"], 4);
 }
 
+// Verify that RemoveDuplicateCastTransformer does not fuse consecutive Cast nodes
+// that are assigned to different execution providers.
+// Regression test for https://github.com/microsoft/onnxruntime/issues/27291
+TEST(TransformerTest, CrossEpCastNodesNotFused) {
+  auto model = std::make_shared<onnxruntime::Model>("test", false, DefaultLoggingManager().DefaultLogger());
+  onnxruntime::Graph& graph = model->MainGraph();
+
+  // Build: X(int64) -> Cast(int64->float32) -> Cast(float32->float16) -> Y(float16)
+  TypeProto tensor_int64;
+  tensor_int64.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT64);
+  TypeProto tensor_float32;
+  tensor_float32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
+  TypeProto tensor_float16;
+  tensor_float16.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT16);
+
+  onnxruntime::NodeArg x_def("X", &tensor_int64);
+  onnxruntime::NodeArg mid_def("mid", &tensor_float32);
+  onnxruntime::NodeArg y_def("Y", &tensor_float16);
+
+  NodeAttributes cast1_attrs = {
+      {"to", utils::MakeAttribute("to",
+                                  static_cast<int64_t>(TensorProto_DataType_FLOAT))}};
+  NodeAttributes cast2_attrs = {
+      {"to", utils::MakeAttribute("to",
+                                  static_cast<int64_t>(TensorProto_DataType_FLOAT16))}};
+
+  // Cast_1 on CPU EP, Cast_2 on WebGPU EP.
+  auto& cast1 = graph.AddNode("Cast_1", "Cast", "int64 to float32",
+                              ArgMap{&x_def}, ArgMap{&mid_def}, &cast1_attrs);
+  cast1.SetExecutionProviderType(onnxruntime::kCpuExecutionProvider);
+
+  auto& cast2 = graph.AddNode("Cast_2", "Cast", "float32 to float16",
+                              ArgMap{&mid_def}, ArgMap{&y_def}, &cast2_attrs);
+  cast2.SetExecutionProviderType(onnxruntime::kWebGpuExecutionProvider);
+
+  auto status = graph.Resolve();
+  ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
+
+  // Run InsertCastTransformer (which internally runs RemoveDuplicateCastTransformer).
+  InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get());
+
+  bool modified = false;
+  status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger());
+  EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
+  status = graph.Resolve();
+  EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
+
+  // Both Cast nodes must survive; they should NOT be fused across EP boundaries.
+  std::map<std::string, int> op_counts = CountOpsInGraph(graph);
+  EXPECT_EQ(op_counts["Cast"], 2) << "Cast nodes on different EPs must not be fused";
+
+  // Verify Cast_2's input is still float32 (not changed to int64).
+  const auto* cast2_input_type = cast2.InputDefs()[0]->TypeAsProto();
+  ASSERT_NE(cast2_input_type, nullptr);
+  EXPECT_EQ(cast2_input_type->tensor_type().elem_type(), TensorProto_DataType_FLOAT)
+      << "Cast_2 input should remain float32, not be changed to int64";
+}
+
 }  // namespace test
 }  // namespace onnxruntime
