Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 59 additions & 18 deletions onnxruntime/python/tools/transformers/fusion_attention_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,10 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
if node_before_layer_norm is None:
continue
child = self.model.find_first_child_by_type(
node_before_layer_norm, "LayerNormalization", input_name_to_nodes, False
node_before_layer_norm,
"LayerNormalization",
input_name_to_nodes,
False,
)
if child is None:
continue
Expand All @@ -146,19 +149,26 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
qkv_nodes = self.model.match_parent_path(
normalize_node,
["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
[1, 1, 0, 0, 0],
[1, None, 0, 0, 0],
)
if qkv_nodes is None:
logger.debug("fuse_attention: failed to match qkv path")
return

reshape_qkv, transpose_qkv, matmul_qkv = qkv_nodes[2], qkv_nodes[3], qkv_nodes[-1]
reshape_qkv, transpose_qkv, matmul_qkv = (
qkv_nodes[2],
qkv_nodes[3],
qkv_nodes[-1],
)

v_nodes = self.model.match_parent_path(
matmul_qkv, ["Reshape", "Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 0, None]
matmul_qkv,
["Reshape", "Transpose", "Reshape", "Add", "MatMul"],
[1, 0, 0, 0, None],
)
if v_nodes is None:
v_nodes = self.model.match_parent_path(matmul_qkv, ["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 1])
v_nodes = self.model.match_parent_path(
matmul_qkv, ["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, None]
)
if v_nodes is None:
logger.debug("fuse_attention: failed to match v path")
return
Expand All @@ -183,14 +193,27 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
if qk_nodes is None:
qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "Mul", "MatMul"], [0, 0, 0, 0])
if qk_nodes is None:
qk_nodes = self.model.match_parent_path(
matmul_qkv, ["Cast", "Cast", "Softmax", "Add", "Mul", "MatMul"], [0, 0, 0, 0, 0, 0]
)
# If attention mask is not used, we can still match the qk path.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggest to change the condition so that layout is more friendly:
Before:

if qk_nodes is None:
    ...
else:
   add_mask = qk_nodes[1]

To

if qk_nodes is not None:
   add_mask = qk_nodes[1]
else:
   ...

Another possible change is to use match_parent_paths.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in #24280

qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Mul", "MatMul"], [0, 0, 0])
if qk_nodes is None:
logger.debug("fuse_attention: failed to match qk path")
return
else:
add_mask = qk_nodes[3]
# Cast nodes are added in the model for fp16.
qk_nodes = self.model.match_parent_path(
matmul_qkv,
["Cast", "Cast", "Softmax", "Add", "Mul", "MatMul"],
[0, 0, 0, 0, 0, 0],
)
if qk_nodes is None:
# If attention mask is not used, we can still match the qk path.
qk_nodes = self.model.match_parent_path(
matmul_qkv,
["Cast", "Cast", "Softmax", "Mul", "MatMul"],
[0, 0, 0, 0, 0],
)
if qk_nodes is None:
logger.debug("fuse_attention: failed to match qk path")
return
else:
add_mask = qk_nodes[3]
else:
add_mask = qk_nodes[1]
else:
Expand All @@ -201,10 +224,14 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
matmul_qk = qk_nodes[-1]

q_nodes = self.model.match_parent_path(
matmul_qk, ["Reshape", "Transpose", "Reshape", "Mul", "Add", "MatMul"], [0, 0, 0, 0, None, None]
matmul_qk,
["Reshape", "Transpose", "Reshape", "Mul", "Add", "MatMul"],
[0, 0, 0, 0, None, None],
)
if q_nodes is None:
q_nodes = self.model.match_parent_path(matmul_qk, ["Transpose", "Reshape", "Add", "MatMul"], [0, 0, 0, 1])
q_nodes = self.model.match_parent_path(
matmul_qk, ["Transpose", "Reshape", "Add", "MatMul"], [0, 0, 0, None]
)
if q_nodes is None:
logger.debug("fuse_attention: failed to match q path")
return
Expand All @@ -216,10 +243,14 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
add_q, matmul_q = q_nodes[-2], q_nodes[-1]

k_nodes = self.model.match_parent_path(
matmul_qk, ["Transpose", "Reshape", "Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 0, 0, None]
matmul_qk,
["Transpose", "Reshape", "Transpose", "Reshape", "Add", "MatMul"],
[1, 0, 0, 0, 0, None],
)
if k_nodes is None:
k_nodes = self.model.match_parent_path(matmul_qk, ["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 1])
k_nodes = self.model.match_parent_path(
matmul_qk, ["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, None]
)
if k_nodes is None:
logger.debug("fuse_attention: failed to match k path")
return
Expand All @@ -242,7 +273,17 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
# 4D Add after Q x K'
add_qk_nodes = self.model.match_parent_path(
add_mask,
["Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze", "Reshape", "Reshape", "Cast"],
[
"Where",
"Sub",
"Cast",
"Expand",
"Unsqueeze",
"Unsqueeze",
"Reshape",
"Reshape",
"Cast",
],
[1, 2, 1, 0, 0, 0, 0, 0, 0],
)
if add_qk_nodes is not None:
Expand Down
17 changes: 12 additions & 5 deletions onnxruntime/python/tools/transformers/fusion_fastgelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,13 +177,12 @@ def fuse_2(self, tanh_node, input_name_to_nodes: dict, output_name_to_node: dict
return
mul_after_mul_half = children[0]

# root_node could be None when root_input is graph input
root_node = self.model.get_parent(
mul_after_mul_half,
0 if mul_after_mul_half.input[1] == mul_half.output[0] else 1,
output_name_to_node,
)
if root_node is None:
return

mul_before_tanh = self.model.match_parent(tanh_node, "Mul", 0, output_name_to_node)
if mul_before_tanh is None:
Expand All @@ -197,7 +196,13 @@ def fuse_2(self, tanh_node, input_name_to_nodes: dict, output_name_to_node: dict
if add_before_tanh is None:
return

mul_after_pow = self.model.match_parent(add_before_tanh, "Mul", None, output_name_to_node, exclude=[root_node])
mul_after_pow = self.model.match_parent(
add_before_tanh,
"Mul",
None,
output_name_to_node,
exclude=[root_node] if root_node else [],
)
if mul_after_pow is None:
return

Expand All @@ -212,7 +217,9 @@ def fuse_2(self, tanh_node, input_name_to_nodes: dict, output_name_to_node: dict
if not self.model.has_constant_input(pow, 3.0):
return

if pow.input[0] != root_node.output[0]:
root_input = mul_after_mul_half.input[0 if mul_after_mul_half.input[1] == mul_half.output[0] else 1]

if pow.input[0] != root_input:
return

subgraph_nodes = [
Expand All @@ -236,7 +243,7 @@ def fuse_2(self, tanh_node, input_name_to_nodes: dict, output_name_to_node: dict
self.nodes_to_remove.extend(subgraph_nodes)
fused_node = helper.make_node(
"FastGelu",
inputs=[root_node.output[0]],
inputs=[root_input],
outputs=mul_after_mul_half.output,
name=self.model.create_node_name("FastGelu"),
)
Expand Down
Binary file not shown.
Binary file not shown.
70 changes: 41 additions & 29 deletions onnxruntime/test/python/transformers/test_gelu_fusions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import unittest

import torch
from parameterized import parameterized
from parity_utilities import find_transformers_source

if find_transformers_source():
Expand Down Expand Up @@ -43,16 +44,6 @@ def forward(self, x):
return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))


test_cases = [
("huggingface", "Gelu", HuggingfaceGelu),
("huggingface", "FastGelu", HuggingfaceFastGelu),
("huggingface", "QuickGelu", HuggingfaceQuickGelu),
("huggingface", "FastGelu", HuggingfaceTorchGeluTanh),
("megatron", "Gelu", MegatronGelu),
("megatron", "FastGelu", MegatronFastGelu),
]


class TestGeluFusions(unittest.TestCase):
def verify_node_count(self, bert_model, expected_node_count, test_name):
for op_type, count in expected_node_count.items():
Expand All @@ -62,25 +53,46 @@ def verify_node_count(self, bert_model, expected_node_count, test_name):
print(f"{op}: {len(bert_model.get_nodes_by_op_type(op))} expected={counter}")
self.assertEqual(len(bert_model.get_nodes_by_op_type(op_type)), count)

def test_fusions(self):
for test_case in test_cases:
source, operator, model_class = test_case
model = model_class()
dummy_input = torch.ones(3, dtype=torch.float32)
test_name = f"{operator}_{source}"
onnx_path = f"{test_name}.onnx"
torch.onnx.export(
model,
(dummy_input),
onnx_path,
input_names=["input"],
output_names=["output"],
)
optimizer = optimize_model(onnx_path, "bert")
# optimizer.save_model_to_file(f"{operator}_{source}_opt.onnx")
os.remove(onnx_path)
expected_node_count = {operator: 1}
self.verify_node_count(optimizer, expected_node_count, test_name)
@parameterized.expand(
[
(("huggingface", "Gelu", HuggingfaceGelu), True),
(("huggingface", "FastGelu", HuggingfaceFastGelu), True),
(("huggingface", "QuickGelu", HuggingfaceQuickGelu), True),
(("huggingface", "FastGelu", HuggingfaceTorchGeluTanh), True),
(("megatron", "Gelu", MegatronGelu), True),
(("megatron", "FastGelu", MegatronFastGelu), True),
(("huggingface", "Gelu", HuggingfaceGelu), False),
(("huggingface", "FastGelu", HuggingfaceFastGelu), False),
(("huggingface", "QuickGelu", HuggingfaceQuickGelu), False),
(("huggingface", "FastGelu", HuggingfaceTorchGeluTanh), False),
(("megatron", "Gelu", MegatronGelu), False),
(("megatron", "FastGelu", MegatronFastGelu), False),
]
)
def test_fusions(self, test_case, dynamo):
source, operator, model_class = test_case
model = model_class()
dummy_input = torch.ones(3, dtype=torch.float32)
test_name = f"{operator}_{source}"
onnx_path = f"{test_name}.onnx"
torch.onnx.export(
model,
(dummy_input,),
onnx_path,
input_names=["input"],
output_names=["output"],
dynamo=dynamo,
)
optimizer = optimize_model(onnx_path, "bert")
# optimizer.save_model_to_file(f"{operator}_{source}_opt.onnx")
os.remove(onnx_path)
# Remove the associated .data file (dynamo)
data_path = onnx_path + ".data"
if os.path.exists(data_path):
os.remove(data_path)
expected_node_count = {operator: 1}

self.verify_node_count(optimizer, expected_node_count, test_name)


if __name__ == "__main__":
Expand Down
Loading
Loading