 #include "contrib_ops/webgpu/bert/multihead_attention.h"
 #include "contrib_ops/webgpu/webgpu_contrib_kernels.h"
 #include "core/providers/webgpu/webgpu_supported_types.h"
+#include "core/providers/webgpu/webgpu_utils.h"
+#include "core/providers/webgpu/math/matmul.h"
 using namespace onnxruntime::webgpu;
 using namespace ::onnxruntime::common;
 using namespace ONNX_NAMESPACE;
@@ -70,6 +72,50 @@ Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_h
   return context.RunProgram(program);
 };
 
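+// SplitPackedQKVProgram writes each element of the packed QKV input into one of
+// three separate outputs, chosen by the element's index along the last axis.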
+Status SplitPackedQKVProgram::GenerateShaderCode(ShaderHelper& sh) const {
+  // Inputs: packed_qkv [B, S, D], outputs: Q, K, V [B, S, D]
+  const auto& packed_qkv = sh.AddInput("packed_qkv", ShaderUsage::UseOffsetToIndices | ShaderUsage::UseUniform);
+  const auto& query = sh.AddOutput("query", ShaderUsage::UseSetByIndices | ShaderUsage::UseUniform);
+  const auto& key = sh.AddOutput("key", ShaderUsage::UseSetByIndices | ShaderUsage::UseUniform);
+  const auto& value = sh.AddOutput("val", ShaderUsage::UseSetByIndices | ShaderUsage::UseUniform);
+  sh.MainFunctionBody()
+      << sh.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.input_size")
+      << "let packed_qkv_indices = " << packed_qkv.OffsetToIndices("global_idx") << ";\n"
+      << "let batch = packed_qkv_indices[0];\n"
+      << "let seq = packed_qkv_indices[1];\n"
+      << "let d = packed_qkv_indices[2];\n"
+      << "let input_data = " << packed_qkv.GetByOffset("global_idx") << ";\n"
+      << "if (d < uniforms.hidden_size) {\n"
+      << "  " << query.SetByIndices("vec3<u32>(batch, seq, d)", "input_data") << ";\n"
+      << "} else if (d < (uniforms.hidden_size + uniforms.kv_hidden_size)) {\n"
+      << "  let kd = d - uniforms.hidden_size;\n"
+      << "  " << key.SetByIndices("vec3<u32>(batch, seq, kd)", "input_data") << ";\n"
+      << "} else {\n"
+      << "  let vd = d - uniforms.hidden_size - uniforms.kv_hidden_size;\n"
+      << "  " << value.SetByIndices("vec3<u32>(batch, seq, vd)", "input_data") << ";\n"
+      << "}\n";
+  return Status::OK();
+}
+
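+// Note: the component width chosen below must evenly divide hidden_size,
+// kv_hidden_size, and v_hidden_size; GetMaxComponents returns the widest
+// vector size that divides its argument, and std::min keeps the three in sync.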
+Status SplitPackedQKV(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& params,
+                      const Tensor* packedQKV, Tensor* query, Tensor* key, Tensor* val, int kv_hidden_size) {
+  // Output Q, K, V in BSD format
+  const int components = std::min({GetMaxComponents(params.hidden_size_), GetMaxComponents(kv_hidden_size), GetMaxComponents(params.v_hidden_size_)});
+  SplitPackedQKVProgram program;
+  auto input_size = packedQKV->Shape().Size();
+  const uint32_t vectorized_input_size = static_cast<uint32_t>(input_size / components);
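+  // One thread per vectorized element: the uniforms below are expressed in
+  // vectorized units so the shader's indices line up with the component width.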
+  program
+      .AddInput({packedQKV, ProgramTensorMetadataDependency::TypeAndRank, components})
+      .AddOutputs({{query, ProgramTensorMetadataDependency::TypeAndRank, components}, {key, ProgramTensorMetadataDependency::TypeAndRank, components}, {val, ProgramTensorMetadataDependency::TypeAndRank, components}})
+      .AddUniformVariables({
+          {vectorized_input_size},
+          {static_cast<uint32_t>(params.hidden_size_ / components)},
+          {static_cast<uint32_t>(kv_hidden_size / components)},
+      })
+      .SetDispatchGroupSize((vectorized_input_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE);
+  return context.RunProgram(program);
+}
+
 void InitVarStub(std::ostringstream& ss, bool has_seqlen_k) {
   if (has_seqlen_k) {
     ss << "total_sequence_length = u32(seqlen_k[batch_idx]) + 1;\n";
@@ -594,113 +640,26 @@ Attention::Attention(const OpKernelInfo& info)
       onnxruntime::contrib::AttentionBase(info, false) {
 }
 
-// QKV preparation program - computes Q, K, V from input, weights, and bias
-class AttentionPrepareProgram final : public Program<AttentionPrepareProgram> {
- public:
-  AttentionPrepareProgram() : Program{"AttentionPrepare"} {}
-
-  Status GenerateShaderCode(ShaderHelper& shader) const override {
-    shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
-    shader.AddInput("weight", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
-    shader.AddInput("bias", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
-    shader.AddOutput("output_q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
-    shader.AddOutput("output_k", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
-    shader.AddOutput("output_v", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
-
-    constexpr int TILE_SIZE = 12;
-
-    shader.AdditionalImplementation() << "const TILE_SIZE = " << TILE_SIZE << "u;\n"
-                                      << "var<workgroup> tileInput: array<input_value_t, " << TILE_SIZE * TILE_SIZE << ">;\n"
-                                      << "var<workgroup> tileWeightQ: array<input_value_t, " << TILE_SIZE * TILE_SIZE << ">;\n"
-                                      << "var<workgroup> tileWeightK: array<input_value_t, " << TILE_SIZE * TILE_SIZE << ">;\n"
-                                      << "var<workgroup> tileWeightV: array<input_value_t, " << TILE_SIZE * TILE_SIZE << ">;\n";
-
-    shader.MainFunctionBody()  // << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.M * uniforms.N")
-        << "let batchIndex = workgroup_id.z / uniforms.num_heads;\n"
-        << "let headNumber = workgroup_id.z % uniforms.num_heads;\n"
-        << "let m = global_id.y;\n"
-        << "let n = global_id.x;\n"
-        << "let inputOffset = batchIndex * (uniforms.M * uniforms.K) + m * uniforms.K;\n"
-        << "let biasOffsetQ = headNumber * uniforms.head_size;\n"
-        << "let biasOffsetK = uniforms.hidden_size + biasOffsetQ;\n"
-        << "let biasOffsetV = uniforms.hidden_size + biasOffsetK;\n"
-        << "var valueQ = input_value_t(0);\n"
-        << "var valueK = input_value_t(0);\n"
-        << "var valueV = input_value_t(0);\n"
-        << "for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {\n"
-        << "  if (m < uniforms.M && w + local_id.x < uniforms.K) {\n"
-        << "    tileInput[TILE_SIZE * local_id.y + local_id.x] = input[inputOffset + w + local_id.x];\n"
-        << "  }\n"
-        << "  if (n < uniforms.N && w + local_id.y < uniforms.K) {\n"
-        << "    let offset = n + (w + local_id.y) * uniforms.ldb;\n"
-        << "    tileWeightQ[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetQ + offset];\n"
-        << "    tileWeightK[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetK + offset];\n"
-        << "    tileWeightV[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetV + offset];\n"
-        << "  }\n"
-        << "  workgroupBarrier();\n"
-        << "  for (var k: u32 = 0u; k<TILE_SIZE && w+k < uniforms.K; k++) {\n"
-        << "    let inputTileOffset = TILE_SIZE * local_id.y + k;\n"
-        << "    let weightTileOffset = TILE_SIZE * k + local_id.x;\n"
-        << "    valueQ += tileInput[inputTileOffset] * tileWeightQ[weightTileOffset];\n"
-        << "    valueK += tileInput[inputTileOffset] * tileWeightK[weightTileOffset];\n"
-        << "    valueV += tileInput[inputTileOffset] * tileWeightV[weightTileOffset];\n"
-        << "  }\n"
-        << "  workgroupBarrier();\n"
-        << "}\n"
-        << "let headOffset = (m * uniforms.N + n) % uniforms.head_size;\n"
-        << "valueQ += bias[headOffset + biasOffsetQ];\n"
-        << "valueK += bias[headOffset + biasOffsetK];\n"
-        << "valueV += bias[headOffset + biasOffsetV];\n"
-        << "let offset = workgroup_id.z * uniforms.M * uniforms.N;\n"
-        << "if (m < uniforms.M && n < uniforms.N) {\n"
-        << "  let outputIdx = offset + m * uniforms.N + n;\n"
-        << "  output_q[outputIdx] = valueQ;\n"
-        << "  output_k[outputIdx] = valueK;\n"
-        << "  output_v[outputIdx] = valueV;\n"
-        << "}\n";
-
-    return Status::OK();
-  }
-
-  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"M", ProgramUniformVariableDataType::Uint32},
-                                          {"K", ProgramUniformVariableDataType::Uint32},
-                                          {"N", ProgramUniformVariableDataType::Uint32},
-                                          {"num_heads", ProgramUniformVariableDataType::Uint32},
-                                          {"head_size", ProgramUniformVariableDataType::Uint32},
-                                          {"hidden_size", ProgramUniformVariableDataType::Uint32},
-                                          {"ldb", ProgramUniformVariableDataType::Uint32});
-};
-
 Status PrepareQKV(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& parameters,
                   const Tensor* input, const Tensor* weights, const Tensor* bias,
                   Tensor* q, Tensor* k, Tensor* v) {
-  constexpr int TILE_SIZE = 12;
-  const int M = parameters.sequence_length_;
-  const int K = parameters.input_hidden_size_;
-  const int N = parameters.head_size_;
-
-  const uint32_t dispatch_x = (parameters.head_size_ + TILE_SIZE - 1) / TILE_SIZE;
-  const uint32_t dispatch_y = (parameters.sequence_length_ + TILE_SIZE - 1) / TILE_SIZE;
-  const uint32_t dispatch_z = parameters.batch_size_ * parameters.num_heads_;
-
-  AttentionPrepareProgram program{};
-  program.AddInputs({{input, ProgramTensorMetadataDependency::TypeAndRank},
-                     {weights, ProgramTensorMetadataDependency::TypeAndRank},
-                     {bias, ProgramTensorMetadataDependency::TypeAndRank}})
-      .AddOutputs({{q, ProgramTensorMetadataDependency::TypeAndRank},
-                   {k, ProgramTensorMetadataDependency::TypeAndRank},
-                   {v, ProgramTensorMetadataDependency::TypeAndRank}})
-      .SetDispatchGroupSize(dispatch_x, dispatch_y, dispatch_z)
-      .SetWorkgroupSize(TILE_SIZE, TILE_SIZE)
-      .AddUniformVariables({{static_cast<uint32_t>(M)},
-                            {static_cast<uint32_t>(K)},
-                            {static_cast<uint32_t>(N)},
-                            {static_cast<uint32_t>(parameters.num_heads_)},
-                            {static_cast<uint32_t>(parameters.head_size_)},
-                            {static_cast<uint32_t>(parameters.hidden_size_)},
-                            {static_cast<uint32_t>(parameters.hidden_size_ + parameters.hidden_size_ + parameters.v_hidden_size_)}});
+  // Use MatMul to compute the packed QKV output: input * weights + bias.
+  // Then use SplitPackedQKV to split it into separate Q, K, V tensors.
+  // Returns Q, K, V in BSD format.
 
-  return context.RunProgram(program);
+  // Create packed QKV tensor with shape [batch_size, sequence_length, hidden_size + hidden_size + v_hidden_size].
+  const int64_t packed_qkv_size = parameters.hidden_size_ + parameters.hidden_size_ + parameters.v_hidden_size_;
+  TensorShapeVector packed_qkv_shape({parameters.batch_size_, parameters.sequence_length_, packed_qkv_size});
+  Tensor packed_qkv = context.CreateGPUTensor(input->DataType(), TensorShape(packed_qkv_shape));
+
+  // Prepare inputs for MatMul.
+  std::vector<const Tensor*> matmul_inputs = {input, weights, bias};
+
+  // Call MatMul: packed_qkv = input * weights + bias.
+  ORT_RETURN_IF_ERROR(onnxruntime::webgpu::ComputeMatMul(&context, Activation(), matmul_inputs, &packed_qkv, true));
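+  // Passing bias as the third input lets the MatMul kernel fold the bias add
+  // into the same dispatch, so no separate elementwise pass is needed.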
+
+  // Output Q, K, V in BSD format.
+  return SplitPackedQKV(context, parameters, &packed_qkv, q, k, v, parameters.hidden_size_);
 }
 
 Status Attention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const {
@@ -755,15 +714,16 @@ Status Attention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context)
     ORT_NOT_IMPLEMENTED("present tensor not implemented for webgpu Attention");
   }
 
-  // Create Q, K, V tensors by computing input * weights + bias
-  TensorShapeVector qkv_shape({parameters.batch_size_, parameters.num_heads_,
-                               parameters.sequence_length_, parameters.head_size_});
-  Tensor Q = context.CreateGPUTensor(input->DataType(), qkv_shape);
-  Tensor K = context.CreateGPUTensor(input->DataType(), qkv_shape);
-  Tensor V = context.CreateGPUTensor(input->DataType(), qkv_shape);
+  // Create Q, K, V tensors in BSD format from input * weights + bias.
+  TensorShapeVector qkv_bsd_shape({parameters.batch_size_, parameters.sequence_length_, parameters.hidden_size_});
+  TensorShapeVector v_bsd_shape({parameters.batch_size_, parameters.sequence_length_, parameters.v_hidden_size_});
+  Tensor Q_bsd = context.CreateGPUTensor(input->DataType(), TensorShape(qkv_bsd_shape));
+  Tensor K_bsd = context.CreateGPUTensor(input->DataType(), TensorShape(qkv_bsd_shape));
+  Tensor V_bsd = context.CreateGPUTensor(input->DataType(), TensorShape(v_bsd_shape));
 
-  // Compute Q, K, V from input, weights, and bias
-  ORT_RETURN_IF_ERROR(PrepareQKV(context, parameters, input, weights, bias, &Q, &K, &V));
+  // Compute Q, K, V from input, weights, and bias (returns BSD format).
+  ORT_RETURN_IF_ERROR(PrepareQKV(context, parameters, input, weights, bias, &Q_bsd, &K_bsd, &V_bsd));
+  parameters.qkv_format_ = Q_K_V_BSNH;
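+  // [B, S, D] is the flattened form of [B, S, N, H] (D = N * H), so the BSD
+  // tensors above already satisfy the Q_K_V_BSNH layout tag.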
 
   // Check if we can use flash attention
   // For Attention operator, we need to create present_key and present_value tensors for flash attention
@@ -774,10 +734,25 @@ Status Attention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context)
   Tensor present_value = context.CreateGPUTensor(input->DataType(), present_kv_shape);
 
   if (CanApplyFlashAttention(nullptr, &present_key, &present_value, parameters, context)) {
-    return ApplyFlashAttention(&Q, &K, &V, attention_bias, output, nullptr, &present_key, nullptr, &present_value,
+    // FlashAttention supports Q_K_V_BSNH format directly.
+    return ApplyFlashAttention(&Q_bsd, &K_bsd, &V_bsd, attention_bias, output, nullptr, &present_key, nullptr, &present_value,
                                parameters, context, nullptr);
   }
 
+  // For the non-flash attention path, convert BSD to BNSH format.
+  TensorShapeVector qkv_bnsh_shape({parameters.batch_size_, parameters.num_heads_, parameters.sequence_length_, parameters.head_size_});
+  TensorShapeVector v_bnsh_shape({parameters.batch_size_, parameters.num_heads_, parameters.sequence_length_, parameters.v_head_size_});
+  Tensor Q = context.CreateGPUTensor(input->DataType(), TensorShape(qkv_bnsh_shape));
+  Tensor K = context.CreateGPUTensor(input->DataType(), TensorShape(qkv_bnsh_shape));
+  Tensor V = context.CreateGPUTensor(input->DataType(), TensorShape(v_bnsh_shape));
+
+  ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads_, parameters.sequence_length_,
+                                        parameters.head_size_, &Q_bsd, nullptr, 0, &Q));
+  ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads_, parameters.sequence_length_,
+                                        parameters.head_size_, &K_bsd, nullptr, 0, &K));
+  ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads_, parameters.sequence_length_,
+                                        parameters.v_head_size_, &V_bsd, nullptr, 0, &V));
+
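+  // TransferBSDToBNSH transposes [B, S, N*H] into [B, N, S, H]; the nullptr/0
+  // arguments mean no bias is applied during the copy.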
   // Apply the actual attention computation
   return ApplyAttention(&Q, &K, &V, attention_bias, nullptr, nullptr, output, /* present_key */ nullptr,
                         /* present_value */ nullptr, /* output_qk */ nullptr, parameters, context, nullptr, nullptr, -1);