@@ -259,3 +259,40 @@ def test_modelwrapper_set_opset_import():
259259 # Test setting ONNX main domain
260260 model .set_opset_import ("" , 13 )
261261 assert model .get_opset_imports () == {"" : 13 , "qonnx.custom_op.general" : 2 }
262+
263+
264+ def test_modelwrapper_get_global_io ():
265+ # Create a simple model with single input and output
266+ in1 = onnx .helper .make_tensor_value_info ("global_in" , onnx .TensorProto .FLOAT , [1 , 4 ])
267+ out1 = onnx .helper .make_tensor_value_info ("global_out" , onnx .TensorProto .FLOAT , [1 , 4 ])
268+ node = onnx .helper .make_node ("Neg" , inputs = ["global_in" ], outputs = ["global_out" ])
269+ graph = onnx .helper .make_graph (
270+ nodes = [node ],
271+ name = "simple_graph" ,
272+ inputs = [in1 ],
273+ outputs = [out1 ],
274+ )
275+ onnx_model = qonnx_make_model (graph , producer_name = "global-io-test-model" )
276+ model = ModelWrapper (onnx_model )
277+
278+ # Test get_first_global_in
279+ assert model .get_first_global_in () == "global_in"
280+
281+ # Test get_first_global_out
282+ assert model .get_first_global_out () == "global_out"
283+
284+ # Test with multi-input model (should still return first input)
285+ in2 = onnx .helper .make_tensor_value_info ("second_in" , onnx .TensorProto .FLOAT , [1 , 4 ])
286+ add_node = onnx .helper .make_node ("Add" , inputs = ["global_in" , "second_in" ], outputs = ["global_out" ])
287+ graph_multi = onnx .helper .make_graph (
288+ nodes = [add_node ],
289+ name = "multi_input_graph" ,
290+ inputs = [in1 , in2 ],
291+ outputs = [out1 ],
292+ )
293+ onnx_model_multi = qonnx_make_model (graph_multi , producer_name = "global-io-multi-test" )
294+ model_multi = ModelWrapper (onnx_model_multi )
295+
296+ # Should still return first input/output
297+ assert model_multi .get_first_global_in () == "global_in"
298+ assert model_multi .get_first_global_out () == "global_out"
0 commit comments