
Commit bfaa406

review feedback
1 parent a8bb1f5 commit bfaa406

5 files changed: +20 -30 lines changed


onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc (+4 -4)
@@ -20,9 +20,9 @@ namespace transformers {
 
 Inputs:
   input_ids: int32 (B, 1)
-  encoder_input_ids: int32 (B, encode_sequence_length) (optional)
+  encoder_input_ids: int32 (B, encode_sequence_length) (optional for old format; removed in new format)
   encoder_attention_mask: int32 (B, encode_sequence_length)
-  encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size) (optional)
+  encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size) (optional for old format; removed in new format)
 
   past_key_self_0: (B, num_heads, past_decode_sequence_length, head_size)
   past_value_self_0: (B, num_heads, past_decode_sequence_length, head_size)
@@ -147,7 +147,8 @@ Status T5DecoderSubgraph::Validate(const std::vector<const NodeArg*>& subgraph_i
 // decoder_feeds: input_ids, encoder_attention_mask,
 // present_key_self_0, present_value_self_0, ...,
 // present_key_cross_0, present_value_cross_0, ...
-
+// past_seq_len (optional), num_beams (optional), cache_indirection (optional)
+//
 // Old format:
 // encoder feeds: encoder_input_ids, encoder_attention_mask, decoder_input_ids (with start tokens)
 // encoder fetches: logits, encoder_hidden_states,
@@ -157,7 +158,6 @@ Status T5DecoderSubgraph::Validate(const std::vector<const NodeArg*>& subgraph_i
 // present_key_self_0, present_value_self_0, ...,
 // present_key_cross_0, present_value_cross_0, ...
 // past_seq_len (optional), num_beams (optional), cache_indirection (optional)
-
 Status T5DecoderSubgraph::CreateInitialFeeds(
     AllocatorPtr cpu_allocator,
     gsl::span<const int32_t> beam_next_tokens,
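As a minimal sketch of the feed layout documented in this hunk: the function and argument names below are hypothetical, only the ordering of the new-format decoder feeds comes from the comment above.

```python
# Hypothetical sketch of the new-format decoder feed order documented above.
# Only the ordering is taken from the subgraph comment; the function and
# argument names are illustrative, not the actual C++ implementation.
def build_decoder_feeds(
    input_ids,                 # int32 (B, 1)
    encoder_attention_mask,    # int32 (B, encode_sequence_length)
    self_kv_pairs,             # [(present_key_self_i, present_value_self_i), ...]
    cross_kv_pairs,            # [(present_key_cross_i, present_value_cross_i), ...]
    past_seq_len=None,         # optional trailing feeds
    num_beams=None,
    cache_indirection=None,
):
    feeds = [input_ids, encoder_attention_mask]
    for key, value in self_kv_pairs:
        feeds.extend([key, value])
    for key, value in cross_kv_pairs:
        feeds.extend([key, value])
    # Optional feeds keep this relative order when present.
    feeds.extend(t for t in (past_seq_len, num_beams, cache_indirection) if t is not None)
    return feeds
```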

onnxruntime/python/tools/transformers/models/t5/convert_to_onnx.py (+1 -1)
@@ -212,7 +212,7 @@ def export_onnx_models(
     else:
         logger.info(f"Skip exporting: existed ONNX model {onnx_path}")
 
-    # Optimize ONNX graph. Note that we have not implemented graph optimization for T5 yet.
+    # Optimize ONNX graph.
     if optimize_onnx or precision != Precision.FLOAT32:
         onnx_shape_path = None
         if shape_infer_before_optimization:
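One consequence of the gate in this hunk: exporting at any precision other than fp32 runs graph optimization even when `optimize_onnx` is not set. A minimal standalone sketch of that predicate, with `Precision` stubbed as an assumption:

```python
# Standalone sketch of the optimization gate above; this Precision enum is a
# stub for illustration, not the tool's actual class.
from enum import Enum

class Precision(Enum):
    FLOAT32 = "fp32"
    FLOAT16 = "fp16"
    INT8 = "int8"

def should_optimize(optimize_onnx: bool, precision: Precision) -> bool:
    # Non-fp32 export goes through the optimizer path, so optimization is
    # forced whenever precision != fp32.
    return optimize_onnx or precision != Precision.FLOAT32

assert should_optimize(False, Precision.FLOAT16)      # fp16 implies optimization
assert not should_optimize(False, Precision.FLOAT32)  # fp32 only when requested
```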

onnxruntime/python/tools/transformers/onnx_model_t5.py (+12 -12)
@@ -237,7 +237,7 @@ def fuse_t5_encoder(self, softmax_node, input_name_to_nodes, output_name_to_node
         )
         if qkv_nodes is None:
             return False
-        matmul_qkv, _transpose_qkv, reshape_qkv = qkv_nodes
+        matmul_qkv, _, reshape_qkv = qkv_nodes
 
         qkv_shape_nodes = self.model.match_parent_path(
             reshape_qkv,
@@ -298,7 +298,7 @@ def fuse_t5_encoder(self, softmax_node, input_name_to_nodes, output_name_to_node
             output_name_to_node,
         )
         if mask_nodes is None:
-            return
+            return False
         mul_node = mask_nodes[2]
 
         _, mul_val = self.model.get_constant_input(mul_node)
@@ -357,7 +357,7 @@ def fuse_t5_encoder(self, softmax_node, input_name_to_nodes, output_name_to_node
         )
         if k_nodes is None:
             return False
-        _, reshape_k, matmul_k = k_nodes
+        _, _, matmul_k = k_nodes
         # todo: check reshape_k parent nodes
 
         q_nodes = self.model.match_parent_path(
@@ -368,7 +368,7 @@ def fuse_t5_encoder(self, softmax_node, input_name_to_nodes, output_name_to_node
         if q_nodes is None:
             return False
 
-        transpose_q, reshape_q, matmul_q = q_nodes
+        _, reshape_q, matmul_q = q_nodes
         # todo: check reshape_q parent nodes
 
         if matmul_q.input[0] != input_shape_node.input[0]:
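The `return False` fix above matters because a bare `return` yields `None`: both are falsy, so truthiness checks hid the inconsistency, but identity checks would not. A standalone illustration (this `match` function is illustrative, not from the fusion code):

```python
# Why the fusion fix replaces a bare `return` with `return False`:
# both results are falsy, but they are not the same value.
def match(found: bool):
    if not found:
        return        # a bare return yields None
    return True

assert match(False) is None       # not False
assert match(False) is not False
assert not match(False)           # still falsy, which is why the bug was latent
```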
@@ -690,6 +690,7 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node):
 
         gather = compute_bias_nodes[5]
         where = compute_bias_nodes[-1]
+        slice = compute_bias_nodes[2]
         unsqueeze = compute_bias_nodes[3]
 
         # Current fusion will not remove the node until the graph is processed.
@@ -790,10 +791,8 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node):
         # Unsqueeze(axes=0)   Cast(to=int64)
         #               \     /
         #                 Sub
-        #
-        # Founatutionally, there is still Slice to get last seq_len rows so end result is same.
-        #
-        # But need to be careful that the shape of some intermediate nodes are changed.
+        # Currently, there is still Slice to get last seq_len rows so end result is same.
+        # But need to be careful that the shape of bias tensor is changed before Slice.
         #
         # RelativePositionBias operator requires query_length == key_length so we shall pass in total_seq_len.
         # Here we get the end value of the Range node as length to pass to the RelativePositionBias node.
@@ -802,20 +801,21 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node):
         # only compute seq_len rows, then we can remove the Slice after the RelativePositionBias node.
         inputs = [bias_table.name, range_node.input[1], range_node.input[1]]
 
-        outputs = [unsqueeze.output[0]]
+        # Use a new tensor name since the shape might be different as mentioned above.
+        bias_output = node_name + "_rel_pos_bias"
+        slice.input[0] = bias_output
+
         rpb_node = helper.make_node(
             "RelativePositionBias",
             inputs=inputs,
-            outputs=outputs,
+            outputs=[bias_output],
             name=node_name,
         )
         rpb_node.domain = "com.microsoft"
         rpb_node.attribute.extend([helper.make_attribute("max_distance", max_distance)])
         rpb_node.attribute.extend([helper.make_attribute("is_bidirectional", is_bidirectional)])
         self.node_name_to_graph_name[rpb_node.name] = self.this_graph_name
         self.nodes_to_add.append(rpb_node)
-
-        self.nodes_to_remove.append(unsqueeze)
         self.prune_graph = True

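For reference, a self-contained sketch of the node construction in the last hunk, using onnx.helper: the op type, domain, and attribute names mirror the fusion code, while the tensor names and the max_distance value (128 is T5's usual default) are placeholders.

```python
# Self-contained sketch of building the com.microsoft RelativePositionBias
# contrib node as the fusion above does. Tensor names and attribute values
# here are placeholders, not taken from a real model.
from onnx import helper

rpb_node = helper.make_node(
    "RelativePositionBias",
    # The operator requires query_length == key_length, so the fusion passes
    # the Range node's end value (total_seq_len) for both lengths.
    inputs=["bias_table", "total_seq_len", "total_seq_len"],
    # A fresh output name: the bias shape can differ from the tensor it
    # replaces, so the downstream Slice input is rewired to this name.
    outputs=["rel_pos_bias"],
    name="RelPosBias_0",
    domain="com.microsoft",
)
rpb_node.attribute.extend([helper.make_attribute("max_distance", 128)])
rpb_node.attribute.extend([helper.make_attribute("is_bidirectional", True)])
```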
onnxruntime/test/python/transformers/test_generation.py (+3 -11)
@@ -196,7 +196,7 @@ def get_tiny_t5_model_dir():
 
 
 class TestBeamSearchT5(unittest.TestCase):
-    """Test BeamSearch for T5 model"""
+    """Test BeamSearch for T5 model with fp32 in CPU"""
 
     def setUp(self):
         tiny_model_dir = get_tiny_t5_model_dir()
@@ -215,8 +215,6 @@ def setUp(self):
             "--repetition_penalty 2.0",
         ]
 
-        self.enable_cuda = torch.cuda.is_available() and "CUDAExecutionProvider" in get_available_providers()
-
         export_t5_onnx_models(
             self.model_name,
             os.path.join(".", "cache_models"),
@@ -263,13 +261,6 @@ def run_beam_search(self, extra_arguments: str):
         result = run(arguments)
         self.assertTrue(result["parity"], f"ORT and PyTorch result is different on CPU for arguments {arguments}")
 
-        # Test GPU
-        if self.enable_cuda:
-            if "--use_gpu" not in arguments:
-                arguments.append("--use_gpu")
-            result = run(arguments)
-            self.assertTrue(result["parity"], f"ORT and PyTorch result is different on GPU for arguments {arguments}")
-
         os.remove(self.beam_search_onnx_path)
 
     def test_return_sequences(self):
@@ -333,6 +324,7 @@ def check_encoder_fusion(self):
         onnx_model = OnnxModel(model)
         op_counters = onnx_model.get_operator_statistics()
         print("encoder ops", op_counters)
+
         expected_node_count = {
             "RelativePositionBias": 1,
             "SimplifiedLayerNormalization": 5 if use_tiny_model else 13,
@@ -351,7 +343,7 @@ def check_decoder_fusion(self):
 
         onnx_model = OnnxModel(model)
         op_counters = onnx_model.get_operator_statistics()
-        print("decoder opators", op_counters)
+        print("decoder ops", op_counters)
 
         expected_node_count = {
             "RelativePositionBias": 1,

onnxruntime/test/testdata/transformers/tiny_t5/tiny_t5.py (+0 -2)
@@ -14,8 +14,6 @@
 save_directory = "tiny_t5"
 model_name = "google-t5/t5-small"
 
-model = T5ForConditionalGeneration.from_pretrained(model_name)
-
 config = T5Config.from_pretrained(model_name)
 
 config.num_heads = 2
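The deleted `from_pretrained` call was dead code: the script only needs the config, shrinks it, and (outside this hunk) builds a randomly initialized model from it. A hedged sketch of that remainder; the overrides besides num_heads and the save call are assumptions:

```python
# Hedged sketch of the rest of the tiny-model recipe (not shown in this hunk).
# The overrides besides num_heads are assumptions for illustration.
from transformers import T5Config, T5ForConditionalGeneration

config = T5Config.from_pretrained("google-t5/t5-small")
config.num_heads = 2
config.num_layers = 2   # assumed additional shrink
config.d_model = 64     # assumed
config.d_kv = 4         # assumed (T5 projects to num_heads * d_kv explicitly)
config.d_ff = 128       # assumed

tiny_model = T5ForConditionalGeneration(config)  # randomly initialized, tiny shapes
tiny_model.save_pretrained("tiny_t5")            # assumed save into save_directory
```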
