@@ -71,6 +71,7 @@ export interface AttentionParameters {
7171 rotaryInterLeaved ?: number ;
7272 sommoothSoftmax ?: number ;
7373 localWindowsSize ?: number ;
74+ packedQKV ?: boolean ;
7475}
7576
7677export interface AttentionAttrs {
@@ -442,13 +443,14 @@ const createInPlaceSoftmaxProgramInfo = (
442443const createAttentionProbsProgramInfo = (
443444 outputCount : number ,
444445 q : TensorView ,
445- key : TensorView ,
446+ key : TensorView | undefined ,
446447 pastKey : TensorView | undefined ,
447448 attentionBias : TensorView | undefined ,
448449 parameters : AttentionParameters ,
449450 pastSequenceLength : number ,
450451 seqLens : TensorView | undefined ,
451452 totalSequenceLengthInput : TensorView | undefined ,
453+ packedQKV : boolean ,
452454) => {
453455 const totalSequenceLength = pastSequenceLength + parameters . kvSequenceLength ;
454456 const probsShape = [ parameters . batchSize , parameters . numHeads , parameters . sequenceLength , totalSequenceLength ] ;
@@ -474,15 +476,17 @@ const createAttentionProbsProgramInfo = (
474476 { type : DataType . uint32 , data : vectorizedHeadSize } ,
475477 { type : DataType . uint32 , data : totalSequenceLength } ,
476478 { type : DataType . uint32 , data : parameters . numHeads } ,
477- { type : DataType . uint32 , data : parameters . headSize } ,
478479 { type : DataType . float , data : alpha } ,
479480 { type : DataType . uint32 , data : pastSequenceLength } ,
480481 { type : DataType . uint32 , data : parameters . kvSequenceLength } ,
481482 { type : DataType . uint32 , data : nReps } ,
482483 ] ;
483484 // Feed pastKey to the shader-code only if it is non-zero and presentKey is being produced
484485 const feedPastKey = presentKey && pastKey && ShapeUtil . size ( pastKey . dims ) > 0 ;
485- const inputDependencies : ProgramInputTensorInfoDependency [ ] = [ 'type' , 'type' ] ;
486+ const inputDependencies : ProgramInputTensorInfoDependency [ ] = [ 'type' ] ;
487+ if ( key ) {
488+ inputDependencies . push ( 'type' ) ;
489+ }
486490 if ( feedPastKey ) {
487491 inputDependencies . push ( 'type' ) ;
488492 }
@@ -501,8 +505,11 @@ const createAttentionProbsProgramInfo = (
501505 }
502506 const getShaderSource = ( shaderHelper : ShaderHelper ) => {
503507 const qInput = inputVariable ( 'q' , q . dataType , q . dims , components ) ;
504- const kInput = inputVariable ( 'key' , key . dataType , key . dims , components ) ;
505- const inputVars = [ qInput , kInput ] ;
508+ const inputVars = [ qInput ] ;
509+ if ( key ) {
510+ const kInput = inputVariable ( 'key' , key . dataType , key . dims , components ) ;
511+ inputVars . push ( kInput ) ;
512+ }
506513 if ( feedPastKey ) {
507514 const pastKeyInput = inputVariable ( 'past_key' , pastKey . dataType , pastKey . dims , components ) ;
508515 inputVars . push ( pastKeyInput ) ;
@@ -532,7 +539,6 @@ const createAttentionProbsProgramInfo = (
532539 { name : 'K' , type : 'u32' } ,
533540 { name : 'N' , type : 'u32' } ,
534541 { name : 'num_heads' , type : 'u32' } ,
535- { name : 'head_size' , type : 'u32' } ,
536542 { name : 'alpha' , type : 'f32' as UniformDataElementType } ,
537543 { name : 'past_sequence_length' , type : 'u32' } ,
538544 { name : 'kv_sequence_length' , type : 'u32' } ,
@@ -555,10 +561,11 @@ const createAttentionProbsProgramInfo = (
555561 let sequence_length = uniforms.M;
556562 var total_sequence_length = uniforms.N;
557563 ${ initVarStub ( seqLensInputVariable , totalSequenceLengthInputVariable , true ) }
564+ let packed_batch_stride = (uniforms.num_heads + 2 * kv_num_heads) * uniforms.M * uniforms.K;
558565 let absKvHeadIdx = batchIdx * kv_num_heads + kvHeadIdx;
559- let qOffset = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K;
566+ let qOffset = ${ packedQKV ? 'batchIdx * packed_batch_stride + headIdx * uniforms.M * uniforms.K' : ' workgroup_id.z * uniforms.M * uniforms.K' } + m * uniforms.K;
560567 ${ feedPastKey && presentKey ? 'let pastKeyOffset = absKvHeadIdx * uniforms.past_sequence_length * uniforms.K;' : '' } ;
561- let kOffset = absKvHeadIdx * uniforms.kv_sequence_length * uniforms.K;
568+ let kOffset = ${ packedQKV ? 'batchIdx * packed_batch_stride + (uniforms.num_heads + kvHeadIdx) * uniforms.kv_sequence_length * uniforms.K' : ' absKvHeadIdx * uniforms.kv_sequence_length * uniforms.K' } ;
562569 ${ presentKey ? 'let presentKeyOffset = absKvHeadIdx * uniforms.N * uniforms.K;' : '' }
563570 var value = ${ f32Type } (0);
564571 for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {
@@ -573,12 +580,12 @@ const createAttentionProbsProgramInfo = (
573580 if (n + local_id.y < past_sequence_length) {
574581 tileK[idx] = past_key[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];
575582 } else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {
576- tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];
583+ tileK[idx] = ${ packedQKV ? 'q' : ' key' } [kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];
577584 }` ;
578585 } else {
579586 return `
580587 if (n + local_id.y < uniforms.kv_sequence_length) {
581- tileK[idx] = key[kOffset + (n + local_id.y) * uniforms.K + w + local_id.x];
588+ tileK[idx] = ${ packedQKV ? 'q' : ' key' } [kOffset + (n + local_id.y) * uniforms.K + w + local_id.x];
582589 }` ;
583590 }
584591 } ) ( ) }
@@ -640,6 +647,7 @@ const createVxAttentionScoreProgramInfo = (
640647 pastSequenceLength : number ,
641648 seqLens : TensorView | undefined = undefined ,
642649 totalSequenceLengthInput : TensorView | undefined = undefined ,
650+ packedQKV : boolean ,
643651) => {
644652 const totalSequenceLength = pastSequenceLength + params . kvSequenceLength ;
645653 const nReps = params . nReps ? params . nReps : 1 ;
@@ -662,7 +670,6 @@ const createVxAttentionScoreProgramInfo = (
662670 { type : DataType . uint32 , data : totalSequenceLength } ,
663671 { type : DataType . uint32 , data : params . vHeadSize } ,
664672 { type : DataType . uint32 , data : params . numHeads } ,
665- { type : DataType . uint32 , data : params . headSize } ,
666673 { type : DataType . uint32 , data : repeatedVHiddenSize } ,
667674 { type : DataType . uint32 , data : pastSequenceLength } ,
668675 { type : DataType . uint32 , data : params . kvSequenceLength } ,
@@ -711,7 +718,6 @@ const createVxAttentionScoreProgramInfo = (
711718 { name : 'K' , type : 'u32' } ,
712719 { name : 'N' , type : 'u32' } ,
713720 { name : 'num_heads' , type : 'u32' } ,
714- { name : 'head_size' , type : 'u32' } ,
715721 { name : 'v_hidden_size' , type : 'u32' } ,
716722 { name : 'past_sequence_length' , type : 'u32' } ,
717723 { name : 'kv_sequence_length' , type : 'u32' } ,
@@ -732,10 +738,11 @@ const createVxAttentionScoreProgramInfo = (
732738 let sequence_length = uniforms.M;
733739 var total_sequence_length = uniforms.K;
734740 ${ initVarStub ( seqLensInputVariable , totalSequenceLengthInputVariable , true ) }
741+ let packed_batch_stride = (uniforms.num_heads + 2 * kv_num_heads) * uniforms.M * uniforms.K;
735742 let offsetA = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K;
736743 let absKvHeadIdx = batchIdx * kv_num_heads + kvHeadIdx; // kvHeadIdx is relative to the batch
737744 ${ feedPastValue && presentValue ? 'let pastValueOffset = absKvHeadIdx * uniforms.N * uniforms.past_sequence_length + n;' : '' } ;
738- let vOffset = absKvHeadIdx * uniforms.N * uniforms.kv_sequence_length + n;
745+ let vOffset = ${ packedQKV ? 'batchIdx * packed_batch_stride + (uniforms.num_heads + kv_num_heads + kvHeadIdx) * uniforms.N * uniforms.kv_sequence_length' : ' absKvHeadIdx * uniforms.N * uniforms.kv_sequence_length' } + n;
739746 ${ presentValue ? 'let presentValueOffset = absKvHeadIdx * uniforms.N * uniforms.K + n;' : '' }
740747 var value = ${ probsHelper . type . storage } (0);
741748 for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {
@@ -796,8 +803,8 @@ const createVxAttentionScoreProgramInfo = (
796803export const applyAttention = (
797804 context : ComputeContext ,
798805 q : TensorView ,
799- k : TensorView ,
800- v : TensorView ,
806+ k : TensorView | undefined ,
807+ v : TensorView | undefined ,
801808 _maskIndex : TensorView | undefined ,
802809 _past : TensorView | undefined ,
803810 pastKey : TensorView | undefined ,
@@ -814,7 +821,10 @@ export const applyAttention = (
814821 const attentionBias =
815822 attentionBiasInput && ShapeUtil . size ( attentionBiasInput . dims ) > 0 ? attentionBiasInput : undefined ;
816823
817- const inputsK = [ q , k ] ;
824+ const inputsK = [ q ] ;
825+ if ( k ) {
826+ inputsK . push ( k ) ;
827+ }
818828 if ( outputCount > 1 && pastKey && ShapeUtil . size ( pastKey . dims ) > 0 ) {
819829 inputsK . push ( pastKey ) ;
820830 }
@@ -839,6 +849,7 @@ export const applyAttention = (
839849 pastSequenceLength ,
840850 seqLens ,
841851 totalSequenceLengthInput ,
852+ parameters . packedQKV === true ,
842853 ) ,
843854 { inputs : inputsK , outputs : outputCount > 1 ? [ - 1 , 1 ] : [ - 1 ] } ,
844855 ) [ 0 ] ;
@@ -859,7 +870,7 @@ export const applyAttention = (
859870 ) ;
860871
861872 // Run AttentionScore
862- const inputsV = [ probs , v ] ;
873+ const inputsV = [ probs , parameters . packedQKV ? q : v ! ] ;
863874 if ( outputCount > 1 && pastValue && ShapeUtil . size ( pastValue . dims ) > 0 ) {
864875 inputsV . push ( pastValue ) ;
865876 }
@@ -873,12 +884,13 @@ export const applyAttention = (
873884 createVxAttentionScoreProgramInfo (
874885 outputCount ,
875886 probs ,
876- v ,
887+ parameters . packedQKV ? q : v ! ,
877888 pastValue ,
878889 parameters ,
879890 pastSequenceLength ,
880891 seqLens ,
881892 totalSequenceLengthInput ,
893+ parameters . packedQKV === true ,
882894 ) ,
883895 {
884896 inputs : inputsV ,