onnxruntime/core/optimizer/insert_cast_transformer.cc (24 additions, 5 deletions)
@@ -433,15 +433,34 @@ class RemoveDuplicateCastTransformer : public GraphTransformer {

       // If all the child nodes are either removed or another Cast node and we're not providing graph output,
       // we can remove this node. Connect those remaining child Cast nodes to current Cast node's input.
+      //
+      // However, we must NOT do this if any kept Cast child is on a different EP than the current node.
+      // Fusing across EP boundaries can produce a node whose input type is not supported by its EP.
+      // For example, Cast(int64->float, CPU) -> Cast(float->float16, WebGPU) would become
+      // Cast(int64->float16, WebGPU), but WebGPU doesn't support int64 inputs.
+      // See: https://github.com/microsoft/onnxruntime/issues/27291
       if (num_children > 0 && nodes_to_remove.size() + cast_nodes_to_keep.size() == num_children &&
           graph_outputs.find(node.OutputDefs()[0]) == graph_outputs_end) {
-        for (auto& n : cast_nodes_to_keep) {
-          Node& cast_node_to_keep = n;
-          graph.SetNodeArgType(*cast_node_to_keep.MutableInputDefs()[0], *node.InputDefs()[0]->TypeAsProto());
+        // Check that all kept Cast children are on the same EP as the current node.
+        bool cross_ep = false;
+        const auto& current_ep = node.GetExecutionProviderType();
+        for (const auto& n : cast_nodes_to_keep) {
+          const Node& kept_node = n;
+          if (kept_node.GetExecutionProviderType() != current_ep) {
+            cross_ep = true;
+            break;
+          }
         }
 
-        removed = graph_utils::RemoveNode(graph, node);
-        modified = true;
+        if (!cross_ep) {
+          for (auto& n : cast_nodes_to_keep) {
+            Node& cast_node_to_keep = n;
+            graph.SetNodeArgType(*cast_node_to_keep.MutableInputDefs()[0], *node.InputDefs()[0]->TypeAsProto());
+          }
+
+          removed = graph_utils::RemoveNode(graph, node);
+          modified = true;
+        }
       }
     }

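The heart of this change is the guard above: a Cast chain is only collapsed when every surviving child Cast runs on the same execution provider as the node being removed. Below is a minimal, self-contained sketch of that decision logic; FakeNode and CanFuseCastChain are illustrative names made up for this example, not onnxruntime APIs.

// Standalone sketch of the cross-EP fusion guard. FakeNode stands in for
// onnxruntime::Node, which cannot be constructed outside a Graph.
#include <iostream>
#include <string>
#include <vector>

struct FakeNode {
  std::string ep;  // e.g. "CPUExecutionProvider" or "WebGpuExecutionProvider"
  const std::string& GetExecutionProviderType() const { return ep; }
};

// Fusing is safe only if every kept child Cast is on the parent's EP;
// otherwise the fused node could receive an input type its EP rejects
// (e.g. int64 handed to WebGPU, as in issue #27291).
bool CanFuseCastChain(const FakeNode& parent,
                      const std::vector<FakeNode>& kept_children) {
  for (const auto& child : kept_children) {
    if (child.GetExecutionProviderType() != parent.GetExecutionProviderType()) {
      return false;  // cross-EP boundary: keep both Cast nodes
    }
  }
  return true;
}

int main() {
  FakeNode cpu_cast{"CPUExecutionProvider"};
  FakeNode webgpu_cast{"WebGpuExecutionProvider"};
  // Mirrors the issue: Cast(int64->float, CPU) feeding Cast(float->float16, WebGPU).
  std::cout << std::boolalpha
            << CanFuseCastChain(cpu_cast, {webgpu_cast}) << "\n"  // false
            << CanFuseCastChain(cpu_cast, {cpu_cast}) << "\n";    // true
}

A std::all_of over the kept children would express the same predicate; the explicit loop with an early exit simply mirrors the transformer code above.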
onnxruntime/test/framework/insert_cast_transformer_test.cc (58 additions, 0 deletions)
@@ -371,5 +371,63 @@ TEST(TransformerTest, IsIsolatedFp16NodeOnCpuTest) {
EXPECT_EQ(ops["Cast"], 4);
}

// Verify that RemoveDuplicateCastTransformer does not fuse consecutive Cast nodes
// that are assigned to different execution providers.
// Regression test for https://github.com/microsoft/onnxruntime/issues/27291
TEST(TransformerTest, CrossEpCastNodesNotFused) {
auto model = std::make_shared<onnxruntime::Model>("test", false, DefaultLoggingManager().DefaultLogger());
onnxruntime::Graph& graph = model->MainGraph();

// Build: X(int64) -> Cast(int64->float32) -> Cast(float32->float16) -> Y(float16)
TypeProto tensor_int64;
tensor_int64.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT64);
TypeProto tensor_float32;
tensor_float32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
TypeProto tensor_float16;
tensor_float16.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT16);

onnxruntime::NodeArg x_def("X", &tensor_int64);
onnxruntime::NodeArg mid_def("mid", &tensor_float32);
onnxruntime::NodeArg y_def("Y", &tensor_float16);

NodeAttributes cast1_attrs = {
{"to", utils::MakeAttribute("to",
static_cast<int64_t>(TensorProto_DataType_FLOAT))}};
NodeAttributes cast2_attrs = {
{"to", utils::MakeAttribute("to",
static_cast<int64_t>(TensorProto_DataType_FLOAT16))}};

// Cast_1 on CPU EP, Cast_2 on WebGPU EP.
auto& cast1 = graph.AddNode("Cast_1", "Cast", "int64 to float32",
ArgMap{&x_def}, ArgMap{&mid_def}, &cast1_attrs);
cast1.SetExecutionProviderType(onnxruntime::kCpuExecutionProvider);

auto& cast2 = graph.AddNode("Cast_2", "Cast", "float32 to float16",
ArgMap{&mid_def}, ArgMap{&y_def}, &cast2_attrs);
cast2.SetExecutionProviderType(onnxruntime::kWebGpuExecutionProvider);

auto status = graph.Resolve();
ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();

// Run InsertCastTransformer (which internally runs RemoveDuplicateCastTransformer)
InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get());

bool modified = false;
status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger());
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
status = graph.Resolve();
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();

// Both Cast nodes must survive — they should NOT be fused across EP boundaries.
std::map<std::string, int> op_counts = CountOpsInGraph(graph);
EXPECT_EQ(op_counts["Cast"], 2) << "Cast nodes on different EPs must not be fused";

// Verify Cast_2's input is still float32 (not changed to int64)
const auto* cast2_input_type = cast2.InputDefs()[0]->TypeAsProto();
ASSERT_NE(cast2_input_type, nullptr);
EXPECT_EQ(cast2_input_type->tensor_type().elem_type(), TensorProto_DataType_FLOAT)
<< "Cast_2 input should remain float32, not be changed to int64";
}

} // namespace test
} // namespace onnxruntime
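Assuming the standard onnxruntime googletest harness (the test binary name can vary by build configuration), the new regression test can be run on its own with a gtest filter:

./onnxruntime_test_all --gtest_filter=TransformerTest.CrossEpCastNodesNotFused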