52 changes: 32 additions & 20 deletions js/web/lib/wasm/jsep/webgpu/ops/attention.ts
@@ -71,6 +71,7 @@ export interface AttentionParameters {
rotaryInterLeaved?: number;
sommoothSoftmax?: number;
localWindowsSize?: number;
packedQKV?: boolean;
}

export interface AttentionAttrs {
@@ -442,13 +443,14 @@ const createInPlaceSoftmaxProgramInfo = (
const createAttentionProbsProgramInfo = (
outputCount: number,
q: TensorView,
key: TensorView,
key: TensorView | undefined,
pastKey: TensorView | undefined,
attentionBias: TensorView | undefined,
parameters: AttentionParameters,
pastSequenceLength: number,
seqLens: TensorView | undefined,
totalSequenceLengthInput: TensorView | undefined,
packedQKV: boolean,
) => {
const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength;
const probsShape = [parameters.batchSize, parameters.numHeads, parameters.sequenceLength, totalSequenceLength];
@@ -474,15 +476,17 @@
{ type: DataType.uint32, data: vectorizedHeadSize },
{ type: DataType.uint32, data: totalSequenceLength },
{ type: DataType.uint32, data: parameters.numHeads },
{ type: DataType.uint32, data: parameters.headSize },
{ type: DataType.float, data: alpha },
{ type: DataType.uint32, data: pastSequenceLength },
{ type: DataType.uint32, data: parameters.kvSequenceLength },
{ type: DataType.uint32, data: nReps },
];
// Feed pastKey to the shader-code only if it is non-zero and presentKey is being produced
const feedPastKey = presentKey && pastKey && ShapeUtil.size(pastKey.dims) > 0;
const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type'];
const inputDependencies: ProgramInputTensorInfoDependency[] = ['type'];
if (key) {
inputDependencies.push('type');
}
if (feedPastKey) {
inputDependencies.push('type');
}
@@ -501,8 +505,11 @@
}
const getShaderSource = (shaderHelper: ShaderHelper) => {
const qInput = inputVariable('q', q.dataType, q.dims, components);
const kInput = inputVariable('key', key.dataType, key.dims, components);
const inputVars = [qInput, kInput];
const inputVars = [qInput];
if (key) {
const kInput = inputVariable('key', key.dataType, key.dims, components);
inputVars.push(kInput);
}
if (feedPastKey) {
const pastKeyInput = inputVariable('past_key', pastKey.dataType, pastKey.dims, components);
inputVars.push(pastKeyInput);
@@ -532,7 +539,6 @@
{ name: 'K', type: 'u32' },
{ name: 'N', type: 'u32' },
{ name: 'num_heads', type: 'u32' },
{ name: 'head_size', type: 'u32' },
{ name: 'alpha', type: 'f32' as UniformDataElementType },
{ name: 'past_sequence_length', type: 'u32' },
{ name: 'kv_sequence_length', type: 'u32' },
@@ -555,10 +561,11 @@
let sequence_length = uniforms.M;
var total_sequence_length = uniforms.N;
${initVarStub(seqLensInputVariable, totalSequenceLengthInputVariable, true)}
let packed_batch_stride = (uniforms.num_heads + 2 * kv_num_heads) * uniforms.M * uniforms.K;
let absKvHeadIdx = batchIdx * kv_num_heads + kvHeadIdx;
let qOffset = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K;
let qOffset = ${packedQKV ? 'batchIdx * packed_batch_stride + headIdx * uniforms.M * uniforms.K' : 'workgroup_id.z * uniforms.M * uniforms.K'} + m * uniforms.K;
${feedPastKey && presentKey ? 'let pastKeyOffset = absKvHeadIdx * uniforms.past_sequence_length * uniforms.K;' : ''};
let kOffset = absKvHeadIdx * uniforms.kv_sequence_length * uniforms.K;
let kOffset = ${packedQKV ? 'batchIdx * packed_batch_stride + (uniforms.num_heads + kvHeadIdx) * uniforms.kv_sequence_length * uniforms.K' : 'absKvHeadIdx * uniforms.kv_sequence_length * uniforms.K'};
${presentKey ? 'let presentKeyOffset = absKvHeadIdx * uniforms.N * uniforms.K;' : ''}
var value = ${f32Type}(0);
for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {
@@ -573,12 +580,12 @@
if (n + local_id.y < past_sequence_length) {
tileK[idx] = past_key[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];
} else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {
tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];
tileK[idx] = ${packedQKV ? 'q' : 'key'}[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];
}`;
} else {
return `
if (n + local_id.y < uniforms.kv_sequence_length) {
tileK[idx] = key[kOffset + (n + local_id.y) * uniforms.K + w + local_id.x];
tileK[idx] = ${packedQKV ? 'q' : 'key'}[kOffset + (n + local_id.y) * uniforms.K + w + local_id.x];
}`;
}
})()}
@@ -623,7 +630,7 @@
return {
name: 'AttentionProbs',
shaderCache: {
hint: `${components};${attentionBias !== undefined};${pastKey !== undefined};${outputCount}`,
hint: `${components};${attentionBias !== undefined};${pastKey !== undefined};${outputCount};${packedQKV}`,
inputDependencies,
},
getRunData: () => ({ outputs, dispatchGroup: dispatch, programUniforms }),
@@ -640,6 +647,7 @@ const createVxAttentionScoreProgramInfo = (
pastSequenceLength: number,
seqLens: TensorView | undefined = undefined,
totalSequenceLengthInput: TensorView | undefined = undefined,
packedQKV: boolean,
) => {
const totalSequenceLength = pastSequenceLength + params.kvSequenceLength;
const nReps = params.nReps ? params.nReps : 1;
@@ -662,7 +670,6 @@
{ type: DataType.uint32, data: totalSequenceLength },
{ type: DataType.uint32, data: params.vHeadSize },
{ type: DataType.uint32, data: params.numHeads },
{ type: DataType.uint32, data: params.headSize },
{ type: DataType.uint32, data: repeatedVHiddenSize },
{ type: DataType.uint32, data: pastSequenceLength },
{ type: DataType.uint32, data: params.kvSequenceLength },
@@ -711,7 +718,6 @@
{ name: 'K', type: 'u32' },
{ name: 'N', type: 'u32' },
{ name: 'num_heads', type: 'u32' },
{ name: 'head_size', type: 'u32' },
{ name: 'v_hidden_size', type: 'u32' },
{ name: 'past_sequence_length', type: 'u32' },
{ name: 'kv_sequence_length', type: 'u32' },
@@ -732,10 +738,11 @@
let sequence_length = uniforms.M;
var total_sequence_length = uniforms.K;
${initVarStub(seqLensInputVariable, totalSequenceLengthInputVariable, true)}
let packed_batch_stride = (uniforms.num_heads + 2 * kv_num_heads) * uniforms.M * uniforms.K;
let offsetA = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K;
let absKvHeadIdx = batchIdx * kv_num_heads + kvHeadIdx; // kvHeadIdx is relative to the batch
${feedPastValue && presentValue ? 'let pastValueOffset = absKvHeadIdx * uniforms.N * uniforms.past_sequence_length + n;' : ''};
let vOffset = absKvHeadIdx * uniforms.N * uniforms.kv_sequence_length + n;
let vOffset = ${packedQKV ? 'batchIdx * packed_batch_stride + (uniforms.num_heads + kv_num_heads + kvHeadIdx) * uniforms.N * uniforms.kv_sequence_length' : 'absKvHeadIdx * uniforms.N * uniforms.kv_sequence_length'} + n;
${presentValue ? 'let presentValueOffset = absKvHeadIdx * uniforms.N * uniforms.K + n;' : ''}
var value = ${probsHelper.type.storage}(0);
for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {
Expand Down Expand Up @@ -787,7 +794,7 @@ const createVxAttentionScoreProgramInfo = (

return {
name: 'AttentionScore',
shaderCache: { hint: `${pastValue !== undefined};${outputCount}`, inputDependencies },
shaderCache: { hint: `${pastValue !== undefined};${outputCount};${packedQKV}`, inputDependencies },
getRunData: () => ({ outputs, dispatchGroup: dispatch, programUniforms }),
getShaderSource,
};
@@ -796,8 +803,8 @@ const createVxAttentionScoreProgramInfo = (
export const applyAttention = (
context: ComputeContext,
q: TensorView,
k: TensorView,
v: TensorView,
k: TensorView | undefined,
v: TensorView | undefined,
_maskIndex: TensorView | undefined,
_past: TensorView | undefined,
pastKey: TensorView | undefined,
@@ -814,7 +821,10 @@ export const applyAttention = (
const attentionBias =
attentionBiasInput && ShapeUtil.size(attentionBiasInput.dims) > 0 ? attentionBiasInput : undefined;

const inputsK = [q, k];
const inputsK = [q];
if (k) {
inputsK.push(k);
}
if (outputCount > 1 && pastKey && ShapeUtil.size(pastKey.dims) > 0) {
inputsK.push(pastKey);
}
@@ -839,6 +849,7 @@
pastSequenceLength,
seqLens,
totalSequenceLengthInput,
parameters.packedQKV === true,
),
{ inputs: inputsK, outputs: outputCount > 1 ? [-1, 1] : [-1] },
)[0];
@@ -859,7 +870,7 @@
);

// Run AttentionScore
const inputsV = [probs, v];
const inputsV = [probs, parameters.packedQKV ? q : v!];
if (outputCount > 1 && pastValue && ShapeUtil.size(pastValue.dims) > 0) {
inputsV.push(pastValue);
}
@@ -873,12 +884,13 @@
createVxAttentionScoreProgramInfo(
outputCount,
probs,
v,
parameters.packedQKV ? q : v!,
pastValue,
parameters,
pastSequenceLength,
seqLens,
totalSequenceLengthInput,
parameters.packedQKV === true,
),
{
inputs: inputsV,
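Note on the packed addressing above (not part of the diff): the new `packed_batch_stride`, `qOffset`, `kOffset`, and `vOffset` expressions index a single BNSH tensor whose head dimension holds the Q, K, and V heads back to back. A minimal TypeScript sketch of that arithmetic follows; the helper name and parameters are illustrative only, offsets are in elements of the already-transposed tensor, and it assumes `kvSequenceLength` equals `sequenceLength` in the packed case, which the shader's use of both `uniforms.M` and `uniforms.kv_sequence_length` against the same buffer implies.

```typescript
// Illustrative helper (not in the PR): element offsets into a packed BNSH tensor
// of shape [batch, numHeads + 2 * kvNumHeads, seqLen, headSize]. Q heads occupy
// indices [0, numHeads), K heads [numHeads, numHeads + kvNumHeads), and V heads
// [numHeads + kvNumHeads, numHeads + 2 * kvNumHeads), mirroring the WGSL offsets.
const packedQkvOffsets = (
  batchIdx: number,
  headIdx: number, // query head index
  kvHeadIdx: number, // key/value head index (headIdx / nReps)
  numHeads: number,
  kvNumHeads: number,
  seqLen: number,
  headSize: number,
) => {
  const headStride = seqLen * headSize; // one head's worth of elements
  const packedBatchStride = (numHeads + 2 * kvNumHeads) * headStride;
  const base = batchIdx * packedBatchStride;
  return {
    qOffset: base + headIdx * headStride,
    kOffset: base + (numHeads + kvHeadIdx) * headStride,
    vOffset: base + (numHeads + kvNumHeads + kvHeadIdx) * headStride,
  };
};
```

For example, with numHeads = 32 and kvNumHeads = 8, K for kvHeadIdx 0 starts at head index 32 and V at head index 40 within each batch.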
34 changes: 7 additions & 27 deletions js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts
Expand Up @@ -7,7 +7,6 @@ import { ComputeContext } from '../types';

import { applyAttention, AttentionMaskType, AttentionParameters, AttentionQkvFormat } from './attention';
import { maybeTransposeToBNSHAndAddBias } from './multihead-attention';
import { createSplitProgramInfo, SplitAttributes } from './split';
import { createTransposeProgramInfo, TransposeAttributes } from './transpose';
export interface GroupQueryAttentionAttributes {
numHeads: number;
@@ -219,6 +218,7 @@ export const validateInputs = (
broadcastResPosBias,
passPastInKv,
qkvFormat,
packedQKV,
};
};

@@ -240,39 +240,19 @@ const maybeTransposeToBNSH = (context: ComputeContext, input: TensorView, params

export const groupQueryAttention = (context: ComputeContext, attributes: GroupQueryAttentionAttributes): void => {
const params = validateInputs(context.inputs, attributes);
if (context.inputs[0].dims.length === 5) {
throw new Error('Packed QKV is not implemented');
}

if (context.inputs[1]?.dims.length === 5) {
throw new Error('Packed KV is not implemented');
}

const q = context.inputs[0];
const k = context.inputs[1] && context.inputs[1].dims.length > 0 ? context.inputs[1] : undefined;
const v = context.inputs[2] && context.inputs[2].dims.length > 0 ? context.inputs[2] : undefined;
const query = context.inputs[0];
const key = context.inputs[1] && context.inputs[1].dims.length > 0 ? context.inputs[1] : undefined;
const value = context.inputs[2] && context.inputs[2].dims.length > 0 ? context.inputs[2] : undefined;
const pastKey = context.inputs[3] && context.inputs[3].dims.length !== 0 ? context.inputs[3] : undefined;
const pastValue = context.inputs[4] && context.inputs[4].dims.length !== 0 ? context.inputs[4] : undefined;
const seqLens = context.inputs.length > 4 ? context.inputs[5] : undefined;
const totalSequenceLengthInput = context.inputs.length > 5 ? context.inputs[6] : undefined;
const kvNumHeads = params.kvNumHeads ? params.kvNumHeads : params.numHeads;

// TODO Remove explicit split operation and use indexing in Attention implementation to avoid overhead.

const splitAttributes: SplitAttributes = createAttributeWithCacheKey({
axis: 2,
numOutputs: 3,
splitSizes: [params.numHeads * params.headSize, kvNumHeads * params.headSize, kvNumHeads * params.headSize],
});
const [query, key, value] =
!k && !v
? context.compute(createSplitProgramInfo([q], splitAttributes), { inputs: [q], outputs: [-1, -1, -1] })
: [q, k!, v!];

const Q = maybeTransposeToBNSHAndAddBias(
context,
params.batchSize,
params.numHeads,
params.packedQKV ? params.numHeads + 2 * attributes.kvNumHeads : params.numHeads,
params.sequenceLength,
params.headSize,
query,
@@ -282,8 +262,8 @@ export const groupQueryAttention = (context: ComputeContext, attributes: GroupQu
applyAttention(
context,
Q,
maybeTransposeToBNSH(context, key, params),
maybeTransposeToBNSH(context, value, params),
key ? maybeTransposeToBNSH(context, key, params) : undefined,
value ? maybeTransposeToBNSH(context, value, params) : undefined,
undefined,
undefined,
pastKey,
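For orientation (also not part of the diff): with packed QKV, `validateInputs` reports `packedQKV`, the query tensor is transposed with `numHeads + 2 * kvNumHeads` heads, and `applyAttention` reads K and V straight out of that packed tensor instead of running a Split program first. A minimal sketch of the source selection follows; whether `validateInputs` flags `packedQKV` exactly when the key/value inputs are absent is an assumption inferred from this diff, not shown in it.

```typescript
// Sketch (not in the PR) of how applyAttention now selects its K/V sources.
const pickAttentionSources = <T>(packedQKV: boolean, q: T, k?: T, v?: T) => ({
  // AttentionProbs: K comes from the dedicated key tensor, or from the packed
  // query tensor itself (the WGSL then reads `q` at kOffset).
  keySource: packedQKV ? q : k!,
  // AttentionScore: V likewise comes from the packed query tensor when packed.
  valueSource: packedQKV ? q : v!,
});
```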
2 changes: 1 addition & 1 deletion js/web/lib/wasm/jsep/webgpu/ops/split.ts
@@ -71,7 +71,7 @@ const writeBufferDataImpl = (outputs: readonly IndicesHelper[]) => {
}`;
};

export const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: SplitAttributes): ProgramInfo => {
const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: SplitAttributes): ProgramInfo => {
const inputShape = inputs[0].dims;
const inputSize = ShapeUtil.size(inputShape);
const dataType = inputs[0].dataType;