
Commit 239df8b

Add CUDA implementation for attn_probs
1 parent 3147d51 commit 239df8b

File tree: 7 files changed, +150 -59 lines

onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc

Lines changed: 1 addition & 1 deletion
@@ -161,7 +161,7 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
       past_value == nullptr &&
       present_k == nullptr &&
       present_v == nullptr &&
-      attn_probs == nullptr &&  // TODO: can we support it?
+      attn_probs == nullptr &&  // TODO: can we support it?
       l2_cache_size_ > 0) {
     MlasFlashAttentionThreadedArgs args;
     args.batch_size = batch_size;

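On the CPU side, the MLAS flash-attention fast path is only taken when attn_probs is not requested. Tiled (flash-style) attention normalizes the softmax online, block by block, keeping only a running max and a running sum per query row, so the full S x T probability matrix is never materialized and there is nothing to copy out. A purely conceptual sketch of that accumulation (illustrative only, not code from this commit or from MLAS; values are scalars here for simplicity):

// Conceptual sketch: online softmax over one query row processed in blocks,
// as tiled/flash attention does. Only a running max and running sum are kept,
// so the per-key probabilities are never all materialized at once. This is
// why the fast path above requires attn_probs == nullptr.
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

float OnlineSoftmaxWeightedSum(const std::vector<float>& scores,  // one row of QK^T, length T
                               const std::vector<float>& values,  // matching scalar values, length T
                               size_t block_size) {
  float running_max = -std::numeric_limits<float>::infinity();
  float running_sum = 0.0f;  // sum of exp(score - running_max)
  float running_out = 0.0f;  // weighted value sum, scaled consistently with running_sum
  for (size_t start = 0; start < scores.size(); start += block_size) {
    const size_t end = std::min(start + block_size, scores.size());
    float block_max = running_max;
    for (size_t i = start; i < end; ++i) block_max = std::max(block_max, scores[i]);
    // Rescale the previous accumulators to the new maximum.
    const float correction = std::exp(running_max - block_max);
    running_sum *= correction;
    running_out *= correction;
    for (size_t i = start; i < end; ++i) {
      const float p = std::exp(scores[i] - block_max);
      running_sum += p;
      running_out += p * values[i];
    }
    running_max = block_max;
  }
  return running_out / running_sum;  // equals dot(softmax(scores), values)
}

When attn_probs is requested, the operator therefore has to fall back to a path that materializes the probabilities explicitly.
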
onnxruntime/contrib_ops/cuda/bert/attention_impl.cu

Lines changed: 17 additions & 10 deletions
@@ -590,15 +590,22 @@ Status UnfusedAttention(

   DUMP_TENSOR_D("QK", data.scratch, batch_size, num_heads, sequence_length, total_sequence_length);

-  constexpr size_t element_size = sizeof(T);
-  const size_t bytes = GetAttentionScratchSize(element_size, batch_size, num_heads,
-                                               sequence_length, total_sequence_length);
-  T* scratch2 = data.scratch + (bytes / element_size);
+  T* softmax_storage;
+  if (data.attn_probs == nullptr) {
+    constexpr size_t element_size = sizeof(T);
+    const size_t bytes = GetAttentionScratchSize(element_size, batch_size, num_heads,
+                                                 sequence_length, total_sequence_length);
+    T* scratch2 = data.scratch + (bytes / element_size);
+    softmax_storage = scratch2;
+  }
+  else {
+    softmax_storage = data.attn_probs;
+  }

   const bool broadcast_attn_bias_dim_0 = parameters.broadcast_attn_bias_dim_0;
   const bool broadcast_attn_bias_dim_1 = parameters.broadcast_attn_bias_dim_1;

-  // Apply softmax and store result R to scratch2: BxNxSxT
+  // Apply softmax and store result R to softmax_storage: BxNxSxT
   if (use_raw_attention_mask) {  // 2d, 3d or 4d attention mask
     const int mask_dimension = static_cast<int>(mask_index_dims.size());

@@ -612,7 +619,7 @@ Status UnfusedAttention(
         ComputeSoftmaxWithRawMask<T>(
             ort_stream, total_sequence_length, sequence_length, batch_size, num_heads,
             mask_index, nullptr, data.attention_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1,
-            data.scratch, scratch2, parameters.is_unidirectional, scale, mask_dimension,
+            data.scratch, softmax_storage, parameters.is_unidirectional, scale, mask_dimension,
             parameters.max_sequence_length, use_persistent_softmax, persistent_softmax_workspace,
             parameters.mask_filter_value));
   } else if (nullptr != mask_index) {  // 1d mask index
@@ -622,24 +629,24 @@ Status UnfusedAttention(
     ORT_RETURN_IF_ERROR(ComputeSoftmaxWithMask1D<T>(
         stream, total_sequence_length, sequence_length, batch_size, num_heads,
         mask_index, mask_start, data.attention_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1,
-        data.scratch, scratch2, parameters.is_unidirectional));
+        data.scratch, softmax_storage, parameters.is_unidirectional));
   } else {  // no mask
     ORT_RETURN_IF_ERROR(
         ComputeSoftmax<T>(
             stream, total_sequence_length, sequence_length, batch_size, num_heads,
             data.attention_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1,
-            data.scratch, scratch2, parameters.is_unidirectional));
+            data.scratch, softmax_storage, parameters.is_unidirectional));
   }

-  DUMP_TENSOR_D("Softmax", scratch2, batch_size, num_heads, sequence_length, total_sequence_length);
+  DUMP_TENSOR_D("Softmax", softmax_storage, batch_size, num_heads, sequence_length, total_sequence_length);

   // compute R*V (as V*R), and store in temp_output (space used by Q): BxNxSxH_v
   T* temp_output = data.q;
   CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(
       cublas, CUBLAS_OP_N, CUBLAS_OP_N,
       v_head_size, sequence_length, total_sequence_length,
       &one, data.v, v_head_size, present_size_per_batch_v,
-      scratch2, total_sequence_length, sequence_length * total_sequence_length,
+      softmax_storage, total_sequence_length, sequence_length * total_sequence_length,
       &zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop, parameters.use_tf32));

   // Temp_output is BxNxSxH_v, transpose to output BxSxNxH_v

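In the unfused CUDA path above, the softmax result now lands either in internal scratch or directly in the attn_probs output, and the following strided-batched GEMM reads the same buffer in both cases (leading dimension total_sequence_length, per batch-head stride sequence_length * total_sequence_length). Those strides imply a contiguous B x N x S x T layout for softmax_storage. A minimal indexing sketch, assuming that row-major layout (the helper name is hypothetical, not part of this commit):

// Hypothetical helper: element offset into a contiguous B x N x S x T
// probability buffer, consistent with the GEMM strides used above
// (ld = total_sequence_length, stride = sequence_length * total_sequence_length).
#include <cstddef>

inline size_t AttnProbIndex(size_t b, size_t n, size_t s, size_t t,
                            size_t num_heads, size_t sequence_length,
                            size_t total_sequence_length) {
  return ((b * num_heads + n) * sequence_length + s) * total_sequence_length + t;
}
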
onnxruntime/contrib_ops/cuda/bert/attention_impl.h

Lines changed: 1 addition & 0 deletions
@@ -81,6 +81,7 @@ struct AttentionData {
   T* present = nullptr;
   T* present_key = nullptr;
   T* present_value = nullptr;
+  T* attn_probs = nullptr;

   void* fused_runner = nullptr;
   const void* fused_cross_attention_kernel = nullptr;

onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc

Lines changed: 13 additions & 2 deletions
@@ -113,6 +113,13 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
   output_shape[2] = static_cast<int64_t>(parameters.v_hidden_size);
   Tensor* output = context->Output(0, output_shape);

+  TensorShapeVector attn_probs_shape(4);
+  attn_probs_shape[0] = static_cast<int64_t>(parameters.batch_size);
+  attn_probs_shape[1] = static_cast<int64_t>(parameters.num_heads);
+  attn_probs_shape[2] = static_cast<int64_t>(sequence_length);
+  attn_probs_shape[3] = static_cast<int64_t>(parameters.total_sequence_length);
+  Tensor* attn_probs = context->Output(3, attn_probs_shape);
+
   std::vector<int64_t> present_dims{
       parameters.batch_size, parameters.num_heads, parameters.total_sequence_length, parameters.head_size};
   TensorShape present_shape(present_dims);
@@ -172,6 +179,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
       parameters.past_sequence_length > 0 &&
       nullptr == attention_bias &&
       nullptr == key_padding_mask &&
+      nullptr == attn_probs &&  // TODO: support attn_probs
       parameters.head_size == parameters.v_head_size &&
       onnxruntime::lean::is_supported(device_prop,
                                       parameters.head_size,
@@ -216,6 +224,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
       !disable_flash_attention_ &&
       nullptr == attention_bias &&
       nullptr == key_padding_mask &&
+      nullptr == attn_probs &&  // TODO: support attn_probs
       parameters.head_size == parameters.v_head_size &&
       onnxruntime::flash::is_supported(device_prop,
                                        parameters.head_size,
@@ -280,7 +289,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
       !is_unidirectional_ &&
       nullptr == key_padding_mask &&
       nullptr == attention_bias &&
-      nullptr == past_key && nullptr == present_key &&
+      nullptr == past_key && nullptr == present_key && nullptr == attn_probs &&
       (parameters.qkv_format == Q_K_V_BSNH || (parameters.qkv_format == Q_KV_BSNH_BSN2H && bias == nullptr)) &&
       parameters.hidden_size == parameters.v_hidden_size &&
       has_fused_cross_attention_kernel(sm, parameters.head_size, parameters.kv_sequence_length);
@@ -305,7 +314,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
       !is_unidirectional_ &&
       nullptr == attention_bias &&
       (parameters.qkv_format == Q_K_V_BSNH || parameters.qkv_format == QKV_BSN3H) &&
-      nullptr == past_key && nullptr == present_key &&
+      nullptr == past_key && nullptr == present_key && nullptr == attn_probs &&
       is_mask_none_or_1d_k_len &&
       parameters.hidden_size == parameters.v_hidden_size &&
       parameters.sequence_length == parameters.kv_sequence_length &&  // self attention only for fused runner
@@ -339,6 +348,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
       kernel_type == AttentionKernelType::AttentionKernel_Default &&
       !disable_memory_efficient_attention_ &&
       is_long_sequence &&
+      nullptr == attn_probs &&  // TODO: support attn_probs
       // Check whether the attention bias alignment is good for memory efficient attention.
       (attention_bias == nullptr || parameters.sequence_length % (4 * sizeof(T)) == 0) &&
       (nullptr == key_padding_mask || parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) &&
@@ -369,6 +379,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
   data.output = reinterpret_cast<CudaT*>(output->MutableData<T>());
   data.present_key = (nullptr == present_key) ? nullptr : reinterpret_cast<CudaT*>(present_key->MutableData<T>());
   data.present_value = (nullptr == present_value) ? nullptr : reinterpret_cast<CudaT*>(present_value->MutableData<T>());
+  data.attn_probs = (nullptr == attn_probs) ? nullptr : reinterpret_cast<CudaT*>(attn_probs->MutableData<T>());
   data.fused_runner = reinterpret_cast<void*>(fused_runner);
   data.fused_cross_attention_kernel = fused_cross_attention_kernel;
   data.use_flash_attention = use_flash_attention;

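On the CUDA MultiHeadAttention op, attn_probs is bound as optional output index 3 with shape [batch_size, num_heads, sequence_length, total_sequence_length], and the lean/flash, fused cross-attention, fused runner, and memory-efficient kernels are all skipped when it is present, so execution falls back to the unfused path that writes the post-softmax probabilities into this output. A minimal host-side sanity check, assuming the probabilities have been copied to host memory as float (helper name and tolerance are hypothetical, not part of this commit):

// Hypothetical check: every (b, n, s) row of attn_probs is a softmax over the
// total_sequence_length key positions, so it should sum to roughly 1.
// Rows belonging to fully masked-out queries may need to be skipped.
#include <cmath>
#include <cstddef>
#include <vector>

bool AttnProbsRowsSumToOne(const std::vector<float>& attn_probs,  // B*N*S*T elements on host
                           size_t batch_size, size_t num_heads,
                           size_t sequence_length, size_t total_sequence_length,
                           float tolerance = 1e-3f) {
  for (size_t b = 0; b < batch_size; ++b) {
    for (size_t n = 0; n < num_heads; ++n) {
      for (size_t s = 0; s < sequence_length; ++s) {
        const float* row = attn_probs.data() +
            ((b * num_heads + n) * sequence_length + s) * total_sequence_length;
        float sum = 0.0f;
        for (size_t t = 0; t < total_sequence_length; ++t) sum += row[t];
        if (std::fabs(sum - 1.0f) > tolerance) return false;
      }
    }
  }
  return true;
}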