Commit 25199fc

Remove explicit split operator in GQA when QKV is packed.
1 parent 7ad7873 commit 25199fc
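
The GroupQueryAttention op previously handled a packed QKV input of shape [batch, sequenceLength, (numHeads + 2 * kvNumHeads) * headSize] by running an explicit Split program to carve out separate Q, K and V tensors before attention ran. With this change the packed tensor is passed through as-is and the attention shaders index into it directly. Below is a minimal TypeScript sketch of the offset arithmetic the shaders now encode, assuming the packed input has already been transposed to BNSH; the PackedDims interface and packedOffsets helper are illustrative only, and only the dimension names match the diff.

// Illustrative sketch: addressing a packed, BNSH-transposed QKV tensor of shape
// [batch, numHeads + 2 * kvNumHeads, seqLen, headSize].
// Q heads come first, then the K heads, then the V heads.
interface PackedDims {
  numHeads: number; // query heads
  kvNumHeads: number; // key/value heads (GQA groups)
  seqLen: number; // M in the shader uniforms
  headSize: number; // K in the shader uniforms
}

const packedOffsets = (d: PackedDims, batchIdx: number, headIdx: number, kvHeadIdx: number) => {
  // Mirrors packed_batch_stride in the shaders below.
  const batchStride = (d.numHeads + 2 * d.kvNumHeads) * d.seqLen * d.headSize;
  const headStride = d.seqLen * d.headSize;
  // Note: the shaders compute the K/V offsets with kv_sequence_length; this sketch
  // assumes no past KV, so kvSequenceLength === seqLen.
  return {
    qOffset: batchIdx * batchStride + headIdx * headStride,
    kOffset: batchIdx * batchStride + (d.numHeads + kvHeadIdx) * headStride,
    vOffset: batchIdx * batchStride + (d.numHeads + d.kvNumHeads + kvHeadIdx) * headStride,
  };
};

Because K and V are just head offsets into the same buffer, the Split dispatch and its intermediate buffers are no longer needed on this path.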

3 files changed: +38 -46 lines changed


js/web/lib/wasm/jsep/webgpu/ops/attention.ts

Lines changed: 30 additions & 18 deletions
@@ -71,6 +71,7 @@ export interface AttentionParameters {
   rotaryInterLeaved?: number;
   sommoothSoftmax?: number;
   localWindowsSize?: number;
+  packedQKV?: boolean;
 }
 
 export interface AttentionAttrs {
@@ -442,13 +443,14 @@ const createInPlaceSoftmaxProgramInfo = (
 const createAttentionProbsProgramInfo = (
   outputCount: number,
   q: TensorView,
-  key: TensorView,
+  key: TensorView | undefined,
   pastKey: TensorView | undefined,
   attentionBias: TensorView | undefined,
   parameters: AttentionParameters,
   pastSequenceLength: number,
   seqLens: TensorView | undefined,
   totalSequenceLengthInput: TensorView | undefined,
+  packedQKV: boolean,
 ) => {
   const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength;
   const probsShape = [parameters.batchSize, parameters.numHeads, parameters.sequenceLength, totalSequenceLength];
@@ -474,15 +476,17 @@ const createAttentionProbsProgramInfo = (
     { type: DataType.uint32, data: vectorizedHeadSize },
     { type: DataType.uint32, data: totalSequenceLength },
     { type: DataType.uint32, data: parameters.numHeads },
-    { type: DataType.uint32, data: parameters.headSize },
     { type: DataType.float, data: alpha },
     { type: DataType.uint32, data: pastSequenceLength },
     { type: DataType.uint32, data: parameters.kvSequenceLength },
     { type: DataType.uint32, data: nReps },
   ];
   // Feed pastKey to the shader-code only if it is non-zero and presentKey is being produced
   const feedPastKey = presentKey && pastKey && ShapeUtil.size(pastKey.dims) > 0;
-  const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type'];
+  const inputDependencies: ProgramInputTensorInfoDependency[] = ['type'];
+  if (key) {
+    inputDependencies.push('type');
+  }
   if (feedPastKey) {
     inputDependencies.push('type');
   }
@@ -501,8 +505,11 @@ const createAttentionProbsProgramInfo = (
   }
   const getShaderSource = (shaderHelper: ShaderHelper) => {
     const qInput = inputVariable('q', q.dataType, q.dims, components);
-    const kInput = inputVariable('key', key.dataType, key.dims, components);
-    const inputVars = [qInput, kInput];
+    const inputVars = [qInput];
+    if (key) {
+      const kInput = inputVariable('key', key.dataType, key.dims, components);
+      inputVars.push(kInput);
+    }
     if (feedPastKey) {
       const pastKeyInput = inputVariable('past_key', pastKey.dataType, pastKey.dims, components);
       inputVars.push(pastKeyInput);
@@ -532,7 +539,6 @@ const createAttentionProbsProgramInfo = (
       { name: 'K', type: 'u32' },
       { name: 'N', type: 'u32' },
       { name: 'num_heads', type: 'u32' },
-      { name: 'head_size', type: 'u32' },
       { name: 'alpha', type: 'f32' as UniformDataElementType },
       { name: 'past_sequence_length', type: 'u32' },
       { name: 'kv_sequence_length', type: 'u32' },
@@ -555,10 +561,11 @@ const createAttentionProbsProgramInfo = (
     let sequence_length = uniforms.M;
     var total_sequence_length = uniforms.N;
     ${initVarStub(seqLensInputVariable, totalSequenceLengthInputVariable, true)}
+    let packed_batch_stride = (uniforms.num_heads + 2 * kv_num_heads) * uniforms.M * uniforms.K;
     let absKvHeadIdx = batchIdx * kv_num_heads + kvHeadIdx;
-    let qOffset = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K;
+    let qOffset = ${packedQKV ? 'batchIdx * packed_batch_stride + headIdx * uniforms.M * uniforms.K' : 'workgroup_id.z * uniforms.M * uniforms.K'} + m * uniforms.K;
     ${feedPastKey && presentKey ? 'let pastKeyOffset = absKvHeadIdx * uniforms.past_sequence_length * uniforms.K;' : ''};
-    let kOffset = absKvHeadIdx * uniforms.kv_sequence_length * uniforms.K;
+    let kOffset = ${packedQKV ? 'batchIdx * packed_batch_stride + (uniforms.num_heads + kvHeadIdx) * uniforms.kv_sequence_length * uniforms.K' : 'absKvHeadIdx * uniforms.kv_sequence_length * uniforms.K'};
     ${presentKey ? 'let presentKeyOffset = absKvHeadIdx * uniforms.N * uniforms.K;' : ''}
     var value = ${f32Type}(0);
     for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {
@@ -573,12 +580,12 @@ const createAttentionProbsProgramInfo = (
         if (n + local_id.y < past_sequence_length) {
           tileK[idx] = past_key[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];
         } else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {
-          tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];
+          tileK[idx] = ${packedQKV ? 'q' : 'key'}[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];
         }`;
       } else {
         return `
         if (n + local_id.y < uniforms.kv_sequence_length) {
-          tileK[idx] = key[kOffset + (n + local_id.y) * uniforms.K + w + local_id.x];
+          tileK[idx] = ${packedQKV ? 'q' : 'key'}[kOffset + (n + local_id.y) * uniforms.K + w + local_id.x];
         }`;
       }
     })()}
@@ -640,6 +647,7 @@ const createVxAttentionScoreProgramInfo = (
   pastSequenceLength: number,
   seqLens: TensorView | undefined = undefined,
   totalSequenceLengthInput: TensorView | undefined = undefined,
+  packedQKV: boolean,
 ) => {
   const totalSequenceLength = pastSequenceLength + params.kvSequenceLength;
   const nReps = params.nReps ? params.nReps : 1;
@@ -662,7 +670,6 @@ const createVxAttentionScoreProgramInfo = (
     { type: DataType.uint32, data: totalSequenceLength },
     { type: DataType.uint32, data: params.vHeadSize },
     { type: DataType.uint32, data: params.numHeads },
-    { type: DataType.uint32, data: params.headSize },
     { type: DataType.uint32, data: repeatedVHiddenSize },
     { type: DataType.uint32, data: pastSequenceLength },
     { type: DataType.uint32, data: params.kvSequenceLength },
@@ -711,7 +718,6 @@ const createVxAttentionScoreProgramInfo = (
       { name: 'K', type: 'u32' },
      { name: 'N', type: 'u32' },
       { name: 'num_heads', type: 'u32' },
-      { name: 'head_size', type: 'u32' },
       { name: 'v_hidden_size', type: 'u32' },
       { name: 'past_sequence_length', type: 'u32' },
       { name: 'kv_sequence_length', type: 'u32' },
@@ -732,10 +738,11 @@ const createVxAttentionScoreProgramInfo = (
     let sequence_length = uniforms.M;
     var total_sequence_length = uniforms.K;
     ${initVarStub(seqLensInputVariable, totalSequenceLengthInputVariable, true)}
+    let packed_batch_stride = (uniforms.num_heads + 2 * kv_num_heads) * uniforms.M * uniforms.K;
     let offsetA = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K;
     let absKvHeadIdx = batchIdx * kv_num_heads + kvHeadIdx; // kvHeadIdx is relative to the batch
     ${feedPastValue && presentValue ? 'let pastValueOffset = absKvHeadIdx * uniforms.N * uniforms.past_sequence_length + n;' : ''};
-    let vOffset = absKvHeadIdx * uniforms.N * uniforms.kv_sequence_length + n;
+    let vOffset = ${packedQKV ? 'batchIdx * packed_batch_stride + (uniforms.num_heads + kv_num_heads + kvHeadIdx) * uniforms.N * uniforms.kv_sequence_length' : 'absKvHeadIdx * uniforms.N * uniforms.kv_sequence_length'} + n;
     ${presentValue ? 'let presentValueOffset = absKvHeadIdx * uniforms.N * uniforms.K + n;' : ''}
     var value = ${probsHelper.type.storage}(0);
     for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {
@@ -796,8 +803,8 @@ const createVxAttentionScoreProgramInfo = (
 export const applyAttention = (
   context: ComputeContext,
   q: TensorView,
-  k: TensorView,
-  v: TensorView,
+  k: TensorView | undefined,
+  v: TensorView | undefined,
   _maskIndex: TensorView | undefined,
   _past: TensorView | undefined,
   pastKey: TensorView | undefined,
@@ -814,7 +821,10 @@ export const applyAttention = (
   const attentionBias =
     attentionBiasInput && ShapeUtil.size(attentionBiasInput.dims) > 0 ? attentionBiasInput : undefined;
 
-  const inputsK = [q, k];
+  const inputsK = [q];
+  if (k) {
+    inputsK.push(k);
+  }
   if (outputCount > 1 && pastKey && ShapeUtil.size(pastKey.dims) > 0) {
     inputsK.push(pastKey);
   }
@@ -839,6 +849,7 @@ export const applyAttention = (
       pastSequenceLength,
       seqLens,
       totalSequenceLengthInput,
+      parameters.packedQKV === true,
     ),
     { inputs: inputsK, outputs: outputCount > 1 ? [-1, 1] : [-1] },
   )[0];
@@ -859,7 +870,7 @@ export const applyAttention = (
   );
 
   // Run AttentionScore
-  const inputsV = [probs, v];
+  const inputsV = [probs, parameters.packedQKV ? q : v!];
   if (outputCount > 1 && pastValue && ShapeUtil.size(pastValue.dims) > 0) {
     inputsV.push(pastValue);
   }
@@ -873,12 +884,13 @@ export const applyAttention = (
     createVxAttentionScoreProgramInfo(
       outputCount,
       probs,
-      v,
+      parameters.packedQKV ? q : v!,
      pastValue,
       parameters,
       pastSequenceLength,
       seqLens,
       totalSequenceLengthInput,
+      parameters.packedQKV === true,
     ),
     {
       inputs: inputsV,
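
The probs and score shaders above pick the source buffer with ${packedQKV ? 'q' : 'key'} and compute kOffset / vOffset into the packed tensor. For concreteness, here is a worked example with hypothetical dimensions, reusing the illustrative packedOffsets sketch from the commit description above; none of these numbers come from the diff.

// Hypothetical configuration: 8 query heads, 2 KV heads, sequence length 128, head size 64.
const dims = { numHeads: 8, kvNumHeads: 2, seqLen: 128, headSize: 64 };
// packed_batch_stride = (8 + 2 * 2) * 128 * 64 = 98304 elements per batch.
const { qOffset, kOffset, vOffset } = packedOffsets(dims, /* batchIdx */ 1, /* headIdx */ 3, /* kvHeadIdx */ 1);
// qOffset = 98304 + 3 * 8192  = 122880 (query head 3 of batch 1)
// kOffset = 98304 + 9 * 8192  = 172032 (K head 1 sits after the 8 query heads)
// vOffset = 98304 + 11 * 8192 = 188416 (V head 1 sits after the 8 Q heads and 2 K heads)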

js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts

Lines changed: 7 additions & 27 deletions
@@ -7,7 +7,6 @@ import { ComputeContext } from '../types';
 
 import { applyAttention, AttentionMaskType, AttentionParameters, AttentionQkvFormat } from './attention';
 import { maybeTransposeToBNSHAndAddBias } from './multihead-attention';
-import { createSplitProgramInfo, SplitAttributes } from './split';
 import { createTransposeProgramInfo, TransposeAttributes } from './transpose';
 export interface GroupQueryAttentionAttributes {
   numHeads: number;
@@ -216,6 +215,7 @@ export const validateInputs = (
     broadcastResPosBias,
     passPastInKv,
     qkvFormat,
+    packedQKV,
   };
 };
 
@@ -237,39 +237,19 @@ const maybeTransposeToBNSH = (context: ComputeContext, input: TensorView, params
 
 export const groupQueryAttention = (context: ComputeContext, attributes: GroupQueryAttentionAttributes): void => {
   const params = validateInputs(context.inputs, attributes);
-  if (context.inputs[0].dims.length === 5) {
-    throw new Error('Packed QKV is not implemented');
-  }
-
-  if (context.inputs[1]?.dims.length === 5) {
-    throw new Error('Packed KV is not implemented');
-  }
 
-  const q = context.inputs[0];
-  const k = context.inputs[1] && context.inputs[1].dims.length > 0 ? context.inputs[1] : undefined;
-  const v = context.inputs[2] && context.inputs[2].dims.length > 0 ? context.inputs[2] : undefined;
+  const query = context.inputs[0];
+  const key = context.inputs[1] && context.inputs[1].dims.length > 0 ? context.inputs[1] : undefined;
+  const value = context.inputs[2] && context.inputs[2].dims.length > 0 ? context.inputs[2] : undefined;
   const pastKey = context.inputs[3] && context.inputs[3].dims.length !== 0 ? context.inputs[3] : undefined;
   const pastValue = context.inputs[4] && context.inputs[4].dims.length !== 0 ? context.inputs[4] : undefined;
   const seqLens = context.inputs.length > 4 ? context.inputs[5] : undefined;
   const totalSequenceLengthInput = context.inputs.length > 5 ? context.inputs[6] : undefined;
-  const kvNumHeads = params.kvNumHeads ? params.kvNumHeads : params.numHeads;
-
-  // TODO Remove explicit split operation and use indexing in Attention implementation to avoid overhead.
-
-  const splitAttributes: SplitAttributes = createAttributeWithCacheKey({
-    axis: 2,
-    numOutputs: 3,
-    splitSizes: [params.numHeads * params.headSize, kvNumHeads * params.headSize, kvNumHeads * params.headSize],
-  });
-  const [query, key, value] =
-    !k && !v
-      ? context.compute(createSplitProgramInfo([q], splitAttributes), { inputs: [q], outputs: [-1, -1, -1] })
-      : [q, k!, v!];
 
   const Q = maybeTransposeToBNSHAndAddBias(
     context,
     params.batchSize,
-    params.numHeads,
+    params.packedQKV ? params.numHeads + 2 * attributes.kvNumHeads : params.numHeads,
     params.sequenceLength,
     params.headSize,
     query,
@@ -279,8 +259,8 @@ export const groupQueryAttention = (context: ComputeContext, attributes: GroupQu
   applyAttention(
     context,
     Q,
-    maybeTransposeToBNSH(context, key, params),
-    maybeTransposeToBNSH(context, value, params),
+    key ? maybeTransposeToBNSH(context, key, params) : undefined,
+    value ? maybeTransposeToBNSH(context, value, params) : undefined,
     undefined,
     undefined,
     pastKey,
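
On the GroupQueryAttention side, the packed path now needs only one adjustment: the BNSH transpose of the packed input covers all numHeads + 2 * kvNumHeads heads, so the K and V head blocks land contiguously behind the Q heads and key/value can be passed to applyAttention as undefined. Rough shape bookkeeping, using the same illustrative dimensions as above (not values taken from the diff):

// Packed input as produced by the model: [batch, seqLen, (numHeads + 2 * kvNumHeads) * headSize]
const packedInputShape = [2, 128, (8 + 2 * 2) * 64]; // [2, 128, 768]
// After maybeTransposeToBNSHAndAddBias with numHeads + 2 * kvNumHeads heads:
// [batch, numHeads + 2 * kvNumHeads, seqLen, headSize] -> Q heads 0..7, K heads 8..9, V heads 10..11
const packedBnshShape = [2, 8 + 2 * 2, 128, 64]; // [2, 12, 128, 64]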

js/web/lib/wasm/jsep/webgpu/ops/split.ts

Lines changed: 1 addition & 1 deletion
@@ -71,7 +71,7 @@ const writeBufferDataImpl = (outputs: readonly IndicesHelper[]) => {
   }`;
 };
 
-export const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: SplitAttributes): ProgramInfo => {
+const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: SplitAttributes): ProgramInfo => {
   const inputShape = inputs[0].dims;
   const inputSize = ShapeUtil.size(inputShape);
   const dataType = inputs[0].dataType;
