
Commit a401588

Merge pull request #1503 from Xilinx/fix/extract_norm
Fix ExtractNormScaleBias transform
2 parents: 46bbe99 + 4b418a0
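For context, LayerNormalization computes y = scale * (x - mean(x)) / sqrt(var(x) + eps) + bias over the last axis, and this transform relies on the identity that the affine parameters can be pulled out of the node: a parameter-free LayerNorm (scale = 1, bias = 0) followed by an elementwise Mul and Add gives the same result. A quick standalone numpy check of that identity (shapes and eps here are illustrative, not taken from the diff):

import numpy as np

def layernorm(x, scale, bias, eps):
    # normalize over the last axis, then apply the affine parameters
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return scale * (x - mean) / np.sqrt(var + eps) + bias

x = np.random.randn(1, 16, 48).astype(np.float32)
gamma = np.random.randn(48).astype(np.float32)
beta = np.random.randn(48).astype(np.float32)

y_fused = layernorm(x, gamma, beta, 1e-12)
# parameter-free LayerNorm followed by standalone Mul and Add nodes
y_split = layernorm(x, np.ones(48, np.float32), np.zeros(48, np.float32), 1e-12) * gamma + beta
assert np.allclose(y_fused, y_split, atol=1e-5)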

2 files changed: +103 −72 lines changed

src/finn/transformation/streamline/extract_norm_scale_bias.py

Lines changed: 63 additions & 72 deletions
@@ -8,17 +8,17 @@
 # MIT license as part of project Brainsmith.
 # All other copyright is held by AMD and is provided under BSD-3-Clause license.
 #
-# Note: This transform was originally written by Thomas Keller (ExpandNorms)
-# and was adjusted.
+# Note: This transform is inspired by a transformation from Thomas Keller (ExpandNorms)
+# and ExtractQuantScaleZeroPt from qonnx.
 #
 ############################################################################
 
 import numpy as np
 from onnx import TensorProto
 from onnx import helper as oh
 from qonnx.transformation.base import Transformation
-from qonnx.transformation.infer_datatypes import InferDataTypes
-from qonnx.transformation.infer_shapes import InferShapes
+from qonnx.transformation.general import GiveUniqueParameterTensors, SortGraph
+from qonnx.transformation.remove import RemoveIdentityOps
 
 
 class ExtractNormScaleBias(Transformation):
@@ -30,85 +30,76 @@ def __init__(self):
 
     def apply(self, model):
         graph = model.graph
-        node_ind = 0
-        graph_modified = False
-
         for node in graph.node:
-            node_ind += 1
             if node.op_type == "LayerNormalization":
-                scale = model.get_initializer(node.input[1])
+                ln_node = node
+                input_ln = node.input[0]
+                scale_tensor = node.input[1]
+                # bias input is optional input
                 if len(node.input) > 2:
-                    bias = model.get_initializer(node.input[2])
+                    bias_tensor = node.input[2]
+                    bias = model.get_initializer(bias_tensor)
                 else:
                     bias = None
-                scale_is_one = (scale == 1).all()
-                bias_is_zero = not np.any(bias)
-                if scale_is_one and (bias_is_zero or bias is None):
+                scale = model.get_initializer(scale_tensor)
+                extract_scale = False
+                extract_bias = False
+                if (scale != 1).any():
+                    extract_scale = True
+                if bias is not None and np.any(bias):
+                    extract_bias = True
+                if (not extract_scale) and (not extract_bias):
                     continue
-                act_shape = model.get_tensor_shape(node.input[0])
-                act_out = node.output[0]
-                if not scale_is_one:
-                    # extract scale into separate Mul node
-                    scale_dt = model.get_tensor_datatype(node.input[1])
-                    # Create new tensors
+                act_shape = model.get_tensor_shape(input_ln)
+                last_node = ln_node
+                final_output = ln_node.output[0]
+                if extract_scale:
+                    # create new Mul node that applies the scale
+                    # Create new tensor
+                    scale_act_in_name = model.make_new_valueinfo_name()
                     scale_act_in = oh.make_tensor_value_info(
-                        model.make_new_valueinfo_name(), TensorProto.FLOAT, act_shape
-                    )
-                    scale_value = oh.make_tensor_value_info(
-                        model.make_new_valueinfo_name(), TensorProto.FLOAT, [act_shape[-1]]
+                        scale_act_in_name, TensorProto.FLOAT, act_shape
                     )
+                    last_node.output[0] = scale_act_in_name
                     graph.value_info.append(scale_act_in)
-                    graph.value_info.append(scale_value)
-
-                    # Update previous output tensor
-                    node.output[0] = scale_act_in.name
-                    # Create Mul node to replace scale
-                    mul_node = oh.make_node("Mul", [scale_act_in.name, scale_value.name], [act_out])
-
-                    # set scale to all ones in LayerNormalization
-                    model.set_initializer(node.input[1], np.ones(act_shape[-1], dtype=np.float32))
-
-                    graph_modified = True
-
-                if not bias_is_zero or bias is not None:
-                    # extract bias into separate Add node
-                    bias_dt = model.get_tensor_datatype(node.input[2])
-                    # Create new input tensor
-                    bias_act_in = oh.make_tensor_value_info(
-                        model.make_new_valueinfo_name(), TensorProto.FLOAT, act_shape
+                    scale_node = oh.make_node(
+                        "Mul", [scale_act_in_name, scale_tensor], [final_output]
                     )
-                    bias_value = oh.make_tensor_value_info(
-                        model.make_new_valueinfo_name(), TensorProto.FLOAT, [act_shape[-1]]
+                    graph.node.append(scale_node)
+                    # important: when tracking a pointer to newly added nodes,
+                    # ensure the item from the container is used, and not the
+                    # make_node result -- those are different objects
+                    # e.g. if we use last_node = scale_node below,
+                    # this will point to the wrong object and cause bugs later
+                    last_node = graph.node[-1]
+                    # remove scale from LayerNorm node
+                    new_scale_name = model.make_new_valueinfo_name()
+                    model.set_initializer(new_scale_name, np.ones(act_shape[-1], dtype=np.float32))
+                    ln_node.input[1] = new_scale_name
+                if extract_bias:
+                    # create new Add node that applies bias
+                    # create new tensor
+                    bias_act_in_name = model.make_new_valueinfo_name()
+                    bias_act_in = oh.make_tensor_value_info(
+                        bias_act_in_name, TensorProto.FLOAT, act_shape
                     )
                     graph.value_info.append(bias_act_in)
-                    graph.value_info.append(bias_value)
-                    # Update previous output tensor
-                    if not scale_is_one:
-                        mul_node.output[0] = bias_act_in.name
-                    else:
-                        node.output[0] = bias_act_in.name
-
-                    # Create Add node to replace bias
-                    add_node = oh.make_node("Add", [bias_act_in.name, bias_value.name], [act_out])
-
-                    # set bias to all zeros in LayerNormalization
-                    model.set_initializer(node.input[2], np.zeros(act_shape[-1], dtype=np.float32))
-
-                    graph_modified = True
+                    bias_node = oh.make_node("Add", [bias_act_in_name, bias_tensor], [final_output])
+                    last_node.output[0] = bias_act_in_name
+                    graph.node.append(bias_node)
+                    # remove bias from LayerNorm node
+                    new_bias_name = model.make_new_valueinfo_name()
+                    model.set_initializer(new_bias_name, np.zeros(act_shape[-1], dtype=np.float32))
+                    ln_node.input[2] = new_bias_name
 
-                # insert new nodes
-                insert_point = node_ind
-                if not scale_is_one:
-                    insert_point += 1
-                    graph.node.insert(insert_point, mul_node)
-                    model.set_initializer(mul_node.input[1], scale)
-                    model.set_tensor_datatype(mul_node.input[1], scale_dt)
-                if not bias_is_zero or bias is not None:
-                    insert_point += 1
-                    graph.node.insert(insert_point, add_node)
-                    model.set_initializer(add_node.input[1], bias)
-                    model.set_tensor_datatype(add_node.input[1], bias_dt)
-                model = model.transform(InferShapes())
-                model = model.transform(InferDataTypes())
+                if extract_scale or extract_bias:
+                    # since we used append() for new nodes, need to call
+                    # SortGraph to ensure correct (topological) order
+                    model = model.transform(SortGraph())
+                    # Remove potential unity multiplications from alpha and beta attributes
+                    model = model.transform(RemoveIdentityOps())
+                    # Ensure unique parameter tensors
                    model = model.transform(GiveUniqueParameterTensors())
+                    return model, True
 
-        return (model, graph_modified)
+        return model, False
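Two details of the rewritten transform deserve a closer look. First, apply() now rewrites at most one LayerNormalization per pass and returns model, True; qonnx's transformation machinery reapplies a transform until the modified flag comes back False, so any remaining nodes are handled on later passes. A simplified sketch of that fixpoint loop (an assumption-level paraphrase; the real driver lives in qonnx's ModelWrapper.transform):

def run_to_fixpoint(model, transformation):
    # keep applying until the transform reports no further change
    model_was_changed = True
    while model_was_changed:
        model, model_was_changed = transformation.apply(model)
    return model

Second, the comment about last_node = graph.node[-1] reflects protobuf semantics: an ONNX graph is a protobuf message, and appending a message to a repeated field stores a copy, so the object returned by make_node and the node actually held by the graph are distinct. A minimal sketch of the pitfall (node and tensor names are illustrative):

from onnx import helper as oh

relu = oh.make_node("Relu", ["x"], ["y"])
graph = oh.make_graph([], "demo", [], [])
graph.node.append(relu)  # protobuf append() stores a copy of the message

print(graph.node[-1] is relu)    # False: two distinct objects
relu.output[0] = "z"             # mutating the original...
print(graph.node[-1].output[0])  # ...leaves the graph's copy at "y"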

tests/fpgadataflow/test_fpgadataflow_layernorm.py

Lines changed: 40 additions & 0 deletions
@@ -20,6 +20,7 @@
 from qonnx.transformation.general import GiveUniqueNodeNames
 from qonnx.transformation.infer_datatypes import InferDataTypes
 from qonnx.transformation.infer_shapes import InferShapes
+from qonnx.transformation.merge_onnx_models import MergeONNXModels
 from qonnx.util.basic import gen_finn_dt_tensor, qonnx_make_model
 
 import finn.core.onnx_exec as oxe
@@ -145,3 +146,42 @@ def test_fpgadataflow_layernorm(idt, ishape, simd, has_scale, has_bias, sim_styl
     exp_cycles = exp_cycles_dict[model.graph.node[0].name]
     assert np.isclose(exp_cycles, cycles_rtlsim, atol=10)
     assert exp_cycles != 0
+
+
+@pytest.mark.transform
+@pytest.mark.parametrize("idt", [DataType["FLOAT32"]])
+@pytest.mark.parametrize("ishape", [[1, 16, 48], [1, 32]])
+@pytest.mark.parametrize("has_scale", [True, False])
+@pytest.mark.parametrize("has_bias", [True, False])
+def test_extract_norm_scale_bias(idt, ishape, has_scale, has_bias):
+    epsilon = 9.999999960041972e-13
+    model1 = create_layernorm_model(idt, ishape, has_scale, has_bias, epsilon)
+    model2 = create_layernorm_model(idt, ishape, has_scale, has_bias, epsilon)
+    model3 = create_layernorm_model(idt, ishape, has_scale, has_bias, epsilon)
+
+    model = model1.transform(MergeONNXModels(model2))
+    model = model.transform(MergeONNXModels(model3))
+
+    assert len(model.get_nodes_by_op_type("LayerNormalization")) == 3
+
+    # reference calculation
+    input = gen_finn_dt_tensor(DataType["FLOAT32"], ishape)
+    input_t = {model.graph.input[0].name: input}
+
+    y_ref = oxe.execute_onnx(model, input_t)[model.graph.output[0].name]
+
+    model = model.transform(InferShapes())
+    model = model.transform(InferDataTypes())
+
+    model = model.transform(ExtractNormScaleBias())
+
+    assert len(model.get_nodes_by_op_type("LayerNormalization")) == 3
+    if has_bias:
+        assert len(model.get_nodes_by_op_type("Add")) == 3
+    if has_scale:
+        assert len(model.get_nodes_by_op_type("Mul")) == 3
+
+    input_t = {model.graph.input[0].name: input}
+
+    y_out = oxe.execute_onnx(model, input_t)[model.graph.output[0].name]
+    assert (y_ref == y_out).all()
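For anyone wanting to apply the transform outside the test suite, a minimal usage sketch (the file name model.onnx is a placeholder; the import path follows the file changed in this commit):

from qonnx.core.modelwrapper import ModelWrapper

from finn.transformation.streamline.extract_norm_scale_bias import ExtractNormScaleBias

# "model.onnx" is a placeholder for any model containing LayerNormalization nodes
model = ModelWrapper("model.onnx")
model = model.transform(ExtractNormScaleBias())
model.save("model_extracted.onnx")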
