Skip to content

Commit 2199cc6

Browse files
authored
Merge pull request #225 from rothej/fix/modelwrapper
Add modelwrapper methods to get network global input and output tensor names
2 parents 3d87295 + 09d71ef commit 2199cc6

File tree

2 files changed

+45
-0
lines changed

2 files changed

+45
-0
lines changed

src/qonnx/core/modelwrapper.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,14 @@ def get_all_tensor_names(self):
543543
names += [x.name for x in graph.output]
544544
return names
545545

546+
def get_first_global_in(self):
547+
"""Return the name of the first global input tensor."""
548+
return self.graph.input[0].name
549+
550+
def get_first_global_out(self):
551+
"""Return the name of the first global output tensor."""
552+
return self.graph.output[0].name
553+
546554
def make_new_valueinfo_name(self):
547555
"""Returns a name that can be used for a new value_info."""
548556
names = self.get_all_tensor_names()

tests/core/test_modelwrapper.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,3 +259,40 @@ def test_modelwrapper_set_opset_import():
259259
# Test setting ONNX main domain
260260
model.set_opset_import("", 13)
261261
assert model.get_opset_imports() == {"": 13, "qonnx.custom_op.general": 2}
262+
263+
264+
def test_modelwrapper_get_global_io():
265+
# Create a simple model with single input and output
266+
in1 = onnx.helper.make_tensor_value_info("global_in", onnx.TensorProto.FLOAT, [1, 4])
267+
out1 = onnx.helper.make_tensor_value_info("global_out", onnx.TensorProto.FLOAT, [1, 4])
268+
node = onnx.helper.make_node("Neg", inputs=["global_in"], outputs=["global_out"])
269+
graph = onnx.helper.make_graph(
270+
nodes=[node],
271+
name="simple_graph",
272+
inputs=[in1],
273+
outputs=[out1],
274+
)
275+
onnx_model = qonnx_make_model(graph, producer_name="global-io-test-model")
276+
model = ModelWrapper(onnx_model)
277+
278+
# Test get_first_global_in
279+
assert model.get_first_global_in() == "global_in"
280+
281+
# Test get_first_global_out
282+
assert model.get_first_global_out() == "global_out"
283+
284+
# Test with multi-input model (should still return first input)
285+
in2 = onnx.helper.make_tensor_value_info("second_in", onnx.TensorProto.FLOAT, [1, 4])
286+
add_node = onnx.helper.make_node("Add", inputs=["global_in", "second_in"], outputs=["global_out"])
287+
graph_multi = onnx.helper.make_graph(
288+
nodes=[add_node],
289+
name="multi_input_graph",
290+
inputs=[in1, in2],
291+
outputs=[out1],
292+
)
293+
onnx_model_multi = qonnx_make_model(graph_multi, producer_name="global-io-multi-test")
294+
model_multi = ModelWrapper(onnx_model_multi)
295+
296+
# Should still return first input/output
297+
assert model_multi.get_first_global_in() == "global_in"
298+
assert model_multi.get_first_global_out() == "global_out"

0 commit comments

Comments
 (0)