
Commit d84314c

Update T5 Onnx Export and Optimization (#23949)
Previously, the encoder ONNX model added extra initialization for the decoder to generate the KV cache from the prompt. That is not necessary. Here we redesign the ONNX export for the T5 model to output two separate models for the encoder and decoder, and move the Linear projections that generate cross-attention features from encoder_hidden_states into the encoder ONNX model. This way, the encoder does not need to output encoder_hidden_states; it only needs to output the features used by cross attention in the decoder.

Major changes:
- [x] Update the T5 ONNX export script.
- [x] Update the convert_generation script.
- [x] Update beam search to support the changes of inputs and outputs (details below).
- [x] Add a tiny T5 model, and enable the generation test for T5 in Linux CI pipelines.

Example change in inputs and outputs for a one-layer model:

**Encoder Inputs**:
- encoder_input_ids: int32 (B, encode_sequence_length)
- encoder_attention_mask: int32 (B, encode_sequence_length)
- ~~decoder_input_ids: int32 (B, 1)~~

**Encoder Outputs**:
- ~~logits: (B, 1, vocab_size)~~
- ~~encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size)~~
- ~~present_key_self_0: (B, num_heads, 1, head_size)~~
- ~~present_value_self_0: (B, num_heads, 1, head_size)~~
- present_key_cross_0: (B, num_heads, encode_sequence_length, head_size)
- present_value_cross_0: (B, num_heads, encode_sequence_length, head_size)

**Decoder Inputs**:
- input_ids: int32 (B, 1)
- ~~encoder_input_ids: int32 (B, encode_sequence_length) (optional for old format; removed in new format)~~
- encoder_attention_mask: int32 (B, encode_sequence_length)
- ~~encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size) (optional for old format; removed in new format)~~
- past_key_self_0: (B, num_heads, past_decode_sequence_length, head_size)
- past_value_self_0: (B, num_heads, past_decode_sequence_length, head_size)
- past_key_cross_0: (B, num_heads, encode_sequence_length, head_size)
- past_value_cross_0: (B, num_heads, encode_sequence_length, head_size)

**Decoder Outputs**:
- logits: (B, 1, vocab_size)
- present_key_self_0: (B, num_heads, past_decode_sequence_length + 1, head_size)
- present_value_self_0: (B, num_heads, past_decode_sequence_length + 1, head_size)

Known issues:
- Some postprocessing (like converting to decoder masked multi-head attention, and past/present buffer sharing) is not done yet. It could be a future work item to integrate with onnxruntime-genai.

### Motivation and Context
Make the encoder ONNX model simpler and more efficient at inference (no need to output encoder_hidden_states).
1 parent 4959468 · commit d84314c

27 files changed: +2578 −753 lines
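The new model format can be exercised directly with the onnxruntime Python API. Below is a minimal sketch, not the project's test code: the file names, the start token id, and the assumption that the first decoder step accepts a zero-length self-attention past are all illustrative, while the input/output names follow the listing above.

```python
import numpy as np
import onnxruntime as ort

# Hypothetical file names; the export script determines the real ones.
encoder = ort.InferenceSession("t5_encoder.onnx")
decoder = ort.InferenceSession("t5_decoder.onnx")

batch_size, encode_sequence_length = 1, 8
encoder_input_ids = np.zeros((batch_size, encode_sequence_length), dtype=np.int32)
encoder_attention_mask = np.ones((batch_size, encode_sequence_length), dtype=np.int32)

# New-format encoder: no logits, no encoder_hidden_states; it returns only the
# cross-attention KV (present_key_cross_i, present_value_cross_i per layer).
cross_kv = encoder.run(None, {"encoder_input_ids": encoder_input_ids,
                              "encoder_attention_mask": encoder_attention_mask})
num_layers = len(cross_kv) // 2
b, num_heads, _, head_size = cross_kv[0].shape

# First decoder step: input_ids holds the start token; self-attention past is empty.
decoder_start_token_id = 0  # assumption: matches the model's generation config
feeds = {"input_ids": np.full((batch_size, 1), decoder_start_token_id, dtype=np.int32),
         "encoder_attention_mask": encoder_attention_mask}
empty_past = np.zeros((b, num_heads, 0, head_size), dtype=cross_kv[0].dtype)
for i in range(num_layers):
    feeds[f"past_key_self_{i}"] = empty_past
    feeds[f"past_value_self_{i}"] = empty_past
    feeds[f"past_key_cross_{i}"] = cross_kv[2 * i]
    feeds[f"past_value_cross_{i}"] = cross_kv[2 * i + 1]

logits, *present_self = decoder.run(None, feeds)  # logits: (B, 1, vocab_size)
next_token = int(np.argmax(logits[0, -1]))
```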

onnxruntime/contrib_ops/cpu/transformers/beam_search.cc

Lines changed: 11 additions & 5 deletions

```diff
@@ -139,13 +139,19 @@ Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state,
     ORT_RETURN_IF_ERROR(t5_encoder_subgraph_->Setup(session_state, subgraph_session_state));
     encoder_feeds_fetches_manager_ = t5_encoder_subgraph_->GetFeedsFetchesManager();

-    if (parameters_->decoder_start_token_id < 0) {
-      ORT_RETURN_IF(t5_encoder_subgraph_->num_subgraph_inputs != 2,
-                    "Encoder subgraph shall have 2 inputs when decoder_start_token_id attribute is empty");
+    if (!t5_encoder_subgraph_->HasLogitsOutput()) {
+      // New format requires start token id.
+      ORT_ENFORCE(parameters_->decoder_start_token_id >= 0);
     } else {
-      ORT_RETURN_IF(t5_encoder_subgraph_->num_subgraph_inputs != 3,
-                    "Encoder subgraph shall have 3 inputs when decoder_start_token_id attribute is available");
+      if (parameters_->decoder_start_token_id < 0) {
+        ORT_RETURN_IF(t5_encoder_subgraph_->num_subgraph_inputs != 2,
+                      "Encoder subgraph shall have 2 inputs when decoder_start_token_id attribute is empty");
+      } else {
+        ORT_RETURN_IF(t5_encoder_subgraph_->num_subgraph_inputs != 3,
+                      "Encoder subgraph shall have 3 inputs when decoder_start_token_id attribute is available");
+      }
     }
+
   } else if (attribute_name == "decoder") {
     ORT_ENFORCE(t5_decoder_subgraph_ == nullptr,
                 "SetupSubgraphExecutionInfo should only be called once for each subgraph.");
```

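For readers checking which format a given encoder model is in, a small inspection sketch (assuming the onnx package and a hypothetical file name) shows the properties this validation keys off:

```python
import onnx

model = onnx.load("t5_encoder.onnx")  # hypothetical path
input_names = [i.name for i in model.graph.input]
output_names = [o.name for o in model.graph.output]

# Old-format encoders emit logits and take decoder_input_ids as a third input;
# new-format encoders do not, so BeamSearch insists on decoder_start_token_id >= 0.
has_logits = "logits" in output_names
print(f"inputs={input_names} outputs={output_names} old_format={has_logits}")
```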
onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h

Lines changed: 47 additions & 24 deletions

```diff
@@ -51,7 +51,13 @@ class BeamSearchT5 : public BeamSearchBase<T> {
         expand_buffer_int32_func_(expand_buffer_int32_func),
         expand_buffer_float_func_(expand_buffer_float_func),
         expand_buffer_float16_func_(expand_buffer_float16_func),
-        create_beam_scorer_func_(create_beam_scorer_func) {}
+        create_beam_scorer_func_(create_beam_scorer_func) {
+    // When decoder uses encoder_hidden_state, make sure the encoder outputs it.
+    if (decoder_subgraph_.UseEncoderHiddenState()) {
+      ORT_ENFORCE(encoder_subgraph_.subgraph_output_names[1] == "encoder_hidden_states");
+    }
+    ORT_ENFORCE(encoder_subgraph_.num_layers == decoder_subgraph_.num_layers);
+  }

 #ifdef USE_CUDA
   Status InitializeCuda(
@@ -160,7 +166,7 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches
                               this->create_encoder_inputs_func_,
                               this->add_to_feeds_func_,
                               buffer,
-                              decoder_input_ids,
+                              decoder_input_ids,  // new format does not use decoder_input_ids in encoder; it is still initialized here when decoder_start_token_id >= 0.
                               this->ort_stream_));

 #ifdef DEBUG_NODE_INPUTS_OUTPUTS
@@ -233,35 +239,47 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches

   std::vector<OrtValue> decoder_fetches;

-  if (current_length + 1 < parameters->max_length) {
+  // When encoder outputs logits (in old format), we need to get the next token from logits.
+  if (current_length + 1 < parameters->max_length && encoder_subgraph_.HasLogitsOutput()) {
     ++iteration_counter;
-    ORT_RETURN_IF_ERROR(this->GenerateNextToken(encoder_fetches[0],
+    const OrtValue& logits = encoder_fetches[0];
+    ORT_RETURN_IF_ERROR(this->GenerateNextToken(logits,
                                                 beam_next_tokens,
                                                 beam_state,
                                                 cpu_state,
                                                 iteration_counter));
     ++current_length;  // Increase sequence length after a new token is generated.
+  }

-    ORT_RETURN_IF_ERROR(decoder_subgraph_.CreateInitialFeeds(this->cpu_allocator_,
-                                                             ReinterpretAsSpan<const int32_t>(beam_next_tokens),
-                                                             this->implicit_inputs_,
-                                                             encoder_feeds,
-                                                             encoder_fetches,
-                                                             decoder_feeds,
-                                                             this->device_copy_int32_func_,
-                                                             this->expand_buffer_int32_func_,
-                                                             this->expand_buffer_float_func_,
-                                                             this->expand_buffer_float16_func_,
-                                                             parameters->num_beams,
-                                                             this->ort_stream_,
-                                                             decoder_subgraph_.UseSequenceAsInputIds(),
-                                                             current_length,
-                                                             cpu_state.sequences,
-                                                             parameters->max_length,
-                                                             decoder_subgraph_.has_decoder_masked_attention_,
-                                                             this->cuda_device_prop_ != nullptr));
+  if (current_length < parameters->max_length) {
+    // when no logits, copy sequence (filled with start token IDs) to input_ids for decoder.
+    bool copy_sequence_to_input_ids = decoder_subgraph_.UseSequenceAsInputIds() || !encoder_subgraph_.HasLogitsOutput();
+    if (copy_sequence_to_input_ids) {
+      ORT_ENFORCE(current_length == cpu_state.sequences.GetSequenceLength());
+    }
+
+    // Generate inputs for next decoder subgraph call.
+    ORT_RETURN_IF_ERROR(decoder_subgraph_.CreateInitialFeeds(
+        this->cpu_allocator_,
+        ReinterpretAsSpan<const int32_t>(beam_next_tokens),
+        this->implicit_inputs_,
+        encoder_feeds,
+        encoder_fetches,
+        decoder_feeds,
+        this->device_copy_int32_func_,
+        this->expand_buffer_int32_func_,
+        this->expand_buffer_float_func_,
+        this->expand_buffer_float16_func_,
+        parameters->num_beams,
+        this->ort_stream_,
+        copy_sequence_to_input_ids,
+        cpu_state.sequences,
+        parameters->max_length,
+        decoder_subgraph_.has_decoder_masked_attention_,
+        this->cuda_device_prop_ != nullptr));

   if (decoder_subgraph_.past_present_share_buffer_) {
+    // Configure buffer sharing of past and present kv cache.
     decoder_fetches.reserve(static_cast<size_t>(decoder_subgraph_.GetFirstPresentOutputIndex()) +
                             2 * static_cast<size_t>(decoder_subgraph_.num_layers));
     decoder_fetches.resize(decoder_subgraph_.GetFirstPresentOutputIndex(), OrtValue());
@@ -299,14 +317,19 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches

   while (current_length < parameters->max_length) {
     iteration_counter++;
+
 #ifdef DEBUG_GENERATION
-    auto cur_len = std::to_string(current_length);
-    dumper->Print("***CurrentLength", cur_len, true);
+    dumper->Print(::onnxruntime::MakeString("Iteration=", iteration_counter,
+                                            ", CurrentLength=", current_length,
+                                            ", num_layers=", decoder_subgraph_.num_layers,
+                                            ", decoder_feeds=", decoder_feeds.size(),
+                                            ", start_token_id=", parameters->decoder_start_token_id));

     for (int i = 0; i < decoder_subgraph_.GetFirstPastInputIndex(); i++) {
       dumper->Print("decoder_feeds", i, true);
       dumper->Print("", decoder_feeds[i]);
     }
+
     for (int i = 0; i < decoder_subgraph_.num_layers; i++) {
       int self_key_idx = decoder_subgraph_.GetFirstPastInputIndex() + 2 * i;
       int self_value_idx = self_key_idx + 1;
```
onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc

Lines changed: 8 additions & 10 deletions

```diff
@@ -36,12 +36,9 @@ Subgraph::Subgraph(
   auto& subgraph_inputs = subgraph.GetInputs();
   auto& subgraph_outputs = subgraph.GetOutputs();

-  // inputs: input_ids, position_ids, attention_mask, past_0, past_1, ...
-  // outputs: logits, present_0, present_1, ...
   num_subgraph_inputs = static_cast<int>(subgraph_inputs.size());
   num_subgraph_outputs = static_cast<int>(subgraph_outputs.size());

-  // CheckSubgraph will verify inputs and outputs later.
   subgraph_input_names.reserve(num_subgraph_inputs);
   for (int i = 0; i < num_subgraph_inputs; ++i) {
     subgraph_input_names.push_back(subgraph_inputs[i]->Name());
@@ -68,10 +65,9 @@ Status Subgraph::Setup(const SessionState& session_state,
   InlinedVector<std::string_view> feed_names;
   feed_names.reserve(static_cast<size_t>(num_subgraph_inputs) + static_cast<size_t>(num_implicit_inputs));

-  // Use the first output (logits) to find device location.
+  // Use the first output to find device location.
   const OrtDevice& default_location = utils::FindDeviceForValue(subgraph_session_state, subgraph_output_names[0]);

-  // The position_ids, attention_mask, past_0, ... are created by this operator so the name doesn't matter.
   feed_names.insert(feed_names.end(), subgraph_input_names.begin(), subgraph_input_names.end());

   const auto& subgraph_map = subgraph_session_state.GetOrtValueNameIdxMap();
@@ -174,13 +170,15 @@ Status Subgraph::GetParameters(const ONNX_NAMESPACE::TensorShapeProto* past_shape,
   }

   // Logits shape is like (batch_size, seq_len, vocabulary_size)
-  ORT_RETURN_IF(logits_shape->dim_size() != 3,
-                "subgraph logits output is expected to have 3 dimension, got ", logits_shape->dim_size());
+  if (logits_shape != nullptr) {
+    ORT_RETURN_IF(logits_shape->dim_size() != 3,
+                  "subgraph logits output is expected to have 3 dimension, got ", logits_shape->dim_size());

-  ORT_RETURN_IF(!logits_shape->dim(2).has_dim_value() || logits_shape->dim(2).dim_value() <= 0,
-                "subgraph past state dimension 2 shall have a positive value for vocabulary size");
+    ORT_RETURN_IF(!logits_shape->dim(2).has_dim_value() || logits_shape->dim(2).dim_value() <= 0,
+                  "subgraph past state dimension 2 shall have a positive value for vocabulary size");

-  this->vocab_size = static_cast<int>(logits_shape->dim(2).dim_value());
+    this->vocab_size = static_cast<int>(logits_shape->dim(2).dim_value());
+  }

   return Status::OK();
 }
```
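A quick way to see what this relaxed check permits is to read the logits output shape straight from an exported model. A sketch, assuming the onnx package and a hypothetical file name:

```python
import onnx

model = onnx.load("t5_decoder.onnx")  # hypothetical path
logits = next((o for o in model.graph.output if o.name == "logits"), None)
if logits is None:
    # e.g. a new-format encoder: no logits output, so the check is skipped entirely.
    print("no logits output; vocab_size is not derived from this subgraph")
else:
    dims = logits.type.tensor_type.shape.dim
    assert len(dims) == 3, "logits is expected to be (batch_size, seq_len, vocab_size)"
    vocab_size = dims[2].dim_value  # 0 if the dimension is symbolic
    assert vocab_size > 0, "vocab_size dimension shall have a positive static value"
    print("vocab_size =", vocab_size)
```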
