Skip to content

Commit 092ebff

Browse files
committed
fixes #2538
1 parent ce4c8b8 commit 092ebff

File tree

2 files changed

+62
-9
lines changed

2 files changed

+62
-9
lines changed

coremltools/converters/mil/frontend/torch/test/test_passes.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
flatten_graph_input_values,
1919
flatten_graph_output_values,
2020
transform_inplace_ops,
21+
remove_getattr_nodes
2122
)
2223
import coremltools as ct
2324

@@ -405,3 +406,52 @@ def forward(self, x):
405406
y_cm = ct_model.predict({'x': x})['y']
406407

407408
assert((y_cm == np.zeros(shape)).all())
409+
410+
411+
@staticmethod
412+
def test_remove_getattr_nodes_immediate_output():
413+
graph_nodes = [
414+
InternalTorchIRNode(
415+
inputs=["self"],
416+
attr={"name": "const_out2", "value": None},
417+
outputs=["const_out2"],
418+
kind="getattr",
419+
),
420+
InternalTorchIRNode(
421+
inputs=["self"],
422+
attr={"name": "const_out1", "value": None},
423+
outputs=["const_out1"],
424+
kind="getattr",
425+
),
426+
InternalTorchIRNode(
427+
inputs=["const_out1", "const_out2"],
428+
attr={"value": None},
429+
outputs=["3"],
430+
kind="tupleconstruct",
431+
),
432+
]
433+
const2 = torch.tensor([5., 6., 7., 8.])
434+
const1 = torch.tensor([1., 2., 3., 4.])
435+
graph_params = {'const_out2': const2,
436+
'const_out1': const1}
437+
graph_inputs = []
438+
graph_outputs = ["const_out1", "const_out2"]
439+
440+
graph = InternalTorchIRGraph(
441+
nodes=graph_nodes,
442+
params=graph_params,
443+
inputs=graph_inputs,
444+
outputs=graph_outputs,
445+
)
446+
447+
for node in graph.nodes:
448+
node.parent = graph
449+
450+
remove_getattr_nodes(graph)
451+
452+
np.testing.assert_equal(graph.nodes[0].kind, "constant")
453+
np.testing.assert_equal(graph.nodes[1].kind, "constant")
454+
np.testing.assert_equal(graph.nodes[2].kind, "tupleconstruct")
455+
np.testing.assert_allclose(graph.nodes[0].attr["value"], const2)
456+
np.testing.assert_allclose(graph.nodes[1].attr["value"], const1)
457+

coremltools/converters/mil/frontend/torch/torchir_passes.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -231,10 +231,9 @@ def forward(self, x):
231231

232232
def remove_getattr_nodes(graph: InternalTorchIRGraph) -> None:
233233
"""
234-
Remove the getattr nodes in the graph
234+
Remove the getattr nodes in the graph that are not output nodes
235235
"""
236236

237-
getattr_nodes = []
238237
new_nodes = []
239238

240239
for node in graph.nodes:
@@ -243,16 +242,20 @@ def remove_getattr_nodes(graph: InternalTorchIRGraph) -> None:
243242
remove_getattr_nodes(block)
244243

245244
if node.kind == "getattr":
246-
getattr_nodes.append(node)
245+
if node.name in graph.outputs:
246+
# create and add new constant node
247+
new_nodes.append(
248+
InternalTorchIRNode(
249+
inputs=[],
250+
outputs=node.outputs,
251+
kind="constant",
252+
name="internal_immediate_output_attr",
253+
attr={"value": node.parent.params[node.name]}
254+
)
255+
)
247256
else:
248257
new_nodes.append(node)
249258

250-
# check the getattr nodes not in the outputs
251-
for node in getattr_nodes:
252-
if node.name in graph.outputs:
253-
raise RuntimeError("{} should not be in the graph outputs.".format(node.name))
254-
255-
# remove the getattr nodes
256259
graph.nodes = new_nodes
257260

258261

0 commit comments

Comments
 (0)