@@ -210,7 +210,7 @@ def snake_case_to_pascal_case(s: str) -> str:
210
210
211
211
class ONNXOperatorReader :
212
212
"""
213
- Utiliy for extracting attribute and input values from an ONNX operator.
213
+ Utility for extracting attribute and input values from an ONNX operator.
214
214
215
215
This keeps track of which attributes have been read so that we can warn about
216
216
any unhandled ones.
@@ -501,10 +501,29 @@ def noop_add_node(node: Node) -> int:
501
501
502
502
output_name = onnx_op .output [0 ]
503
503
504
- tensor = ONNXOperatorReader (
505
- onnx_op , input_indexes = [], add_node = noop_add_node
506
- ).require_attr ("value" , "tensor" )
507
- const_node = constant_node_from_onnx_initializer (tensor , output_name )
504
+ attrs = ONNXOperatorReader (onnx_op , input_indexes = [], add_node = noop_add_node )
505
+ if (tensor := attrs .get_attr ("value" , "tensor" , None )) is not None :
506
+ const_node = constant_node_from_onnx_initializer (tensor , output_name )
507
+ else :
508
+ if (int_ := attrs .get_attr ("value_int" , "int" , None )) is not None :
509
+ shape = []
510
+ data = np .array (int_ ).astype (np .int32 )
511
+ elif (ints := attrs .get_attr ("value_ints" , "ints" , None )) is not None :
512
+ shape = [len (ints )]
513
+ data = np .array (ints ).astype (np .int32 )
514
+ elif (float_ := attrs .get_attr ("value_float" , "float" , None )) is not None :
515
+ shape = []
516
+ data = np .array (float_ ).astype (np .float32 )
517
+ elif (floats := attrs .get_attr ("value_floats" , "floats" , None )) is not None :
518
+ shape = [len (floats )]
519
+ data = np .array (floats ).astype (np .float32 )
520
+ else :
521
+ # Unsupported attributes: value_string, value_strings
522
+ raise Exception (
523
+ f'Unable to get value from "Constant" operator "{ onnx_op .name } "'
524
+ )
525
+ const_node = ConstantNode ("dummy_name" , shape , data )
526
+
508
527
const_node .name = output_name
509
528
510
529
return const_node
0 commit comments