Skip to content

Commit 75b8c66

Browse files
committed
[AMD][TDM] Clarify activeWarps=0 sentinel in comments
activeWarps=0 is reserved as the sentinel meaning "warp_bases absent, all warps active." When warp_bases is present, activeWarps = 2^(number of non-zero bit-rows), so it is always at least 1 (2^0 = 1 even when every warp_bases row is zero). This distinction matters for understanding the conditional logic in fillTDMDescriptor and emitTDMLoadStore.
1 parent 8f34e7f commit 75b8c66

2 files changed

Lines changed: 24 additions & 23 deletions

File tree

third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.cpp

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -534,7 +534,8 @@ swapOutDimSemantics(const triton::LinearLayout &layout, StringAttr dimA,
534534
// Fill TDM descriptor for regular load/store operations (1D-5D tensors).
535535
// activeWarps: number of warps that actually issue TDM copies (power of two,
536536
// <= numWarps). Warps with warpId >= activeWarps get pred=0 (hardware no-op).
537-
// A value of 0 means all warps are active (no partial TDM copy).
537+
// 0 is the sentinel for "warp_bases absent, all warps active"; when warp_bases
538+
// is present, activeWarps is at least 1.
538539
void fillTDMDescriptor(
539540
RewriterBase &rewriter, Location loc,
540541
const LLVMTypeConverter *typeConverter, Type elementType,
@@ -601,9 +602,8 @@ void fillTDMDescriptor(
601602
if (numDims >= 2) {
602603
// tile_dim1: lower 16 bits of group1[4]
603604
group1[4] = b.and_(group1[4], b.i32_val(0xFFFF0000));
604-
group1[4] = b.or_(group1[4],
605-
b.and_(decodedBlockShape[numDims - 2],
606-
b.i32_val(0xFFFF)));
605+
group1[4] = b.or_(
606+
group1[4], b.and_(decodedBlockShape[numDims - 2], b.i32_val(0xFFFF)));
607607
}
608608
if (numDims >= 3) {
609609
// tile_dim2: upper 16 bits of group1[4]
@@ -1000,17 +1000,15 @@ static int64_t computePerPartitionSliceStride(
10001000

10011001
// Emit a single TDM intrinsic (load or store) for the given block shape.
10021002
// This handles both the 2D (d2 intrinsic) and >2D (full intrinsic) cases.
1003-
static void
1004-
emitTDMIntrinsic(RewriterBase &rewriter, Location loc,
1005-
const LLVMTypeConverter *typeConverter, ArrayRef<Value> desc,
1006-
size_t numDims, Type elementType,
1007-
SmallVector<int64_t> effectiveBlockShape, int numWarps,
1008-
unsigned padInterval, unsigned padAmount,
1009-
SmallVector<Value> globalOffset, ArrayRef<Value> instrDstPtrs,
1010-
Value pred, Value multicastMask, Value barrier,
1011-
const triton::LinearLayout &instrSharedLayout, Value ctaId,
1012-
bool isLoad, bool isRowMajor, ArrayRef<unsigned> warpsPerCTA,
1013-
int activeWarps = 0) {
1003+
static void emitTDMIntrinsic(
1004+
RewriterBase &rewriter, Location loc,
1005+
const LLVMTypeConverter *typeConverter, ArrayRef<Value> desc,
1006+
size_t numDims, Type elementType, SmallVector<int64_t> effectiveBlockShape,
1007+
int numWarps, unsigned padInterval, unsigned padAmount,
1008+
SmallVector<Value> globalOffset, ArrayRef<Value> instrDstPtrs, Value pred,
1009+
Value multicastMask, Value barrier,
1010+
const triton::LinearLayout &instrSharedLayout, Value ctaId, bool isLoad,
1011+
bool isRowMajor, ArrayRef<unsigned> warpsPerCTA, int activeWarps = 0) {
10141012
auto b = TritonLLVMOpBuilder(loc, rewriter);
10151013
auto v8i32Ty = VectorType::get(8, rewriter.getI32Type());
10161014
Value group4Zero = LLVM::ZeroOp::create(rewriter, loc, v8i32Ty);
@@ -1042,12 +1040,12 @@ emitTDMIntrinsic(RewriterBase &rewriter, Location loc,
10421040
auto group0Vec = SmallVector<Value>(desc.begin(), desc.begin() + 4);
10431041
auto group1Vec = SmallVector<Value>(desc.begin() + 4, desc.end());
10441042

1045-
fillTDMDescriptor(
1046-
rewriter, loc, typeConverter, elementType, effectiveBlockShape,
1047-
numWarps, padInterval, padAmount, group0Vec, group1Vec, std::nullopt,
1048-
std::nullopt, globalOffset, instrDstPtrs, pred, multicastMask, barrier,
1049-
instrSharedLayout, ctaId, !isLoad, isRowMajor, warpsPerCTA,
1050-
activeWarps);
1043+
fillTDMDescriptor(rewriter, loc, typeConverter, elementType,
1044+
effectiveBlockShape, numWarps, padInterval, padAmount,
1045+
group0Vec, group1Vec, std::nullopt, std::nullopt,
1046+
globalOffset, instrDstPtrs, pred, multicastMask, barrier,
1047+
instrSharedLayout, ctaId, !isLoad, isRowMajor,
1048+
warpsPerCTA, activeWarps);
10511049

10521050
auto group0 = packLLVector(loc, group0Vec, rewriter);
10531051
auto group1 = packLLVector(loc, group1Vec, rewriter);
@@ -1086,7 +1084,9 @@ void emitTDMLoadStore(RewriterBase &rewriter, Location loc,
10861084
assert(numDims <= 5);
10871085

10881086
// Determine activeWarps from warp_bases.
1089-
// The non-zero prefix length gives log2(activeWarps).
1087+
// activeWarps = 2^(number of non-zero bit-rows), so it is at least 1 when
1088+
// warp_bases is present (even all-zero rows yield activeWarps=1).
1089+
// activeWarps=0 is the sentinel for "attribute absent, all warps active."
10901090
int activeWarps = 0;
10911091
if (!warpBases.empty()) {
10921092
int numBits = warpBases.size() / numDims;

third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ TDMDescriptor createTDMDescriptor(RewriterBase &rewriter, Location loc,
4444
// For partitioned shared memory, dstPtrs contains multiple base pointers and
4545
// the correct one is selected based on sharedLayout's partition dimension.
4646
// activeWarps: number of warps that actually issue TDM copies (power of two,
47-
// <= numWarps). 0 means all warps are active (no partial TDM copy).
47+
// <= numWarps). 0 is the sentinel for "warp_bases absent, all warps active";
48+
// when warp_bases is present, activeWarps is at least 1.
4849
void fillTDMDescriptor(
4950
RewriterBase &rewriter, Location loc,
5051
const LLVMTypeConverter *typeConverter, Type elementType,

0 commit comments

Comments
 (0)