Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/qonnx/transformation/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,16 @@ def apply(self, model):
for node_idx, n in enumerate(node_list):
node_pred = model.find_direct_predecessors(n)
if node_pred is None:
if len(n.input) > 0:
# check if node inputs are connected to graph inputs or initializers
# if so, we can keep the node in the graph
for name in n.input:
if util.get_by_name(model.graph.initializer, name) or \
util.get_by_name(model.graph.input, name):
# this node is connected to graph inputs or initializers
# so we can keep it in the graph
graph_dependencies[node_idx] = set()
break
# Will also eliminate nodes that are floating around for some reason
continue

Expand Down
31 changes: 31 additions & 0 deletions tests/transformation/test_sort_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,34 @@ def test_sort_nonlinear_graph():
# import matplotlib.pyplot as plt
# plt.plot(sizes,times,"--o")
# plt.grid(True)

def test_sort_graph_node_only_connected_to_graphio():
"""
Test that SortGraph does not remove nodes that are only connected to graph inputs/outputs.
Occurs when the graph has more than one node.
"""
ch = 2
ifmdim = 16
input_shape = (1, ch, ifmdim, ifmdim)

top_in0 = helper.make_tensor_value_info("top_in0", TensorProto.FLOAT, input_shape)
top_in1 = helper.make_tensor_value_info("top_in1", TensorProto.FLOAT, input_shape)
top_out0 = helper.make_tensor_value_info("top_out0", TensorProto.FLOAT, input_shape)
top_out1 = helper.make_tensor_value_info("top_out1", TensorProto.FLOAT, input_shape)

modelproto = qonnx_make_model(
helper.make_graph(
name="test",
inputs=[top_in0, top_in1],
outputs=[top_out0, top_out1],
nodes=[
helper.make_node("Identity", ["top_in0"], ["top_out0"], name="id0"),
helper.make_node("Identity", ["top_in1"], ["top_out1"], name="id1"),
],
)
)
model = ModelWrapper(modelproto)
model = model.transform(SortGraph())

# Ensure that sort did not remove the Identity nodes
assert len(model.graph.node) == 2, "SortGraph removed nodes connected only to graph inputs/outputs."
Loading