@@ -590,15 +590,22 @@ Status UnfusedAttention(
 
   DUMP_TENSOR_D("QK", data.scratch, batch_size, num_heads, sequence_length, total_sequence_length);
 
-  constexpr size_t element_size = sizeof(T);
-  const size_t bytes = GetAttentionScratchSize(element_size, batch_size, num_heads,
-                                               sequence_length, total_sequence_length);
-  T* scratch2 = data.scratch + (bytes / element_size);
+  T* softmax_storage;
+  if (data.attn_probs == nullptr) {
+    constexpr size_t element_size = sizeof(T);
+    const size_t bytes = GetAttentionScratchSize(element_size, batch_size, num_heads,
+                                                 sequence_length, total_sequence_length);
+    T* scratch2 = data.scratch + (bytes / element_size);
+    softmax_storage = scratch2;
+  }
+  else {
+    softmax_storage = data.attn_probs;
+  }
 
   const bool broadcast_attn_bias_dim_0 = parameters.broadcast_attn_bias_dim_0;
   const bool broadcast_attn_bias_dim_1 = parameters.broadcast_attn_bias_dim_1;
 
-  // Apply softmax and store result R to scratch2: BxNxSxT
+  // Apply softmax and store result R to softmax_storage: BxNxSxT
   if (use_raw_attention_mask) {  // 2d, 3d or 4d attention mask
     const int mask_dimension = static_cast<int>(mask_index_dims.size());
 
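Note: the added branch only uses the second scratch region when no attn_probs output is requested; if the optional output buffer is present, the softmax probabilities are written directly into it. Below is a minimal sketch of that selection logic, with ScratchBytesForScores standing in for GetAttentionScratchSize and a 256-byte alignment of the score region assumed (both are assumptions for illustration, not the actual ORT helper):

```cpp
#include <cstddef>

// Assumed stand-in for GetAttentionScratchSize: bytes of one BxNxSxT tensor of T,
// rounded up to a 256-byte boundary (the alignment value is an assumption).
template <typename T>
size_t ScratchBytesForScores(size_t batch_size, size_t num_heads,
                             size_t sequence_length, size_t total_sequence_length) {
  const size_t bytes = sizeof(T) * batch_size * num_heads * sequence_length * total_sequence_length;
  constexpr size_t kAlignment = 256;
  return ((bytes + kAlignment - 1) / kAlignment) * kAlignment;
}

// Mirror of the new selection: prefer the optional attn_probs output buffer,
// otherwise fall back to the scratch region that follows the QK scores.
template <typename T>
T* SelectSoftmaxStorage(T* scratch, T* attn_probs,
                        size_t batch_size, size_t num_heads,
                        size_t sequence_length, size_t total_sequence_length) {
  if (attn_probs != nullptr) {
    return attn_probs;  // write probabilities straight into the graph output
  }
  const size_t bytes = ScratchBytesForScores<T>(batch_size, num_heads,
                                                sequence_length, total_sequence_length);
  return scratch + bytes / sizeof(T);  // second region right after the QK scores
}
```

Either way the downstream softmax and GEMM calls see a single softmax_storage pointer, which is what the remaining hunks switch over to.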
@@ -612,7 +619,7 @@ Status UnfusedAttention(
         ComputeSoftmaxWithRawMask<T>(
             ort_stream, total_sequence_length, sequence_length, batch_size, num_heads,
             mask_index, nullptr, data.attention_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1,
-            data.scratch, scratch2, parameters.is_unidirectional, scale, mask_dimension,
+            data.scratch, softmax_storage, parameters.is_unidirectional, scale, mask_dimension,
             parameters.max_sequence_length, use_persistent_softmax, persistent_softmax_workspace,
             parameters.mask_filter_value));
   } else if (nullptr != mask_index) {  // 1d mask index
@@ -622,24 +629,24 @@ Status UnfusedAttention(
     ORT_RETURN_IF_ERROR(ComputeSoftmaxWithMask1D<T>(
         stream, total_sequence_length, sequence_length, batch_size, num_heads,
         mask_index, mask_start, data.attention_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1,
-        data.scratch, scratch2, parameters.is_unidirectional));
+        data.scratch, softmax_storage, parameters.is_unidirectional));
   } else {  // no mask
     ORT_RETURN_IF_ERROR(
         ComputeSoftmax<T>(
             stream, total_sequence_length, sequence_length, batch_size, num_heads,
             data.attention_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1,
-            data.scratch, scratch2, parameters.is_unidirectional));
+            data.scratch, softmax_storage, parameters.is_unidirectional));
   }
 
-  DUMP_TENSOR_D("Softmax", scratch2, batch_size, num_heads, sequence_length, total_sequence_length);
+  DUMP_TENSOR_D("Softmax", softmax_storage, batch_size, num_heads, sequence_length, total_sequence_length);
 
   // compute R*V (as V*R), and store in temp_output (space used by Q): BxNxSxH_v
   T* temp_output = data.q;
   CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(
       cublas, CUBLAS_OP_N, CUBLAS_OP_N,
       v_head_size, sequence_length, total_sequence_length,
       &one, data.v, v_head_size, present_size_per_batch_v,
-      scratch2, total_sequence_length, sequence_length * total_sequence_length,
+      softmax_storage, total_sequence_length, sequence_length * total_sequence_length,
       &zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop, parameters.use_tf32));
 
   // Temp_output is BxNxSxH_v, transpose to output BxSxNxH_v
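Note: cuBLAS is column-major, so the per-head row-major product probs (SxT) x V (TxH_v) is issued as a column-major GEMM with m = v_head_size, n = sequence_length, k = total_sequence_length, batched over B*N heads; the only change in this hunk is that the probabilities now come from softmax_storage. A rough sketch of the same shape mapping using the stock cublasSgemmStridedBatched (float only; the V stride is assumed to be total_sequence_length * v_head_size, and this is not the cublasGemmStridedBatchedHelper wrapper used above):

```cpp
#include <cublas_v2.h>

// probs: B*N batches of SxT row-major, v: B*N batches of TxH_v row-major (assumed layout),
// out: B*N batches of SxH_v row-major. Row-major out = probs * v is computed as
// column-major out^T = v^T * probs^T, i.e. A = v, B = probs, C = out.
cublasStatus_t ProbsTimesV(cublasHandle_t cublas,
                           const float* probs, const float* v, float* out,
                           int batches, int sequence_length,
                           int total_sequence_length, int v_head_size) {
  const float one = 1.0f;
  const float zero = 0.0f;
  return cublasSgemmStridedBatched(
      cublas, CUBLAS_OP_N, CUBLAS_OP_N,
      /*m=*/v_head_size, /*n=*/sequence_length, /*k=*/total_sequence_length,
      &one,
      v, /*lda=*/v_head_size,
      /*strideA=*/static_cast<long long>(total_sequence_length) * v_head_size,
      probs, /*ldb=*/total_sequence_length,
      /*strideB=*/static_cast<long long>(sequence_length) * total_sequence_length,
      &zero,
      out, /*ldc=*/v_head_size,
      /*strideC=*/static_cast<long long>(sequence_length) * v_head_size,
      batches);
}
```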