@@ -237,7 +237,7 @@ def fuse_t5_encoder(self, softmax_node, input_name_to_nodes, output_name_to_node
         )
         if qkv_nodes is None:
             return False
-        matmul_qkv, _transpose_qkv, reshape_qkv = qkv_nodes
+        matmul_qkv, _, reshape_qkv = qkv_nodes

         qkv_shape_nodes = self.model.match_parent_path(
             reshape_qkv,
@@ -298,7 +298,7 @@ def fuse_t5_encoder(self, softmax_node, input_name_to_nodes, output_name_to_node
             output_name_to_node,
         )
         if mask_nodes is None:
-            return
+            return False
         mul_node = mask_nodes[2]

         _, mul_val = self.model.get_constant_input(mul_node)
@@ -357,7 +357,7 @@ def fuse_t5_encoder(self, softmax_node, input_name_to_nodes, output_name_to_node
         )
         if k_nodes is None:
             return False
-        _, reshape_k, matmul_k = k_nodes
+        _, _, matmul_k = k_nodes
         # todo: check reshape_k parent nodes

         q_nodes = self.model.match_parent_path(
@@ -368,7 +368,7 @@ def fuse_t5_encoder(self, softmax_node, input_name_to_nodes, output_name_to_node
         if q_nodes is None:
             return False

-        transpose_q, reshape_q, matmul_q = q_nodes
+        _, reshape_q, matmul_q = q_nodes
         # todo: check reshape_q parent nodes

         if matmul_q.input[0] != input_shape_node.input[0]:
@@ -690,6 +690,7 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node):

         gather = compute_bias_nodes[5]
         where = compute_bias_nodes[-1]
+        slice = compute_bias_nodes[2]
         unsqueeze = compute_bias_nodes[3]

         # Current fusion will not remove the node until the graph is processed.
@@ -790,10 +791,8 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node):
         #   Unsqueeze(axes=0)    Cast(to=int64)
         #             \         /
         #                 Sub
-        #
-        # Founatutionally, there is still Slice to get last seq_len rows so end result is same.
-        #
-        # But need to be careful that the shape of some intermediate nodes are changed.
+        # Currently, there is still a Slice to get the last seq_len rows, so the end result is the same.
+        # But be careful that the shape of the bias tensor is changed before the Slice.
         #
         # RelativePositionBias operator requires query_length == key_length so we shall pass in total_seq_len.
         # Here we get the end value of the Range node as length to pass to the RelativePositionBias node.
@@ -802,20 +801,21 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node):
         # only compute seq_len rows, then we can remove the Slice after the RelativePositionBias node.
         inputs = [bias_table.name, range_node.input[1], range_node.input[1]]

-        outputs = [unsqueeze.output[0]]
+        # Use a new tensor name since the shape might be different as mentioned above.
+        bias_output = node_name + "_rel_pos_bias"
+        slice.input[0] = bias_output
+
         rpb_node = helper.make_node(
             "RelativePositionBias",
             inputs=inputs,
-            outputs=outputs,
+            outputs=[bias_output],
             name=node_name,
         )
         rpb_node.domain = "com.microsoft"
         rpb_node.attribute.extend([helper.make_attribute("max_distance", max_distance)])
         rpb_node.attribute.extend([helper.make_attribute("is_bidirectional", is_bidirectional)])
         self.node_name_to_graph_name[rpb_node.name] = self.this_graph_name
         self.nodes_to_add.append(rpb_node)
-
-        self.nodes_to_remove.append(unsqueeze)
         self.prune_graph = True

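Note: below is a minimal, standalone sketch (not part of the change itself) of the rewiring the last hunk performs: emit a com.microsoft RelativePositionBias node under a fresh tensor name and point the existing Slice at it instead of reusing the old Unsqueeze output. Input names such as "bias_table", "range_end", and the Slice start/end/axes initializers are illustrative placeholders, not names from the model.

from onnx import helper

node_name = "RelPosBias_0"
bias_output = node_name + "_rel_pos_bias"

# The fused node reads the bias table and uses the Range end value for both lengths,
# since the contrib op requires query_length == key_length.
rpb_node = helper.make_node(
    "RelativePositionBias",
    inputs=["bias_table", "range_end", "range_end"],
    outputs=[bias_output],
    name=node_name,
    domain="com.microsoft",
)
rpb_node.attribute.extend([helper.make_attribute("max_distance", 128)])
rpb_node.attribute.extend([helper.make_attribute("is_bidirectional", 1)])

# Keep the Slice that trims to the last seq_len rows, but feed it the new tensor,
# because the bias shape before the Slice changes after fusion.
slice_node = helper.make_node(
    "Slice",
    inputs=[bias_output, "starts", "ends", "axes"],
    outputs=["sliced_bias"],
)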