Merged
135 changes: 63 additions & 72 deletions src/finn/transformation/streamline/extract_norm_scale_bias.py
@@ -8,17 +8,17 @@
 # MIT license as part of project Brainsmith.
 # All other copyright is held by AMD and is provided under BSD-3-Clause license.
 #
-# Note: This transform was originally written by Thomas Keller (ExpandNorms)
-# and was adjusted.
+# Note: This transform is inspired by a transformation from Thomas Keller (ExpandNorms)
+# and ExtractQuantScaleZeroPt from qonnx.
 #
 ############################################################################
 
 import numpy as np
 from onnx import TensorProto
 from onnx import helper as oh
 from qonnx.transformation.base import Transformation
-from qonnx.transformation.infer_datatypes import InferDataTypes
-from qonnx.transformation.infer_shapes import InferShapes
+from qonnx.transformation.general import GiveUniqueParameterTensors, SortGraph
+from qonnx.transformation.remove import RemoveIdentityOps
 
 
 class ExtractNormScaleBias(Transformation):
@@ -30,85 +30,76 @@ def __init__(self):
 
     def apply(self, model):
         graph = model.graph
-        node_ind = 0
-        graph_modified = False
 
         for node in graph.node:
-            node_ind += 1
             if node.op_type == "LayerNormalization":
-                scale = model.get_initializer(node.input[1])
+                ln_node = node
+                input_ln = node.input[0]
+                scale_tensor = node.input[1]
                 # bias input is optional input
                 if len(node.input) > 2:
-                    bias = model.get_initializer(node.input[2])
+                    bias_tensor = node.input[2]
+                    bias = model.get_initializer(bias_tensor)
                 else:
                     bias = None
-                scale_is_one = (scale == 1).all()
-                bias_is_zero = not np.any(bias)
-                if scale_is_one and (bias_is_zero or bias is None):
+                scale = model.get_initializer(scale_tensor)
+                extract_scale = False
+                extract_bias = False
+                if (scale != 1).any():
+                    extract_scale = True
+                if bias is not None and np.any(bias):
+                    extract_bias = True
+                if (not extract_scale) and (not extract_bias):
                     continue
-                act_shape = model.get_tensor_shape(node.input[0])
-                act_out = node.output[0]
-                if not scale_is_one:
-                    # extract scale into separate Mul node
-                    scale_dt = model.get_tensor_datatype(node.input[1])
-                    # Create new tensors
-                    scale_act_in = oh.make_tensor_value_info(
-                        model.make_new_valueinfo_name(), TensorProto.FLOAT, act_shape
-                    )
-                    scale_value = oh.make_tensor_value_info(
-                        model.make_new_valueinfo_name(), TensorProto.FLOAT, [act_shape[-1]]
-                    )
-                    graph.value_info.append(scale_act_in)
-                    graph.value_info.append(scale_value)
-
-                    # Update previous output tensor
-                    node.output[0] = scale_act_in.name
-                    # Create Mul node to replace scale
-                    mul_node = oh.make_node("Mul", [scale_act_in.name, scale_value.name], [act_out])
-
-                    # set scale to all ones in LayerNormalization
-                    model.set_initializer(node.input[1], np.ones(act_shape[-1], dtype=np.float32))
-
-                    graph_modified = True
-
-                if not bias_is_zero or bias is not None:
-                    # extract bias into separate Add node
-                    bias_dt = model.get_tensor_datatype(node.input[2])
-                    # Create new input tensor
-                    bias_act_in = oh.make_tensor_value_info(
-                        model.make_new_valueinfo_name(), TensorProto.FLOAT, act_shape
-                    )
-                    bias_value = oh.make_tensor_value_info(
-                        model.make_new_valueinfo_name(), TensorProto.FLOAT, [act_shape[-1]]
-                    )
-                    graph.value_info.append(bias_act_in)
-                    graph.value_info.append(bias_value)
-                    # Update previous output tensor
-                    if not scale_is_one:
-                        mul_node.output[0] = bias_act_in.name
-                    else:
-                        node.output[0] = bias_act_in.name
-
-                    # Create Add node to replace bias
-                    add_node = oh.make_node("Add", [bias_act_in.name, bias_value.name], [act_out])
-
-                    # set bias to all zeros in LayerNormalization
-                    model.set_initializer(node.input[2], np.zeros(act_shape[-1], dtype=np.float32))
-
-                    graph_modified = True
-
-                # insert new nodes
-                insert_point = node_ind
-                if not scale_is_one:
-                    insert_point += 1
-                    graph.node.insert(insert_point, mul_node)
-                    model.set_initializer(mul_node.input[1], scale)
-                    model.set_tensor_datatype(mul_node.input[1], scale_dt)
-                if not bias_is_zero or bias is not None:
-                    insert_point += 1
-                    graph.node.insert(insert_point, add_node)
-                    model.set_initializer(add_node.input[1], bias)
-                    model.set_tensor_datatype(add_node.input[1], bias_dt)
-                model = model.transform(InferShapes())
-                model = model.transform(InferDataTypes())
-
-        return (model, graph_modified)
+                act_shape = model.get_tensor_shape(input_ln)
+                last_node = ln_node
+                final_output = ln_node.output[0]
+                if extract_scale:
+                    # create new Mul node that applies the scale
+                    # create new tensor
+                    scale_act_in_name = model.make_new_valueinfo_name()
+                    scale_act_in = oh.make_tensor_value_info(
+                        scale_act_in_name, TensorProto.FLOAT, act_shape
+                    )
+                    last_node.output[0] = scale_act_in_name
+                    graph.value_info.append(scale_act_in)
+                    scale_node = oh.make_node(
+                        "Mul", [scale_act_in_name, scale_tensor], [final_output]
+                    )
+                    graph.node.append(scale_node)
+                    # important: when tracking a pointer to newly added nodes,
+                    # ensure the item from the container is used, and not the
+                    # make_node result -- those are different objects
+                    # e.g. if we use last_node = scale_node below,
+                    # this will point to the wrong object and cause bugs later
+                    last_node = graph.node[-1]
+                    # remove scale from LayerNorm node
+                    new_scale_name = model.make_new_valueinfo_name()
+                    model.set_initializer(new_scale_name, np.ones(act_shape[-1], dtype=np.float32))
+                    ln_node.input[1] = new_scale_name
+                if extract_bias:
+                    # create new Add node that applies bias
+                    # create new tensor
+                    bias_act_in_name = model.make_new_valueinfo_name()
+                    bias_act_in = oh.make_tensor_value_info(
+                        bias_act_in_name, TensorProto.FLOAT, act_shape
+                    )
+                    graph.value_info.append(bias_act_in)
+                    bias_node = oh.make_node("Add", [bias_act_in_name, bias_tensor], [final_output])
+                    last_node.output[0] = bias_act_in_name
+                    graph.node.append(bias_node)
+                    # remove bias from LayerNorm node
+                    new_bias_name = model.make_new_valueinfo_name()
+                    model.set_initializer(new_bias_name, np.zeros(act_shape[-1], dtype=np.float32))
+                    ln_node.input[2] = new_bias_name
+                if extract_scale or extract_bias:
+                    # since we used append() for new nodes, need to call
+                    # SortGraph to ensure correct (topological) order
+                    model = model.transform(SortGraph())
+                    # Remove potential unity multiplications from alpha and beta attributes
+                    model = model.transform(RemoveIdentityOps())
+                    # Ensure unique parameter tensors
+                    model = model.transform(GiveUniqueParameterTensors())
+                    return model, True
+        return model, False
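
For reviewers who want to exercise the new transform in isolation, here is a minimal usage sketch (not part of the PR, assuming this branch is installed): it builds a single-node LayerNormalization model with the same onnx/qonnx helpers the diff above uses, then checks that the scale and bias were pulled out into standalone Mul and Add nodes. The tensor names, shapes, opset 17, and the 2.0/0.5 parameter values are illustrative choices, not from the PR.

import numpy as np
from onnx import TensorProto
from onnx import helper as oh
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.util.basic import qonnx_make_model

from finn.transformation.streamline.extract_norm_scale_bias import ExtractNormScaleBias

ishape = [1, 16, 48]
inp = oh.make_tensor_value_info("inp", TensorProto.FLOAT, ishape)
outp = oh.make_tensor_value_info("outp", TensorProto.FLOAT, ishape)
# LayerNormalization over the last axis, with a non-trivial scale and bias
ln = oh.make_node(
    "LayerNormalization", ["inp", "ln_scale", "ln_bias"], ["outp"], axis=-1, epsilon=1e-5
)
graph = oh.make_graph([ln], "ln_demo", [inp], [outp])
model = ModelWrapper(qonnx_make_model(graph, opset_imports=[oh.make_opsetid("", 17)]))
# illustrative values; anything other than all-ones / all-zeros triggers extraction
model.set_initializer("ln_scale", np.full(ishape[-1], 2.0, dtype=np.float32))
model.set_initializer("ln_bias", np.full(ishape[-1], 0.5, dtype=np.float32))

model = model.transform(ExtractNormScaleBias())

# the LayerNormalization keeps a unit scale and zero bias; the original
# parameters now live in a trailing Mul and Add
assert len(model.get_nodes_by_op_type("LayerNormalization")) == 1
assert len(model.get_nodes_by_op_type("Mul")) == 1
assert len(model.get_nodes_by_op_type("Add")) == 1

Note that apply() returns (model, True) after rewriting a node, so ModelWrapper.transform re-runs it until nothing changes; on the second pass the now unit scale and zero bias fall through the continue and it returns (model, False).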
40 changes: 40 additions & 0 deletions tests/fpgadataflow/test_fpgadataflow_layernorm.py
@@ -20,6 +20,7 @@
 from qonnx.transformation.general import GiveUniqueNodeNames
 from qonnx.transformation.infer_datatypes import InferDataTypes
 from qonnx.transformation.infer_shapes import InferShapes
+from qonnx.transformation.merge_onnx_models import MergeONNXModels
 from qonnx.util.basic import gen_finn_dt_tensor, qonnx_make_model
 
 import finn.core.onnx_exec as oxe
@@ -145,3 +146,42 @@ def test_fpgadataflow_layernorm(idt, ishape, simd, has_scale, has_bias, sim_styl
     exp_cycles = exp_cycles_dict[model.graph.node[0].name]
     assert np.isclose(exp_cycles, cycles_rtlsim, atol=10)
     assert exp_cycles != 0
+
+
+@pytest.mark.transform
+@pytest.mark.parametrize("idt", [DataType["FLOAT32"]])
+@pytest.mark.parametrize("ishape", [[1, 16, 48], [1, 32]])
+@pytest.mark.parametrize("has_scale", [True, False])
+@pytest.mark.parametrize("has_bias", [True, False])
+def test_extract_norm_scale_bias(idt, ishape, has_scale, has_bias):
+    epsilon = 9.999999960041972e-13
+    model1 = create_layernorm_model(idt, ishape, has_scale, has_bias, epsilon)
+    model2 = create_layernorm_model(idt, ishape, has_scale, has_bias, epsilon)
+    model3 = create_layernorm_model(idt, ishape, has_scale, has_bias, epsilon)
+
+    model = model1.transform(MergeONNXModels(model2))
+    model = model.transform(MergeONNXModels(model3))
+
+    assert len(model.get_nodes_by_op_type("LayerNormalization")) == 3
+
+    # reference calculation
+    input = gen_finn_dt_tensor(DataType["FLOAT32"], ishape)
+    input_t = {model.graph.input[0].name: input}
+
+    y_ref = oxe.execute_onnx(model, input_t)[model.graph.output[0].name]
+
+    model = model.transform(InferShapes())
+    model = model.transform(InferDataTypes())
+
+    model = model.transform(ExtractNormScaleBias())
+
+    assert len(model.get_nodes_by_op_type("LayerNormalization")) == 3
+    if has_bias:
+        assert len(model.get_nodes_by_op_type("Add")) == 3
+    if has_scale:
+        assert len(model.get_nodes_by_op_type("Mul")) == 3
+
+    input_t = {model.graph.input[0].name: input}
+
+    y_out = oxe.execute_onnx(model, input_t)[model.graph.output[0].name]
+    assert (y_ref == y_out).all()
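
The final assert (y_ref == y_out).all() demands bit-exact agreement. That is reasonable here because the rewrite is numerically exact: a unit-scale, zero-bias LayerNormalization followed by Mul and Add performs the same float32 operations as the fused op, and multiplying by 1.0 or adding 0.0 is exact in IEEE float32. A back-of-the-envelope numpy check of that identity (not part of the PR; shapes chosen to match the test):

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((1, 16, 48)).astype(np.float32)
scale = rng.standard_normal(48).astype(np.float32)
bias = rng.standard_normal(48).astype(np.float32)

# plain LayerNormalization over the last axis
mean = x.mean(axis=-1, keepdims=True)
var = x.var(axis=-1, keepdims=True)
norm = ((x - mean) / np.sqrt(var + 1e-12)).astype(np.float32)

fused = norm * scale + bias                # LN with scale and bias baked in
split = (norm * 1.0 + 0.0) * scale + bias  # unit-scale/zero-bias LN, then Mul, then Add
assert np.array_equal(fused, split)        # *1.0 and +0.0 are exact in float32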