
Commit 47a0077

Fix attention fusion in conformer encoder (#23711)
### Description

This PR updates the attention fusion for conformer-encoder models. It is a follow-up to [this PR](#23528).

### Motivation and Context

Subsequent modeling code updates have changed (and will continue to change) the graph fusions. However, the three trailing attention mask nodes (`Cast --> Unsqueeze --> Equal`) will remain, so the attention fusion should keep working under future modeling code changes when handling the attention mask.
1 parent c7aa9a7 commit 47a0077

File tree

1 file changed, +3 −3 lines changed

onnxruntime/python/tools/transformers/fusion_conformer_attention.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -79,11 +79,11 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
         where_qk = qk_nodes[2]
         mask_nodes = self.model.match_parent_path(
             where_qk,
-            ["Equal", "Unsqueeze", "Cast", "Expand"],
-            [0, 0, 0, 0],
+            ["Equal", "Unsqueeze", "Cast"],
+            [0, 0, 0],
         )
         if mask_nodes is not None:
-            attn_mask = mask_nodes[-2].output[0]
+            attn_mask = mask_nodes[-1].output[0]

         add_qk, matmul_qk = qk_nodes[-2], qk_nodes[-1]
```
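To illustrate what the changed `match_parent_path` call does, here is a minimal self-contained sketch of parent-path matching over a toy node graph. This is not ONNX Runtime's actual implementation (the real helper resolves producers through tensor names via `output_name_to_node`); the `Node` class and graph below are hypothetical stand-ins that mirror the mask pattern this commit matches.

```python
# Hypothetical, simplified model of ONNX-style graph nodes. In real ONNX,
# node inputs/outputs are tensor names; here inputs point directly at
# producer Node objects to keep the sketch short.
class Node:
    def __init__(self, op_type, inputs=(), outputs=("out",)):
        self.op_type = op_type
        self.inputs = list(inputs)    # producer nodes feeding this node
        self.output = list(outputs)   # output tensor names (mirrors NodeProto.output)

def match_parent_path(node, op_types, input_indices):
    """Walk upward from `node`, checking that the producer at each given
    input index has the expected op type. Return the matched parent chain
    (outermost op first), or None if any step fails."""
    path = []
    current = node
    for op_type, idx in zip(op_types, input_indices):
        if idx >= len(current.inputs):
            return None
        parent = current.inputs[idx]
        if parent.op_type != op_type:
            return None
        path.append(parent)
        current = parent
    return path

# Toy graph mirroring the mask pattern: Cast --> Unsqueeze --> Equal --> Where
cast = Node("Cast", outputs=("mask_cast_out",))
unsqueeze = Node("Unsqueeze", [cast])
equal = Node("Equal", [unsqueeze])
where_qk = Node("Where", [equal])

# The shortened three-node pattern from this commit matches...
mask_nodes = match_parent_path(where_qk, ["Equal", "Unsqueeze", "Cast"], [0, 0, 0])
attn_mask = mask_nodes[-1].output[0]  # "mask_cast_out": the Cast node's output

# ...while the old four-node pattern (with a trailing Expand) no longer does.
old_match = match_parent_path(
    where_qk, ["Equal", "Unsqueeze", "Cast", "Expand"], [0, 0, 0, 0]
)
```

This also shows why the second hunk changes `mask_nodes[-2]` to `mask_nodes[-1]`: with the `Expand` dropped from the pattern, the `Cast` node is now the last element of the matched path rather than the second-to-last.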
