Skip to content

Commit 17ae730

Browse files
authored
[TorchOnnxToTorch] Add pad+scatter KV cache and mask to GQA (#4471)
Replace cat-based KV cache with pad+scatter for correct shared-buffer semantics. With cat(past, current), the current token always lands at past_seq_len, but ORT's GQA places it at `seqlens_k[b]` per batch element. Pad with zeros then scatter to match. Add causal attention mask: `k <= pastLen[b] + q`, reshaped to `[batch, 1, seq, kv_seq]` and passed as a boolean mask to SDPA.
1 parent b083368 commit 17ae730

File tree

3 files changed

+408
-78
lines changed

3 files changed

+408
-78
lines changed

lib/Conversion/TorchOnnxToTorch/ComMicrosoftDomain.cpp

Lines changed: 210 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,7 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
290290
// `do_rotary` attribute. If it's false, then the input operands can be
291291
// 7 but if it's true then the operands have to be 9 including cos_cache
292292
// and sin_cache for rotary_embedding.
293+
// TODO: Add support for packed_qkv.
293294
if (!((operands.size() == 9) || (!doRotary && operands.size() == 7)))
294295
return rewriter.notifyMatchFailure(
295296
binder.op, "Unimplemented: excepted input operands to be either "
@@ -322,8 +323,6 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
322323
binder.op, "Unimplemented: softcap attribute is not supported, "
323324
"hence it should have default value equal to 0.0");
324325

325-
// TODO: Add support for packed_qkv.
326-
327326
Location loc = binder.getLoc();
328327
MLIRContext *context = binder.op->getContext();
329328
Value query = operands[0];
@@ -562,19 +561,224 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
562561
/*scale=*/cstFloatOne);
563562
}
564563

565-
// Do attention.
566-
Value cstEnableGQA = Torch::ConstantBoolOp::create(rewriter, loc, true);
564+
// Build present_key/present_value by padding past with zeros, then
565+
// scattering current K/V into the correct per-batch position.
566+
//
567+
// Why pad instead of cat? With cat(past, current), the current token
568+
// ends up at position past_seq_len. With variable seqlens_k, ORT places
569+
// the current token at seqlens_k[b] and leaves position past_seq_len as
570+
// zero/uninitialized. Using cat+scatter leaves a stale copy of the
571+
// current token at past_seq_len, which doesn't match ORT's output.
572+
// Padding with zeros then scattering avoids this.
573+
//
574+
// constant_pad_nd pads innermost dims first: [0, 0, 0, seq_len]
575+
// dim 3 (head_size): [0, 0] -- no padding
576+
// dim 2 (seq): [0, seq_len] -- extend by seq_len on the right
567577
Value cstFloatZero = Torch::ConstantFloatOp::create(
568578
rewriter, loc, rewriter.getType<Torch::FloatType>(),
569579
rewriter.getF64FloatAttr(0.0));
580+
Value padList = Torch::PrimListConstructOp::create(
581+
rewriter, loc, intListType,
582+
SmallVector<Value>{cstIntZero, cstIntZero, cstIntZero,
583+
cstSequenceLength});
584+
Value presentKey = Torch::AtenConstantPadNdOp::create(
585+
rewriter, loc, resultTypes[1], pastKey, padList, cstFloatZero);
586+
Value presentValue = Torch::AtenConstantPadNdOp::create(
587+
rewriter, loc, resultTypes[2], pastValue, padList, cstFloatZero);
588+
589+
// Scatter current K/V into the padded buffer at position pastLen[b]+q.
590+
// pastLen = seqlens_k + 1 - seq_len
591+
Value totalSeqForScatter = Torch::AtenAddScalarOp::create(
592+
rewriter, loc, seqlensK.getType(), seqlensK, cstIntOne, cstIntOne);
593+
Value pastLen = Torch::AtenSubScalarOp::create(
594+
rewriter, loc, totalSeqForScatter.getType(), totalSeqForScatter,
595+
cstSequenceLength, cstIntOne);
596+
597+
// qRange: [0, 1, ..., seqLen-1] — shared by scatter and mask.
598+
Torch::ValueTensorType qRangeType = Torch::ValueTensorType::get(
599+
context, {sequenceLength},
600+
rewriter.getIntegerType(64, /*isSigned=*/true));
601+
Value qRange = Torch::AtenArangeOp::create(
602+
rewriter, loc, qRangeType, cstSequenceLength, cstInt64Dtype,
603+
/*layout=*/cstNone, /*device=*/cstNone, /*pin_memory=*/cstNone);
604+
605+
// pastLen -> [B, 1, 1, 1] for 4D scatter index broadcasting
606+
Value scatterViewList = Torch::PrimListConstructOp::create(
607+
rewriter, loc, intListType,
608+
SmallVector<Value>{cstIntMinusOne, cstIntOne, cstIntOne,
609+
cstIntOne});
610+
SmallVector<int64_t> pastLenView4dSizes{batchSize, 1, 1, 1};
611+
Torch::ValueTensorType pastLenView4dType = Torch::ValueTensorType::get(
612+
context, pastLenView4dSizes,
613+
rewriter.getIntegerType(64, /*isSigned=*/true));
614+
Value pastLenView4d = Torch::AtenViewOp::create(
615+
rewriter, loc, pastLenView4dType, pastLen, scatterViewList);
616+
617+
// qRange -> [1, 1, seq, 1] for scatter
618+
Value scatterQViewList = Torch::PrimListConstructOp::create(
619+
rewriter, loc, intListType,
620+
SmallVector<Value>{cstIntOne, cstIntOne, cstIntMinusOne,
621+
cstIntOne});
622+
SmallVector<int64_t> scatterQViewSizes{1, 1, sequenceLength, 1};
623+
Torch::ValueTensorType scatterQViewType = Torch::ValueTensorType::get(
624+
context, scatterQViewSizes,
625+
rewriter.getIntegerType(64, /*isSigned=*/true));
626+
Value scatterQRangeView = Torch::AtenViewOp::create(
627+
rewriter, loc, scatterQViewType, qRange, scatterQViewList);
628+
629+
// scatterIdxBase = pastLen[B,1,1,1] + qRange[1,1,seq,1]
630+
// -> [B, 1, seq, 1]
631+
SmallVector<int64_t> scatterIdxBaseSizes{batchSize, 1, sequenceLength,
632+
1};
633+
Torch::ValueTensorType scatterIdxBaseType = Torch::ValueTensorType::get(
634+
context, scatterIdxBaseSizes,
635+
rewriter.getIntegerType(64, /*isSigned=*/true));
636+
Value scatterIdxBase = Torch::AtenAddTensorOp::create(
637+
rewriter, loc, scatterIdxBaseType, pastLenView4d, scatterQRangeView,
638+
cstIntOne);
639+
640+
// Expand to [B, kv_heads, seq, head_size] to match current K/V shape
641+
SmallVector<int64_t> scatterExpandSizes{batchSize, kvNumHeads,
642+
sequenceLength, headSize};
643+
Torch::ValueTensorType scatterIdxType = Torch::ValueTensorType::get(
644+
context, scatterExpandSizes,
645+
rewriter.getIntegerType(64, /*isSigned=*/true));
646+
Value scatterExpandSizeList = Torch::PrimListConstructOp::create(
647+
rewriter, loc, intListType,
648+
SmallVector<Value>{cstBatchSize, cstKVNumHeads, cstSequenceLength,
649+
cstHeadSize});
650+
Value scatterIdx = Torch::AtenExpandOp::create(
651+
rewriter, loc, scatterIdxType, scatterIdxBase,
652+
scatterExpandSizeList, /*implicit=*/cstFalse);
653+
654+
// Scatter current K/V into buffer at position pastLen[b] + q
655+
presentKey = Torch::AtenScatterSrcOp::create(
656+
rewriter, loc, resultTypes[1], presentKey, cstDim2, scatterIdx,
657+
kRotary);
658+
presentValue = Torch::AtenScatterSrcOp::create(
659+
rewriter, loc, resultTypes[2], presentValue, cstDim2, scatterIdx,
660+
vInput);
661+
662+
// Generate causal attention mask.
663+
// With scatter, KV layout matches ORT: current at pastLen[b].
664+
// Simple boolean mask: k <= pastLen[b] + q
665+
// Mask shape: [batch, 1, seqLen, kvSeqLen] (i1).
666+
Value attnMask = cstNone;
667+
668+
// Get the KV sequence length from presentKey shape
669+
Torch::ValueTensorType presentKeyType =
670+
cast<Torch::ValueTensorType>(presentKey.getType());
671+
if (presentKeyType.hasSizes() &&
672+
presentKeyType.getSizes().size() == 4) {
673+
int64_t kvSeqLen = presentKeyType.getSizes()[2];
674+
675+
// Only generate mask if KV sequence length is dynamic or > 0
676+
// For dynamic shapes or non-trivial sequences, we need to mask
677+
if (kvSeqLen == Torch::kUnknownSize || kvSeqLen > 0) {
678+
// Get KV sequence dimension size
679+
Value kvSeqLenVal = Torch::AtenSizeIntOp::create(
680+
rewriter, loc, rewriter.getType<Torch::IntType>(), presentKey,
681+
cstDim2);
682+
683+
// kRange: [0, 1, 2, ..., kvSeqLen-1] shape [kvSeqLen]
684+
Torch::ValueTensorType kRangeType = Torch::ValueTensorType::get(
685+
context, {kvSeqLen},
686+
rewriter.getIntegerType(64, /*isSigned=*/true));
687+
Value kRange = Torch::AtenArangeOp::create(
688+
rewriter, loc, kRangeType, kvSeqLenVal, cstInt64Dtype,
689+
/*layout=*/cstNone, /*device=*/cstNone, /*pin_memory=*/cstNone);
690+
691+
// Reshape for broadcasting:
692+
// pastLen: [batch] -> [batch, 1, 1]
693+
// qRange: [seqLen] -> [1, seqLen, 1] (reuses qRange from above)
694+
// kRange: [kvSeqLen] -> [1, 1, kvSeqLen]
695+
696+
// pastLen -> [batch, 1, 1]
697+
Value seqlensViewList = Torch::PrimListConstructOp::create(
698+
rewriter, loc, intListType,
699+
SmallVector<Value>{cstIntMinusOne, cstIntOne, cstIntOne});
700+
SmallVector<int64_t> seqlensViewSizes{batchSize, 1, 1};
701+
Torch::ValueTensorType seqlensViewType =
702+
Torch::ValueTensorType::get(
703+
context, seqlensViewSizes,
704+
rewriter.getIntegerType(64, /*isSigned=*/true));
705+
Value pastLenView = Torch::AtenViewOp::create(
706+
rewriter, loc, seqlensViewType, pastLen, seqlensViewList);
707+
708+
// qRange -> [1, seqLen, 1]
709+
Value qViewList = Torch::PrimListConstructOp::create(
710+
rewriter, loc, intListType,
711+
SmallVector<Value>{cstIntOne, cstIntMinusOne, cstIntOne});
712+
SmallVector<int64_t> qViewSizes{1, sequenceLength, 1};
713+
Torch::ValueTensorType qViewType = Torch::ValueTensorType::get(
714+
context, qViewSizes,
715+
rewriter.getIntegerType(64, /*isSigned=*/true));
716+
Value qRangeView = Torch::AtenViewOp::create(
717+
rewriter, loc, qViewType, qRange, qViewList);
718+
719+
// kRange -> [1, 1, kvSeqLen]
720+
Value kViewList = Torch::PrimListConstructOp::create(
721+
rewriter, loc, intListType,
722+
SmallVector<Value>{cstIntOne, cstIntOne, cstIntMinusOne});
723+
SmallVector<int64_t> kViewSizes{1, 1, kvSeqLen};
724+
Torch::ValueTensorType kViewType = Torch::ValueTensorType::get(
725+
context, kViewSizes,
726+
rewriter.getIntegerType(64, /*isSigned=*/true));
727+
Value kRangeView = Torch::AtenViewOp::create(
728+
rewriter, loc, kViewType, kRange, kViewList);
729+
730+
// Causal mask: k <= pastLen + q
731+
// pastLenView[batch,1,1] + qRangeView[1,seqLen,1]
732+
// -> [batch, seqLen, 1]
733+
SmallVector<int64_t> pastLenPlusQSizes{batchSize, sequenceLength,
734+
1};
735+
Torch::ValueTensorType pastLenPlusQType =
736+
Torch::ValueTensorType::get(
737+
context, pastLenPlusQSizes,
738+
rewriter.getIntegerType(64, /*isSigned=*/true));
739+
Value pastLenPlusQ = Torch::AtenAddTensorOp::create(
740+
rewriter, loc, pastLenPlusQType, pastLenView, qRangeView,
741+
cstIntOne);
742+
743+
// kRangeView[1,1,kvSeqLen] <= pastLenPlusQ[batch,seqLen,1]
744+
// -> [batch, seqLen, kvSeqLen]
745+
SmallVector<int64_t> maskBoolSizes{batchSize, sequenceLength,
746+
kvSeqLen};
747+
Torch::ValueTensorType maskBoolType = Torch::ValueTensorType::get(
748+
context, maskBoolSizes, rewriter.getI1Type());
749+
Value causalMask = Torch::AtenLeTensorOp::create(
750+
rewriter, loc, maskBoolType, kRangeView, pastLenPlusQ);
751+
752+
// Reshape to [batch, 1, seqLen, kvSeqLen] for SDPA.
753+
// Pass the boolean mask directly — downstream backends (e.g.
754+
// IREE's iree_linalg_ext.attention) handle bool-to-float
755+
// conversion internally.
756+
Value maskReshapeSizeList = Torch::PrimListConstructOp::create(
757+
rewriter, loc, intListType,
758+
SmallVector<Value>{cstBatchSize, cstIntOne, cstSequenceLength,
759+
kvSeqLenVal});
760+
SmallVector<int64_t> attnMaskSizes{batchSize, 1, sequenceLength,
761+
kvSeqLen};
762+
Torch::ValueTensorType attnMaskType = Torch::ValueTensorType::get(
763+
context, attnMaskSizes, rewriter.getI1Type());
764+
attnMask = Torch::AtenReshapeOp::create(
765+
rewriter, loc, attnMaskType, causalMask, maskReshapeSizeList);
766+
}
767+
}
768+
769+
// Do attention with full KV cache (past + current) and mask.
770+
Value cstEnableGQA = Torch::ConstantBoolOp::create(rewriter, loc, true);
570771
Value cstScale = cstNone;
571772
if (scale != 0.0f)
572773
cstScale = Torch::ConstantFloatOp::create(
573774
rewriter, loc, rewriter.getType<Torch::FloatType>(),
574775
rewriter.getF64FloatAttr(scale));
776+
777+
// Use presentKey/presentValue (full KV cache) for attention, not just
778+
// the current token's K/V. This is essential for proper KV caching.
575779
Value attention = Torch::AtenScaledDotProductAttentionOp::create(
576-
rewriter, loc, qRotary.getType(), qRotary, kRotary, vInput,
577-
/*attn_mask=*/cstNone,
780+
rewriter, loc, qRotary.getType(), qRotary, presentKey, presentValue,
781+
/*attn_mask=*/attnMask,
578782
/*dropout_p=*/cstFloatZero, /*is_causal=*/cstFalse, cstScale,
579783
cstEnableGQA);
580784

@@ -607,40 +811,6 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
607811
attnTransposed,
608812
attentionResultSizesList);
609813

610-
// Compute 2nd and 3rd result: present_key, present_value.
611-
// present_key = torch.cat([past_key, key], dim=2) or past_key
612-
// present_value = torch.cat([past_value, value], dim=2) or past_value
613-
Value presentKey = pastKey, presentValue = pastValue;
614-
if (!llvm::equal(
615-
cast<Torch::ValueTensorType>(pastKey.getType()).getSizes(),
616-
cast<Torch::ValueTensorType>(resultTypes[1]).getSizes())) {
617-
Value cstConcatDim = Torch::ConstantIntOp::create(
618-
rewriter, loc, rewriter.getI64IntegerAttr(2));
619-
Type kvListElemType = keyType.getWithSizesAndDtype(
620-
/*optionalSizes=*/std::nullopt,
621-
/*optionalDtype=*/nullptr);
622-
Type kvListType = Torch::ListType::get(kvListElemType);
623-
Value keyList = Torch::PrimListConstructOp::create(
624-
rewriter, loc, kvListType, SmallVector<Value>{pastKey, kRotary});
625-
presentKey = Torch::AtenCatOp::create(rewriter, loc, resultTypes[1],
626-
keyList, cstConcatDim);
627-
}
628-
629-
if (!llvm::equal(
630-
cast<Torch::ValueTensorType>(pastValue.getType()).getSizes(),
631-
cast<Torch::ValueTensorType>(resultTypes[2]).getSizes())) {
632-
Value cstConcatDim = Torch::ConstantIntOp::create(
633-
rewriter, loc, rewriter.getI64IntegerAttr(2));
634-
Type kvListElemType = keyType.getWithSizesAndDtype(
635-
/*optionalSizes=*/std::nullopt,
636-
/*optionalDtype=*/nullptr);
637-
Type kvListType = Torch::ListType::get(kvListElemType);
638-
Value valueList = Torch::PrimListConstructOp::create(
639-
rewriter, loc, kvListType, SmallVector<Value>{pastValue, vInput});
640-
presentValue = Torch::AtenCatOp::create(rewriter, loc, resultTypes[2],
641-
valueList, cstConcatDim);
642-
}
643-
644814
rewriter.replaceOp(binder.op, {attention, presentKey, presentValue});
645815
return success();
646816
});

0 commit comments

Comments
 (0)