
Commit 5bc10a3

[webgpu] Optimize AttentionPrepare (#26850)
This pull request refactors and streamlines the computation of the Q, K, and V tensors in the WebGPU BERT Attention operator. The main changes are removing the custom QKV preparation kernel in favor of a more modular approach, a MatMul operation followed by a dedicated split kernel, and generalizing the QKV splitting logic for broader reuse. This improves maintainability and code reuse, and it also improves performance because the MatMul op has already received many optimizations. With this change, PrepareQKV drops from 751.67 ms to 128.88 ms in the phi4-vision model.

Before:

| Kernel | Time (ms) | Percentage (%) |
| -- | -- | -- |
| Attention\|AttentionPrepare | 751.67 | 49.91 |

After:

| Kernel | Time (ms) | Percentage (%) |
| -- | -- | -- |
| Attention\|MatMul | 120.87 | 19.77 |
| Attention\|SplitPackedQKV | 1.94 | 0.32 |
1 parent 63b5cef commit 5bc10a3
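For illustration, the sketch below is a plain-C++ reference of the new two-step flow described above: one GEMM with bias produces the packed [B, S, hidden + hidden + v_hidden] projection, and a split pass slices the last dimension into Q, K, and V in BSD layout. This is not the WebGPU implementation (that is in attention.cc below); the function name and flat std::vector layout are assumptions made to keep the example self-contained.

```cpp
// Reference (CPU) version of the MatMul + SplitPackedQKV flow, illustration only.
// input:   [B, S, K]      (K = input hidden size)
// weights: [K, 2*H + Hv]  (Q, K, V projections packed along the output dim)
// bias:    [2*H + Hv]
// Outputs: Q [B, S, H], K [B, S, H], V [B, S, Hv] in BSD layout.
#include <cstdint>
#include <vector>

void PrepareQKVReference(const std::vector<float>& input, const std::vector<float>& weights,
                         const std::vector<float>& bias,
                         int64_t B, int64_t S, int64_t K, int64_t H, int64_t Hv,
                         std::vector<float>& q, std::vector<float>& k, std::vector<float>& v) {
  const int64_t N = 2 * H + Hv;  // packed projection width
  // Step 1: one fused GEMM + bias produces the packed QKV projection [B, S, N].
  std::vector<float> packed(B * S * N);
  for (int64_t bs = 0; bs < B * S; ++bs) {
    for (int64_t n = 0; n < N; ++n) {
      float acc = bias[n];
      for (int64_t kk = 0; kk < K; ++kk) {
        acc += input[bs * K + kk] * weights[kk * N + n];
      }
      packed[bs * N + n] = acc;
    }
  }
  // Step 2: split the last dimension: [0, H) -> Q, [H, 2H) -> K, [2H, N) -> V.
  q.resize(B * S * H);
  k.resize(B * S * H);
  v.resize(B * S * Hv);
  for (int64_t bs = 0; bs < B * S; ++bs) {
    for (int64_t d = 0; d < N; ++d) {
      const float x = packed[bs * N + d];
      if (d < H) {
        q[bs * H + d] = x;
      } else if (d < 2 * H) {
        k[bs * H + (d - H)] = x;
      } else {
        v[bs * Hv + (d - 2 * H)] = x;
      }
    }
  }
}
```

In the actual operator the GEMM step reuses the already-optimized MatMul program, and the split step is a single elementwise WebGPU dispatch, which is why the combined cost is dominated by the MatMul time in the table above.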

5 files changed: +102 / -159 lines

onnxruntime/contrib_ops/webgpu/bert/attention.cc

Lines changed: 87 additions & 112 deletions
```diff
@@ -8,6 +8,8 @@
 #include "contrib_ops/webgpu/bert/multihead_attention.h"
 #include "contrib_ops/webgpu/webgpu_contrib_kernels.h"
 #include "core/providers/webgpu/webgpu_supported_types.h"
+#include "core/providers/webgpu/webgpu_utils.h"
+#include "core/providers/webgpu/math/matmul.h"
 using namespace onnxruntime::webgpu;
 using namespace ::onnxruntime::common;
 using namespace ONNX_NAMESPACE;
```
```diff
@@ -70,6 +72,50 @@ Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_h
   return context.RunProgram(program);
 };
 
+Status SplitPackedQKVProgram::GenerateShaderCode(ShaderHelper& sh) const {
+  // Inputs: packed_qkv [B, S, D], outputs: Q, K, V [B, S, D]
+  const auto& packed_qkv = sh.AddInput("packed_qkv", ShaderUsage::UseOffsetToIndices | ShaderUsage::UseUniform);
+  const auto& query = sh.AddOutput("query", ShaderUsage::UseSetByIndices | ShaderUsage::UseUniform);
+  const auto& key = sh.AddOutput("key", ShaderUsage::UseSetByIndices | ShaderUsage::UseUniform);
+  const auto& value = sh.AddOutput("val", ShaderUsage::UseSetByIndices | ShaderUsage::UseUniform);
+  sh.MainFunctionBody()
+      << sh.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.input_size")
+      << "  let packed_qkv_indices = " << packed_qkv.OffsetToIndices("global_idx") << ";\n"
+      << "  let batch = packed_qkv_indices[0];\n"
+      << "  let seq = packed_qkv_indices[1];\n"
+      << "  let d = packed_qkv_indices[2];\n"
+      << "  let input_data = " << packed_qkv.GetByOffset("global_idx") << ";\n"
+      << "  if (d < uniforms.hidden_size) {\n"
+      << "    " << query.SetByIndices("vec3<u32>(batch, seq, d)", "input_data") << ";\n"
+      << "  } else if (d < (uniforms.hidden_size + uniforms.kv_hidden_size)) {\n"
+      << "    let kd = d - uniforms.hidden_size;\n"
+      << "    " << key.SetByIndices("vec3<u32>(batch, seq, kd)", "input_data") << ";\n"
+      << "  } else {\n"
+      << "    let vd = d - uniforms.hidden_size - uniforms.kv_hidden_size;\n"
+      << "    " << value.SetByIndices("vec3<u32>(batch, seq, vd)", "input_data") << ";\n"
+      << "  }\n";
+  return Status::OK();
+}
+
+Status SplitPackedQKV(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& params,
+                      const Tensor* packedQKV, Tensor* query, Tensor* key, Tensor* val, int kv_hidden_size) {
+  // Output Q, K, V in BSD format
+  const int components = std::min({GetMaxComponents(params.hidden_size_), GetMaxComponents(kv_hidden_size), GetMaxComponents(params.v_hidden_size_)});
+  SplitPackedQKVProgram program;
+  auto input_size = packedQKV->Shape().Size();
+  const uint32_t vectorized_input_size = static_cast<uint32_t>(input_size / components);
+  program
+      .AddInput({packedQKV, ProgramTensorMetadataDependency::TypeAndRank, components})
+      .AddOutputs({{query, ProgramTensorMetadataDependency::TypeAndRank, components}, {key, ProgramTensorMetadataDependency::TypeAndRank, components}, {val, ProgramTensorMetadataDependency::TypeAndRank, components}})
+      .AddUniformVariables({
+          {vectorized_input_size},
+          {static_cast<uint32_t>(params.hidden_size_ / components)},
+          {static_cast<uint32_t>(kv_hidden_size / components)},
+      })
+      .SetDispatchGroupSize((vectorized_input_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE);
+  return context.RunProgram(program);
+}
+
 void InitVarStub(std::ostringstream& ss, bool has_seqlen_k) {
   if (has_seqlen_k) {
     ss << "total_sequence_length = u32(seqlen_k[batch_idx]) + 1;\n";
```
```diff
@@ -594,113 +640,26 @@ Attention::Attention(const OpKernelInfo& info)
       onnxruntime::contrib::AttentionBase(info, false) {
 }
 
-// QKV preparation program - computes Q, K, V from input, weights, and bias
-class AttentionPrepareProgram final : public Program<AttentionPrepareProgram> {
- public:
-  AttentionPrepareProgram() : Program{"AttentionPrepare"} {}
-
-  Status GenerateShaderCode(ShaderHelper& shader) const override {
-    shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
-    shader.AddInput("weight", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
-    shader.AddInput("bias", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
-    shader.AddOutput("output_q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
-    shader.AddOutput("output_k", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
-    shader.AddOutput("output_v", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
-
-    constexpr int TILE_SIZE = 12;
-
-    shader.AdditionalImplementation() << "const TILE_SIZE = " << TILE_SIZE << "u;\n"
-                                      << "var<workgroup> tileInput: array<input_value_t, " << TILE_SIZE * TILE_SIZE << ">;\n"
-                                      << "var<workgroup> tileWeightQ: array<input_value_t, " << TILE_SIZE * TILE_SIZE << ">;\n"
-                                      << "var<workgroup> tileWeightK: array<input_value_t, " << TILE_SIZE * TILE_SIZE << ">;\n"
-                                      << "var<workgroup> tileWeightV: array<input_value_t, " << TILE_SIZE * TILE_SIZE << ">;\n";
-
-    shader.MainFunctionBody()  //<< shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.M * uniforms.N")
-        << "let batchIndex = workgroup_id.z / uniforms.num_heads;\n"
-        << "let headNumber = workgroup_id.z % uniforms.num_heads;\n"
-        << "let m = global_id.y;\n"
-        << "let n = global_id.x;\n"
-        << "let inputOffset = batchIndex * (uniforms.M * uniforms.K) + m * uniforms.K;\n"
-        << "let biasOffsetQ = headNumber * uniforms.head_size;\n"
-        << "let biasOffsetK = uniforms.hidden_size + biasOffsetQ;\n"
-        << "let biasOffsetV = uniforms.hidden_size + biasOffsetK;\n"
-        << "var valueQ = input_value_t(0);\n"
-        << "var valueK = input_value_t(0);\n"
-        << "var valueV = input_value_t(0);\n"
-        << "for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {\n"
-        << "  if (m < uniforms.M && w + local_id.x < uniforms.K) {\n"
-        << "    tileInput[TILE_SIZE * local_id.y + local_id.x] = input[inputOffset + w + local_id.x];\n"
-        << "  }\n"
-        << "  if (n < uniforms.N && w + local_id.y < uniforms.K) {\n"
-        << "    let offset = n + (w + local_id.y) * uniforms.ldb;\n"
-        << "    tileWeightQ[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetQ + offset];\n"
-        << "    tileWeightK[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetK + offset];\n"
-        << "    tileWeightV[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetV + offset];\n"
-        << "  }\n"
-        << "  workgroupBarrier();\n"
-        << "  for (var k: u32 = 0u; k<TILE_SIZE && w+k < uniforms.K; k++) {\n"
-        << "    let inputTileOffset = TILE_SIZE * local_id.y + k;\n"
-        << "    let weightTileOffset = TILE_SIZE * k + local_id.x;\n"
-        << "    valueQ += tileInput[inputTileOffset] * tileWeightQ[weightTileOffset];\n"
-        << "    valueK += tileInput[inputTileOffset] * tileWeightK[weightTileOffset];\n"
-        << "    valueV += tileInput[inputTileOffset] * tileWeightV[weightTileOffset];\n"
-        << "  }\n"
-        << "  workgroupBarrier();\n"
-        << "}\n"
-        << "let headOffset = (m * uniforms.N + n) % uniforms.head_size;\n"
-        << "valueQ += bias[headOffset + biasOffsetQ];\n"
-        << "valueK += bias[headOffset + biasOffsetK];\n"
-        << "valueV += bias[headOffset + biasOffsetV];\n"
-        << "let offset = workgroup_id.z * uniforms.M * uniforms.N;\n"
-        << "if (m < uniforms.M && n < uniforms.N) {\n"
-        << "  let outputIdx = offset + m * uniforms.N + n;\n"
-        << "  output_q[outputIdx] = valueQ;\n"
-        << "  output_k[outputIdx] = valueK;\n"
-        << "  output_v[outputIdx] = valueV;\n"
-        << "}\n";
-
-    return Status::OK();
-  }
-
-  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"M", ProgramUniformVariableDataType::Uint32},
-                                          {"K", ProgramUniformVariableDataType::Uint32},
-                                          {"N", ProgramUniformVariableDataType::Uint32},
-                                          {"num_heads", ProgramUniformVariableDataType::Uint32},
-                                          {"head_size", ProgramUniformVariableDataType::Uint32},
-                                          {"hidden_size", ProgramUniformVariableDataType::Uint32},
-                                          {"ldb", ProgramUniformVariableDataType::Uint32});
-};
-
 Status PrepareQKV(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& parameters,
                   const Tensor* input, const Tensor* weights, const Tensor* bias,
                   Tensor* q, Tensor* k, Tensor* v) {
-  constexpr int TILE_SIZE = 12;
-  const int M = parameters.sequence_length_;
-  const int K = parameters.input_hidden_size_;
-  const int N = parameters.head_size_;
-
-  const uint32_t dispatch_x = (parameters.head_size_ + TILE_SIZE - 1) / TILE_SIZE;
-  const uint32_t dispatch_y = (parameters.sequence_length_ + TILE_SIZE - 1) / TILE_SIZE;
-  const uint32_t dispatch_z = parameters.batch_size_ * parameters.num_heads_;
-
-  AttentionPrepareProgram program{};
-  program.AddInputs({{input, ProgramTensorMetadataDependency::TypeAndRank},
-                     {weights, ProgramTensorMetadataDependency::TypeAndRank},
-                     {bias, ProgramTensorMetadataDependency::TypeAndRank}})
-      .AddOutputs({{q, ProgramTensorMetadataDependency::TypeAndRank},
-                   {k, ProgramTensorMetadataDependency::TypeAndRank},
-                   {v, ProgramTensorMetadataDependency::TypeAndRank}})
-      .SetDispatchGroupSize(dispatch_x, dispatch_y, dispatch_z)
-      .SetWorkgroupSize(TILE_SIZE, TILE_SIZE)
-      .AddUniformVariables({{static_cast<uint32_t>(M)},
-                            {static_cast<uint32_t>(K)},
-                            {static_cast<uint32_t>(N)},
-                            {static_cast<uint32_t>(parameters.num_heads_)},
-                            {static_cast<uint32_t>(parameters.head_size_)},
-                            {static_cast<uint32_t>(parameters.hidden_size_)},
-                            {static_cast<uint32_t>(parameters.hidden_size_ + parameters.hidden_size_ + parameters.v_hidden_size_)}});
+  // Use MatMul to compute packed QKV output: input * weights + bias
+  // Then use SplitPackedQKV to split into Q, K, V in BSD format
+  // Returns Q, K, V in BSD format
 
-  return context.RunProgram(program);
+  // Create packed QKV tensor with shape [batch_size, sequence_length, hidden_size + hidden_size + v_hidden_size]
+  const int64_t packed_qkv_size = parameters.hidden_size_ + parameters.hidden_size_ + parameters.v_hidden_size_;
+  TensorShapeVector packed_qkv_shape({parameters.batch_size_, parameters.sequence_length_, packed_qkv_size});
+  Tensor packed_qkv = context.CreateGPUTensor(input->DataType(), TensorShape(packed_qkv_shape));
+
+  // Prepare inputs for MatMul
+  std::vector<const Tensor*> matmul_inputs = {input, weights, bias};
+
+  // Call MatMul: packed_qkv = input * weights + bias
+  ORT_RETURN_IF_ERROR(onnxruntime::webgpu::ComputeMatMul(&context, Activation(), matmul_inputs, &packed_qkv, true));
+
+  // Output Q, K, V in BSD format
+  return SplitPackedQKV(context, parameters, &packed_qkv, q, k, v, parameters.hidden_size_);
 }
 
 Status Attention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const {
```
```diff
@@ -755,15 +714,16 @@ Status Attention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context)
     ORT_NOT_IMPLEMENTED("present tensor not implemented for webgpu Attention");
   }
 
-  // Create Q, K, V tensors by computing input * weights + bias
-  TensorShapeVector qkv_shape({parameters.batch_size_, parameters.num_heads_,
-                               parameters.sequence_length_, parameters.head_size_});
-  Tensor Q = context.CreateGPUTensor(input->DataType(), qkv_shape);
-  Tensor K = context.CreateGPUTensor(input->DataType(), qkv_shape);
-  Tensor V = context.CreateGPUTensor(input->DataType(), qkv_shape);
+  // Create Q, K, V tensors in BSD format from input * weights + bias
+  TensorShapeVector qkv_bsd_shape({parameters.batch_size_, parameters.sequence_length_, parameters.hidden_size_});
+  TensorShapeVector v_bsd_shape({parameters.batch_size_, parameters.sequence_length_, parameters.v_hidden_size_});
+  Tensor Q_bsd = context.CreateGPUTensor(input->DataType(), TensorShape(qkv_bsd_shape));
+  Tensor K_bsd = context.CreateGPUTensor(input->DataType(), TensorShape(qkv_bsd_shape));
+  Tensor V_bsd = context.CreateGPUTensor(input->DataType(), TensorShape(v_bsd_shape));
 
-  // Compute Q, K, V from input, weights, and bias
-  ORT_RETURN_IF_ERROR(PrepareQKV(context, parameters, input, weights, bias, &Q, &K, &V));
+  // Compute Q, K, V from input, weights, and bias (returns BSD format)
+  ORT_RETURN_IF_ERROR(PrepareQKV(context, parameters, input, weights, bias, &Q_bsd, &K_bsd, &V_bsd));
+  parameters.qkv_format_ = Q_K_V_BSNH;
 
   // Check if we can use flash attention
   // For Attention operator, we need to create present_key and present_value tensors for flash attention
```
```diff
@@ -774,10 +734,25 @@ Status Attention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context)
   Tensor present_value = context.CreateGPUTensor(input->DataType(), present_kv_shape);
 
   if (CanApplyFlashAttention(nullptr, &present_key, &present_value, parameters, context)) {
-    return ApplyFlashAttention(&Q, &K, &V, attention_bias, output, nullptr, &present_key, nullptr, &present_value,
+    // FlashAttention supports Q_K_V_BSNH format directly
+    return ApplyFlashAttention(&Q_bsd, &K_bsd, &V_bsd, attention_bias, output, nullptr, &present_key, nullptr, &present_value,
                                parameters, context, nullptr);
   }
 
+  // For non-flash attention path, convert BSD to BNSH format
+  TensorShapeVector qkv_bnsh_shape({parameters.batch_size_, parameters.num_heads_, parameters.sequence_length_, parameters.head_size_});
+  TensorShapeVector v_bnsh_shape({parameters.batch_size_, parameters.num_heads_, parameters.sequence_length_, parameters.v_head_size_});
+  Tensor Q = context.CreateGPUTensor(input->DataType(), TensorShape(qkv_bnsh_shape));
+  Tensor K = context.CreateGPUTensor(input->DataType(), TensorShape(qkv_bnsh_shape));
+  Tensor V = context.CreateGPUTensor(input->DataType(), TensorShape(v_bnsh_shape));
+
+  ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads_, parameters.sequence_length_,
+                                        parameters.head_size_, &Q_bsd, nullptr, 0, &Q));
+  ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads_, parameters.sequence_length_,
+                                        parameters.head_size_, &K_bsd, nullptr, 0, &K));
+  ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads_, parameters.sequence_length_,
+                                        parameters.v_head_size_, &V_bsd, nullptr, 0, &V));
+
   // Apply the actual attention computation
   return ApplyAttention(&Q, &K, &V, attention_bias, nullptr, nullptr, output, /* present_key */ nullptr,
                         /* present_value */ nullptr, /* output_qk */ nullptr, parameters, context, nullptr, nullptr, -1);
```

onnxruntime/contrib_ops/webgpu/bert/attention.h

Lines changed: 11 additions & 0 deletions
```diff
@@ -32,6 +32,17 @@ class TransferBSDToBNSHProgram final : public Program<TransferBSDToBNSHProgram>
   bool has_bias_;
 };
 
+class SplitPackedQKVProgram final : public Program<SplitPackedQKVProgram> {
+ public:
+  SplitPackedQKVProgram() : Program{"SplitPackedQKV"} {}
+
+  Status GenerateShaderCode(ShaderHelper& sh) const override;
+
+  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"input_size", ProgramUniformVariableDataType::Uint32},
+                                          {"hidden_size", ProgramUniformVariableDataType::Uint32},
+                                          {"kv_hidden_size", ProgramUniformVariableDataType::Uint32});
+};
+
 class AttentionProbsProgram final : public Program<AttentionProbsProgram> {
  public:
   AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key,
```

onnxruntime/contrib_ops/webgpu/bert/attention_common.h

Lines changed: 3 additions & 0 deletions
```diff
@@ -122,6 +122,9 @@ struct WebgpuAttentionParameters {
 Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_heads, int sequence_length,
                          int head_size, const Tensor* input_tensor, const Tensor* bias, int bias_offset, Tensor* output_tensor);
 
+Status SplitPackedQKV(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& params,
+                      const Tensor* packedQKV, Tensor* query, Tensor* key, Tensor* val, int kv_hidden_size);
+
 Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
                       const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value,
                       Tensor* output_qk, WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context,
```
