Skip to content

Commit d8d1802

Browse files
committed
fix(frontend): non fusable graph is not an error
reported in https://community.zama.ai/t/implementing-comparison-strategies/3469/5
1 parent ecb729b commit d8d1802

File tree

2 files changed

+15
-9
lines changed

2 files changed

+15
-9
lines changed

frontends/concrete-python/concrete/fhe/compilation/utils.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -248,12 +248,15 @@ def fuse(graph: Graph, artifacts: Optional["FunctionDebugArtifacts"] = None):
248248
all_nodes, start_nodes, terminal_node = subgraph_to_fuse
249249
processed_terminal_nodes.add(terminal_node)
250250

251-
conversion_result = convert_subgraph_to_subgraph_node(
252-
graph,
253-
all_nodes,
254-
start_nodes,
255-
terminal_node,
256-
)
251+
try:
252+
conversion_result = convert_subgraph_to_subgraph_node(
253+
graph,
254+
all_nodes,
255+
start_nodes,
256+
terminal_node,
257+
)
258+
except NotFusable:
259+
conversion_result = None
257260
if conversion_result is None:
258261
continue
259262

@@ -769,7 +772,7 @@ def convert_subgraph_to_subgraph_node(
769772
if terminal_node.properties["name"] == "where":
770773
return None
771774

772-
raise RuntimeError(
775+
raise NotFusable(
773776
"A subgraph within the function you are trying to compile cannot be fused "
774777
"because it has multiple input nodes\n\n"
775778
+ graph.format(highlighted_nodes=base_highlighted_nodes, show_bounds=False)
@@ -826,6 +829,8 @@ def convert_subgraph_to_subgraph_node(
826829

827830
return subgraph_node, variable_input_node
828831

832+
class NotFusable(RuntimeError):
833+
pass
829834

830835
def check_subgraph_fusibility(
831836
graph: Graph,
@@ -868,7 +873,7 @@ def check_subgraph_fusibility(
868873

869874
if not node.is_fusable:
870875
base_highlighted_nodes[node] = ["this node is not fusable", node.location]
871-
raise RuntimeError(
876+
raise NotFusable(
872877
"A subgraph within the function you are trying to compile cannot be fused "
873878
"because of a node, which is marked explicitly as non-fusable\n\n"
874879
+ graph.format(highlighted_nodes=base_highlighted_nodes, show_bounds=False)
@@ -879,7 +884,7 @@ def check_subgraph_fusibility(
879884
"this node has a different shape than the input node",
880885
node.location,
881886
]
882-
raise RuntimeError(
887+
raise NotFusable(
883888
"A subgraph within the function you are trying to compile cannot be fused "
884889
"because of a node, which is has a different shape than the input node\n\n"
885890
+ graph.format(highlighted_nodes=base_highlighted_nodes, show_bounds=False)

frontends/concrete-python/tests/execution/test_comparison.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
("<=", lambda x, y: x <= y),
6565
(">", lambda x, y: x > y),
6666
(">=", lambda x, y: x >= y),
67+
("mixed", lambda x, y: fhe.if_then_else(x < y, x, y))
6768
]
6869
),
6970
# bit widths

0 commit comments

Comments
 (0)