@@ -534,7 +534,8 @@ swapOutDimSemantics(const triton::LinearLayout &layout, StringAttr dimA,
534534// Fill TDM descriptor for regular load/store operations (1D-5D tensors).
535535// activeWarps: number of warps that actually issue TDM copies (power of two,
536536// <= numWarps). Warps with warpId >= activeWarps get pred=0 (hardware no-op).
537- // A value of 0 means all warps are active (no partial TDM copy).
537+ // 0 is the sentinel for "warp_bases absent, all warps active"; when warp_bases
538+ // is present, activeWarps is at least 1.
538539void fillTDMDescriptor (
539540 RewriterBase &rewriter, Location loc,
540541 const LLVMTypeConverter *typeConverter, Type elementType,
@@ -601,9 +602,8 @@ void fillTDMDescriptor(
601602 if (numDims >= 2 ) {
602603 // tile_dim1: lower 16 bits of group1[4]
603604 group1[4 ] = b.and_ (group1[4 ], b.i32_val (0xFFFF0000 ));
604- group1[4 ] = b.or_ (group1[4 ],
605- b.and_ (decodedBlockShape[numDims - 2 ],
606- b.i32_val (0xFFFF )));
605+ group1[4 ] = b.or_ (
606+ group1[4 ], b.and_ (decodedBlockShape[numDims - 2 ], b.i32_val (0xFFFF )));
607607 }
608608 if (numDims >= 3 ) {
609609 // tile_dim2: upper 16 bits of group1[4]
@@ -1000,17 +1000,15 @@ static int64_t computePerPartitionSliceStride(
10001000
10011001// Emit a single TDM intrinsic (load or store) for the given block shape.
10021002// This handles both the 2D (d2 intrinsic) and >2D (full intrinsic) cases.
1003- static void
1004- emitTDMIntrinsic (RewriterBase &rewriter, Location loc,
1005- const LLVMTypeConverter *typeConverter, ArrayRef<Value> desc,
1006- size_t numDims, Type elementType,
1007- SmallVector<int64_t > effectiveBlockShape, int numWarps,
1008- unsigned padInterval, unsigned padAmount,
1009- SmallVector<Value> globalOffset, ArrayRef<Value> instrDstPtrs,
1010- Value pred, Value multicastMask, Value barrier,
1011- const triton::LinearLayout &instrSharedLayout, Value ctaId,
1012- bool isLoad, bool isRowMajor, ArrayRef<unsigned > warpsPerCTA,
1013- int activeWarps = 0 ) {
1003+ static void emitTDMIntrinsic (
1004+ RewriterBase &rewriter, Location loc,
1005+ const LLVMTypeConverter *typeConverter, ArrayRef<Value> desc,
1006+ size_t numDims, Type elementType, SmallVector<int64_t > effectiveBlockShape,
1007+ int numWarps, unsigned padInterval, unsigned padAmount,
1008+ SmallVector<Value> globalOffset, ArrayRef<Value> instrDstPtrs, Value pred,
1009+ Value multicastMask, Value barrier,
1010+ const triton::LinearLayout &instrSharedLayout, Value ctaId, bool isLoad,
1011+ bool isRowMajor, ArrayRef<unsigned > warpsPerCTA, int activeWarps = 0 ) {
10141012 auto b = TritonLLVMOpBuilder (loc, rewriter);
10151013 auto v8i32Ty = VectorType::get (8 , rewriter.getI32Type ());
10161014 Value group4Zero = LLVM::ZeroOp::create (rewriter, loc, v8i32Ty);
@@ -1042,12 +1040,12 @@ emitTDMIntrinsic(RewriterBase &rewriter, Location loc,
10421040 auto group0Vec = SmallVector<Value>(desc.begin (), desc.begin () + 4 );
10431041 auto group1Vec = SmallVector<Value>(desc.begin () + 4 , desc.end ());
10441042
1045- fillTDMDescriptor (
1046- rewriter, loc, typeConverter, elementType, effectiveBlockShape ,
1047- numWarps, padInterval, padAmount, group0Vec, group1Vec, std::nullopt ,
1048- std:: nullopt , globalOffset, instrDstPtrs, pred, multicastMask, barrier,
1049- instrSharedLayout, ctaId, !isLoad, isRowMajor, warpsPerCTA ,
1050- activeWarps);
1043+ fillTDMDescriptor (rewriter, loc, typeConverter, elementType,
1044+ effectiveBlockShape, numWarps, padInterval, padAmount ,
1045+ group0Vec, group1Vec, std:: nullopt , std::nullopt ,
1046+ globalOffset, instrDstPtrs, pred, multicastMask, barrier,
1047+ instrSharedLayout, ctaId, !isLoad, isRowMajor,
1048+ warpsPerCTA, activeWarps);
10511049
10521050 auto group0 = packLLVector (loc, group0Vec, rewriter);
10531051 auto group1 = packLLVector (loc, group1Vec, rewriter);
@@ -1086,7 +1084,9 @@ void emitTDMLoadStore(RewriterBase &rewriter, Location loc,
10861084 assert (numDims <= 5 );
10871085
10881086 // Determine activeWarps from warp_bases.
1089- // The non-zero prefix length gives log2(activeWarps).
1087+ // activeWarps = 2^(number of non-zero bit-rows), so it is at least 1 whenever
1088+ // warp_bases is present (if every bit-row is zero, activeWarps = 2^0 = 1).
1089+ // activeWarps = 0 is the sentinel for "attribute absent, all warps active".
10901090 int activeWarps = 0 ;
10911091 if (!warpBases.empty ()) {
10921092 int numBits = warpBases.size () / numDims;
0 commit comments