@@ -290,6 +290,7 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
290290 // `do_rotary` attribute. If it's false, then the input operands can be
291291 // 7 but if it's true then the operands have to be 9 including cos_cache
292292 // and sin_cache for rotary_embedding.
293+ // TODO: Add support for packed_qkv.
293294 if (!((operands.size () == 9 ) || (!doRotary && operands.size () == 7 )))
294295 return rewriter.notifyMatchFailure (
295296 binder.op , " Unimplemented: excepted input operands to be either "
@@ -322,8 +323,6 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
322323 binder.op , " Unimplemented: softcap attribute is not supported, "
323324 " hence it should have default value equal to 0.0" );
324325
325- // TODO: Add support for packed_qkv.
326-
327326 Location loc = binder.getLoc ();
328327 MLIRContext *context = binder.op ->getContext ();
329328 Value query = operands[0 ];
@@ -562,19 +561,224 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
562561 /* scale=*/ cstFloatOne);
563562 }
564563
565- // Do attention.
566- Value cstEnableGQA = Torch::ConstantBoolOp::create (rewriter, loc, true );
564+ // Build present_key/present_value by padding past with zeros, then
565+ // scattering current K/V into the correct per-batch position.
566+ //
567+ // Why pad instead of cat? With cat(past, current), the current token
568+ // ends up at position past_seq_len. With variable seqlens_k, ORT places
569+ // the current token at seqlens_k[b] and leaves position past_seq_len as
570+ // zero/uninitialized. Using cat+scatter leaves a stale copy of the
571+ // current token at past_seq_len which doesn't match ORT's output.
572+ // Padding with zeros then scattering avoids this.
573+ //
574+ // constant_pad_nd pads innermost dims first: [0, 0, 0, seq_len]
575+ // dim 3 (head_size): [0, 0] -- no padding
576+ // dim 2 (seq): [0, seq_len] -- extend by seq_len on the right
567577 Value cstFloatZero = Torch::ConstantFloatOp::create (
568578 rewriter, loc, rewriter.getType <Torch::FloatType>(),
569579 rewriter.getF64FloatAttr (0.0 ));
580+ Value padList = Torch::PrimListConstructOp::create (
581+ rewriter, loc, intListType,
582+ SmallVector<Value>{cstIntZero, cstIntZero, cstIntZero,
583+ cstSequenceLength});
584+ Value presentKey = Torch::AtenConstantPadNdOp::create (
585+ rewriter, loc, resultTypes[1 ], pastKey, padList, cstFloatZero);
586+ Value presentValue = Torch::AtenConstantPadNdOp::create (
587+ rewriter, loc, resultTypes[2 ], pastValue, padList, cstFloatZero);
588+
589+ // Scatter current K/V into the padded buffer at position pastLen[b]+q.
590+ // pastLen = seqlens_k + 1 - seq_len
591+ Value totalSeqForScatter = Torch::AtenAddScalarOp::create (
592+ rewriter, loc, seqlensK.getType (), seqlensK, cstIntOne, cstIntOne);
593+ Value pastLen = Torch::AtenSubScalarOp::create (
594+ rewriter, loc, totalSeqForScatter.getType (), totalSeqForScatter,
595+ cstSequenceLength, cstIntOne);
596+
597+ // qRange: [0, 1, ..., seqLen-1] — shared by scatter and mask.
598+ Torch::ValueTensorType qRangeType = Torch::ValueTensorType::get (
599+ context, {sequenceLength},
600+ rewriter.getIntegerType (64 , /* isSigned=*/ true ));
601+ Value qRange = Torch::AtenArangeOp::create (
602+ rewriter, loc, qRangeType, cstSequenceLength, cstInt64Dtype,
603+ /* layout=*/ cstNone, /* device=*/ cstNone, /* pin_memory=*/ cstNone);
604+
605+ // pastLen -> [B, 1, 1, 1] for 4D scatter index broadcasting
606+ Value scatterViewList = Torch::PrimListConstructOp::create (
607+ rewriter, loc, intListType,
608+ SmallVector<Value>{cstIntMinusOne, cstIntOne, cstIntOne,
609+ cstIntOne});
610+ SmallVector<int64_t > pastLenView4dSizes{batchSize, 1 , 1 , 1 };
611+ Torch::ValueTensorType pastLenView4dType = Torch::ValueTensorType::get (
612+ context, pastLenView4dSizes,
613+ rewriter.getIntegerType (64 , /* isSigned=*/ true ));
614+ Value pastLenView4d = Torch::AtenViewOp::create (
615+ rewriter, loc, pastLenView4dType, pastLen, scatterViewList);
616+
617+ // qRange -> [1, 1, seq, 1] for scatter
618+ Value scatterQViewList = Torch::PrimListConstructOp::create (
619+ rewriter, loc, intListType,
620+ SmallVector<Value>{cstIntOne, cstIntOne, cstIntMinusOne,
621+ cstIntOne});
622+ SmallVector<int64_t > scatterQViewSizes{1 , 1 , sequenceLength, 1 };
623+ Torch::ValueTensorType scatterQViewType = Torch::ValueTensorType::get (
624+ context, scatterQViewSizes,
625+ rewriter.getIntegerType (64 , /* isSigned=*/ true ));
626+ Value scatterQRangeView = Torch::AtenViewOp::create (
627+ rewriter, loc, scatterQViewType, qRange, scatterQViewList);
628+
629+ // scatterIdxBase = pastLen[B,1,1,1] + qRange[1,1,seq,1]
630+ // -> [B, 1, seq, 1]
631+ SmallVector<int64_t > scatterIdxBaseSizes{batchSize, 1 , sequenceLength,
632+ 1 };
633+ Torch::ValueTensorType scatterIdxBaseType = Torch::ValueTensorType::get (
634+ context, scatterIdxBaseSizes,
635+ rewriter.getIntegerType (64 , /* isSigned=*/ true ));
636+ Value scatterIdxBase = Torch::AtenAddTensorOp::create (
637+ rewriter, loc, scatterIdxBaseType, pastLenView4d, scatterQRangeView,
638+ cstIntOne);
639+
640+ // Expand to [B, kv_heads, seq, head_size] to match current K/V shape
641+ SmallVector<int64_t > scatterExpandSizes{batchSize, kvNumHeads,
642+ sequenceLength, headSize};
643+ Torch::ValueTensorType scatterIdxType = Torch::ValueTensorType::get (
644+ context, scatterExpandSizes,
645+ rewriter.getIntegerType (64 , /* isSigned=*/ true ));
646+ Value scatterExpandSizeList = Torch::PrimListConstructOp::create (
647+ rewriter, loc, intListType,
648+ SmallVector<Value>{cstBatchSize, cstKVNumHeads, cstSequenceLength,
649+ cstHeadSize});
650+ Value scatterIdx = Torch::AtenExpandOp::create (
651+ rewriter, loc, scatterIdxType, scatterIdxBase,
652+ scatterExpandSizeList, /* implicit=*/ cstFalse);
653+
654+ // Scatter current K/V into buffer at position pastLen[b] + q
655+ presentKey = Torch::AtenScatterSrcOp::create (
656+ rewriter, loc, resultTypes[1 ], presentKey, cstDim2, scatterIdx,
657+ kRotary );
658+ presentValue = Torch::AtenScatterSrcOp::create (
659+ rewriter, loc, resultTypes[2 ], presentValue, cstDim2, scatterIdx,
660+ vInput);
661+
662+ // Generate causal attention mask.
663+ // With scatter, KV layout matches ORT: current at pastLen[b].
664+ // Simple boolean mask: k <= pastLen[b] + q
665+ // Mask shape: [batch, 1, seqLen, kvSeqLen] (i1).
666+ Value attnMask = cstNone;
667+
668+ // Get the KV sequence length from presentKey shape
669+ Torch::ValueTensorType presentKeyType =
670+ cast<Torch::ValueTensorType>(presentKey.getType ());
671+ if (presentKeyType.hasSizes () &&
672+ presentKeyType.getSizes ().size () == 4 ) {
673+ int64_t kvSeqLen = presentKeyType.getSizes ()[2 ];
674+
675+ // Only generate mask if KV sequence length is dynamic or > 0
676+ // For dynamic shapes or non-trivial sequences, we need to mask
677+ if (kvSeqLen == Torch::kUnknownSize || kvSeqLen > 0 ) {
678+ // Get KV sequence dimension size
679+ Value kvSeqLenVal = Torch::AtenSizeIntOp::create (
680+ rewriter, loc, rewriter.getType <Torch::IntType>(), presentKey,
681+ cstDim2);
682+
683+ // kRange: [0, 1, 2, ..., kvSeqLen-1] shape [kvSeqLen]
684+ Torch::ValueTensorType kRangeType = Torch::ValueTensorType::get (
685+ context, {kvSeqLen},
686+ rewriter.getIntegerType (64 , /* isSigned=*/ true ));
687+ Value kRange = Torch::AtenArangeOp::create (
688+ rewriter, loc, kRangeType , kvSeqLenVal, cstInt64Dtype,
689+ /* layout=*/ cstNone, /* device=*/ cstNone, /* pin_memory=*/ cstNone);
690+
691+ // Reshape for broadcasting:
692+ // pastLen: [batch] -> [batch, 1, 1]
693+ // qRange: [seqLen] -> [1, seqLen, 1] (reuses qRange from above)
694+ // kRange: [kvSeqLen] -> [1, 1, kvSeqLen]
695+
696+ // pastLen -> [batch, 1, 1]
697+ Value seqlensViewList = Torch::PrimListConstructOp::create (
698+ rewriter, loc, intListType,
699+ SmallVector<Value>{cstIntMinusOne, cstIntOne, cstIntOne});
700+ SmallVector<int64_t > seqlensViewSizes{batchSize, 1 , 1 };
701+ Torch::ValueTensorType seqlensViewType =
702+ Torch::ValueTensorType::get (
703+ context, seqlensViewSizes,
704+ rewriter.getIntegerType (64 , /* isSigned=*/ true ));
705+ Value pastLenView = Torch::AtenViewOp::create (
706+ rewriter, loc, seqlensViewType, pastLen, seqlensViewList);
707+
708+ // qRange -> [1, seqLen, 1]
709+ Value qViewList = Torch::PrimListConstructOp::create (
710+ rewriter, loc, intListType,
711+ SmallVector<Value>{cstIntOne, cstIntMinusOne, cstIntOne});
712+ SmallVector<int64_t > qViewSizes{1 , sequenceLength, 1 };
713+ Torch::ValueTensorType qViewType = Torch::ValueTensorType::get (
714+ context, qViewSizes,
715+ rewriter.getIntegerType (64 , /* isSigned=*/ true ));
716+ Value qRangeView = Torch::AtenViewOp::create (
717+ rewriter, loc, qViewType, qRange, qViewList);
718+
719+ // kRange -> [1, 1, kvSeqLen]
720+ Value kViewList = Torch::PrimListConstructOp::create (
721+ rewriter, loc, intListType,
722+ SmallVector<Value>{cstIntOne, cstIntOne, cstIntMinusOne});
723+ SmallVector<int64_t > kViewSizes {1 , 1 , kvSeqLen};
724+ Torch::ValueTensorType kViewType = Torch::ValueTensorType::get (
725+ context, kViewSizes ,
726+ rewriter.getIntegerType (64 , /* isSigned=*/ true ));
727+ Value kRangeView = Torch::AtenViewOp::create (
728+ rewriter, loc, kViewType , kRange , kViewList );
729+
730+ // Causal mask: k <= pastLen + q
731+ // pastLenView[batch,1,1] + qRangeView[1,seqLen,1]
732+ // -> [batch, seqLen, 1]
733+ SmallVector<int64_t > pastLenPlusQSizes{batchSize, sequenceLength,
734+ 1 };
735+ Torch::ValueTensorType pastLenPlusQType =
736+ Torch::ValueTensorType::get (
737+ context, pastLenPlusQSizes,
738+ rewriter.getIntegerType (64 , /* isSigned=*/ true ));
739+ Value pastLenPlusQ = Torch::AtenAddTensorOp::create (
740+ rewriter, loc, pastLenPlusQType, pastLenView, qRangeView,
741+ cstIntOne);
742+
743+ // kRangeView[1,1,kvSeqLen] <= pastLenPlusQ[batch,seqLen,1]
744+ // -> [batch, seqLen, kvSeqLen]
745+ SmallVector<int64_t > maskBoolSizes{batchSize, sequenceLength,
746+ kvSeqLen};
747+ Torch::ValueTensorType maskBoolType = Torch::ValueTensorType::get (
748+ context, maskBoolSizes, rewriter.getI1Type ());
749+ Value causalMask = Torch::AtenLeTensorOp::create (
750+ rewriter, loc, maskBoolType, kRangeView , pastLenPlusQ);
751+
752+ // Reshape to [batch, 1, seqLen, kvSeqLen] for SDPA.
753+ // Pass the boolean mask directly — downstream backends (e.g.
754+ // IREE's iree_linalg_ext.attention) handle bool-to-float
755+ // conversion internally.
756+ Value maskReshapeSizeList = Torch::PrimListConstructOp::create (
757+ rewriter, loc, intListType,
758+ SmallVector<Value>{cstBatchSize, cstIntOne, cstSequenceLength,
759+ kvSeqLenVal});
760+ SmallVector<int64_t > attnMaskSizes{batchSize, 1 , sequenceLength,
761+ kvSeqLen};
762+ Torch::ValueTensorType attnMaskType = Torch::ValueTensorType::get (
763+ context, attnMaskSizes, rewriter.getI1Type ());
764+ attnMask = Torch::AtenReshapeOp::create (
765+ rewriter, loc, attnMaskType, causalMask, maskReshapeSizeList);
766+ }
767+ }
768+
769+ // Do attention with full KV cache (past + current) and mask.
770+ Value cstEnableGQA = Torch::ConstantBoolOp::create (rewriter, loc, true );
570771 Value cstScale = cstNone;
571772 if (scale != 0 .0f )
572773 cstScale = Torch::ConstantFloatOp::create (
573774 rewriter, loc, rewriter.getType <Torch::FloatType>(),
574775 rewriter.getF64FloatAttr (scale));
776+
777+ // Use presentKey/presentValue (full KV cache) for attention, not just
778+ // the current token's K/V. This is essential for proper KV caching.
575779 Value attention = Torch::AtenScaledDotProductAttentionOp::create (
576- rewriter, loc, qRotary.getType (), qRotary, kRotary , vInput ,
577- /* attn_mask=*/ cstNone ,
780+ rewriter, loc, qRotary.getType (), qRotary, presentKey, presentValue ,
781+ /* attn_mask=*/ attnMask ,
578782 /* dropout_p=*/ cstFloatZero, /* is_causal=*/ cstFalse, cstScale,
579783 cstEnableGQA);
580784
@@ -607,40 +811,6 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
607811 attnTransposed,
608812 attentionResultSizesList);
609813
610- // Compute 2nd and 3rd result: present_key, present_value.
611- // present_key = torch.cat([past_key, key], dim=2) or past_key
612- // present_value = torch.cat([past_value, value], dim=2) or past_value
613- Value presentKey = pastKey, presentValue = pastValue;
614- if (!llvm::equal (
615- cast<Torch::ValueTensorType>(pastKey.getType ()).getSizes (),
616- cast<Torch::ValueTensorType>(resultTypes[1 ]).getSizes ())) {
617- Value cstConcatDim = Torch::ConstantIntOp::create (
618- rewriter, loc, rewriter.getI64IntegerAttr (2 ));
619- Type kvListElemType = keyType.getWithSizesAndDtype (
620- /* optionalSizes=*/ std::nullopt ,
621- /* optionalDtype=*/ nullptr );
622- Type kvListType = Torch::ListType::get (kvListElemType);
623- Value keyList = Torch::PrimListConstructOp::create (
624- rewriter, loc, kvListType, SmallVector<Value>{pastKey, kRotary });
625- presentKey = Torch::AtenCatOp::create (rewriter, loc, resultTypes[1 ],
626- keyList, cstConcatDim);
627- }
628-
629- if (!llvm::equal (
630- cast<Torch::ValueTensorType>(pastValue.getType ()).getSizes (),
631- cast<Torch::ValueTensorType>(resultTypes[2 ]).getSizes ())) {
632- Value cstConcatDim = Torch::ConstantIntOp::create (
633- rewriter, loc, rewriter.getI64IntegerAttr (2 ));
634- Type kvListElemType = keyType.getWithSizesAndDtype (
635- /* optionalSizes=*/ std::nullopt ,
636- /* optionalDtype=*/ nullptr );
637- Type kvListType = Torch::ListType::get (kvListElemType);
638- Value valueList = Torch::PrimListConstructOp::create (
639- rewriter, loc, kvListType, SmallVector<Value>{pastValue, vInput});
640- presentValue = Torch::AtenCatOp::create (rewriter, loc, resultTypes[2 ],
641- valueList, cstConcatDim);
642- }
643-
644814 rewriter.replaceOp (binder.op , {attention, presentKey, presentValue});
645815 return success ();
646816 });
0 commit comments