@@ -371,5 +371,63 @@ TEST(TransformerTest, IsIsolatedFp16NodeOnCpuTest) {
   EXPECT_EQ(ops["Cast"], 4);
 }
 
+// Verify that RemoveDuplicateCastTransformer does not fuse consecutive Cast nodes
+// that are assigned to different execution providers.
+// Regression test for https://github.com/microsoft/onnxruntime/issues/27291
+TEST(TransformerTest, CrossEpCastNodesNotFused) {
+  auto model = std::make_shared<onnxruntime::Model>("test", false, DefaultLoggingManager().DefaultLogger());
+  onnxruntime::Graph& graph = model->MainGraph();
+
+  // Build: X(int64) -> Cast(int64->float32) -> Cast(float32->float16) -> Y(float16)
+  TypeProto tensor_int64;
+  tensor_int64.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT64);
+  TypeProto tensor_float32;
+  tensor_float32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
+  TypeProto tensor_float16;
+  tensor_float16.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT16);
+
+  onnxruntime::NodeArg x_def("X", &tensor_int64);
+  onnxruntime::NodeArg mid_def("mid", &tensor_float32);
+  onnxruntime::NodeArg y_def("Y", &tensor_float16);
+
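+  // Each Cast node's "to" attribute selects the target element type of the cast.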
+  NodeAttributes cast1_attrs = {
+      {"to", utils::MakeAttribute("to",
+                                  static_cast<int64_t>(TensorProto_DataType_FLOAT))}};
+  NodeAttributes cast2_attrs = {
+      {"to", utils::MakeAttribute("to",
+                                  static_cast<int64_t>(TensorProto_DataType_FLOAT16))}};
+
+  // Cast_1 on CPU EP, Cast_2 on WebGPU EP.
+  auto& cast1 = graph.AddNode("Cast_1", "Cast", "int64 to float32",
+                              ArgMap{&x_def}, ArgMap{&mid_def}, &cast1_attrs);
+  cast1.SetExecutionProviderType(onnxruntime::kCpuExecutionProvider);
+
+  auto& cast2 = graph.AddNode("Cast_2", "Cast", "float32 to float16",
+                              ArgMap{&mid_def}, ArgMap{&y_def}, &cast2_attrs);
+  cast2.SetExecutionProviderType(onnxruntime::kWebGpuExecutionProvider);
+
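+  // Resolve the graph so type/shape information is populated before the transformer runs.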
+  auto status = graph.Resolve();
+  ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
+
+  // Run InsertCastTransformer (which internally runs RemoveDuplicateCastTransformer).
+  InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get());
+
+  bool modified = false;
+  status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger());
+  EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
+  status = graph.Resolve();
+  EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
+
+  // Both Cast nodes must survive; they should NOT be fused across EP boundaries.
+  std::map<std::string, int> op_counts = CountOpsInGraph(graph);
+  EXPECT_EQ(op_counts["Cast"], 2) << "Cast nodes on different EPs must not be fused";
+
+  // Verify that Cast_2's input is still float32 (not changed to int64).
+  const auto* cast2_input_type = cast2.InputDefs()[0]->TypeAsProto();
+  ASSERT_NE(cast2_input_type, nullptr);
+  EXPECT_EQ(cast2_input_type->tensor_type().elem_type(), TensorProto_DataType_FLOAT)
+      << "Cast_2 input should remain float32, not be changed to int64";
+}
+
 }  // namespace test
 }  // namespace onnxruntime