Skip to content

Commit 72d0d90

Browse files
authored
[AMD][gluon][gfx1250] Add tensor async gather support using TDM (#9313)
Implements tensor async_gather using TDM in a similar fashion to #9299 on Gluon.
1 parent 395752c commit 72d0d90

8 files changed

Lines changed: 505 additions & 54 deletions

File tree

python/src/gluon_ir.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -981,6 +981,12 @@ void init_gluon_ir(py::module &&m) {
981981
self.create<ttag::AsyncTDMScatterOp>(descPtr, dstRowIndices,
982982
dstColOffset, src, barrier);
983983
})
984+
.def("create_async_tdm_gather",
985+
[](GluonOpBuilder &self, Value descPtr, Value srcRowIndices,
986+
Value srcColOffset, Value dst, Value barrier) {
987+
self.create<ttag::AsyncTDMGatherOp>(descPtr, srcRowIndices,
988+
srcColOffset, dst, barrier);
989+
})
984990
.def("create_tdm_prefetch",
985991
[](GluonOpBuilder &self, Value descPtr, std::vector<Value> &indices,
986992
Value pred, bool speculative, bool returnOffsets) -> Value {

python/triton/experimental/gluon/language/amd/gfx1250/tdm.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,44 @@ def async_scatter(desc: tensor_descriptor, dst_row_indices: ttgl.tensor, dst_col
215215
mbarrier_handle)
216216

217217

218+
@builtin
219+
def async_gather(desc: tensor_descriptor, src_row_indices: ttgl.tensor, src_col_offset, dst: shared_memory_descriptor,
220+
mbarrier: shared_memory_descriptor = None, _semantic=None) -> None:
221+
"""Gather data from non-contiguous rows in global memory to shared memory asynchronously.
222+
223+
This operation uses TDM gather mode to read data from non-contiguous rows in global memory.
224+
Unlike async_load which reads from contiguous rows, gather allows reading from arbitrary
225+
rows specified by the src_row_indices tensor.
226+
227+
The dtype of src_row_indices determines the index size:
228+
- int16: up to 16 rows can be gathered per TDM instruction
229+
- int32: up to 8 rows can be gathered per TDM instruction
230+
If more rows are needed, multiple TDM instructions will be automatically issued.
231+
232+
Args:
233+
desc (tensor_descriptor): the source tensor descriptor. Must be 2D.
234+
src_row_indices (tensor): 1D tensor of row indices (int16 or int32) in the source tensor.
235+
src_col_offset (int or tensor): the starting column offset in the source tensor
236+
for all gathered rows.
237+
dst (shared_memory_descriptor): the shared memory destination to store gathered data. Must be 2D.
238+
mbarrier (shared_memory_descriptor, optional): The barrier object to signal "arrive" on.
239+
"""
240+
ndim = len(desc.block_shape)
241+
assert ndim == 2, f"TDM gather only supports 2D tensors, got {ndim}D"
242+
243+
dst_ndim = len(dst.shape)
244+
assert dst_ndim == 2, f"TDM gather dst must be 2D, got {dst_ndim}D"
245+
246+
# Convert src_col_offset to i32
247+
src_col_offset_handle = _semantic._convert_to_ir_values([src_col_offset], require_i64=False)[0]
248+
249+
mbarrier = _unwrap_if_constexpr(mbarrier)
250+
mbarrier_handle = mbarrier.handle if mbarrier is not None else ttgl.ir.value()
251+
252+
_semantic.builder.create_async_tdm_gather(desc.handle, src_row_indices.handle, src_col_offset_handle, dst.handle,
253+
mbarrier_handle)
254+
255+
218256
@builtin
219257
def prefetch(src: tensor_descriptor, offsets: List[ttgl.constexpr | ttgl.tensor], pred: bool = True,
220258
speculative: bool = False, _semantic=None) -> None:

third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -870,6 +870,45 @@ def AsyncTDMScatterOp : TT_AMDGPU_Op<"async_tdm_scatter"> {
870870
let hasVerifier = 1;
871871
}
872872

873+
//===----------------------------------------------------------------------===//
874+
// AsyncTDMGatherOp
875+
//===----------------------------------------------------------------------===//
876+
877+
def AsyncTDMGatherOp : TT_AMDGPU_Op<"async_tdm_gather"> {
878+
let summary = "Gather data from non-contiguous global memory rows to local memory asynchronously";
879+
880+
let description = [{
881+
This operation gathers data from non-contiguous rows in global memory to local
882+
memory using TDM gather mode.
883+
Unlike the regular async_tdm_copy_global_to_local which reads from contiguous memory,
884+
this operation uses src_row_indices to specify which rows in global memory to read from.
885+
886+
The descriptor must be 2D. The src_row_indices specify which rows in global memory
887+
to read from. The element type of src_row_indices determines the index size:
888+
- I16: 16-bit indices, up to 16 rows per instruction
889+
- I32: 32-bit indices, up to 8 rows per instruction
890+
If more rows are needed, multiple TDM instructions will be issued.
891+
892+
The src_col_offset specifies the starting column in the source tensor for
893+
all gathered rows.
894+
}];
895+
896+
let arguments = (ins
897+
Arg<TT_TensorDescType, "", [MemRead<GlobalMemory>]>:$desc,
898+
TensorOf<[I16, I32]>:$src_row_indices,
899+
I32:$src_col_offset,
900+
Arg<TTG_MemDescType, "", [MemWrite<SharedMemory>]>:$dst,
901+
Optional<TTG_MemDescType>:$barrier
902+
);
903+
904+
let assemblyFormat = [{
905+
$desc `[` $src_row_indices `,` $src_col_offset `]` `to` $dst (`,` `barrier` `=` $barrier^)?
906+
attr-dict `:` qualified(type($src_row_indices)) `,` qualified(type($dst)) (`,` qualified(type($barrier))^)? `->` qualified(type($desc))
907+
}];
908+
909+
let hasVerifier = 1;
910+
}
911+
873912
//===----------------------------------------------------------------------===//
874913
// AsyncTDMWait
875914
//===----------------------------------------------------------------------===//

third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -812,6 +812,49 @@ LogicalResult AsyncTDMScatterOp::verify() {
812812
return success();
813813
}
814814

815+
LogicalResult AsyncTDMGatherOp::verify() {
816+
auto tensorDescTy = getDesc().getType();
817+
auto smemTy = getDst().getType();
818+
819+
// TDM gather mode only supports 2D tensors
820+
auto blockShape = tensorDescTy.getBlockType().getShape();
821+
if (blockShape.size() != 2)
822+
return emitOpError("TDM gather only supports 2D tensors, got ")
823+
<< blockShape.size() << "D";
824+
825+
// Check that every dimension of the block shape is <= 2^16
826+
auto verifyResult = verifyTDMBlockSize(getOperation(), blockShape);
827+
if (failed(verifyResult))
828+
return verifyResult;
829+
830+
auto srcRowIndicesType = cast<RankedTensorType>(getSrcRowIndices().getType());
831+
if (srcRowIndicesType.getRank() != 1)
832+
return emitOpError("src_row_indices must be a 1D tensor");
833+
834+
// Element type (i16 or i32) is already verified by ODS constraint
835+
// TensorOf<[I16, I32]>
836+
837+
int64_t numIndices = srcRowIndicesType.getShape()[0];
838+
if (!llvm::isPowerOf2_64(numIndices))
839+
return emitOpError("src_row_indices size must be a power of 2, got ")
840+
<< numIndices;
841+
842+
auto swizzledEnc =
843+
llvm::dyn_cast<gpu::SwizzledSharedEncodingAttr>(smemTy.getEncoding());
844+
if (swizzledEnc && swizzledEnc.getMaxPhase() != 1)
845+
return emitOpError("TDM does not support swizzling");
846+
847+
auto paddedEnc =
848+
llvm::dyn_cast<gpu::PaddedSharedEncodingAttr>(smemTy.getEncoding());
849+
if (paddedEnc)
850+
return emitOpError("TDM gather does not support padding");
851+
852+
if (!paddedEnc && !swizzledEnc)
853+
return emitOpError("Invalid shared memory layout for TDM");
854+
855+
return success();
856+
}
857+
815858
// -- InitBarrierOp --
816859
LogicalResult InitBarrierOp::verify() {
817860
if (failed(verifyBarrierType(*this, getAlloc().getType())))

third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 91 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1384,10 +1384,89 @@ struct AsyncTDMScatterOpConversion
13841384

13851385
// Predicate must be i32 (not i1) to match other elements in group0
13861386
Value pred = arith::ConstantIntOp::create(rewriter, loc, 1, 32);
1387-
mlir::LLVM::AMD::emitTDMScatter(rewriter, loc, getTypeConverter(), desc,
1388-
shapePerCTA, srcPtr, pred, elementType,
1389-
barrierPtr, cgaLayout, ctaId, dstRowIndices,
1390-
dstColOffset, use32BitIndices);
1387+
mlir::LLVM::AMD::emitTDMGatherScatter(
1388+
rewriter, loc, getTypeConverter(), desc, shapePerCTA, srcPtr, pred,
1389+
elementType, barrierPtr, cgaLayout, ctaId, dstRowIndices, dstColOffset,
1390+
use32BitIndices, /*isGather=*/false);
1391+
1392+
rewriter.eraseOp(op);
1393+
return success();
1394+
}
1395+
};
1396+
1397+
struct AsyncTDMGatherOpConversion
1398+
: public ConvertOpToLLVMPattern<triton::amdgpu::AsyncTDMGatherOp>,
1399+
public LoadStoreConversionBase {
1400+
AsyncTDMGatherOpConversion(LLVMTypeConverter &converter,
1401+
const AMD::TargetInfo &targetInfo,
1402+
ModuleAxisInfoAnalysis &axisAnalysisPass,
1403+
PatternBenefit benefit)
1404+
: ConvertOpToLLVMPattern(converter, benefit),
1405+
LoadStoreConversionBase(targetInfo, axisAnalysisPass) {}
1406+
1407+
LogicalResult
1408+
matchAndRewrite(triton::amdgpu::AsyncTDMGatherOp op, OpAdaptor adaptor,
1409+
ConversionPatternRewriter &rewriter) const override {
1410+
auto loc = op.getLoc();
1411+
auto b = TritonLLVMOpBuilder(loc, rewriter);
1412+
1413+
auto tensorDescTy = op.getDesc().getType();
1414+
auto smemTy = op.getDst().getType();
1415+
Type elementType = getTypeConverter()->convertType(smemTy.getElementType());
1416+
1417+
SmallVector<Value> desc =
1418+
unpackLLElements(loc, adaptor.getDesc(), rewriter);
1419+
1420+
SmallVector<int64_t> blockShape =
1421+
llvm::to_vector(tensorDescTy.getBlockType().getShape());
1422+
1423+
// Gather only supports 2D tensors
1424+
assert(blockShape.size() == 2 &&
1425+
"TDM gather mode only supports 2D tensors");
1426+
1427+
auto dstMemObj = LLVM::getSharedMemoryObjectFromStruct(
1428+
loc, adaptor.getDst(), elementType, rewriter);
1429+
Value dstPtr = dstMemObj.getBase();
1430+
int numWarps = triton::gpu::lookupNumWarps(op);
1431+
1432+
Value barrierPtr = nullptr;
1433+
if (op.getBarrier()) {
1434+
auto smemObj = LLVM::getSharedMemoryObjectFromStruct(
1435+
loc, adaptor.getBarrier(),
1436+
typeConverter->convertType(
1437+
op.getBarrier().getType().getElementType()),
1438+
rewriter);
1439+
barrierPtr = smemObj.getBase();
1440+
}
1441+
1442+
// Get the source row indices for gather
1443+
SmallVector<Value> srcRowIndices =
1444+
unpackLLElements(loc, adaptor.getSrcRowIndices(), rewriter);
1445+
1446+
auto shapePerCTA = triton::gpu::getShapePerCTA(smemTy);
1447+
1448+
// Get the source column offset
1449+
Value srcColOffset = adaptor.getSrcColOffset();
1450+
1451+
// Determine index size from the element type of src_row_indices
1452+
auto srcRowIndicesType =
1453+
cast<RankedTensorType>(op.getSrcRowIndices().getType());
1454+
bool use32BitIndices =
1455+
srcRowIndicesType.getElementType().getIntOrFloatBitWidth() == 32;
1456+
1457+
// Create the CGA layout
1458+
auto sharedLayout = triton::gpu::toLinearLayout(smemTy);
1459+
auto kBlock = rewriter.getStringAttr("block");
1460+
auto cgaLayout = sharedLayout.sublayout(
1461+
{kBlock}, to_vector(sharedLayout.getOutDimNames()));
1462+
auto ctaId = targetInfo.getClusterCTAId(rewriter, loc);
1463+
1464+
// Predicate must be i32 (not i1) to match other elements in group0
1465+
Value pred = arith::ConstantIntOp::create(rewriter, loc, 1, 32);
1466+
mlir::LLVM::AMD::emitTDMGatherScatter(
1467+
rewriter, loc, getTypeConverter(), desc, shapePerCTA, dstPtr, pred,
1468+
elementType, barrierPtr, cgaLayout, ctaId, srcRowIndices, srcColOffset,
1469+
use32BitIndices, /*isGather=*/true);
13911470

13921471
rewriter.eraseOp(op);
13931472
return success();
@@ -2320,13 +2399,14 @@ void populateLoadStoreOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
23202399
RewritePatternSet &patterns,
23212400
ModuleAxisInfoAnalysis &axisInfoAnalysis,
23222401
PatternBenefit benefit) {
2323-
patterns.add<
2324-
AtomicCASOpConversion, AtomicRMWOpConversion, LoadOpConversion,
2325-
StoreOpConversion, BufferLoadOpConversion, BufferLoadToLocalOpConversion,
2326-
BufferStoreOpConversion, BufferAtomicRMWOpConversion,
2327-
AsyncCopyGlobalToLocalOpConversion, AsyncCopyLocalToGlobalOpConversion,
2328-
BufferAtomicCASOpConversion, AsyncTDMCopyGlobalToLocalOpConversion,
2329-
AsyncTDMCopyLocalToGlobalOpConversion, AsyncTDMScatterOpConversion>(
2402+
patterns.add<AtomicCASOpConversion, AtomicRMWOpConversion, LoadOpConversion,
2403+
StoreOpConversion, BufferLoadOpConversion,
2404+
BufferLoadToLocalOpConversion, BufferStoreOpConversion,
2405+
BufferAtomicRMWOpConversion, AsyncCopyGlobalToLocalOpConversion,
2406+
AsyncCopyLocalToGlobalOpConversion, BufferAtomicCASOpConversion,
2407+
AsyncTDMCopyGlobalToLocalOpConversion,
2408+
AsyncTDMCopyLocalToGlobalOpConversion,
2409+
AsyncTDMScatterOpConversion, AsyncTDMGatherOpConversion>(
23302410
typeConverter, targetInfo, axisInfoAnalysis, benefit);
23312411
patterns.add<AsyncWaitOpConversion>(typeConverter, targetInfo, benefit);
23322412
patterns.add<TDMPrefetchConversion>(typeConverter, targetInfo, benefit);

third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.cpp

Lines changed: 32 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -605,9 +605,10 @@ void fillTDMDescriptor(
605605
}
606606
}
607607

608-
// Fill TDM descriptor for scatter operation (2D only).
609-
// Scatter writes data from LDS to non-contiguous rows in global memory.
610-
void fillTDMDescriptorForScatter(
608+
// Fill TDM descriptor for gather/scatter operations (2D only).
609+
// Gather reads from non-contiguous rows in global memory to LDS.
610+
// Scatter writes from LDS to non-contiguous rows in global memory.
611+
void fillTDMDescriptorForGatherScatter(
611612
RewriterBase &rewriter, Location loc,
612613
const LLVMTypeConverter *typeConverter, Type elementType,
613614
SmallVector<int64_t> blockShape, SmallVector<Value> &group0,
@@ -616,7 +617,7 @@ void fillTDMDescriptorForScatter(
616617
Value ldsPtr, Value pred, Value barrierPtr,
617618
const triton::LinearLayout &cgaLayout, Value ctaId,
618619
ArrayRef<Value> rowIndices, bool use32BitIndices) {
619-
assert(!rowIndices.empty() && "Scatter requires row indices.");
620+
assert(!rowIndices.empty() && "Gather/scatter requires row indices.");
620621

621622
auto ctx = rewriter.getContext();
622623
auto b = TritonLLVMOpBuilder(loc, rewriter);
@@ -649,17 +650,17 @@ void fillTDMDescriptorForScatter(
649650
Value ldsOffset = b.mul(ldsRowOffset, b.i32_val(blockShape[1]));
650651
ldsPtr = b.gep(sharedPtrTy, elementType, ldsPtr, ldsOffset);
651652

652-
// Update group0 with addresses and enable scatter
653+
// Update group0 with addresses and enable gather/scatter mode
653654
Value globalAddr = b.ptrtoint(i64_ty, globalPtr);
654655
Value ldsAddr = b.ptrtoint(i32_ty, ldsPtr);
655656

656-
// Set scatter bits: bit 31 = enable, bit 30 = 32-bit indices
657-
Value predWithScatter = b.or_(pred, b.i32_val(1 << 31));
657+
// Set gather/scatter bits: bit 31 = enable, bit 30 = 32-bit indices
658+
Value predWithGatherScatter = b.or_(pred, b.i32_val(1 << 31));
658659
if (use32BitIndices) {
659-
predWithScatter = b.or_(predWithScatter, b.i32_val(1 << 30));
660+
predWithGatherScatter = b.or_(predWithGatherScatter, b.i32_val(1 << 30));
660661
}
661662

662-
group0[0] = predWithScatter;
663+
group0[0] = predWithGatherScatter;
663664
group0[1] = ldsAddr;
664665
group0[2] = b.trunc(i32_ty, globalAddr);
665666

@@ -784,28 +785,29 @@ void emitTDMLoadStore(RewriterBase &rewriter, Location loc,
784785
}
785786
}
786787

787-
// Emit a TDM scatter operation to write non-contiguous rows from LDS to global.
788-
void emitTDMScatter(RewriterBase &rewriter, Location loc,
789-
const LLVMTypeConverter *typeConverter,
790-
ArrayRef<Value> desc, ArrayRef<int64_t> blockShape,
791-
Value srcPtr, Value pred, Type elementType,
792-
Value barrierPtr, const triton::LinearLayout &cgaLayout,
793-
Value ctaId, ArrayRef<Value> rowIndices, Value colOffset,
794-
bool use32BitIndices) {
788+
// Emit a TDM gather or scatter operation for non-contiguous row access.
789+
void emitTDMGatherScatter(RewriterBase &rewriter, Location loc,
790+
const LLVMTypeConverter *typeConverter,
791+
ArrayRef<Value> desc, ArrayRef<int64_t> blockShape,
792+
Value ldsPtr, Value pred, Type elementType,
793+
Value barrierPtr,
794+
const triton::LinearLayout &cgaLayout, Value ctaId,
795+
ArrayRef<Value> rowIndices, Value colOffset,
796+
bool use32BitIndices, bool isGather) {
795797
auto b = TritonLLVMOpBuilder(loc, rewriter);
796798

797-
assert(!rowIndices.empty() && "Scatter requires row indices");
798-
assert(colOffset && "Scatter requires column offset");
799+
assert(!rowIndices.empty() && "Gather/scatter requires row indices");
800+
assert(colOffset && "Gather/scatter requires column offset");
799801

800802
// Determine max indices per instruction based on index size
801803
size_t maxIndicesPerInstr = use32BitIndices ? 8 : 16;
802804
size_t numIndices = rowIndices.size();
803805

804-
// Get the descriptor groups (scatter uses 2D format: 12 dwords)
806+
// Get the descriptor groups (gather/scatter uses 2D format: 12 dwords)
805807
auto group0Vec = SmallVector<Value>(desc.begin(), desc.begin() + 4);
806808
auto group1Vec = SmallVector<Value>(desc.begin() + 4, desc.end());
807809

808-
// For TDM scatter, we need group2 and group3 for indices
810+
// For TDM gather/scatter, we need group2 and group3 for indices
809811
SmallVector<Value> group2Vec(4, b.i32_val(0));
810812
SmallVector<Value> group3Vec(4, b.i32_val(0));
811813

@@ -824,12 +826,12 @@ void emitTDMScatter(RewriterBase &rewriter, Location loc,
824826
auto g2 = group2Vec;
825827
auto g3 = group3Vec;
826828

827-
// Fill the descriptor for scatter:
829+
// Fill the descriptor for gather/scatter:
828830
// - ldsRowOffset: row offset within shared memory for this batch
829831
// - colOffset: starting column in global memory
830-
fillTDMDescriptorForScatter(
832+
fillTDMDescriptorForGatherScatter(
831833
rewriter, loc, typeConverter, elementType, to_vector(blockShape), g0,
832-
g1, g2, g3, b.i32_val(startIdx), colOffset, srcPtr, pred, barrierPtr,
834+
g1, g2, g3, b.i32_val(startIdx), colOffset, ldsPtr, pred, barrierPtr,
833835
cgaLayout, ctaId, batchIndices, use32BitIndices);
834836

835837
// Pack and emit the instruction
@@ -838,10 +840,13 @@ void emitTDMScatter(RewriterBase &rewriter, Location loc,
838840
auto group2 = packLLVector(loc, g2, rewriter);
839841
auto group3 = packLLVector(loc, g3, rewriter);
840842

841-
// Scatter uses tensor.store.from.lds (not the d2 variant) because it
842-
// needs group2/group3 for indices
843+
// Gather/scatter uses full 4-group format (not the d2 variant) for indices
844+
// Gather: tensor.load.to.lds (global -> LDS)
845+
// Scatter: tensor.store.from.lds (LDS -> global)
846+
const char *intrinsicName = isGather ? "llvm.amdgcn.tensor.load.to.lds"
847+
: "llvm.amdgcn.tensor.store.from.lds";
843848
LLVM::createLLVMIntrinsicCallOp(
844-
rewriter, loc, "llvm.amdgcn.tensor.store.from.lds", {},
849+
rewriter, loc, intrinsicName, {},
845850
{group0, group1, group2, group3, b.i32_val(0)});
846851
}
847852
}

0 commit comments

Comments
 (0)