From fdd5ceb97de27d72ab12051cf3c31d6992295ceb Mon Sep 17 00:00:00 2001 From: Satya Kumar Jandhyala Date: Wed, 15 Jan 2025 12:01:32 -0800 Subject: [PATCH 01/23] Added GroupQueryAttention do_rotary attribute. --- .../webgpu/bert/group_query_attention.cc | 89 ++++++++++++++++++- .../webgpu/bert/group_query_attention.h | 21 +++++ 2 files changed, 106 insertions(+), 4 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index 31c8af9b4f922..9649677119f1c 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -5,6 +5,7 @@ #include "contrib_ops/webgpu/bert/attention_common.h" #include "contrib_ops/webgpu/bert/group_query_attention.h" #include "contrib_ops/webgpu/webgpu_contrib_kernels.h" +#include "contrib_ops/webgpu/bert/rotary_embedding.h" #include "core/providers/webgpu/webgpu_supported_types.h" @@ -29,13 +30,78 @@ ONNX_OPERATOR_KERNEL_EX( .InputMemoryType(OrtMemTypeCPUInput, 6), GroupQueryAttention); +Status GeneratePositionIDsProgram::GenerateShaderCode(ShaderHelper& sh) const { + sh.AddInput("seqlens", ShaderUsage::UseUniform); + sh.AddOutput("output", ShaderUsage::UseUniform); + sh.MainFunctionBody() << "let batch_idx = global_idx / uniforms.sequence_length;\n" + << "let sequence_idx = i32(global_idx % uniforms.sequence_length);\n" + << "var pos_id: u32 = 0u;\n" + << "if (is_first_prompt == 0) {\n" + << " let total_seqlen = ${seqLensInputHelper.getByOffset('batch_idx')} + 1;\n" + << " let past_seqlen = total_seqlen - i32(uniforms.sequence_length);\n" + << " if (past_seqlen + sequence_idx < total_seqlen) {\n" + << " pos_id = u32(past_seqlen + sequence_idx);\n" + << " } else {\n" + << " pos_id = 1u;\n" + << " }\n" + << "}\n" + << "output[global_idx] = pos_id;\n"; + return Status::OK(); +} + +Status GeneratePositionIDs(onnxruntime::webgpu::ComputeContext& context, const 
WebgpuAttentionParameters& params, const Tensor* seqlens, Tensor* output_tensor) { + GeneratePositionIDsProgram program(params); + program.AddInput(seqlens) + .AddOutput(output_tensor) + .AddUniformVariables({{static_cast(params.batch_size_)}, + {static_cast(params.sequence_length_)}, + {static_cast(params.num_heads_)}, + {static_cast(params.head_size_)}, + {static_cast(params.rotary_dim_)}, + {static_cast(params.rotary_interleaved_)}, + {static_cast(params.is_first_prompt_ ? 0 : 1)}, + {static_cast(params.total_sequence_length_)}}); + return context.RunProgram(program); +} + +Status RunRotaryEmbedding(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& params, const Tensor* input, const Tensor* pos_ids, const Tensor* cos_cache, const Tensor* sin_cache, Tensor* output) { + const auto half_rotary_embedding_dim = gsl::narrow(cos_cache->Shape()[1]); + + const TensorShape global_shape({params.batch_size_, params.sequence_length_, params.num_heads_, params.head_size_ - half_rotary_embedding_dim}); + const auto rank = global_shape.NumDimensions(); + std::vector global_dims(rank); + std::vector global_strides(rank); + for (size_t j = 0; j < rank; ++j) { + global_dims[j] = gsl::narrow(global_shape[j]); + global_strides[j] = gsl::narrow(global_shape.SizeFromDimension(j + 1)); + } + const auto input_output_strides = std::vector({gsl::narrow(params.batch_size_), gsl::narrow(params.hidden_size_), gsl::narrow(params.head_size_), 1}); + const auto output_size = gsl::narrow(global_shape.Size()); + + RotaryEmbeddingProgram program(params.rotary_interleaved_); + program + .CacheHint(params.rotary_interleaved_) + .AddInputs({{input, ProgramTensorMetadataDependency::Rank}, + {pos_ids, ProgramTensorMetadataDependency::Rank}, + {cos_cache, ProgramTensorMetadataDependency::Rank}, + {sin_cache, ProgramTensorMetadataDependency::Rank}}) + .AddOutput(output) + .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + 
.AddUniformVariables({{params.scale_}, + {gsl::make_span(global_dims)}, + {gsl::make_span(global_strides)}, + {gsl::make_span(input_output_strides)}}) + .AddIndices(TensorShape{1, 1}); + return context.RunProgram(program); +} + Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { const Tensor* query = context.Input(0); const Tensor* key = context.Input(1); const Tensor* value = context.Input(2); const Tensor* past_key = context.Input(3); const Tensor* past_value = context.Input(4); - const Tensor* seqlen_k = context.Input(5); + const Tensor* seqlens_k = context.Input(5); const Tensor* total_seqlen_tensor = context.Input(6); const Tensor* cos_cache = context.Input(7); const Tensor* sin_cache = context.Input(8); @@ -51,7 +117,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& &params, num_heads_, kv_num_heads_, - seqlen_k, + seqlens_k, total_seqlen_tensor, scale_, softcap_)); @@ -59,6 +125,21 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& if (parameters.is_packed_qkv_) { ORT_NOT_IMPLEMENTED("Packed QKV of shape (B, L, N, 3, H) not implemented for webgpu-ep."); } + if (do_rotary_) { + Tensor q = context.CreateGPUTensor(query->DataType(), query->Shape()); + Tensor k = context.CreateGPUTensor(key->DataType(), key->Shape()); + TensorShape pos_ids_shape = parameters.is_first_prompt_ ? 
TensorShape({1}) : TensorShape({parameters.batch_size_ * parameters.sequence_length_}); + Tensor pos_ids = context.CreateGPUTensor(DataTypeImpl::GetType(), pos_ids_shape); + ORT_RETURN_IF_ERROR(GeneratePositionIDs(context, parameters, seqlens_k, &pos_ids)); + + ORT_RETURN_IF_ERROR(RunRotaryEmbedding(context, parameters, query, &pos_ids, cos_cache, sin_cache, &q)); + + ORT_RETURN_IF_ERROR(RunRotaryEmbedding(context, parameters, key, &pos_ids, cos_cache, sin_cache, &k)); + + query = &q; + key = &k; + } + TensorShapeVector output_shape(3); output_shape[0] = static_cast(parameters.batch_size_); output_shape[1] = static_cast(parameters.sequence_length_); @@ -82,7 +163,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context, parameters.num_heads_, parameters.sequence_length_, parameters.head_size_, query, nullptr, 0, &Q)); if (parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH) { // key and value in BNSH format return ApplyAttention(&Q, key, value, nullptr, past_key, past_value, output, present_key, - present_value, parameters, context, seqlen_k); + present_value, parameters, context, seqlens_k); } TensorShapeVector k_new_dims({parameters.batch_size_, parameters.kv_num_heads_, @@ -99,7 +180,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.kv_num_heads_, parameters.kv_sequence_length_, parameters.v_head_size_, value, nullptr, 0, &V)); return ApplyAttention(&Q, &K, &V, nullptr, past_key, past_value, output, present_key, - present_value, parameters, context, seqlen_k); + present_value, parameters, context, seqlens_k); } } // namespace webgpu diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h index 04969dc778927..a06bc1946d044 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h @@ 
-14,6 +14,27 @@ namespace webgpu { using namespace onnxruntime::webgpu; +Status GeneratePositionIDs(onnxruntime::webgpu::ComputeContext& context, bool is_first_prompt, int batch_size, int sequence_length, int num_heads, int head_size, int rotary_embedding_dim, bool interleaved, int total_seqlen, const Tensor* seqlens, Tensor* output_tensor); + +class GeneratePositionIDsProgram final : public Program { + public: + GeneratePositionIDsProgram(const WebgpuAttentionParameters& params) : Program{"GeneratePositionIDs"}, params_(params) {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"batch_size", ProgramUniformVariableDataType::Uint32}, + {"sequence_length", ProgramUniformVariableDataType::Uint32}, + {"num_heads", ProgramUniformVariableDataType::Uint32}, + {"head_size", ProgramUniformVariableDataType::Uint32}, + {"rotary_embedding_dim", ProgramUniformVariableDataType::Uint32}, + {"interleaved", ProgramUniformVariableDataType::Uint32}, + {"position_ids_format", ProgramUniformVariableDataType::Uint32}, + {"total_seqlen", ProgramUniformVariableDataType::Uint32}); + + private: + const WebgpuAttentionParameters& params_; +}; + class GroupQueryAttention final : public WebGpuKernel { public: GroupQueryAttention(const OpKernelInfo& info) : WebGpuKernel(info) { From f6b022252d8118c14a6cb3a2d2db2213b6b0fd6f Mon Sep 17 00:00:00 2001 From: Satya Kumar Jandhyala Date: Wed, 15 Jan 2025 17:09:49 -0800 Subject: [PATCH 02/23] Added packed QKV and rotary embedding support for GQA --- .../contrib_ops/webgpu/bert/attention.cc | 44 ++++++--- .../contrib_ops/webgpu/bert/attention.h | 10 +- .../webgpu/bert/group_query_attention.cc | 93 +++++++++++-------- 3 files changed, 89 insertions(+), 58 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index 568e75b38a98f..c0ec6dce4a1f7 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ 
b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -80,7 +80,9 @@ void InitVarStub(std::ostringstream& ss, const Tensor* seqlen_k) { Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { shader.AddInput("q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); - shader.AddInput("key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + if (!is_packed_qkv_) { + shader.AddInput("key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + } if (feed_past_key_) { shader.AddInput("past_key", ShaderUsage::UseUniform); } @@ -102,13 +104,21 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { << "let m = workgroup_id.y * TILE_SIZE;\n" << "let n = workgroup_id.x * TILE_SIZE;\n" << "let batch_idx = workgroup_id.z / uniforms.num_heads;\n" - << "let qOffset = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K;\n" << "let sequence_length = uniforms.M;\n" << "var total_sequence_length = uniforms.N;\n"; + if (is_packed_qkv_) { + shader.MainFunctionBody() << "let head_idx = workgroup_id.z % uniforms.num_heads;\n" + << "let kv_num_heads = uniforms.kv_num_heads /" << n_reps_ << ";\n" + << "let packed_batch_stride = (uniforms.num_heads + 2 * kv_num_heads) * uniforms.M * uniforms.K;\n" + << "let qOffset = batch_idx * packed_batch_stride + head_idx * uniforms.M * uniforms.K;\n" + << "let kOffset = batchIdx * packed_batch_stride + (uniforms.num_heads + kvHeadIdx) * uniforms.kv_sequence_length * uniforms.K;\n"; + } else { + shader.MainFunctionBody() << "let qOffset = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K;\n" + << "let kOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.kv_sequence_length * uniforms.K;\n"; + } std::ostringstream oss; InitVarStub(oss, seqlen_k_); shader.MainFunctionBody() << oss.str(); - shader.MainFunctionBody() << "let kOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.kv_sequence_length * uniforms.K;\n"; if (has_present_key_) { shader.MainFunctionBody() 
<< "let presentKeyOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.present_sequence_length * uniforms.K;\n"; } @@ -126,11 +136,11 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { << " let pastKeyOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.past_sequence_length * uniforms.K;\n" << " tileK[idx] = " << (past_present_share_buffer_ ? "present_key" : "past_key") << "[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n" << " } else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n" - << " tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];\n" + << " tileK[idx] = " << (is_packed_qkv_ ? "q" : "key") << "[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];\n" << " }\n"; } else { shader.MainFunctionBody() << " if (n + local_id.y < uniforms.kv_sequence_length) {\n" - " tileK[idx] = key[kOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n" + " tileK[idx] = " << (is_packed_qkv_ ? "q" : "key") << "[kOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n" " }\n"; } @@ -181,9 +191,11 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o const int components = parameters.head_size_ % 4 == 0 ? 4 : (parameters.head_size_ % 2 == 0 ? 
2 : 1); AttentionProbsProgram program{"AttentionProbs", feed_past_key, has_present_key, has_attention_bias, tile_size, - components, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_}; - program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components}, - {K, ProgramTensorMetadataDependency::TypeAndRank, components}}); + components, parameters.is_first_prompt_, parameters.is_packed_qkv_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_}; + program.AddInput({Q, ProgramTensorMetadataDependency::TypeAndRank, components}); + if (K != nullptr) { + program.AddInput({K, ProgramTensorMetadataDependency::TypeAndRank, components}); + } if (feed_past_key) { program.AddInput({past_key, ProgramTensorMetadataDependency::TypeAndRank, components}); } @@ -203,7 +215,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o (parameters.sequence_length_ + tile_size - 1) / tile_size, parameters.batch_size_ * parameters.num_heads_) .SetWorkgroupSize(tile_size, tile_size) - .CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_key, has_present_key, has_attention_bias, seqlen_k != nullptr, components, parameters.is_first_prompt_) + .CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_key, has_present_key, has_attention_bias, seqlen_k != nullptr, components, parameters.is_first_prompt_, parameters.is_packed_qkv_) .AddUniformVariables({{static_cast(parameters.sequence_length_)}, {static_cast(vectorized_head_size)}, {static_cast(total_sequence_length)}, @@ -331,7 +343,13 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const { std::ostringstream oss; InitVarStub(oss, seqlen_k_); shader.MainFunctionBody() << oss.str(); - shader.MainFunctionBody() << "let vOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.N * uniforms.kv_sequence_length + n;\n"; + if (is_packed_qkv_) { + 
shader.MainFunctionBody() << "let kv_num_heads = uniforms.num_heads / " << n_reps_ << ";\n" + << "let packed_batch_stride = (uniforms.num_heads + 2 * kv_num_heads) * uniforms.M * uniforms.K;\n" + << "let vOffset = batch_idx * packed_batch_stride + (uniforms.num_heads + kv_num_heads + kvHeadIdx) * uniforms.N * uniforms.kv_sequence_length + n;\n"; + } else { + shader.MainFunctionBody() << "let vOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.N * uniforms.kv_sequence_length + n;\n"; + } if (has_present_value_) { shader.MainFunctionBody() << "let presentValueOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.N * uniforms.present_sequence_length + n;\n"; } @@ -399,7 +417,7 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int const bool has_present_value = output_count > 1 && past_value != nullptr; constexpr int tile_size = 12; - VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_}; + VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, parameters.is_first_prompt_, parameters.is_packed_qkv_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_}; program.AddInputs({{probs, ProgramTensorMetadataDependency::TypeAndRank}, {V, ProgramTensorMetadataDependency::TypeAndRank}}); if (feed_past_value) { @@ -416,7 +434,7 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int program.SetDispatchGroupSize((parameters.v_head_size_ + tile_size - 1) / tile_size, (parameters.sequence_length_ + tile_size - 1) / tile_size, parameters.batch_size_ * parameters.num_heads_) - .CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_value, has_present_value, seqlen_k != nullptr, parameters.is_first_prompt_) + .CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, 
feed_past_value, has_present_value, seqlen_k != nullptr, parameters.is_first_prompt_, parameters.is_packed_qkv_) .SetWorkgroupSize(tile_size, tile_size) .AddUniformVariables({{static_cast(parameters.sequence_length_)}, {static_cast(total_sequence_length)}, @@ -451,7 +469,7 @@ Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const T ORT_RETURN_IF_ERROR(ComputeInPlaceSoftmax(context, &probs, parameters.batch_size_, parameters.num_heads_, parameters.past_sequence_length_, parameters.sequence_length_, total_sequence_length, seqlen_k, parameters.is_first_prompt_)); - ORT_RETURN_IF_ERROR(ComputeVxAttentionScore(context, output_count, &probs, V, past_value, output, present_value, + ORT_RETURN_IF_ERROR(ComputeVxAttentionScore(context, output_count, &probs, parameters.is_packed_qkv_ ? Q : V, past_value, output, present_value, parameters, past_sequence_length, total_sequence_length, seqlen_k)); return Status::OK(); diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.h b/onnxruntime/contrib_ops/webgpu/bert/attention.h index 164ea72b07d9d..3dcb339bd896d 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.h @@ -34,8 +34,8 @@ class TransferBSDToBNSHProgram final : public Program class AttentionProbsProgram final : public Program { public: AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key, - bool has_attention_bias, int tile_size, int components, bool is_first_prompt, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false) - : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) { + bool has_attention_bias, int tile_size, int components, bool is_first_prompt, bool 
is_packed_qkv, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false) + : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt), is_packed_qkv_(is_packed_qkv) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -64,6 +64,7 @@ class AttentionProbsProgram final : public Program { const Tensor* seqlen_k_; bool past_present_share_buffer_; bool is_first_prompt_; + bool is_packed_qkv_; }; class InPlaceSoftmaxProgram final : public Program { @@ -90,8 +91,8 @@ class InPlaceSoftmaxProgram final : public Program { class VxAttentionScoreProgram final : public Program { public: - VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, bool is_first_prompt, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false) - : Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) { + VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, bool is_first_prompt, bool is_packed_qkv, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false) + : Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt), is_packed_qkv_(is_packed_qkv) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -118,6 +119,7 @@ class 
VxAttentionScoreProgram final : public Program { const Tensor* seqlen_k_; bool past_present_share_buffer_; bool is_first_prompt_; + bool is_packed_qkv_; }; } // namespace webgpu diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index 9649677119f1c..819371d60bd04 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -64,10 +64,10 @@ Status GeneratePositionIDs(onnxruntime::webgpu::ComputeContext& context, const W return context.RunProgram(program); } -Status RunRotaryEmbedding(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& params, const Tensor* input, const Tensor* pos_ids, const Tensor* cos_cache, const Tensor* sin_cache, Tensor* output) { +Status RunRotaryEmbedding(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& params, const Tensor* input, const Tensor* pos_ids, const Tensor* cos_cache, const Tensor* sin_cache, Tensor* output, bool is_packed_qkv, bool is_query_input) { const auto half_rotary_embedding_dim = gsl::narrow(cos_cache->Shape()[1]); - - const TensorShape global_shape({params.batch_size_, params.sequence_length_, params.num_heads_, params.head_size_ - half_rotary_embedding_dim}); + auto num_heads = is_packed_qkv ? params.num_heads_ + 2 * params.kv_num_heads_ : (is_query_input ? 
params.num_heads_ : params.kv_num_heads_); + const TensorShape global_shape({params.batch_size_, params.sequence_length_, num_heads, params.head_size_ - half_rotary_embedding_dim}); const auto rank = global_shape.NumDimensions(); std::vector global_dims(rank); std::vector global_strides(rank); @@ -122,23 +122,6 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& scale_, softcap_)); WebgpuAttentionParameters parameters(params); - if (parameters.is_packed_qkv_) { - ORT_NOT_IMPLEMENTED("Packed QKV of shape (B, L, N, 3, H) not implemented for webgpu-ep."); - } - if (do_rotary_) { - Tensor q = context.CreateGPUTensor(query->DataType(), query->Shape()); - Tensor k = context.CreateGPUTensor(key->DataType(), key->Shape()); - TensorShape pos_ids_shape = parameters.is_first_prompt_ ? TensorShape({1}) : TensorShape({parameters.batch_size_ * parameters.sequence_length_}); - Tensor pos_ids = context.CreateGPUTensor(DataTypeImpl::GetType(), pos_ids_shape); - ORT_RETURN_IF_ERROR(GeneratePositionIDs(context, parameters, seqlens_k, &pos_ids)); - - ORT_RETURN_IF_ERROR(RunRotaryEmbedding(context, parameters, query, &pos_ids, cos_cache, sin_cache, &q)); - - ORT_RETURN_IF_ERROR(RunRotaryEmbedding(context, parameters, key, &pos_ids, cos_cache, sin_cache, &k)); - - query = &q; - key = &k; - } TensorShapeVector output_shape(3); output_shape[0] = static_cast(parameters.batch_size_); @@ -155,32 +138,60 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& Tensor* present_value = context.Output(2, present_kv_shape); parameters.past_present_share_buffer_ = present_key != nullptr && present_value != nullptr && past_key != nullptr && past_value != nullptr && past_key->DataRaw() == present_key->DataRaw() && past_value->DataRaw() == present_value->DataRaw(); - TensorShapeVector q_new_dims({parameters.batch_size_, parameters.num_heads_, + TensorShapeVector q_new_dims({parameters.batch_size_, parameters.is_packed_qkv_ ? 
parameters.num_heads_ + 2 * parameters.kv_num_heads_ : parameters.num_heads_, parameters.sequence_length_, parameters.head_size_}); TensorShape q_new_shape(q_new_dims); - Tensor Q = context.CreateGPUTensor(query->DataType(), q_new_shape); + Tensor qBNSH = context.CreateGPUTensor(query->DataType(), q_new_shape); ORT_RETURN_IF_ERROR(TransferBSDToBNSH( - context, parameters.num_heads_, parameters.sequence_length_, parameters.head_size_, query, nullptr, 0, &Q)); - if (parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH) { // key and value in BNSH format - return ApplyAttention(&Q, key, value, nullptr, past_key, past_value, output, present_key, + context, parameters.num_heads_, parameters.sequence_length_, parameters.head_size_, query, nullptr, 0, &qBNSH)); + if (!parameters.is_packed_qkv_) { + TensorShapeVector k_new_dims({parameters.batch_size_, parameters.kv_num_heads_, + parameters.kv_sequence_length_, parameters.head_size_}); + TensorShape k_new_shape(k_new_dims); + Tensor kBNSH = context.CreateGPUTensor(key->DataType(), k_new_shape); + ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.kv_num_heads_, parameters.kv_sequence_length_, + parameters.head_size_, key, nullptr, 0, &kBNSH)); + + TensorShapeVector v_new_dims({parameters.batch_size_, parameters.kv_num_heads_, + parameters.kv_sequence_length_, parameters.v_head_size_}); + TensorShape v_new_shape(v_new_dims); + Tensor vBNSH = context.CreateGPUTensor(value->DataType(), v_new_shape); + ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.kv_num_heads_, parameters.kv_sequence_length_, + parameters.v_head_size_, value, nullptr, 0, &vBNSH)); + + if (do_rotary_) { + Tensor qRotary = context.CreateGPUTensor(qBNSH.DataType(), qBNSH.Shape()); + Tensor kRotary = context.CreateGPUTensor(kBNSH.DataType(), kBNSH.Shape()); + TensorShape pos_ids_shape = parameters.is_first_prompt_ ? 
TensorShape({1}) : TensorShape({parameters.batch_size_ * parameters.sequence_length_}); + Tensor pos_ids = context.CreateGPUTensor(DataTypeImpl::GetType(), pos_ids_shape); + ORT_RETURN_IF_ERROR(GeneratePositionIDs(context, parameters, seqlens_k, &pos_ids)); + + ORT_RETURN_IF_ERROR(RunRotaryEmbedding(context, parameters, &qBNSH, &pos_ids, cos_cache, sin_cache, &qRotary, /* is_packed_qkv = */ false, /* is_query_input = */ true)); + + ORT_RETURN_IF_ERROR(RunRotaryEmbedding(context, parameters, &kBNSH, &pos_ids, cos_cache, sin_cache, &kRotary, /* is_packed_qkv = */ false, /* is_query_input = */ false)); + return ApplyAttention(&qRotary, &kRotary, &vBNSH, nullptr, past_key, past_value, output, present_key, + present_value, parameters, context, seqlens_k); + } else { + return ApplyAttention(&qBNSH, &kBNSH, &vBNSH, nullptr, past_key, past_value, output, present_key, + present_value, parameters, context, seqlens_k); + } + } else { + // Q, K and V are packed. Both key and value are nullptr + if (parameters.do_rotary_) { + Tensor qRotary = context.CreateGPUTensor(qBNSH.DataType(), qBNSH.Shape()); + TensorShape pos_ids_shape = parameters.is_first_prompt_ ? 
TensorShape({1}) : TensorShape({parameters.batch_size_ * parameters.sequence_length_}); + Tensor pos_ids = context.CreateGPUTensor(DataTypeImpl::GetType(), pos_ids_shape); + ORT_RETURN_IF_ERROR(GeneratePositionIDs(context, parameters, seqlens_k, &pos_ids)); + + ORT_RETURN_IF_ERROR(RunRotaryEmbedding(context, parameters, &qBNSH, &pos_ids, cos_cache, sin_cache, &qRotary, /* is_packed_qkv = */ true, /* is_query_input = */ true)); + + return ApplyAttention(&qRotary, nullptr, nullptr, nullptr, past_key, past_value, output, present_key, present_value, parameters, context, seqlens_k); + } else { + return ApplyAttention(&qBNSH, nullptr, nullptr, nullptr, past_key, past_value, output, present_key, + present_value, parameters, context, seqlens_k); + } } - - TensorShapeVector k_new_dims({parameters.batch_size_, parameters.kv_num_heads_, - parameters.kv_sequence_length_, parameters.head_size_}); - TensorShape k_new_shape(k_new_dims); - Tensor K = context.CreateGPUTensor(key->DataType(), k_new_shape); - ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.kv_num_heads_, parameters.kv_sequence_length_, - parameters.head_size_, key, nullptr, 0, &K)); - - TensorShapeVector v_new_dims({parameters.batch_size_, parameters.kv_num_heads_, - parameters.kv_sequence_length_, parameters.v_head_size_}); - TensorShape v_new_shape(v_new_dims); - Tensor V = context.CreateGPUTensor(value->DataType(), v_new_shape); - ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.kv_num_heads_, parameters.kv_sequence_length_, - parameters.v_head_size_, value, nullptr, 0, &V)); - return ApplyAttention(&Q, &K, &V, nullptr, past_key, past_value, output, present_key, - present_value, parameters, context, seqlens_k); } } // namespace webgpu From ae87526406cf2cab0cbf27784832f94a32390ed0 Mon Sep 17 00:00:00 2001 From: Satya Kumar Jandhyala Date: Thu, 16 Jan 2025 14:14:08 -0800 Subject: [PATCH 03/23] Fix lint errors. 
--- onnxruntime/contrib_ops/webgpu/bert/attention.cc | 4 ++-- .../contrib_ops/webgpu/bert/group_query_attention.cc | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index c0ec6dce4a1f7..7ad3a736fdc18 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -140,8 +140,8 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { << " }\n"; } else { shader.MainFunctionBody() << " if (n + local_id.y < uniforms.kv_sequence_length) {\n" - " tileK[idx] = " << (is_packed_qkv_ ? "q" : "key") << "[kOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n" - " }\n"; + << " tileK[idx] = " << (is_packed_qkv_ ? "q" : "key") << "[kOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n" + << " }\n"; } if (has_present_key_) { diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index 819371d60bd04..03f75f387b5db 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -170,10 +170,10 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& ORT_RETURN_IF_ERROR(RunRotaryEmbedding(context, parameters, &kBNSH, &pos_ids, cos_cache, sin_cache, &kRotary, /* is_packed_qkv = */ false, /* is_query_input = */ false)); return ApplyAttention(&qRotary, &kRotary, &vBNSH, nullptr, past_key, past_value, output, present_key, - present_value, parameters, context, seqlens_k); + present_value, parameters, context, seqlens_k); } else { return ApplyAttention(&qBNSH, &kBNSH, &vBNSH, nullptr, past_key, past_value, output, present_key, - present_value, parameters, context, seqlens_k); + present_value, parameters, context, seqlens_k); } } else { // Q, K and V are packed. 
Both key and value are nullptr @@ -186,10 +186,10 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& ORT_RETURN_IF_ERROR(RunRotaryEmbedding(context, parameters, &qBNSH, &pos_ids, cos_cache, sin_cache, &qRotary, /* is_packed_qkv = */ true, /* is_query_input = */ true)); return ApplyAttention(&qRotary, nullptr, nullptr, nullptr, past_key, past_value, output, present_key, - present_value, parameters, context, seqlens_k); + present_value, parameters, context, seqlens_k); } else { return ApplyAttention(&qBNSH, nullptr, nullptr, nullptr, past_key, past_value, output, present_key, - present_value, parameters, context, seqlens_k); + present_value, parameters, context, seqlens_k); } } } From df90ffafe78768bb4514d8141b0f2b23d45e39fe Mon Sep 17 00:00:00 2001 From: Satya Kumar Jandhyala Date: Thu, 16 Jan 2025 14:14:55 -0800 Subject: [PATCH 04/23] Fixed shader code compilation errors. --- onnxruntime/contrib_ops/webgpu/bert/attention.cc | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index 7ad3a736fdc18..8f27c9d6f6f2c 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -98,7 +98,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { } shader.AdditionalImplementation() << "var tileQ: array;\n" - << "var tileK: array;\n" + << "var tileK: array<" << (is_packed_qkv_ ? "q_value_t" : "key_value_t") << ", " << tile_size_ * tile_size_ << ">;\n" << "alias f32_val_t = " << (components_ == 4 ? "vec4" : (components_ == 2 ? 
"vec2" : "f32")) << ";\n"; shader.MainFunctionBody() << "// x holds the N and y holds the M\n" << "let m = workgroup_id.y * TILE_SIZE;\n" @@ -108,10 +108,11 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { << "var total_sequence_length = uniforms.N;\n"; if (is_packed_qkv_) { shader.MainFunctionBody() << "let head_idx = workgroup_id.z % uniforms.num_heads;\n" - << "let kv_num_heads = uniforms.kv_num_heads /" << n_reps_ << ";\n" + << "let kv_num_heads = uniforms.num_heads /" << n_reps_ << ";\n" << "let packed_batch_stride = (uniforms.num_heads + 2 * kv_num_heads) * uniforms.M * uniforms.K;\n" << "let qOffset = batch_idx * packed_batch_stride + head_idx * uniforms.M * uniforms.K;\n" - << "let kOffset = batchIdx * packed_batch_stride + (uniforms.num_heads + kvHeadIdx) * uniforms.kv_sequence_length * uniforms.K;\n"; + << "let kvHeadIdx = head_idx % kv_num_heads;\n" + << "let kOffset = batch_idx * packed_batch_stride + (uniforms.num_heads + kvHeadIdx) * uniforms.kv_sequence_length * uniforms.K;\n"; } else { shader.MainFunctionBody() << "let qOffset = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K;\n" << "let kOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.kv_sequence_length * uniforms.K;\n"; @@ -346,6 +347,7 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const { if (is_packed_qkv_) { shader.MainFunctionBody() << "let kv_num_heads = uniforms.num_heads / " << n_reps_ << ";\n" << "let packed_batch_stride = (uniforms.num_heads + 2 * kv_num_heads) * uniforms.M * uniforms.K;\n" + << "let kvHeadIdx = head_idx % kv_num_heads;\n" << "let vOffset = batch_idx * packed_batch_stride + (uniforms.num_heads + kv_num_heads + kvHeadIdx) * uniforms.N * uniforms.kv_sequence_length + n;\n"; } else { shader.MainFunctionBody() << "let vOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.N * uniforms.kv_sequence_length + n;\n"; From 07044628e281280a697302b286bbc80f04dafcef Mon Sep 17 00:00:00 2001 From: 
Satya Kumar Jandhyala Date: Thu, 16 Jan 2025 14:20:19 -0800 Subject: [PATCH 05/23] more lint stuff --- onnxruntime/contrib_ops/webgpu/bert/attention.cc | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index 8f27c9d6f6f2c..0abfa5746b387 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -125,12 +125,12 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { } shader.MainFunctionBody() << "var value = f32_val_t(0);\n" - "for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {\n" - " if (global_id.y < uniforms.M && w + local_id.x < uniforms.K) {\n" - " tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * uniforms.K + w + local_id.x];\n" - " }\n" - " if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) {\n" - " var idx = TILE_SIZE * local_id.y + local_id.x;\n"; + << "for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {\n" + << " if (global_id.y < uniforms.M && w + local_id.x < uniforms.K) {\n" + << " tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * uniforms.K + w + local_id.x];\n" + << " }\n" + << " if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) {\n" + << " var idx = TILE_SIZE * local_id.y + local_id.x;\n"; if ((feed_past_key_ && has_present_key_) || (past_present_share_buffer_ && !is_first_prompt_)) { shader.MainFunctionBody() << " if (n + local_id.y < past_sequence_length) {\n" From 177f535bb2c05db6faf6e767e1d8cc5686f74421 Mon Sep 17 00:00:00 2001 From: Satya Kumar Jandhyala Date: Thu, 16 Jan 2025 17:07:57 -0800 Subject: [PATCH 06/23] Fixed shader code issues. 
--- .../webgpu/bert/group_query_attention.cc | 18 ++++++++++-------- .../webgpu/bert/group_query_attention.h | 2 +- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index 03f75f387b5db..78e9e1538eea6 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -36,8 +36,8 @@ Status GeneratePositionIDsProgram::GenerateShaderCode(ShaderHelper& sh) const { sh.MainFunctionBody() << "let batch_idx = global_idx / uniforms.sequence_length;\n" << "let sequence_idx = i32(global_idx % uniforms.sequence_length);\n" << "var pos_id: u32 = 0u;\n" - << "if (is_first_prompt == 0) {\n" - << " let total_seqlen = ${seqLensInputHelper.getByOffset('batch_idx')} + 1;\n" + << "if (uniforms.is_first_prompt == 0) {\n" + << " let total_seqlen = seqlens[batch_idx] + 1;\n" << " let past_seqlen = total_seqlen - i32(uniforms.sequence_length);\n" << " if (past_seqlen + sequence_idx < total_seqlen) {\n" << " pos_id = u32(past_seqlen + sequence_idx);\n" @@ -45,12 +45,13 @@ Status GeneratePositionIDsProgram::GenerateShaderCode(ShaderHelper& sh) const { << " pos_id = 1u;\n" << " }\n" << "}\n" - << "output[global_idx] = pos_id;\n"; + << "output[global_idx] = vec2(pos_id, u32(select(0, 0xFFFFFFF, pos_id < 0)));\n"; return Status::OK(); } Status GeneratePositionIDs(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& params, const Tensor* seqlens, Tensor* output_tensor) { GeneratePositionIDsProgram program(params); + auto output_size = params.is_first_prompt_ ? 
1 : params.batch_size_ * params.sequence_length_; program.AddInput(seqlens) .AddOutput(output_tensor) .AddUniformVariables({{static_cast(params.batch_size_)}, @@ -60,14 +61,15 @@ Status GeneratePositionIDs(onnxruntime::webgpu::ComputeContext& context, const W {static_cast(params.rotary_dim_)}, {static_cast(params.rotary_interleaved_)}, {static_cast(params.is_first_prompt_ ? 0 : 1)}, - {static_cast(params.total_sequence_length_)}}); + {static_cast(params.total_sequence_length_)}}) + .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE); return context.RunProgram(program); } Status RunRotaryEmbedding(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& params, const Tensor* input, const Tensor* pos_ids, const Tensor* cos_cache, const Tensor* sin_cache, Tensor* output, bool is_packed_qkv, bool is_query_input) { const auto half_rotary_embedding_dim = gsl::narrow(cos_cache->Shape()[1]); auto num_heads = is_packed_qkv ? params.num_heads_ + 2 * params.kv_num_heads_ : (is_query_input ? params.num_heads_ : params.kv_num_heads_); - const TensorShape global_shape({params.batch_size_, params.sequence_length_, num_heads, params.head_size_ - half_rotary_embedding_dim}); + const TensorShape global_shape({params.batch_size_, params.sequence_length_, num_heads, static_cast(params.head_size_) - static_cast(half_rotary_embedding_dim)}); const auto rank = global_shape.NumDimensions(); std::vector global_dims(rank); std::vector global_strides(rank); @@ -162,7 +164,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& if (do_rotary_) { Tensor qRotary = context.CreateGPUTensor(qBNSH.DataType(), qBNSH.Shape()); Tensor kRotary = context.CreateGPUTensor(kBNSH.DataType(), kBNSH.Shape()); - TensorShape pos_ids_shape = parameters.is_first_prompt_ ? TensorShape({1}) : TensorShape({parameters.batch_size_ * parameters.sequence_length_}); + TensorShape pos_ids_shape = parameters.is_first_prompt_ ? 
TensorShape({1}) : TensorShape({static_cast(parameters.batch_size_) * static_cast(parameters.sequence_length_)}); Tensor pos_ids = context.CreateGPUTensor(DataTypeImpl::GetType(), pos_ids_shape); ORT_RETURN_IF_ERROR(GeneratePositionIDs(context, parameters, seqlens_k, &pos_ids)); @@ -177,9 +179,9 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& } } else { // Q, K and V are packed. Both key and value are nullptr - if (parameters.do_rotary_) { + if (do_rotary_) { Tensor qRotary = context.CreateGPUTensor(qBNSH.DataType(), qBNSH.Shape()); - TensorShape pos_ids_shape = parameters.is_first_prompt_ ? TensorShape({1}) : TensorShape({parameters.batch_size_ * parameters.sequence_length_}); + TensorShape pos_ids_shape = parameters.is_first_prompt_ ? TensorShape({1}) : TensorShape({static_cast(parameters.batch_size_) * static_cast(parameters.sequence_length_)}); Tensor pos_ids = context.CreateGPUTensor(DataTypeImpl::GetType(), pos_ids_shape); ORT_RETURN_IF_ERROR(GeneratePositionIDs(context, parameters, seqlens_k, &pos_ids)); diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h index a06bc1946d044..1eb87048dfa09 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h @@ -28,7 +28,7 @@ class GeneratePositionIDsProgram final : public Program Date: Tue, 21 Jan 2025 14:24:13 -0800 Subject: [PATCH 07/23] Added split functionality to unpack packed-QKV. 
--- .../webgpu/bert/group_query_attention.cc | 36 +++++++++++++++++++ .../webgpu/bert/group_query_attention.h | 10 ++++++ 2 files changed, 46 insertions(+) diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index 78e9e1538eea6..1632dac648ed6 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -30,6 +30,42 @@ ONNX_OPERATOR_KERNEL_EX( .InputMemoryType(OrtMemTypeCPUInput, 6), GroupQueryAttention); +Status SplitPackedQKVProgram::GenerateShaderCode(ShaderHelper& sh) const { + const auto& packed_qkv = sh.AddInput("packed_qkv", ShaderUsage::UseOffsetToIndices | ShaderUsage::UseUniform); + const auto& query = sh.AddOutput("query", ShaderUsage::UseSetByIndices | ShaderUsage::UseUniform); + const auto& key = sh.AddOutput("key", ShaderUsage::UseSetByIndices | ShaderUsage::UseUniform); + const auto& value = sh.AddOutput("val", ShaderUsage::UseSetByIndices | ShaderUsage::UseUniform); + sh.MainFunctionBody() << " let packed_qkv_indices = " << packed_qkv.OffsetToIndices("global_idx") << ";\n" + << " let input_data = " << packed_qkv.GetByOffset("global_idx") << ";\n" + << " let index = " << packed_qkv.IndicesGet("packed_qkv_indices", "2") << ";\n" + << " if (index < uniforms.hidden_size) {\n" + << " " << query.SetByIndices("packed_qkv_indices", "input_data") << ";\n" + << " } else if (index < (uniforms.hidden_size + uniforms.kv_hidden_size)) {\n" + << " var key_indices = packed_qkv_indices;\n" + << " " << key.IndicesSet("key_indices", "2", "u32(index - uniforms.hidden_size)") << ";\n" + << " " << key.SetByIndices("key_indices", "input_data") << ";\n" + << " } else {\n" + << " var val_indices = packed_qkv_indices;\n" + << " " << value.IndicesSet("val_indices", "2", "u32(index - uniforms.hidden_size - uniforms.kv_hidden_size)") << ";\n" + << " " << value.SetByIndices("val_indices", "input_data") << ";\n" + 
<< " }"; + return Status::OK(); +} + +Status SplitPackedQKV(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& params, const Tensor* packedQKV, Tensor* query, Tensor* key, Tensor* val) { + SplitPackedQKVProgram program(params); + auto input_size = packedQKV->Shape().Size(); + program + .AddInput({packedQKV, ProgramTensorMetadataDependency::Rank}) + .AddOutputs({{query, ProgramTensorMetadataDependency::Rank}, {key, ProgramTensorMetadataDependency::Rank}, {val, ProgramTensorMetadataDependency::Rank}}) + .AddUniformVariables({ + {static_cast(params.hidden_size_)}, + {static_cast(params.kv_hidden_size_)}, + }) + .SetDispatchGroupSize((input_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE); + return context.RunProgram(program); +} + Status GeneratePositionIDsProgram::GenerateShaderCode(ShaderHelper& sh) const { sh.AddInput("seqlens", ShaderUsage::UseUniform); sh.AddOutput("output", ShaderUsage::UseUniform); diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h index 1eb87048dfa09..4c0b750e4526a 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h @@ -30,6 +30,16 @@ class GeneratePositionIDsProgram final : public Program { + public: + SplitPackedQKVProgram(const WebgpuAttentionParameters& params) : Program{"SplitPackedQKV"}, params_(params) {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"hidden_size", ProgramUniformVariableDataType::Uint32}, + {"kv_hidden_size", ProgramUniformVariableDataType::Uint32}); private: const WebgpuAttentionParameters& params_; From f0d238a3e843950be7a2325aa94f4f0689c22ec0 Mon Sep 17 00:00:00 2001 From: Satya Kumar Jandhyala Date: Tue, 21 Jan 2025 14:29:39 -0800 Subject: [PATCH 08/23] Removed unnecessary uniforms in GeneratePositionIdsProgram --- .../webgpu/bert/group_query_attention.cc | 18 
++++++------------ .../webgpu/bert/group_query_attention.h | 11 ++--------- 2 files changed, 8 insertions(+), 21 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index 1632dac648ed6..53de4270ff8f2 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -86,18 +86,12 @@ Status GeneratePositionIDsProgram::GenerateShaderCode(ShaderHelper& sh) const { } Status GeneratePositionIDs(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& params, const Tensor* seqlens, Tensor* output_tensor) { - GeneratePositionIDsProgram program(params); - auto output_size = params.is_first_prompt_ ? 1 : params.batch_size_ * params.sequence_length_; - program.AddInput(seqlens) - .AddOutput(output_tensor) - .AddUniformVariables({{static_cast(params.batch_size_)}, - {static_cast(params.sequence_length_)}, - {static_cast(params.num_heads_)}, - {static_cast(params.head_size_)}, - {static_cast(params.rotary_dim_)}, - {static_cast(params.rotary_interleaved_)}, - {static_cast(params.is_first_prompt_ ? 
0 : 1)}, - {static_cast(params.total_sequence_length_)}}) + GeneratePositionIDsProgram program; + auto output_size = params.batch_size_ * params.sequence_length_; + program.CacheHint(params.batch_size_, params.sequence_length_) + .AddInput({seqlens, ProgramTensorMetadataDependency::Rank}) + .AddOutput({output_tensor, ProgramTensorMetadataDependency::Rank}) + .AddUniformVariables({{static_cast(params.sequence_length_)}}) .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE); return context.RunProgram(program); } diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h index 4c0b750e4526a..55127e5017ad7 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h @@ -18,18 +18,11 @@ Status GeneratePositionIDs(onnxruntime::webgpu::ComputeContext& context, bool is class GeneratePositionIDsProgram final : public Program { public: - GeneratePositionIDsProgram(const WebgpuAttentionParameters& params) : Program{"GeneratePositionIDs"}, params_(params) {} + GeneratePositionIDsProgram() : Program{"GeneratePositionIDs"} {} Status GenerateShaderCode(ShaderHelper& sh) const override; - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"batch_size", ProgramUniformVariableDataType::Uint32}, - {"sequence_length", ProgramUniformVariableDataType::Uint32}, - {"num_heads", ProgramUniformVariableDataType::Uint32}, - {"head_size", ProgramUniformVariableDataType::Uint32}, - {"rotary_embedding_dim", ProgramUniformVariableDataType::Uint32}, - {"interleaved", ProgramUniformVariableDataType::Uint32}, - {"is_first_prompt", ProgramUniformVariableDataType::Uint32}, - {"total_seqlen", ProgramUniformVariableDataType::Uint32}); + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"sequence_length", ProgramUniformVariableDataType::Uint32}); }; class SplitPackedQKVProgram final : public Program { From 1009fc9a63bfc9eac80e531dc5174e715667a4e6 Mon 
Sep 17 00:00:00 2001 From: Satya Kumar Jandhyala Date: Tue, 21 Jan 2025 14:31:53 -0800 Subject: [PATCH 09/23] Apply split and rotrary embedding before converting input ro BSD to BNSH --- .../webgpu/bert/group_query_attention.cc | 126 ++++++++---------- 1 file changed, 59 insertions(+), 67 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index 53de4270ff8f2..1bf3a50c67ce0 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -67,21 +67,17 @@ Status SplitPackedQKV(onnxruntime::webgpu::ComputeContext& context, const Webgpu } Status GeneratePositionIDsProgram::GenerateShaderCode(ShaderHelper& sh) const { - sh.AddInput("seqlens", ShaderUsage::UseUniform); - sh.AddOutput("output", ShaderUsage::UseUniform); - sh.MainFunctionBody() << "let batch_idx = global_idx / uniforms.sequence_length;\n" - << "let sequence_idx = i32(global_idx % uniforms.sequence_length);\n" - << "var pos_id: u32 = 0u;\n" - << "if (uniforms.is_first_prompt == 0) {\n" - << " let total_seqlen = seqlens[batch_idx] + 1;\n" + const auto& output = sh.AddOutput("output", ShaderUsage::UseUniform); + const auto& seqlens = sh.AddInput("seqlens", ShaderUsage::UseUniform); + sh.MainFunctionBody() << " let batch_idx = global_idx / uniforms.sequence_length;\n" + << " let sequence_idx = i32(global_idx % uniforms.sequence_length);\n" + << " var pos_id: u32 = 1u;\n" + << " let total_seqlen = " << seqlens.GetByOffset("batch_idx") << " + 1;\n" << " let past_seqlen = total_seqlen - i32(uniforms.sequence_length);\n" << " if (past_seqlen + sequence_idx < total_seqlen) {\n" << " pos_id = u32(past_seqlen + sequence_idx);\n" - << " } else {\n" - << " pos_id = 1u;\n" - << " }\n" - << "}\n" - << "output[global_idx] = vec2(pos_id, u32(select(0, 0xFFFFFFF, pos_id < 0)));\n"; + << " }\n"; + sh.MainFunctionBody() << " " << 
output.SetByOffset("global_idx", "pos_id"); return Status::OK(); } @@ -96,10 +92,10 @@ Status GeneratePositionIDs(onnxruntime::webgpu::ComputeContext& context, const W return context.RunProgram(program); } -Status RunRotaryEmbedding(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& params, const Tensor* input, const Tensor* pos_ids, const Tensor* cos_cache, const Tensor* sin_cache, Tensor* output, bool is_packed_qkv, bool is_query_input) { +Status RunRotaryEmbedding(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& params, const Tensor* input, const Tensor* pos_ids, const Tensor* cos_cache, const Tensor* sin_cache, Tensor* output, bool is_query_input) { const auto half_rotary_embedding_dim = gsl::narrow(cos_cache->Shape()[1]); - auto num_heads = is_packed_qkv ? params.num_heads_ + 2 * params.kv_num_heads_ : (is_query_input ? params.num_heads_ : params.kv_num_heads_); - const TensorShape global_shape({params.batch_size_, params.sequence_length_, num_heads, static_cast(params.head_size_) - static_cast(half_rotary_embedding_dim)}); + const auto head_size = params.rotary_dim_ == 0 ? half_rotary_embedding_dim * 2 : params.head_size_; + const TensorShape global_shape({params.batch_size_, params.sequence_length_, (is_query_input ? 
params.hidden_size_ : params.kv_hidden_size_) / head_size, static_cast(head_size) - static_cast(half_rotary_embedding_dim)}); const auto rank = global_shape.NumDimensions(); std::vector global_dims(rank); std::vector global_strides(rank); @@ -107,7 +103,7 @@ Status RunRotaryEmbedding(onnxruntime::webgpu::ComputeContext& context, const We global_dims[j] = gsl::narrow(global_shape[j]); global_strides[j] = gsl::narrow(global_shape.SizeFromDimension(j + 1)); } - const auto input_output_strides = std::vector({gsl::narrow(params.batch_size_), gsl::narrow(params.hidden_size_), gsl::narrow(params.head_size_), 1}); + const auto input_output_strides = std::vector({gsl::narrow(input->Shape().SizeFromDimension(1)), gsl::narrow(params.sequence_length_) * gsl::narrow(head_size), head_size, 1}); const auto output_size = gsl::narrow(global_shape.Size()); RotaryEmbeddingProgram program(params.rotary_interleaved_); @@ -138,7 +134,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& const Tensor* cos_cache = context.Input(7); const Tensor* sin_cache = context.Input(8); - GroupQueryAttentionParameters params; + GroupQueryAttentionParameters params = {}; ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query, key, value, @@ -170,60 +166,56 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& Tensor* present_value = context.Output(2, present_kv_shape); parameters.past_present_share_buffer_ = present_key != nullptr && present_value != nullptr && past_key != nullptr && past_value != nullptr && past_key->DataRaw() == present_key->DataRaw() && past_value->DataRaw() == present_value->DataRaw(); - TensorShapeVector q_new_dims({parameters.batch_size_, parameters.is_packed_qkv_ ? 
parameters.num_heads_ + 2 * parameters.kv_num_heads_ : parameters.num_heads_, - parameters.sequence_length_, parameters.head_size_}); - TensorShape q_new_shape(q_new_dims); - Tensor qBNSH = context.CreateGPUTensor(query->DataType(), q_new_shape); - ORT_RETURN_IF_ERROR(TransferBSDToBNSH( - context, parameters.num_heads_, parameters.sequence_length_, parameters.head_size_, query, nullptr, 0, &qBNSH)); - if (!parameters.is_packed_qkv_) { - TensorShapeVector k_new_dims({parameters.batch_size_, parameters.kv_num_heads_, - parameters.kv_sequence_length_, parameters.head_size_}); - TensorShape k_new_shape(k_new_dims); - Tensor kBNSH = context.CreateGPUTensor(key->DataType(), k_new_shape); - ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.kv_num_heads_, parameters.kv_sequence_length_, - parameters.head_size_, key, nullptr, 0, &kBNSH)); - - TensorShapeVector v_new_dims({parameters.batch_size_, parameters.kv_num_heads_, - parameters.kv_sequence_length_, parameters.v_head_size_}); - TensorShape v_new_shape(v_new_dims); - Tensor vBNSH = context.CreateGPUTensor(value->DataType(), v_new_shape); - ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.kv_num_heads_, parameters.kv_sequence_length_, - parameters.v_head_size_, value, nullptr, 0, &vBNSH)); - - if (do_rotary_) { - Tensor qRotary = context.CreateGPUTensor(qBNSH.DataType(), qBNSH.Shape()); - Tensor kRotary = context.CreateGPUTensor(kBNSH.DataType(), kBNSH.Shape()); - TensorShape pos_ids_shape = parameters.is_first_prompt_ ? 
TensorShape({1}) : TensorShape({static_cast(parameters.batch_size_) * static_cast(parameters.sequence_length_)}); - Tensor pos_ids = context.CreateGPUTensor(DataTypeImpl::GetType(), pos_ids_shape); - ORT_RETURN_IF_ERROR(GeneratePositionIDs(context, parameters, seqlens_k, &pos_ids)); - - ORT_RETURN_IF_ERROR(RunRotaryEmbedding(context, parameters, &qBNSH, &pos_ids, cos_cache, sin_cache, &qRotary, /* is_packed_qkv = */ false, /* is_query_input = */ true)); + Tensor qSplit; + Tensor kSplit; + Tensor vSplit; + if (parameters.is_packed_qkv_) { + qSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.hidden_size_})); + kSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.kv_hidden_size_})); + vSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.kv_hidden_size_})); + ORT_RETURN_IF_ERROR(SplitPackedQKV(context, parameters, query, &qSplit, &kSplit, &vSplit)); + parameters.is_packed_qkv_ = false; + query = &qSplit; + key = &kSplit; + value = &vSplit; + } - ORT_RETURN_IF_ERROR(RunRotaryEmbedding(context, parameters, &kBNSH, &pos_ids, cos_cache, sin_cache, &kRotary, /* is_packed_qkv = */ false, /* is_query_input = */ false)); - return ApplyAttention(&qRotary, &kRotary, &vBNSH, nullptr, past_key, past_value, output, present_key, - present_value, parameters, context, seqlens_k); - } else { - return ApplyAttention(&qBNSH, &kBNSH, &vBNSH, nullptr, past_key, past_value, output, present_key, - present_value, parameters, context, seqlens_k); - } - } else { - // Q, K and V are packed. Both key and value are nullptr - if (do_rotary_) { - Tensor qRotary = context.CreateGPUTensor(qBNSH.DataType(), qBNSH.Shape()); - TensorShape pos_ids_shape = parameters.is_first_prompt_ ? 
TensorShape({1}) : TensorShape({static_cast(parameters.batch_size_) * static_cast(parameters.sequence_length_)}); - Tensor pos_ids = context.CreateGPUTensor(DataTypeImpl::GetType(), pos_ids_shape); + Tensor qRotary; + Tensor kRotary; + if (do_rotary_) { + qRotary = context.CreateGPUTensor(query->DataType(), query->Shape()); + kRotary = context.CreateGPUTensor(key->DataType(), key->Shape()); + auto pos_ids_shape = parameters.is_first_prompt_ ? TensorShape({1, 1}) : TensorShape({parameters.batch_size_, parameters.sequence_length_}); + Tensor pos_ids = context.CreateGPUTensor(DataTypeImpl::GetType(), pos_ids_shape); + if (!parameters.is_first_prompt_) { ORT_RETURN_IF_ERROR(GeneratePositionIDs(context, parameters, seqlens_k, &pos_ids)); - - ORT_RETURN_IF_ERROR(RunRotaryEmbedding(context, parameters, &qBNSH, &pos_ids, cos_cache, sin_cache, &qRotary, /* is_packed_qkv = */ true, /* is_query_input = */ true)); - - return ApplyAttention(&qRotary, nullptr, nullptr, nullptr, past_key, past_value, output, present_key, - present_value, parameters, context, seqlens_k); - } else { - return ApplyAttention(&qBNSH, nullptr, nullptr, nullptr, past_key, past_value, output, present_key, - present_value, parameters, context, seqlens_k); } + ORT_RETURN_IF_ERROR(RunRotaryEmbedding(context, parameters, query, &pos_ids, cos_cache, sin_cache, &qRotary, /* is_query_input = */ true)); + + ORT_RETURN_IF_ERROR(RunRotaryEmbedding(context, parameters, key, &pos_ids, cos_cache, sin_cache, &kRotary, /* is_query_input = */ false)); + query = &qRotary; + key = &kRotary; } + + TensorShapeVector q_new_dims({parameters.batch_size_, parameters.num_heads_, parameters.sequence_length_, parameters.head_size_}); + TensorShape q_new_shape(q_new_dims); + Tensor qBNSH = context.CreateGPUTensor(query->DataType(), q_new_shape); + ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads_, parameters.sequence_length_, parameters.head_size_, query, nullptr, 0, &qBNSH)); + TensorShapeVector 
k_new_dims({parameters.batch_size_, parameters.kv_num_heads_, + parameters.kv_sequence_length_, parameters.head_size_}); + TensorShape k_new_shape(k_new_dims); + Tensor kBNSH = context.CreateGPUTensor(key->DataType(), k_new_shape); + ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.kv_num_heads_, parameters.kv_sequence_length_, + parameters.head_size_, key, nullptr, 0, &kBNSH)); + + TensorShapeVector v_new_dims({parameters.batch_size_, parameters.kv_num_heads_, + parameters.kv_sequence_length_, parameters.v_head_size_}); + TensorShape v_new_shape(v_new_dims); + Tensor vBNSH = context.CreateGPUTensor(value->DataType(), v_new_shape); + ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.kv_num_heads_, parameters.kv_sequence_length_, + parameters.v_head_size_, value, nullptr, 0, &vBNSH)); + return ApplyAttention(&qBNSH, &kBNSH, &vBNSH, nullptr, past_key, past_value, output, present_key, + present_value, parameters, context, seqlens_k); } } // namespace webgpu From a4d848277ce5695a61826c4ea44c447b24082bd7 Mon Sep 17 00:00:00 2001 From: Satya Kumar Jandhyala Date: Tue, 21 Jan 2025 14:32:49 -0800 Subject: [PATCH 10/23] Fix the input_output_stride for 4-dim input. --- onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc index bc8b7493fc916..5e0703f87f496 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc @@ -95,7 +95,7 @@ Status RotaryEmbedding::ComputeInternal(onnxruntime::webgpu::ComputeContext& con input_shape.NumDimensions() == 3 ? std::vector({batch_stride, hidden_size, head_size, 1}) : (input_shape.NumDimensions() == 4 - ? std::vector({batch_stride, head_size, sequence_length * head_size, 1}) + ? 
std::vector({batch_stride, sequence_length * head_size, head_size, 1}) : std::vector({})); program From e406b817f4a02e2aec979af5e2ff26a49bfc9e44 Mon Sep 17 00:00:00 2001 From: Satya Kumar Jandhyala Date: Mon, 3 Feb 2025 09:54:05 -0800 Subject: [PATCH 11/23] Allocate position_ids tensor size/shape even for the first prompt --- onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index 1bf3a50c67ce0..1e69c9463ab3c 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -169,7 +169,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& Tensor qSplit; Tensor kSplit; Tensor vSplit; - if (parameters.is_packed_qkv_) { + if (parameters.is_packed_qkv_ && do_rotary_) { qSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.hidden_size_})); kSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.kv_hidden_size_})); vSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.kv_hidden_size_})); @@ -185,7 +185,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& if (do_rotary_) { qRotary = context.CreateGPUTensor(query->DataType(), query->Shape()); kRotary = context.CreateGPUTensor(key->DataType(), key->Shape()); - auto pos_ids_shape = parameters.is_first_prompt_ ? 
TensorShape({1, 1}) : TensorShape({parameters.batch_size_, parameters.sequence_length_}); + auto pos_ids_shape = TensorShape({parameters.batch_size_, parameters.sequence_length_}); Tensor pos_ids = context.CreateGPUTensor(DataTypeImpl::GetType(), pos_ids_shape); if (!parameters.is_first_prompt_) { ORT_RETURN_IF_ERROR(GeneratePositionIDs(context, parameters, seqlens_k, &pos_ids)); From 0b08117e8f24d40d084393ef0c366da79320dce8 Mon Sep 17 00:00:00 2001 From: SatyaKumarJ Date: Tue, 18 Feb 2025 10:19:02 -0800 Subject: [PATCH 12/23] Fixed the input_output_strides --- onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index 1e69c9463ab3c..ec32de55ecce6 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -95,7 +95,8 @@ Status GeneratePositionIDs(onnxruntime::webgpu::ComputeContext& context, const W Status RunRotaryEmbedding(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& params, const Tensor* input, const Tensor* pos_ids, const Tensor* cos_cache, const Tensor* sin_cache, Tensor* output, bool is_query_input) { const auto half_rotary_embedding_dim = gsl::narrow(cos_cache->Shape()[1]); const auto head_size = params.rotary_dim_ == 0 ? half_rotary_embedding_dim * 2 : params.head_size_; - const TensorShape global_shape({params.batch_size_, params.sequence_length_, (is_query_input ? params.hidden_size_ : params.kv_hidden_size_) / head_size, static_cast(head_size) - static_cast(half_rotary_embedding_dim)}); + const auto hidden_size = is_query_input ? 
params.hidden_size_ : params.kv_hidden_size_; + const TensorShape global_shape({params.batch_size_, params.sequence_length_, hidden_size / head_size, static_cast(head_size) - static_cast(half_rotary_embedding_dim)}); const auto rank = global_shape.NumDimensions(); std::vector global_dims(rank); std::vector global_strides(rank); @@ -103,7 +104,7 @@ Status RunRotaryEmbedding(onnxruntime::webgpu::ComputeContext& context, const We global_dims[j] = gsl::narrow(global_shape[j]); global_strides[j] = gsl::narrow(global_shape.SizeFromDimension(j + 1)); } - const auto input_output_strides = std::vector({gsl::narrow(input->Shape().SizeFromDimension(1)), gsl::narrow(params.sequence_length_) * gsl::narrow(head_size), head_size, 1}); + const auto input_output_strides = std::vector({gsl::narrow(input->Shape().SizeFromDimension(1)), gsl::narrow(hidden_size), gsl::narrow(head_size), 1}); const auto output_size = gsl::narrow(global_shape.Size()); RotaryEmbeddingProgram program(params.rotary_interleaved_); From a7328f577e971221d2221748fefc345fcd9256e6 Mon Sep 17 00:00:00 2001 From: SatyaKumarJ Date: Wed, 19 Feb 2025 12:05:04 -0800 Subject: [PATCH 13/23] Added is_first_prompt to the shader that generates position ids to initialize pos_ids to zeros for the first prompt.
--- .../webgpu/bert/group_query_attention.cc | 26 +++++++++---------- .../webgpu/bert/group_query_attention.h | 2 +- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index ec32de55ecce6..8dddc80b62927 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -69,14 +69,18 @@ Status SplitPackedQKV(onnxruntime::webgpu::ComputeContext& context, const Webgpu Status GeneratePositionIDsProgram::GenerateShaderCode(ShaderHelper& sh) const { const auto& output = sh.AddOutput("output", ShaderUsage::UseUniform); const auto& seqlens = sh.AddInput("seqlens", ShaderUsage::UseUniform); - sh.MainFunctionBody() << " let batch_idx = global_idx / uniforms.sequence_length;\n" + sh.MainFunctionBody() << "var pos_id: i32 = 0;\n" + << "if (uniforms.is_first_prompt == 0) {\n" + << " let batch_idx = global_idx / uniforms.sequence_length;\n" << " let sequence_idx = i32(global_idx % uniforms.sequence_length);\n" - << " var pos_id: u32 = 1u;\n" << " let total_seqlen = " << seqlens.GetByOffset("batch_idx") << " + 1;\n" << " let past_seqlen = total_seqlen - i32(uniforms.sequence_length);\n" << " if (past_seqlen + sequence_idx < total_seqlen) {\n" - << " pos_id = u32(past_seqlen + sequence_idx);\n" - << " }\n"; + << " pos_id = past_seqlen + sequence_idx;\n" + << " } else {\n" + << " pos_id = 1;\n" + << " }\n" + << "}\n"; sh.MainFunctionBody() << " " << output.SetByOffset("global_idx", "pos_id"); return Status::OK(); } @@ -84,19 +88,18 @@ Status GeneratePositionIDsProgram::GenerateShaderCode(ShaderHelper& sh) const { Status GeneratePositionIDs(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& params, const Tensor* seqlens, Tensor* output_tensor) { GeneratePositionIDsProgram program; auto output_size = params.batch_size_ * params.sequence_length_; 
- program.CacheHint(params.batch_size_, params.sequence_length_) - .AddInput({seqlens, ProgramTensorMetadataDependency::Rank}) + program.AddInput({seqlens, ProgramTensorMetadataDependency::Rank}) .AddOutput({output_tensor, ProgramTensorMetadataDependency::Rank}) - .AddUniformVariables({{static_cast(params.sequence_length_)}}) + .AddUniformVariables({{static_cast(params.sequence_length_)}, {static_cast(params.is_first_prompt_ ? 1 : 0)}}) .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE); return context.RunProgram(program); } Status RunRotaryEmbedding(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& params, const Tensor* input, const Tensor* pos_ids, const Tensor* cos_cache, const Tensor* sin_cache, Tensor* output, bool is_query_input) { const auto half_rotary_embedding_dim = gsl::narrow(cos_cache->Shape()[1]); - const auto head_size = params.rotary_dim_ == 0 ? half_rotary_embedding_dim * 2 : params.head_size_; + const auto head_size = params.head_size_; const auto hidden_size = is_query_input ? 
params.hidden_size_ : params.kv_hidden_size_; - const TensorShape global_shape({params.batch_size_, params.sequence_length_, hidden_size / head_size, static_cast(head_size) - static_cast(half_rotary_embedding_dim)}); + const TensorShape global_shape({params.batch_size_, params.sequence_length_, hidden_size / head_size, static_cast(head_size - half_rotary_embedding_dim)}); const auto rank = global_shape.NumDimensions(); std::vector global_dims(rank); std::vector global_strides(rank); @@ -188,11 +191,8 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& kRotary = context.CreateGPUTensor(key->DataType(), key->Shape()); auto pos_ids_shape = TensorShape({parameters.batch_size_, parameters.sequence_length_}); Tensor pos_ids = context.CreateGPUTensor(DataTypeImpl::GetType(), pos_ids_shape); - if (!parameters.is_first_prompt_) { - ORT_RETURN_IF_ERROR(GeneratePositionIDs(context, parameters, seqlens_k, &pos_ids)); - } + ORT_RETURN_IF_ERROR(GeneratePositionIDs(context, parameters, seqlens_k, &pos_ids)); ORT_RETURN_IF_ERROR(RunRotaryEmbedding(context, parameters, query, &pos_ids, cos_cache, sin_cache, &qRotary, /* is_query_input = */ true)); - ORT_RETURN_IF_ERROR(RunRotaryEmbedding(context, parameters, key, &pos_ids, cos_cache, sin_cache, &kRotary, /* is_query_input = */ false)); query = &qRotary; key = &kRotary; diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h index 55127e5017ad7..161e278de2b4e 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h @@ -22,7 +22,7 @@ class GeneratePositionIDsProgram final : public Program { From 531c6e314c870b2e5f02107c107b80540580a89a Mon Sep 17 00:00:00 2001 From: SatyaKumarJ Date: Wed, 19 Feb 2025 17:44:33 -0800 Subject: [PATCH 14/23] Fixed position_ids generation code. 
--- .../webgpu/bert/group_query_attention.cc | 68 +++++++++++-------- .../webgpu/bert/group_query_attention.h | 4 +- 2 files changed, 42 insertions(+), 30 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index 8dddc80b62927..e9fcfc2cd24c3 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -69,19 +69,29 @@ Status SplitPackedQKV(onnxruntime::webgpu::ComputeContext& context, const Webgpu Status GeneratePositionIDsProgram::GenerateShaderCode(ShaderHelper& sh) const { const auto& output = sh.AddOutput("output", ShaderUsage::UseUniform); const auto& seqlens = sh.AddInput("seqlens", ShaderUsage::UseUniform); - sh.MainFunctionBody() << "var pos_id: i32 = 0;\n" - << "if (uniforms.is_first_prompt == 0) {\n" + sh.MainFunctionBody() << " var pos_id: i32 = 0;\n" << " let batch_idx = global_idx / uniforms.sequence_length;\n" << " let sequence_idx = i32(global_idx % uniforms.sequence_length);\n" - << " let total_seqlen = " << seqlens.GetByOffset("batch_idx") << " + 1;\n" - << " let past_seqlen = total_seqlen - i32(uniforms.sequence_length);\n" - << " if (past_seqlen + sequence_idx < total_seqlen) {\n" - << " pos_id = past_seqlen + sequence_idx;\n" - << " } else {\n" - << " pos_id = 1;\n" - << " }\n" - << "}\n"; - sh.MainFunctionBody() << " " << output.SetByOffset("global_idx", "pos_id"); + << " let seqlen = " << seqlens.GetByOffset("batch_idx") << ";\n" + << " let total_seqlen = seqlen + 1;\n" + << " if (uniforms.is_first_prompt > 0) {\n" + << " if (sequence_idx < total_seqlen) {\n" + << " pos_id = sequence_idx;\n" + << " } else {\n" + << " pos_id = 1;\n" + << " }\n" + << " " << output.SetByOffset("global_idx", "pos_id") << "\n" + << " } else if (uniforms.is_subsequent_prompt > 0) {\n" + << " let past_seqlen = total_seqlen - i32(uniforms.sequence_length);\n" + << " if (past_seqlen + 
sequence_idx < total_seqlen) {\n" + << " pos_id = past_seqlen + sequence_idx;\n" + << " } else {\n" + << " pos_id = 1;\n" + << " }\n" + << " " << output.SetByOffset("global_idx", "pos_id") << "\n" + << " } else if (global_idx < uniforms.batch_size) {\n" + << " " << output.SetByOffset("global_idx", "seqlen") << "\n" + << " }\n"; return Status::OK(); } @@ -90,7 +100,7 @@ Status GeneratePositionIDs(onnxruntime::webgpu::ComputeContext& context, const W auto output_size = params.batch_size_ * params.sequence_length_; program.AddInput({seqlens, ProgramTensorMetadataDependency::Rank}) .AddOutput({output_tensor, ProgramTensorMetadataDependency::Rank}) - .AddUniformVariables({{static_cast(params.sequence_length_)}, {static_cast(params.is_first_prompt_ ? 1 : 0)}}) + .AddUniformVariables({{static_cast(params.batch_size_)}, {static_cast(params.sequence_length_)}, {static_cast(params.is_first_prompt_ ? 1 : 0)}, {static_cast(params.is_subsequent_prompt_ ? 1 : 0)}}) .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE); return context.RunProgram(program); } @@ -201,21 +211,25 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& TensorShapeVector q_new_dims({parameters.batch_size_, parameters.num_heads_, parameters.sequence_length_, parameters.head_size_}); TensorShape q_new_shape(q_new_dims); Tensor qBNSH = context.CreateGPUTensor(query->DataType(), q_new_shape); - ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads_, parameters.sequence_length_, parameters.head_size_, query, nullptr, 0, &qBNSH)); - TensorShapeVector k_new_dims({parameters.batch_size_, parameters.kv_num_heads_, - parameters.kv_sequence_length_, parameters.head_size_}); - TensorShape k_new_shape(k_new_dims); - Tensor kBNSH = context.CreateGPUTensor(key->DataType(), k_new_shape); - ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.kv_num_heads_, parameters.kv_sequence_length_, - parameters.head_size_, key, nullptr, 0, &kBNSH)); - - 
TensorShapeVector v_new_dims({parameters.batch_size_, parameters.kv_num_heads_, - parameters.kv_sequence_length_, parameters.v_head_size_}); - TensorShape v_new_shape(v_new_dims); - Tensor vBNSH = context.CreateGPUTensor(value->DataType(), v_new_shape); - ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.kv_num_heads_, parameters.kv_sequence_length_, - parameters.v_head_size_, value, nullptr, 0, &vBNSH)); - return ApplyAttention(&qBNSH, &kBNSH, &vBNSH, nullptr, past_key, past_value, output, present_key, + ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.is_packed_qkv_ ? parameters.num_heads_ + 2 * parameters.kv_num_heads_ : parameters.num_heads_, parameters.sequence_length_, parameters.head_size_, query, nullptr, 0, &qBNSH)); + query = &qBNSH; + Tensor kBNSH; + Tensor vBNSH; + if (nullptr != key) { + TensorShapeVector k_new_dims({parameters.batch_size_, parameters.kv_num_heads_, parameters.kv_sequence_length_, parameters.head_size_}); + TensorShape k_new_shape(k_new_dims); + kBNSH = context.CreateGPUTensor(key->DataType(), k_new_shape); + ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.kv_num_heads_, parameters.kv_sequence_length_, parameters.head_size_, key, nullptr, 0, &kBNSH)); + key = &kBNSH; + } + if (nullptr != value) { + TensorShapeVector v_new_dims({parameters.batch_size_, parameters.kv_num_heads_, parameters.kv_sequence_length_, parameters.v_head_size_}); + TensorShape v_new_shape(v_new_dims); + vBNSH = context.CreateGPUTensor(value->DataType(), v_new_shape); + ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.kv_num_heads_, parameters.kv_sequence_length_, parameters.v_head_size_, value, nullptr, 0, &vBNSH)); + value = &vBNSH; + } + return ApplyAttention(query, key, value, nullptr, past_key, past_value, output, present_key, present_value, parameters, context, seqlens_k); } diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h index 
161e278de2b4e..b13ee220fd96c 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h @@ -14,15 +14,13 @@ namespace webgpu { using namespace onnxruntime::webgpu; -Status GeneratePositionIDs(onnxruntime::webgpu::ComputeContext& context, bool is_first_prompt, int batch_size, int sequence_length, int num_heads, int head_size, int rotary_embedding_dim, bool interleaved, int total_seqlen, const Tensor* seqlens, Tensor* output_tensor); - class GeneratePositionIDsProgram final : public Program { public: GeneratePositionIDsProgram() : Program{"GeneratePositionIDs"} {} Status GenerateShaderCode(ShaderHelper& sh) const override; - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"sequence_length", ProgramUniformVariableDataType::Uint32}, {"is_first_prompt", ProgramUniformVariableDataType::Uint32}); + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"batch_size", ProgramUniformVariableDataType::Uint32}, {"sequence_length", ProgramUniformVariableDataType::Uint32}, {"is_first_prompt", ProgramUniformVariableDataType::Uint32}, {"is_subsequent_prompt", ProgramUniformVariableDataType::Uint32}); }; class SplitPackedQKVProgram final : public Program { From 29819ed59b98a7c5aa1ec456d1bbd1e7ada3f10f Mon Sep 17 00:00:00 2001 From: SatyaKumarJ Date: Thu, 20 Feb 2025 11:07:27 -0800 Subject: [PATCH 15/23] Check is_first_prompt and is_subsequent_prompt flags in the C++ code instead of the shader code.
--- .../webgpu/bert/group_query_attention.cc | 50 ++++++++++--------- .../webgpu/bert/group_query_attention.h | 7 ++- 2 files changed, 32 insertions(+), 25 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index e9fcfc2cd24c3..412a3505fabee 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -72,35 +72,39 @@ Status GeneratePositionIDsProgram::GenerateShaderCode(ShaderHelper& sh) const { sh.MainFunctionBody() << " var pos_id: i32 = 0;\n" << " let batch_idx = global_idx / uniforms.sequence_length;\n" << " let sequence_idx = i32(global_idx % uniforms.sequence_length);\n" - << " let seqlen = " << seqlens.GetByOffset("batch_idx") << ";\n" - << " let total_seqlen = seqlen + 1;\n" - << " if (uniforms.is_first_prompt > 0) {\n" - << " if (sequence_idx < total_seqlen) {\n" - << " pos_id = sequence_idx;\n" - << " } else {\n" - << " pos_id = 1;\n" - << " }\n" - << " " << output.SetByOffset("global_idx", "pos_id") << "\n" - << " } else if (uniforms.is_subsequent_prompt > 0) {\n" - << " let past_seqlen = total_seqlen - i32(uniforms.sequence_length);\n" - << " if (past_seqlen + sequence_idx < total_seqlen) {\n" - << " pos_id = past_seqlen + sequence_idx;\n" - << " } else {\n" - << " pos_id = 1;\n" - << " }\n" - << " " << output.SetByOffset("global_idx", "pos_id") << "\n" - << " } else if (global_idx < uniforms.batch_size) {\n" - << " " << output.SetByOffset("global_idx", "seqlen") << "\n" - << " }\n"; + << " let seqlen = " << seqlens.GetByOffset("batch_idx") << ";\n"; + if (is_first_prompt_) { + sh.MainFunctionBody() << " let total_seqlen = seqlen + 1;\n" + << " if (sequence_idx < total_seqlen) {\n" + << " pos_id = sequence_idx;\n" + << " } else {\n" + << " pos_id = 1;\n" + << " }\n" + << " " << output.SetByOffset("global_idx", "pos_id") << "\n"; + } else if (is_subsequent_prompt_) { + 
sh.MainFunctionBody() << " let total_seqlen = seqlen + 1;\n" + << " let past_seqlen = total_seqlen - i32(uniforms.sequence_length);\n" + << " if (past_seqlen + sequence_idx < total_seqlen) {\n" + << " pos_id = past_seqlen + sequence_idx;\n" + << " } else {\n" + << " pos_id = 1;\n" + << " }\n" + << " " << output.SetByOffset("global_idx", "pos_id") << "\n"; + } else { + sh.MainFunctionBody() << " if (global_idx < uniforms.batch_size) {\n" + << " " << output.SetByOffset("global_idx", "seqlen") << "\n" + << " }\n"; + } return Status::OK(); } Status GeneratePositionIDs(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& params, const Tensor* seqlens, Tensor* output_tensor) { - GeneratePositionIDsProgram program; + GeneratePositionIDsProgram program(params.is_first_prompt_, params.is_subsequent_prompt_); auto output_size = params.batch_size_ * params.sequence_length_; - program.AddInput({seqlens, ProgramTensorMetadataDependency::Rank}) + program.CacheHint(params.is_first_prompt_, params.is_subsequent_prompt_) + .AddInput({seqlens, ProgramTensorMetadataDependency::Rank}) .AddOutput({output_tensor, ProgramTensorMetadataDependency::Rank}) - .AddUniformVariables({{static_cast(params.batch_size_)}, {static_cast(params.sequence_length_)}, {static_cast(params.is_first_prompt_ ? 1 : 0)}, {static_cast(params.is_subsequent_prompt_ ? 
1 : 0)}}) + .AddUniformVariables({{static_cast(params.batch_size_)}, {static_cast(params.sequence_length_)}}) .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE); return context.RunProgram(program); } diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h index b13ee220fd96c..896c1f10a9b55 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h @@ -16,11 +16,14 @@ using namespace onnxruntime::webgpu; class GeneratePositionIDsProgram final : public Program { public: - GeneratePositionIDsProgram() : Program{"GeneratePositionIDs"} {} + GeneratePositionIDsProgram(bool is_first_prompt, bool is_subsequent_prompt) : Program{"GeneratePositionIDs"}, is_first_prompt_(is_first_prompt), is_subsequent_prompt_(is_subsequent_prompt){} Status GenerateShaderCode(ShaderHelper& sh) const override; - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"batch_size", ProgramUniformVariableDataType::Uint32}, {"sequence_length", ProgramUniformVariableDataType::Uint32}, {"is_first_prompt", ProgramUniformVariableDataType::Uint32}, {"is_subsequent_prompt", ProgramUniformVariableDataType::Uint32}); + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"batch_size", ProgramUniformVariableDataType::Uint32}, {"sequence_length", ProgramUniformVariableDataType::Uint32}); + private: + bool is_first_prompt_; + bool is_subsequent_prompt_; }; class SplitPackedQKVProgram final : public Program { From 6bbef62609d4eaa751f9df1ce440eec0cdd135dc Mon Sep 17 00:00:00 2001 From: SatyaKumarJ Date: Thu, 20 Feb 2025 16:10:38 -0800 Subject: [PATCH 16/23] lint --- onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h index 896c1f10a9b55..28d21bf589c18 100644 
--- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h @@ -16,12 +16,13 @@ using namespace onnxruntime::webgpu; class GeneratePositionIDsProgram final : public Program { public: - GeneratePositionIDsProgram(bool is_first_prompt, bool is_subsequent_prompt) : Program{"GeneratePositionIDs"}, is_first_prompt_(is_first_prompt), is_subsequent_prompt_(is_subsequent_prompt){} + GeneratePositionIDsProgram(bool is_first_prompt, bool is_subsequent_prompt) : Program{"GeneratePositionIDs"}, is_first_prompt_(is_first_prompt), is_subsequent_prompt_(is_subsequent_prompt) {} Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"batch_size", ProgramUniformVariableDataType::Uint32}, {"sequence_length", ProgramUniformVariableDataType::Uint32}); - private: + + private: bool is_first_prompt_; bool is_subsequent_prompt_; }; From ff84b7b3293595c3211a3777b3d1800a445b6c56 Mon Sep 17 00:00:00 2001 From: SatyaKumarJ Date: Thu, 20 Feb 2025 17:08:26 -0800 Subject: [PATCH 17/23] Removed unused variable. 
--- onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc | 2 +- onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h | 5 +---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index 412a3505fabee..3dab7262bb956 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -53,7 +53,7 @@ Status SplitPackedQKVProgram::GenerateShaderCode(ShaderHelper& sh) const { } Status SplitPackedQKV(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& params, const Tensor* packedQKV, Tensor* query, Tensor* key, Tensor* val) { - SplitPackedQKVProgram program(params); + SplitPackedQKVProgram program; auto input_size = packedQKV->Shape().Size(); program .AddInput({packedQKV, ProgramTensorMetadataDependency::Rank}) diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h index 28d21bf589c18..1fb1e1ffc91fd 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h @@ -29,15 +29,12 @@ class GeneratePositionIDsProgram final : public Program { public: - SplitPackedQKVProgram(const WebgpuAttentionParameters& params) : Program{"SplitPackedQKV"}, params_(params) {} + SplitPackedQKVProgram() : Program{"SplitPackedQKV"} {} Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"hidden_size", ProgramUniformVariableDataType::Uint32}, {"kv_hidden_size", ProgramUniformVariableDataType::Uint32}); - - private: - const WebgpuAttentionParameters& params_; }; class GroupQueryAttention final : public WebGpuKernel { From e468128e04472aca0f3b64ef419d02618015e175 Mon Sep 17 00:00:00 2001 From: SatyaKumarJ Date: Wed, 26 Feb 2025 11:52:26 -0800 Subject: 
[PATCH 18/23] Added condition to check do_rotary before call fa2 --- onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index 59423ad9315ca..32ddfe839fef6 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -184,7 +184,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& Tensor* present_key = context.Output(1, present_kv_shape); Tensor* present_value = context.Output(2, present_kv_shape); parameters.past_present_share_buffer_ = present_key != nullptr && present_value != nullptr && past_key != nullptr && past_value != nullptr && past_key->DataRaw() == present_key->DataRaw() && past_value->DataRaw() == present_value->DataRaw(); - if (CanApplyFlashAttention(nullptr /* bias */, present_key, present_value, parameters, context)) { + if (!do_rotary && CanApplyFlashAttention(nullptr /* bias */, present_key, present_value, parameters, context)) { return ApplyFlashAttention(query, key, value, nullptr /* attention_bias */, output, past_key, present_key, past_value, present_value, parameters, context); } From 9f2782c518f3e92e9d0c382ec4c86ad737732e8c Mon Sep 17 00:00:00 2001 From: SatyaKumarJ Date: Wed, 26 Feb 2025 11:53:47 -0800 Subject: [PATCH 19/23] typo --- onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index 32ddfe839fef6..9bad32372a13c 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -184,7 +184,7 @@ Status 
GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& Tensor* present_key = context.Output(1, present_kv_shape); Tensor* present_value = context.Output(2, present_kv_shape); parameters.past_present_share_buffer_ = present_key != nullptr && present_value != nullptr && past_key != nullptr && past_value != nullptr && past_key->DataRaw() == present_key->DataRaw() && past_value->DataRaw() == present_value->DataRaw(); - if (!do_rotary && CanApplyFlashAttention(nullptr /* bias */, present_key, present_value, parameters, context)) { + if (!do_rotary_ && CanApplyFlashAttention(nullptr /* bias */, present_key, present_value, parameters, context)) { return ApplyFlashAttention(query, key, value, nullptr /* attention_bias */, output, past_key, present_key, past_value, present_value, parameters, context); } From b4a6546cbb894d096312363b4eb35a6c8834dc8b Mon Sep 17 00:00:00 2001 From: SatyaKumarJ Date: Wed, 26 Feb 2025 14:13:43 -0800 Subject: [PATCH 20/23] Revert changes to rotary embedding code. --- onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc index 5e0703f87f496..bc8b7493fc916 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc @@ -95,7 +95,7 @@ Status RotaryEmbedding::ComputeInternal(onnxruntime::webgpu::ComputeContext& con input_shape.NumDimensions() == 3 ? std::vector({batch_stride, hidden_size, head_size, 1}) : (input_shape.NumDimensions() == 4 - ? std::vector({batch_stride, sequence_length * head_size, head_size, 1}) + ? 
std::vector({batch_stride, head_size, sequence_length * head_size, 1}) : std::vector({})); program From a80e9f91a40b4055e1d468ae947aafbb889cdaa7 Mon Sep 17 00:00:00 2001 From: SatyaKumarJ Date: Wed, 5 Mar 2025 19:55:26 -0800 Subject: [PATCH 21/23] Removed packed QKV support in attention. --- .../contrib_ops/webgpu/bert/attention.cc | 62 +++++++------------ .../contrib_ops/webgpu/bert/attention.h | 10 ++- .../webgpu/bert/group_query_attention.cc | 53 ++++++++-------- 3 files changed, 53 insertions(+), 72 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index 5d82557be2969..0d4afc8c13f4b 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -80,9 +80,7 @@ void InitVarStub(std::ostringstream& ss, const Tensor* seqlen_k) { Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { shader.AddInput("q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); - if (!is_packed_qkv_) { - shader.AddInput("key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); - } + shader.AddInput("key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); if (feed_past_key_) { shader.AddInput("past_key", ShaderUsage::UseUniform); } @@ -98,51 +96,42 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { } shader.AdditionalImplementation() << "var tileQ: array;\n" - << "var tileK: array<" << (is_packed_qkv_ ? "q_value_t" : "key_value_t") << ", " << tile_size_ * tile_size_ << ">;\n" + << "var tileK: array;\n" << "alias f32_val_t = " << (components_ == 4 ? "vec4" : (components_ == 2 ? 
"vec2" : "f32")) << ";\n"; shader.MainFunctionBody() << "// x holds the N and y holds the M\n" << "let m = workgroup_id.y * TILE_SIZE;\n" << "let n = workgroup_id.x * TILE_SIZE;\n" << "let batch_idx = workgroup_id.z / uniforms.num_heads;\n" + << "let qOffset = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K;\n" << "let sequence_length = uniforms.M;\n" << "var total_sequence_length = uniforms.N;\n"; - if (is_packed_qkv_) { - shader.MainFunctionBody() << "let head_idx = workgroup_id.z % uniforms.num_heads;\n" - << "let kv_num_heads = uniforms.num_heads /" << n_reps_ << ";\n" - << "let packed_batch_stride = (uniforms.num_heads + 2 * kv_num_heads) * uniforms.M * uniforms.K;\n" - << "let qOffset = batch_idx * packed_batch_stride + head_idx * uniforms.M * uniforms.K;\n" - << "let kvHeadIdx = head_idx % kv_num_heads;\n" - << "let kOffset = batch_idx * packed_batch_stride + (uniforms.num_heads + kvHeadIdx) * uniforms.kv_sequence_length * uniforms.K;\n"; - } else { - shader.MainFunctionBody() << "let qOffset = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K;\n" - << "let kOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.kv_sequence_length * uniforms.K;\n"; - } std::ostringstream oss; InitVarStub(oss, seqlen_k_); shader.MainFunctionBody() << oss.str(); + shader.MainFunctionBody() << "let kOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.kv_sequence_length * uniforms.K;\n"; if (has_present_key_) { shader.MainFunctionBody() << "let presentKeyOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.present_sequence_length * uniforms.K;\n"; } shader.MainFunctionBody() << "var value = f32_val_t(0);\n" - << "for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {\n" - << " if (global_id.y < uniforms.M && w + local_id.x < uniforms.K) {\n" - << " tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * uniforms.K + w + local_id.x];\n" - << " }\n" - << " if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) {\n" - << " var 
idx = TILE_SIZE * local_id.y + local_id.x;\n"; + "for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {\n" + " if (global_id.y < uniforms.M && w + local_id.x < uniforms.K) {\n" + " tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * uniforms.K + w + local_id.x];\n" + " }\n" + " if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) {\n" + " var idx = TILE_SIZE * local_id.y + local_id.x;\n"; if ((feed_past_key_ && has_present_key_) || (past_present_share_buffer_ && !is_first_prompt_)) { shader.MainFunctionBody() << " if (n + local_id.y < past_sequence_length) {\n" << " let pastKeyOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.past_sequence_length * uniforms.K;\n" << " tileK[idx] = " << (past_present_share_buffer_ ? "present_key" : "past_key") << "[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n" << " } else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n" - << " tileK[idx] = " << (is_packed_qkv_ ? "q" : "key") << "[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];\n" + << " tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];\n" << " }\n"; } else { shader.MainFunctionBody() << " if (n + local_id.y < uniforms.kv_sequence_length) {\n" - << " tileK[idx] = " << (is_packed_qkv_ ? "q" : "key") << "[kOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n" - << " }\n"; + " tileK[idx] = key[kOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n" + " }\n"; } if (has_present_key_) { @@ -192,11 +181,9 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o const int components = parameters.head_size_ % 4 == 0 ? 4 : (parameters.head_size_ % 2 == 0 ? 
2 : 1); AttentionProbsProgram program{"AttentionProbs", feed_past_key, has_present_key, has_attention_bias, tile_size, - components, parameters.is_first_prompt_, parameters.is_packed_qkv_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_}; - program.AddInput({Q, ProgramTensorMetadataDependency::TypeAndRank, components}); - if (K != nullptr) { - program.AddInput({K, ProgramTensorMetadataDependency::TypeAndRank, components}); - } + components, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_}; + program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components}, + {K, ProgramTensorMetadataDependency::TypeAndRank, components}}); if (feed_past_key) { program.AddInput({past_key, ProgramTensorMetadataDependency::TypeAndRank, components}); } @@ -216,7 +203,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o (parameters.sequence_length_ + tile_size - 1) / tile_size, parameters.batch_size_ * parameters.num_heads_) .SetWorkgroupSize(tile_size, tile_size) - .CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_key, has_present_key, has_attention_bias, seqlen_k != nullptr, components, parameters.is_first_prompt_, parameters.is_packed_qkv_) + .CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_key, has_present_key, has_attention_bias, seqlen_k != nullptr, components, parameters.is_first_prompt_) .AddUniformVariables({{static_cast(parameters.sequence_length_)}, {static_cast(vectorized_head_size)}, {static_cast(total_sequence_length)}, @@ -344,14 +331,7 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const { std::ostringstream oss; InitVarStub(oss, seqlen_k_); shader.MainFunctionBody() << oss.str(); - if (is_packed_qkv_) { - shader.MainFunctionBody() << "let kv_num_heads = uniforms.num_heads / " << n_reps_ << ";\n" - << "let packed_batch_stride = (uniforms.num_heads + 2 * 
kv_num_heads) * uniforms.M * uniforms.K;\n" - << "let kvHeadIdx = head_idx % kv_num_heads;\n" - << "let vOffset = batch_idx * packed_batch_stride + (uniforms.num_heads + kv_num_heads + kvHeadIdx) * uniforms.N * uniforms.kv_sequence_length + n;\n"; - } else { - shader.MainFunctionBody() << "let vOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.N * uniforms.kv_sequence_length + n;\n"; - } + shader.MainFunctionBody() << "let vOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.N * uniforms.kv_sequence_length + n;\n"; if (has_present_value_) { shader.MainFunctionBody() << "let presentValueOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.N * uniforms.present_sequence_length + n;\n"; } @@ -420,7 +400,7 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int const int components = parameters.v_head_size_ % 4 == 0 ? 4 : (parameters.v_head_size_ % 2 == 0 ? 2 : 1); constexpr int tile_size = 12; int tile_n_size = tile_size * components; - VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, parameters.is_first_prompt_, parameters.is_packed_qkv_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_}; + VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_}; program.AddInputs({{probs, ProgramTensorMetadataDependency::TypeAndRank}, {V, ProgramTensorMetadataDependency::TypeAndRank, components}}); if (feed_past_value) { @@ -437,7 +417,7 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int program.SetDispatchGroupSize((parameters.v_head_size_ + tile_n_size - 1) / tile_n_size, (parameters.sequence_length_ + tile_size - 1) / tile_size, parameters.batch_size_ * parameters.num_heads_) - .CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_value, has_present_value, seqlen_k != 
nullptr, parameters.is_first_prompt_, parameters.is_packed_qkv_) + .CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_value, has_present_value, seqlen_k != nullptr, parameters.is_first_prompt_) .SetWorkgroupSize(tile_size, tile_size) .AddUniformVariables({{static_cast(parameters.sequence_length_)}, {static_cast(total_sequence_length)}, @@ -472,7 +452,7 @@ Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const T ORT_RETURN_IF_ERROR(ComputeInPlaceSoftmax(context, &probs, parameters.batch_size_, parameters.num_heads_, parameters.past_sequence_length_, parameters.sequence_length_, total_sequence_length, seqlen_k, parameters.is_first_prompt_)); - ORT_RETURN_IF_ERROR(ComputeVxAttentionScore(context, output_count, &probs, parameters.is_packed_qkv_ ? Q : V, past_value, output, present_value, + ORT_RETURN_IF_ERROR(ComputeVxAttentionScore(context, output_count, &probs, V, past_value, output, present_value, parameters, past_sequence_length, total_sequence_length, seqlen_k)); return Status::OK(); diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.h b/onnxruntime/contrib_ops/webgpu/bert/attention.h index 3dcb339bd896d..164ea72b07d9d 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.h @@ -34,8 +34,8 @@ class TransferBSDToBNSHProgram final : public Program class AttentionProbsProgram final : public Program { public: AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key, - bool has_attention_bias, int tile_size, int components, bool is_first_prompt, bool is_packed_qkv, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false) - : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), n_reps_(n_reps), seqlen_k_(seqlen_k), 
past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt), is_packed_qkv_(is_packed_qkv) { + bool has_attention_bias, int tile_size, int components, bool is_first_prompt, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false) + : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -64,7 +64,6 @@ class AttentionProbsProgram final : public Program { const Tensor* seqlen_k_; bool past_present_share_buffer_; bool is_first_prompt_; - bool is_packed_qkv_; }; class InPlaceSoftmaxProgram final : public Program { @@ -91,8 +90,8 @@ class InPlaceSoftmaxProgram final : public Program { class VxAttentionScoreProgram final : public Program { public: - VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, bool is_first_prompt, bool is_packed_qkv, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false) - : Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt), is_packed_qkv_(is_packed_qkv) { + VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, bool is_first_prompt, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false) + : Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), n_reps_(n_reps), seqlen_k_(seqlen_k), 
past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -119,7 +118,6 @@ class VxAttentionScoreProgram final : public Program { const Tensor* seqlen_k_; bool past_present_share_buffer_; bool is_first_prompt_; - bool is_packed_qkv_; }; } // namespace webgpu diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index 9bad32372a13c..2a45cc21d450f 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -148,7 +148,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& const Tensor* value = context.Input(2); const Tensor* past_key = context.Input(3); const Tensor* past_value = context.Input(4); - const Tensor* seqlens_k = context.Input(5); + const Tensor* seqlen_k = context.Input(5); const Tensor* total_seqlen_tensor = context.Input(6); const Tensor* cos_cache = context.Input(7); const Tensor* sin_cache = context.Input(8); @@ -164,12 +164,11 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& ¶ms, num_heads_, kv_num_heads_, - seqlens_k, + seqlen_k, total_seqlen_tensor, scale_, softcap_)); WebgpuAttentionParameters parameters(params); - TensorShapeVector output_shape(3); output_shape[0] = static_cast(parameters.batch_size_); output_shape[1] = static_cast(parameters.sequence_length_); @@ -184,6 +183,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& Tensor* present_key = context.Output(1, present_kv_shape); Tensor* present_value = context.Output(2, present_kv_shape); parameters.past_present_share_buffer_ = present_key != nullptr && present_value != nullptr && past_key != nullptr && past_value != nullptr && past_key->DataRaw() == present_key->DataRaw() && past_value->DataRaw() == present_value->DataRaw(); + if 
(!do_rotary_ && CanApplyFlashAttention(nullptr /* bias */, present_key, present_value, parameters, context)) { return ApplyFlashAttention(query, key, value, nullptr /* attention_bias */, output, past_key, present_key, past_value, present_value, parameters, context); @@ -192,7 +192,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& Tensor qSplit; Tensor kSplit; Tensor vSplit; - if (parameters.is_packed_qkv_ && do_rotary_) { + if (parameters.is_packed_qkv_) { qSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.hidden_size_})); kSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.kv_hidden_size_})); vSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.kv_hidden_size_})); @@ -210,36 +210,39 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& kRotary = context.CreateGPUTensor(key->DataType(), key->Shape()); auto pos_ids_shape = TensorShape({parameters.batch_size_, parameters.sequence_length_}); Tensor pos_ids = context.CreateGPUTensor(DataTypeImpl::GetType(), pos_ids_shape); - ORT_RETURN_IF_ERROR(GeneratePositionIDs(context, parameters, seqlens_k, &pos_ids)); + ORT_RETURN_IF_ERROR(GeneratePositionIDs(context, parameters, seqlen_k, &pos_ids)); ORT_RETURN_IF_ERROR(RunRotaryEmbedding(context, parameters, query, &pos_ids, cos_cache, sin_cache, &qRotary, /* is_query_input = */ true)); ORT_RETURN_IF_ERROR(RunRotaryEmbedding(context, parameters, key, &pos_ids, cos_cache, sin_cache, &kRotary, /* is_query_input = */ false)); query = &qRotary; key = &kRotary; } - TensorShapeVector q_new_dims({parameters.batch_size_, parameters.num_heads_, parameters.sequence_length_, parameters.head_size_}); + TensorShapeVector q_new_dims({parameters.batch_size_, parameters.num_heads_, + 
parameters.sequence_length_, parameters.head_size_}); TensorShape q_new_shape(q_new_dims); - Tensor qBNSH = context.CreateGPUTensor(query->DataType(), q_new_shape); - ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.is_packed_qkv_ ? parameters.num_heads_ + 2 * parameters.kv_num_heads_ : parameters.num_heads_, parameters.sequence_length_, parameters.head_size_, query, nullptr, 0, &qBNSH)); - query = &qBNSH; - Tensor kBNSH; - Tensor vBNSH; - if (nullptr != key) { - TensorShapeVector k_new_dims({parameters.batch_size_, parameters.kv_num_heads_, parameters.kv_sequence_length_, parameters.head_size_}); - TensorShape k_new_shape(k_new_dims); - kBNSH = context.CreateGPUTensor(key->DataType(), k_new_shape); - ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.kv_num_heads_, parameters.kv_sequence_length_, parameters.head_size_, key, nullptr, 0, &kBNSH)); - key = &kBNSH; + Tensor Q = context.CreateGPUTensor(query->DataType(), q_new_shape); + ORT_RETURN_IF_ERROR(TransferBSDToBNSH( + context, parameters.num_heads_, parameters.sequence_length_, parameters.head_size_, query, nullptr, 0, &Q)); + if (parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH) { // key and value in BNSH format + return ApplyAttention(&Q, key, value, nullptr, past_key, past_value, output, present_key, + present_value, parameters, context, seqlen_k); } - if (nullptr != value) { - TensorShapeVector v_new_dims({parameters.batch_size_, parameters.kv_num_heads_, parameters.kv_sequence_length_, parameters.v_head_size_}); + + TensorShapeVector k_new_dims({parameters.batch_size_, parameters.kv_num_heads_, + parameters.kv_sequence_length_, parameters.head_size_}); + TensorShape k_new_shape(k_new_dims); + Tensor K = context.CreateGPUTensor(key->DataType(), k_new_shape); + ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.kv_num_heads_, parameters.kv_sequence_length_, + parameters.head_size_, key, nullptr, 0, &K)); + + TensorShapeVector v_new_dims({parameters.batch_size_, parameters.kv_num_heads_, 
+ parameters.kv_sequence_length_, parameters.v_head_size_}); TensorShape v_new_shape(v_new_dims); - vBNSH = context.CreateGPUTensor(value->DataType(), v_new_shape); - ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.kv_num_heads_, parameters.kv_sequence_length_, parameters.v_head_size_, value, nullptr, 0, &vBNSH)); - value = &vBNSH; - } - return ApplyAttention(query, key, value, nullptr, past_key, past_value, output, present_key, - present_value, parameters, context, seqlens_k); + Tensor V = context.CreateGPUTensor(value->DataType(), v_new_shape); + ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.kv_num_heads_, parameters.kv_sequence_length_, + parameters.v_head_size_, value, nullptr, 0, &V)); + return ApplyAttention(&Q, &K, &V, nullptr, past_key, past_value, output, present_key, + present_value, parameters, context, seqlen_k); } } // namespace webgpu From 19c48d41c8365e3141c02546c28f09ad36d20893 Mon Sep 17 00:00:00 2001 From: SatyaKumarJ Date: Wed, 5 Mar 2025 22:06:56 -0800 Subject: [PATCH 22/23] lint --- .../contrib_ops/webgpu/bert/group_query_attention.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index 2a45cc21d450f..04fcb7600cf0a 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -218,11 +218,11 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& } TensorShapeVector q_new_dims({parameters.batch_size_, parameters.num_heads_, - parameters.sequence_length_, parameters.head_size_}); + parameters.sequence_length_, parameters.head_size_}); TensorShape q_new_shape(q_new_dims); Tensor Q = context.CreateGPUTensor(query->DataType(), q_new_shape); ORT_RETURN_IF_ERROR(TransferBSDToBNSH( - context, parameters.num_heads_, parameters.sequence_length_, parameters.head_size_, query, 
nullptr, 0, &Q)); + context, parameters.num_heads_, parameters.sequence_length_, parameters.head_size_, query, nullptr, 0, &Q)); if (parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH) { // key and value in BNSH format return ApplyAttention(&Q, key, value, nullptr, past_key, past_value, output, present_key, present_value, parameters, context, seqlen_k); @@ -230,14 +230,14 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& TensorShapeVector k_new_dims({parameters.batch_size_, parameters.kv_num_heads_, parameters.kv_sequence_length_, parameters.head_size_}); - TensorShape k_new_shape(k_new_dims); + TensorShape k_new_shape(k_new_dims); Tensor K = context.CreateGPUTensor(key->DataType(), k_new_shape); ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.kv_num_heads_, parameters.kv_sequence_length_, parameters.head_size_, key, nullptr, 0, &K)); TensorShapeVector v_new_dims({parameters.batch_size_, parameters.kv_num_heads_, parameters.kv_sequence_length_, parameters.v_head_size_}); - TensorShape v_new_shape(v_new_dims); + TensorShape v_new_shape(v_new_dims); Tensor V = context.CreateGPUTensor(value->DataType(), v_new_shape); ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.kv_num_heads_, parameters.kv_sequence_length_, parameters.v_head_size_, value, nullptr, 0, &V)); From ce3d60b7b6c5d8ccf62c547e9981f46bd4ed4a90 Mon Sep 17 00:00:00 2001 From: SatyaKumarJ Date: Thu, 13 Mar 2025 13:39:27 -0700 Subject: [PATCH 23/23] Replaced gsl::naroow with gsl::narrow_cast --- .../contrib_ops/webgpu/bert/group_query_attention.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index 04fcb7600cf0a..f002db108035f 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -111,7 +111,7 @@ Status 
GeneratePositionIDs(onnxruntime::webgpu::ComputeContext& context, const W } Status RunRotaryEmbedding(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& params, const Tensor* input, const Tensor* pos_ids, const Tensor* cos_cache, const Tensor* sin_cache, Tensor* output, bool is_query_input) { - const auto half_rotary_embedding_dim = gsl::narrow(cos_cache->Shape()[1]); + const auto half_rotary_embedding_dim = gsl::narrow_cast(cos_cache->Shape()[1]); const auto head_size = params.head_size_; const auto hidden_size = is_query_input ? params.hidden_size_ : params.kv_hidden_size_; const TensorShape global_shape({params.batch_size_, params.sequence_length_, hidden_size / head_size, static_cast(head_size - half_rotary_embedding_dim)}); @@ -119,11 +119,11 @@ Status RunRotaryEmbedding(onnxruntime::webgpu::ComputeContext& context, const We std::vector global_dims(rank); std::vector global_strides(rank); for (size_t j = 0; j < rank; ++j) { - global_dims[j] = gsl::narrow(global_shape[j]); - global_strides[j] = gsl::narrow(global_shape.SizeFromDimension(j + 1)); + global_dims[j] = gsl::narrow_cast(global_shape[j]); + global_strides[j] = gsl::narrow_cast(global_shape.SizeFromDimension(j + 1)); } - const auto input_output_strides = std::vector({gsl::narrow(input->Shape().SizeFromDimension(1)), gsl::narrow(hidden_size), gsl::narrow(head_size), 1}); - const auto output_size = gsl::narrow(global_shape.Size()); + const auto input_output_strides = std::vector({gsl::narrow_cast(input->Shape().SizeFromDimension(1)), gsl::narrow_cast(hidden_size), gsl::narrow_cast(head_size), 1}); + const auto output_size = gsl::narrow_cast(global_shape.Size()); RotaryEmbeddingProgram program(params.rotary_interleaved_); program