@@ -605,9 +605,10 @@ void fillTDMDescriptor(
605605 }
606606}
607607
608- // Fill TDM descriptor for scatter operation (2D only).
609- // Scatter writes data from LDS to non-contiguous rows in global memory.
610- void fillTDMDescriptorForScatter (
608+ // Fill TDM descriptor for gather/scatter operations (2D only).
609+ // Gather reads from non-contiguous rows in global memory to LDS.
610+ // Scatter writes from LDS to non-contiguous rows in global memory.
611+ void fillTDMDescriptorForGatherScatter (
611612 RewriterBase &rewriter, Location loc,
612613 const LLVMTypeConverter *typeConverter, Type elementType,
613614 SmallVector<int64_t > blockShape, SmallVector<Value> &group0,
@@ -616,7 +617,7 @@ void fillTDMDescriptorForScatter(
616617 Value ldsPtr, Value pred, Value barrierPtr,
617618 const triton::LinearLayout &cgaLayout, Value ctaId,
618619 ArrayRef<Value> rowIndices, bool use32BitIndices) {
619- assert (!rowIndices.empty () && " Scatter requires row indices." );
620+ assert (!rowIndices.empty () && " Gather/scatter requires row indices." );
620621
621622 auto ctx = rewriter.getContext ();
622623 auto b = TritonLLVMOpBuilder (loc, rewriter);
@@ -649,17 +650,17 @@ void fillTDMDescriptorForScatter(
649650 Value ldsOffset = b.mul (ldsRowOffset, b.i32_val (blockShape[1 ]));
650651 ldsPtr = b.gep (sharedPtrTy, elementType, ldsPtr, ldsOffset);
651652
652- // Update group0 with addresses and enable scatter
653+ // Update group0 with addresses and enable gather/ scatter mode
653654 Value globalAddr = b.ptrtoint (i64_ty, globalPtr);
654655 Value ldsAddr = b.ptrtoint (i32_ty, ldsPtr);
655656
656- // Set scatter bits: bit 31 = enable, bit 30 = 32-bit indices
657- Value predWithScatter = b.or_ (pred, b.i32_val (1 << 31 ));
657+ // Set gather/ scatter bits: bit 31 = enable, bit 30 = 32-bit indices
658+ Value predWithGatherScatter = b.or_ (pred, b.i32_val (1 << 31 ));
658659 if (use32BitIndices) {
659- predWithScatter = b.or_ (predWithScatter , b.i32_val (1 << 30 ));
660+ predWithGatherScatter = b.or_ (predWithGatherScatter , b.i32_val (1 << 30 ));
660661 }
661662
662- group0[0 ] = predWithScatter ;
663+ group0[0 ] = predWithGatherScatter ;
663664 group0[1 ] = ldsAddr;
664665 group0[2 ] = b.trunc (i32_ty, globalAddr);
665666
@@ -784,28 +785,29 @@ void emitTDMLoadStore(RewriterBase &rewriter, Location loc,
784785 }
785786}
786787
787- // Emit a TDM scatter operation to write non-contiguous rows from LDS to global.
788- void emitTDMScatter (RewriterBase &rewriter, Location loc,
789- const LLVMTypeConverter *typeConverter,
790- ArrayRef<Value> desc, ArrayRef<int64_t > blockShape,
791- Value srcPtr, Value pred, Type elementType,
792- Value barrierPtr, const triton::LinearLayout &cgaLayout,
793- Value ctaId, ArrayRef<Value> rowIndices, Value colOffset,
794- bool use32BitIndices) {
788+ // Emit a TDM gather or scatter operation for non-contiguous row access.
789+ void emitTDMGatherScatter (RewriterBase &rewriter, Location loc,
790+ const LLVMTypeConverter *typeConverter,
791+ ArrayRef<Value> desc, ArrayRef<int64_t > blockShape,
792+ Value ldsPtr, Value pred, Type elementType,
793+ Value barrierPtr,
794+ const triton::LinearLayout &cgaLayout, Value ctaId,
795+ ArrayRef<Value> rowIndices, Value colOffset,
796+ bool use32BitIndices, bool isGather) {
795797 auto b = TritonLLVMOpBuilder (loc, rewriter);
796798
797- assert (!rowIndices.empty () && " Scatter requires row indices" );
798- assert (colOffset && " Scatter requires column offset" );
799+ assert (!rowIndices.empty () && " Gather/scatter requires row indices" );
800+ assert (colOffset && " Gather/scatter requires column offset" );
799801
800802 // Determine max indices per instruction based on index size
801803 size_t maxIndicesPerInstr = use32BitIndices ? 8 : 16 ;
802804 size_t numIndices = rowIndices.size ();
803805
804- // Get the descriptor groups (scatter uses 2D format: 12 dwords)
806+ // Get the descriptor groups (gather/ scatter uses 2D format: 12 dwords)
805807 auto group0Vec = SmallVector<Value>(desc.begin (), desc.begin () + 4 );
806808 auto group1Vec = SmallVector<Value>(desc.begin () + 4 , desc.end ());
807809
808- // For TDM scatter, we need group2 and group3 for indices
810+ // For TDM gather/ scatter, we need group2 and group3 for indices
809811 SmallVector<Value> group2Vec (4 , b.i32_val (0 ));
810812 SmallVector<Value> group3Vec (4 , b.i32_val (0 ));
811813
@@ -824,12 +826,12 @@ void emitTDMScatter(RewriterBase &rewriter, Location loc,
824826 auto g2 = group2Vec;
825827 auto g3 = group3Vec;
826828
827- // Fill the descriptor for scatter:
829+ // Fill the descriptor for gather/ scatter:
828830 // - ldsRowOffset: row offset within shared memory for this batch
829831 // - colOffset: starting column in global memory
830- fillTDMDescriptorForScatter (
832+ fillTDMDescriptorForGatherScatter (
831833 rewriter, loc, typeConverter, elementType, to_vector (blockShape), g0,
832- g1, g2, g3, b.i32_val (startIdx), colOffset, srcPtr , pred, barrierPtr,
834+ g1, g2, g3, b.i32_val (startIdx), colOffset, ldsPtr , pred, barrierPtr,
833835 cgaLayout, ctaId, batchIndices, use32BitIndices);
834836
835837 // Pack and emit the instruction
@@ -838,10 +840,13 @@ void emitTDMScatter(RewriterBase &rewriter, Location loc,
838840 auto group2 = packLLVector (loc, g2, rewriter);
839841 auto group3 = packLLVector (loc, g3, rewriter);
840842
841- // Scatter uses tensor.store.from.lds (not the d2 variant) because it
842- // needs group2/group3 for indices
843+ // Gather/scatter uses full 4-group format (not the d2 variant) for indices
844+ // Gather: tensor.load.to.lds (global -> LDS)
845+ // Scatter: tensor.store.from.lds (LDS -> global)
846+ const char *intrinsicName = isGather ? " llvm.amdgcn.tensor.load.to.lds"
847+ : " llvm.amdgcn.tensor.store.from.lds" ;
843848 LLVM::createLLVMIntrinsicCallOp (
844- rewriter, loc, " llvm.amdgcn.tensor.store.from.lds " , {},
849+ rewriter, loc, intrinsicName , {},
845850 {group0, group1, group2, group3, b.i32_val (0 )});
846851 }
847852}
0 commit comments