@@ -520,8 +520,8 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int
520520
521521Status ApplyAttention (const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
522522 const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value,
523- WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* head_sink ,
524- const Tensor* seqlen_k, int local_window_size) {
523+ Tensor* output_qk, WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context,
524+ const Tensor* head_sink, const Tensor* seqlen_k, int local_window_size) {
525525 const int output_count = std::min ({context.OutputCount (), 1 + (past_key != nullptr ? 1 : 0 ) + (past_value != nullptr ? 1 : 0 )});
526526 const int past_sequence_length = output_count > 1 ? parameters.past_sequence_length_ : 0 ;
527527 const int total_sequence_length =
@@ -534,6 +534,11 @@ Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const T
534534 ORT_RETURN_IF_ERROR (ComputeAttentionProbs (context, output_count, Q, K, past_key, attention_bias, &probs, present_key,
535535 parameters, past_sequence_length, total_sequence_length, seqlen_k));
536536
537+ if (output_qk != nullptr ) {
538+ // Copy the pre-softmax attention scores (scaled Q*K^T) to output_qk before the in-place softmax below overwrites probs
539+ ORT_RETURN_IF_ERROR (context.CopyTensor (probs, *output_qk));
540+ }
541+
537542 ORT_RETURN_IF_ERROR (ComputeInPlaceSoftmax (context, &probs,
538543 parameters.batch_size_ , parameters.num_heads_ , parameters.past_sequence_length_ , parameters.sequence_length_ , total_sequence_length, seqlen_k, parameters.is_first_prompt_ , parameters.use_smooth_softmax_ , head_sink, local_window_size));
539544
@@ -730,7 +735,7 @@ Status Attention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context)
730735
731736 // Apply the actual attention computation
732737 return ApplyAttention (&Q, &K, &V, attention_bias, nullptr , nullptr , output, /* present_key */ nullptr ,
733- /* present_value */ nullptr , parameters, context, nullptr , nullptr , -1 );
738+ /* present_value */ nullptr , /* output_qk */ nullptr , parameters, context, nullptr , nullptr , -1 );
734739}
735740
736741} // namespace webgpu
0 commit comments