@@ -248,12 +248,15 @@ def fuse(graph: Graph, artifacts: Optional["FunctionDebugArtifacts"] = None):
248248 all_nodes , start_nodes , terminal_node = subgraph_to_fuse
249249 processed_terminal_nodes .add (terminal_node )
250250
251- conversion_result = convert_subgraph_to_subgraph_node (
252- graph ,
253- all_nodes ,
254- start_nodes ,
255- terminal_node ,
256- )
251+ try :
252+ conversion_result = convert_subgraph_to_subgraph_node (
253+ graph ,
254+ all_nodes ,
255+ start_nodes ,
256+ terminal_node ,
257+ )
258+ except NotFusable :
259+ conversion_result = None
257260 if conversion_result is None :
258261 continue
259262
@@ -769,7 +772,7 @@ def convert_subgraph_to_subgraph_node(
769772 if terminal_node .properties ["name" ] == "where" :
770773 return None
771774
772- raise RuntimeError (
775+ raise NotFusable (
773776 "A subgraph within the function you are trying to compile cannot be fused "
774777 "because it has multiple input nodes\n \n "
775778 + graph .format (highlighted_nodes = base_highlighted_nodes , show_bounds = False )
@@ -826,6 +829,8 @@ def convert_subgraph_to_subgraph_node(
826829
827830 return subgraph_node , variable_input_node
828831
832+ class NotFusable (RuntimeError ):
833+ pass
829834
830835def check_subgraph_fusibility (
831836 graph : Graph ,
@@ -868,7 +873,7 @@ def check_subgraph_fusibility(
868873
869874 if not node .is_fusable :
870875 base_highlighted_nodes [node ] = ["this node is not fusable" , node .location ]
871- raise RuntimeError (
876+ raise NotFusable (
872877 "A subgraph within the function you are trying to compile cannot be fused "
873878 "because of a node, which is marked explicitly as non-fusable\n \n "
874879 + graph .format (highlighted_nodes = base_highlighted_nodes , show_bounds = False )
@@ -879,7 +884,7 @@ def check_subgraph_fusibility(
879884 "this node has a different shape than the input node" ,
880885 node .location ,
881886 ]
882- raise RuntimeError (
887+ raise NotFusable (
883888 "A subgraph within the function you are trying to compile cannot be fused "
884889 "because of a node, which is has a different shape than the input node\n \n "
885890 + graph .format (highlighted_nodes = base_highlighted_nodes , show_bounds = False )
0 commit comments