# MIT license as part of project Brainsmith.
# All other copyright is held by AMD and is provided under BSD-3-Clause license.
#
# Note: This transform is inspired by a transformation from Thomas Keller (ExpandNorms)
# and ExtractQuantScaleZeroPt from qonnx.
#
############################################################################

import numpy as np
from onnx import TensorProto
from onnx import helper as oh
from qonnx.transformation.base import Transformation
from qonnx.transformation.general import GiveUniqueParameterTensors, SortGraph
from qonnx.transformation.remove import RemoveIdentityOps


class ExtractNormScaleBias(Transformation):
    # NOTE(review): __init__ (hidden between the diff hunks this view was
    # reconstructed from) is not visible here.
3131 def apply (self , model ):
3232 graph = model .graph
33- node_ind = 0
34- graph_modified = False
35-
3633 for node in graph .node :
37- node_ind += 1
3834 if node .op_type == "LayerNormalization" :
39- scale = model .get_initializer (node .input [1 ])
35+ ln_node = node
36+ input_ln = node .input [0 ]
37+ scale_tensor = node .input [1 ]
38+ # bias input is optional input
4039 if len (node .input ) > 2 :
41- bias = model .get_initializer (node .input [2 ])
40+ bias_tensor = node .input [2 ]
41+ bias = model .get_initializer (bias_tensor )
4242 else :
4343 bias = None
44- scale_is_one = (scale == 1 ).all ()
45- bias_is_zero = not np .any (bias )
46- if scale_is_one and (bias_is_zero or bias is None ):
44+ scale = model .get_initializer (scale_tensor )
45+ extract_scale = False
46+ extract_bias = False
47+ if (scale != 1 ).any ():
48+ extract_scale = True
49+ if bias is not None and np .any (bias ):
50+ extract_bias = True
51+ if (not extract_scale ) and (not extract_bias ):
4752 continue
48- act_shape = model .get_tensor_shape (node .input [0 ])
49- act_out = node .output [0 ]
50- if not scale_is_one :
51- # extract scale into separate Mul node
52- scale_dt = model .get_tensor_datatype (node .input [1 ])
53- # Create new tensors
53+ act_shape = model .get_tensor_shape (input_ln )
54+ last_node = ln_node
55+ final_output = ln_node .output [0 ]
56+ if extract_scale :
57+ # create new Mul node that applies the scale
58+ # Create new tensor
59+ scale_act_in_name = model .make_new_valueinfo_name ()
5460 scale_act_in = oh .make_tensor_value_info (
55- model .make_new_valueinfo_name (), TensorProto .FLOAT , act_shape
56- )
57- scale_value = oh .make_tensor_value_info (
58- model .make_new_valueinfo_name (), TensorProto .FLOAT , [act_shape [- 1 ]]
61+ scale_act_in_name , TensorProto .FLOAT , act_shape
5962 )
63+ last_node .output [0 ] = scale_act_in_name
6064 graph .value_info .append (scale_act_in )
61- graph .value_info .append (scale_value )
62-
63- # Update previous output tensor
64- node .output [0 ] = scale_act_in .name
65- # Create Mul node to replace scale
66- mul_node = oh .make_node ("Mul" , [scale_act_in .name , scale_value .name ], [act_out ])
67-
68- # set scale to all ones in LayerNormalization
69- model .set_initializer (node .input [1 ], np .ones (act_shape [- 1 ], dtype = np .float32 ))
70-
71- graph_modified = True
72-
73- if not bias_is_zero or bias is not None :
74- # extract bias into separate Add node
75- bias_dt = model .get_tensor_datatype (node .input [2 ])
76- # Create new input tensor
77- bias_act_in = oh .make_tensor_value_info (
78- model .make_new_valueinfo_name (), TensorProto .FLOAT , act_shape
65+ scale_node = oh .make_node (
66+ "Mul" , [scale_act_in_name , scale_tensor ], [final_output ]
7967 )
80- bias_value = oh .make_tensor_value_info (
81- model .make_new_valueinfo_name (), TensorProto .FLOAT , [act_shape [- 1 ]]
68+ graph .node .append (scale_node )
69+ # important: when tracking a pointer to newly added nodes,
70+ # ensure the item from the container is used, and not the
71+ # make_node result -- those are different objects
72+ # e.g. if we use last_node = scale_node below,
73+ # this will point to the wrong object and cause bugs later
74+ last_node = graph .node [- 1 ]
75+ # remove scale from LayerNorm node
76+ new_scale_name = model .make_new_valueinfo_name ()
77+ model .set_initializer (new_scale_name , np .ones (act_shape [- 1 ], dtype = np .float32 ))
78+ ln_node .input [1 ] = new_scale_name
79+ if extract_bias :
80+ # create new Add node that applies bias
81+ # create new tensor
82+ bias_act_in_name = model .make_new_valueinfo_name ()
83+ bias_act_in = oh .make_tensor_value_info (
84+ bias_act_in_name , TensorProto .FLOAT , act_shape
8285 )
8386 graph .value_info .append (bias_act_in )
84- graph .value_info .append (bias_value )
85- # Update previous output tensor
86- if not scale_is_one :
87- mul_node .output [0 ] = bias_act_in .name
88- else :
89- node .output [0 ] = bias_act_in .name
90-
91- # Create Add node to replace bias
92- add_node = oh .make_node ("Add" , [bias_act_in .name , bias_value .name ], [act_out ])
93-
94- # set bias to all zeros in LayerNormalization
95- model .set_initializer (node .input [2 ], np .zeros (act_shape [- 1 ], dtype = np .float32 ))
96-
97- graph_modified = True
87+ bias_node = oh .make_node ("Add" , [bias_act_in_name , bias_tensor ], [final_output ])
88+ last_node .output [0 ] = bias_act_in_name
89+ graph .node .append (bias_node )
90+ # remove bias from LayerNorm node
91+ new_bias_name = model .make_new_valueinfo_name ()
92+ model .set_initializer (new_bias_name , np .zeros (act_shape [- 1 ], dtype = np .float32 ))
93+ ln_node .input [2 ] = new_bias_name
9894
99- # insert new nodes
100- insert_point = node_ind
101- if not scale_is_one :
102- insert_point += 1
103- graph .node .insert (insert_point , mul_node )
104- model .set_initializer (mul_node .input [1 ], scale )
105- model .set_tensor_datatype (mul_node .input [1 ], scale_dt )
106- if not bias_is_zero or bias is not None :
107- insert_point += 1
108- graph .node .insert (insert_point , add_node )
109- model .set_initializer (add_node .input [1 ], bias )
110- model .set_tensor_datatype (add_node .input [1 ], bias_dt )
111- model = model .transform (InferShapes ())
112- model = model .transform (InferDataTypes ())
95+ if extract_scale or extract_bias :
96+ # since we used append() for new nodes, need to call
97+ # SortGraph to ensure correct (topological) order
98+ model = model .transform (SortGraph ())
99+ # Remove potential unity multiplications from alpha and beta attributes
100+ model = model .transform (RemoveIdentityOps ())
101+ # Ensure unique parameter tensors
102+ model = model .transform (GiveUniqueParameterTensors ())
103+ return model , True
113104
114- return ( model , graph_modified )
105+ return model , False