diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index 3948751981b5..ed12924042ab 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -613,10 +613,12 @@ struct LocalSharedMemoryAddress { // Compute per-element shared-memory addresses for a local atomic/ldst update by // replacing `coords[*][axis]` with `idxValues[*]` and mapping the resulting // logical coordinates back to shared-memory offsets and target CTAs. -SmallVector computeLocalAddrs( - Location loc, triton::gpu::MemDescType memDescTy, - SharedMemoryObject smemObj, Type llvmElemTy, ArrayRef idxValues, - ArrayRef> coords, unsigned axis, RewriterBase &rewriter); +SmallVector +computeLocalAddrs(Location loc, triton::gpu::MemDescType memDescTy, + SharedMemoryObject smemObj, Type llvmElemTy, + ArrayRef idxValues, + ArrayRef> coords, unsigned axis, + RewriterBase &rewriter, ArrayRef offsets = {}); // Backend-agnostic preparation for lowering LocalAtomicScatterRMWOp. struct LocalAtomicScatterRMWInfo { diff --git a/include/triton/Dialect/TritonInstrument/IR/TritonInstrumentOps.td b/include/triton/Dialect/TritonInstrument/IR/TritonInstrumentOps.td index 1a48850a3525..8f1782e3846d 100644 --- a/include/triton/Dialect/TritonInstrument/IR/TritonInstrumentOps.td +++ b/include/triton/Dialect/TritonInstrument/IR/TritonInstrumentOps.td @@ -15,6 +15,7 @@ include "triton/Dialect/TritonInstrument/IR/TritonInstrumentAttrDefs.td" // Interfaces // def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">; +def SharedMemory : Resource<"::mlir::triton::gpu::SharedMemory">; // // Ops @@ -78,6 +79,29 @@ def TTI_ExperimentalClusterCTAIdOp let assemblyFormat = "attr-dict `:` type($result)"; } +def TTI_ExperimentalLocalGatherOp + : TTI_Op<"experimental_local_gather"> { + let summary = "Gather elements from shared memory with logical base offsets"; + let description = [{ + Gather elements from a shared memory descriptor using an index tensor along + one axis, after shifting the logical source coordinates by rank-sized scalar + offsets. This is intentionally private to instrumentation passes. + }]; + let arguments = (ins + Arg]>:$src, + TT_IntTensor:$indices, + Variadic:$offsets, + I32Attr:$axis + ); + let results = (outs TT_Tensor:$result); + + let assemblyFormat = [{ + $src `[` $indices `]` `offsets` `=` `[` $offsets `]` + attr-dict `:` qualified(type($src)) `,` type($indices) `->` type($result) + }]; + let hasVerifier = 1; +} + def TTI_ExperimentalGSanInitOp : TTI_Op<"experimental_gsan_init"> { let summary = "Initialize GSan thread"; diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index f6044cc3223a..a8bb36b74232 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -545,7 +545,7 @@ computeLocalAddrs(Location loc, triton::gpu::MemDescType memDescTy, SharedMemoryObject smemObj, Type llvmElemTy, ArrayRef idxValues, ArrayRef> coords, unsigned axis, - RewriterBase &rewriter) { + RewriterBase &rewriter, ArrayRef offsets) { MLIRContext *ctx = memDescTy.getContext(); auto b = TritonLLVMOpBuilder(loc, rewriter); @@ -580,9 +580,12 @@ computeLocalAddrs(Location loc, triton::gpu::MemDescType memDescTy, idx = b.zext(i32_ty, idx); } - // Copy coordinates and replace the axis coordinate with the index value + // Copy coordinates, replace the axis coordinate with the index value, and + // then shift all logical coordinates by the optional base offsets. SmallVector indices(coords[i]); indices[axis] = idx; + for (auto [dim, offset] : llvm::enumerate(offsets)) + indices[dim] = b.add(indices[dim], offset); // Apply inverted shared layout to compute offset SmallVector> inputs; diff --git a/lib/Conversion/TritonInstrumentToLLVM/InstrumentationToLLVM.cpp b/lib/Conversion/TritonInstrumentToLLVM/InstrumentationToLLVM.cpp index ba5303472036..47491469193b 100644 --- a/lib/Conversion/TritonInstrumentToLLVM/InstrumentationToLLVM.cpp +++ b/lib/Conversion/TritonInstrumentToLLVM/InstrumentationToLLVM.cpp @@ -347,6 +347,51 @@ struct ClusterCTAIdOpConversion const TargetInfoBase &targetInfo; }; +struct LocalGatherOpConversion + : public ConvertOpToLLVMPattern { + LocalGatherOpConversion(const LLVMTypeConverter &converter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(converter, + benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(tti::ExperimentalLocalGatherOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto memDescTy = cast(op.getSrc().getType()); + auto regTy = cast(op.getType()); + auto typeConverter = getTypeConverter(); + + Type llvmElemTy = typeConverter->convertType(memDescTy.getElementType()); + auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), + llvmElemTy, rewriter); + auto idxValues = unpackLLElements(loc, adaptor.getIndices(), rewriter); + auto dstIndices = + emitIndices(loc, rewriter, targetInfo, regTy.getEncoding(), regTy, + /*withCTAOffset=*/true); + SmallVector offsets(adaptor.getOffsets()); + + auto addrs = + computeLocalAddrs(loc, memDescTy, smemObj, llvmElemTy, idxValues, + dstIndices, op.getAxis(), rewriter, offsets); + auto b = TritonLLVMOpBuilder(loc, rewriter); + SmallVector results = + llvm::map_to_vector(addrs, [&](const LocalSharedMemoryAddress &addr) { + return targetInfo.loadDShared(rewriter, loc, addr.ptr, addr.ctaId, + llvmElemTy, b.true_val()); + }); + Value result = packLLElements(loc, typeConverter, results, rewriter, regTy); + + rewriter.replaceOp(op, result); + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; + } // namespace void mlir::triton::populateInstrumentationToLLVMPatterns( @@ -358,4 +403,5 @@ void mlir::triton::populateInstrumentationToLLVMPatterns( patterns.add(typeConverter, targetInfo); patterns.add(typeConverter); patterns.add(typeConverter, targetInfo); + patterns.add(typeConverter, targetInfo); } diff --git a/lib/Dialect/TritonInstrument/IR/Ops.cpp b/lib/Dialect/TritonInstrument/IR/Ops.cpp index 9b8ec082eaff..468e34444986 100644 --- a/lib/Dialect/TritonInstrument/IR/Ops.cpp +++ b/lib/Dialect/TritonInstrument/IR/Ops.cpp @@ -40,6 +40,40 @@ LogicalResult DotI8Op::verify() { bEnc); } +LogicalResult ExperimentalLocalGatherOp::verify() { + auto srcTy = getSrc().getType(); + auto indicesTy = cast(getIndices().getType()); + auto dstTy = cast(getType()); + unsigned axis = getAxis(); + + if (!isa(srcTy.getEncoding())) + return emitError("source must have shared memory encoding"); + + if (!indicesTy.getElementType().isInteger()) + return emitError("indices must have integer element type"); + + if (dstTy.getShape() != indicesTy.getShape()) + return emitError("result shape must match indices shape"); + + if (srcTy.getRank() != indicesTy.getRank()) + return emitError("source and indices must have the same rank"); + + if (axis >= srcTy.getRank()) + return emitError("axis ") + << axis << " is out of bounds for source rank " << srcTy.getRank(); + + if (srcTy.getElementType() != dstTy.getElementType()) + return emitError("result element type must match source element type"); + + if (indicesTy.getEncoding() != dstTy.getEncoding()) + return emitError("indices and result must have the same layout"); + + if (static_cast(getOffsets().size()) != srcTy.getRank()) + return emitError("offset count must match source rank"); + + return success(); +} + template struct PushFPSanThroughViewPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; diff --git a/lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp b/lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp index 37c8dae6f580..c768609f5dd6 100644 --- a/lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp +++ b/lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp @@ -68,19 +68,26 @@ bool canUseI8MmaTile(int64_t m, int64_t n, int numWarps) { return m >= 16 && n >= 8 && (m / 16) * (n / 8) >= numWarps; } -std::pair getMmaEmulationTileShape(PatternRewriter &rewriter, - int64_t m, int64_t n, - int64_t k, - IntegerType accElem) { +std::pair +getMmaEmulationTileShape(PatternRewriter &rewriter, int64_t m, int64_t n, + int64_t k, IntegerType accElem, + bool directShared = false) { + std::pair tile = {std::min(kTileM, m), + std::min(kTileN, n)}; + int64_t numWarps = + ttg::lookupNumWarps(rewriter.getInsertionBlock()->getParent()); if (supportsI8DotDecomposition(rewriter, accElem) && (k % kI8MmaK) == 0) { - int64_t numWarps = - ttg::lookupNumWarps(rewriter.getInsertionBlock()->getParent()); int64_t tileM = std::min(16 * numWarps, m); int64_t tileN = std::min(8 * numWarps, n); if (canUseI8MmaTile(tileM, tileN, numWarps)) - return {tileM, tileN}; + tile = {tileM, tileN}; } - return {std::min(kTileM, m), std::min(kTileN, n)}; + if (directShared) { + int64_t widerN = std::min(16 * numWarps, n); + if (widerN > tile.second && canUseI8MmaTile(tile.first, widerN, numWarps)) + tile.second = widerN; + } + return tile; } Operation *createGlobalScratchBarrier(PatternRewriter &rewriter, Location loc, @@ -152,6 +159,16 @@ struct ScratchInfo { RankedTensorType tensorType; }; +struct MmaOperandSource { + Value scratchPtr; + Value sharedMemdesc; + RankedTensorType tileType; + int64_t rowStride; + int64_t stride; + + bool isShared() const { return static_cast(sharedMemdesc); } +}; + struct ScratchState { std::optional canonical; DenseMap byScope; @@ -885,31 +902,18 @@ Value fpsanVariadicExternTagged(PatternRewriter &rewriter, Location loc, } std::optional -createOperandScratch(PatternRewriter &rewriter, Location loc, - TmemScratchManager &scratch, Value memdesc, - ttg::MemDescType memTy, bool isTmem, Region *scope) { +createTmemOperandScratch(PatternRewriter &rewriter, Location loc, + TmemScratchManager &scratch, Value memdesc, + ttg::MemDescType memTy, Region *scope) { auto layout = scratch.getScratchEncoding(rewriter, memdesc, memTy); auto tensorTy = RankedTensorType::get(memTy.getShape(), memTy.getElementType(), layout); - Value fullVal; - if (isTmem) { - auto info = scratch.getOrCreate(memdesc, rewriter, scope); - if (!info) - return std::nullopt; - fullVal = loadFpSanScratchMemory(rewriter, loc, info->ptr, tensorTy); - if (!fullVal) - return std::nullopt; - } else { - if (scratch.usesSharedClusterState()) { - // A two-CTA TMA barrier is waited on by the lead CTA. FPSAN executes the - // emulated MMA in both CTAs, so rendezvous before the partner snapshots - // the shared-memory operand. - ttng::ClusterBarrierOp::create(rewriter, loc); - } - fullVal = - ttg::LocalLoadOp::create(rewriter, loc, tensorTy, memdesc, Value()) - .getResult(); - } + auto info = scratch.getOrCreate(memdesc, rewriter, scope); + if (!info) + return std::nullopt; + Value fullVal = loadFpSanScratchMemory(rewriter, loc, info->ptr, tensorTy); + if (!fullVal) + return std::nullopt; int64_t elSize = memTy.getElementType().getIntOrFloatBitWidth() / 8; int64_t alignment = std::max(elSize, 16); int64_t sizeInBytes = product(memTy.getShape()) * elSize; @@ -923,6 +927,19 @@ createOperandScratch(PatternRewriter &rewriter, Location loc, return ScratchInfo{ptr, tensorTy}; } +std::optional createMmaOperandSource( + PatternRewriter &rewriter, Location loc, TmemScratchManager &scratch, + Value memdesc, ttg::MemDescType memTy, bool isTmem, RankedTensorType tileTy, + Region *scope, int64_t rowStride, int64_t stride) { + if (!isTmem) + return MmaOperandSource{Value(), memdesc, tileTy, rowStride, stride}; + auto info = + createTmemOperandScratch(rewriter, loc, scratch, memdesc, memTy, scope); + if (!info) + return std::nullopt; + return MmaOperandSource{info->ptr, Value(), tileTy, rowStride, stride}; +} + std::optional createWGMMAScratch(PatternRewriter &rewriter, Location loc, Value operand) { if (auto memTy = dyn_cast(operand.getType())) { @@ -1025,6 +1042,75 @@ Value loadScratchStrided2D(PatternRewriter &rewriter, Location loc, Value base, stride1); } +Value loadMmaOperand(PatternRewriter &rewriter, Location loc, + const MmaOperandSource &source, RankedTensorType resultTy, + bool isLhs, Value tileOffset, Value kOffset) { + if (!source.isShared()) { + Value rowOffset = isLhs ? tileOffset : kOffset; + Value colOffset = isLhs ? kOffset : tileOffset; + Value rowStride = arith::ConstantOp::create( + rewriter, loc, rewriter.getI32IntegerAttr(source.rowStride)); + Value stride = arith::ConstantOp::create( + rewriter, loc, rewriter.getI32IntegerAttr(source.stride)); + Value row = arith::MulIOp::create(rewriter, loc, rowOffset, rowStride); + Value col = arith::MulIOp::create(rewriter, loc, colOffset, stride); + Value offset = arith::AddIOp::create(rewriter, loc, row, col); + Value ptr = tt::AddPtrOp::create(rewriter, loc, source.scratchPtr.getType(), + source.scratchPtr, offset); + return loadScratchStrided2D(rewriter, loc, ptr, resultTy, source.rowStride, + source.stride); + } + + Value shared = source.sharedMemdesc; + auto sharedTy = cast(shared.getType()); + unsigned tileAxis = isLhs ? 0 : 1; + unsigned kAxis = 1 - tileAxis; + auto cgaLayout = ttg::getCGALayout(sharedTy.getEncoding()); + auto replicatedCGA = ttg::CGAEncodingAttr::fromSplitParams( + rewriter.getContext(), cgaLayout.getCTAsPerCGA(), + SmallVector(2, 1), cgaLayout.getCTAOrder()); + SmallVector loadShape(resultTy.getShape()); + int numWarps = ttg::lookupNumWarps(rewriter.getInsertionBlock()->getParent()); + int threadsPerWarp = ttg::lookupThreadsPerWarp(rewriter); + SmallVector threadsPerWarpShape(2, 1); + threadsPerWarpShape[tileAxis] = + std::min(loadShape[tileAxis], threadsPerWarp); + threadsPerWarpShape[kAxis] = threadsPerWarp / threadsPerWarpShape[tileAxis]; + SmallVector warpsPerCTA(2, 1); + warpsPerCTA[kAxis] = numWarps; + SmallVector sizePerThread(2, 1); + sizePerThread[tileAxis] = loadShape[tileAxis] / threadsPerWarpShape[tileAxis]; + auto loadLayout = ttg::BlockedEncodingAttr::get( + rewriter.getContext(), sizePerThread, threadsPerWarpShape, warpsPerCTA, + SmallVector{tileAxis, kAxis}, replicatedCGA); + auto loadTy = + RankedTensorType::get(loadShape, resultTy.getElementType(), loadLayout); + auto indicesTy = + RankedTensorType::get(loadShape, rewriter.getI32Type(), loadLayout); + auto kEncoding = getSingleDimSliceEncoding(loadLayout, kAxis); + auto kTy = RankedTensorType::get({loadShape[kAxis]}, rewriter.getI32Type(), + kEncoding); + Value indices = + tt::MakeRangeOp::create(rewriter, loc, kTy, 0, loadShape[kAxis]); + Value kSplat = tt::SplatOp::create(rewriter, loc, kTy, kOffset); + indices = arith::AddIOp::create(rewriter, loc, indices, kSplat); + indices = expandAllSlicedDims(rewriter, loc, indices); + if (cast(indices.getType()).getShape() != + ArrayRef(loadShape)) + indices = tt::BroadcastOp::create(rewriter, loc, indicesTy, indices); + Value zero = + arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(0)); + SmallVector offsets(2, zero); + offsets[tileAxis] = tileOffset; + Value loaded = ExperimentalLocalGatherOp::create( + rewriter, loc, loadTy, shared, indices, offsets, + rewriter.getI32IntegerAttr(kAxis)) + .getResult(); + if (loaded.getType() != resultTy) + loaded = ttg::ConvertLayoutOp::create(rewriter, loc, resultTy, loaded); + return loaded; +} + Operation *storeScratchStrided2D(PatternRewriter &rewriter, Location loc, Value base, Value tensor, RankedTensorType tensorTy, int64_t stride0, @@ -1266,17 +1352,17 @@ Value unpackPackedFp4Tensor(PatternRewriter &rewriter, Location loc, } Value loadOperandK32(PatternRewriter &rewriter, Location loc, bool isLhs, - Value tilePtr, RankedTensorType fullTileTy, Value kI32, - int64_t rowStride, int64_t stride, + const MmaOperandSource &source, Value tileIdx, Value kI32, int64_t packFactor = 1) { - SmallVector logicalShape{isLhs ? fullTileTy.getShape()[0] : kI8MmaK, - isLhs ? kI8MmaK : fullTileTy.getShape()[1]}; + SmallVector logicalShape{ + isLhs ? source.tileType.getShape()[0] : kI8MmaK, + isLhs ? kI8MmaK : source.tileType.getShape()[1]}; SmallVector rawShape = logicalShape; rawShape[isLhs ? 1 : 0] /= packFactor; - auto rawLayout = getOptimizedBlockedEncoding(rewriter, rawShape, - fullTileTy.getElementType()); - auto rawTy = - RankedTensorType::get(rawShape, fullTileTy.getElementType(), rawLayout); + auto rawLayout = getOptimizedBlockedEncoding( + rewriter, rawShape, source.tileType.getElementType()); + auto rawTy = RankedTensorType::get(rawShape, source.tileType.getElementType(), + rawLayout); Value packedKIdx = kI32; if (packFactor != 1) { @@ -1284,13 +1370,8 @@ Value loadOperandK32(PatternRewriter &rewriter, Location loc, bool isLhs, rewriter, loc, rewriter.getI32IntegerAttr(packFactor)); packedKIdx = arith::DivUIOp::create(rewriter, loc, kI32, factor); } - Value kStride = arith::ConstantOp::create( - rewriter, loc, rewriter.getI32IntegerAttr(isLhs ? stride : rowStride)); - Value offset = arith::MulIOp::create(rewriter, loc, packedKIdx, kStride); - Value chunkPtr = - tt::AddPtrOp::create(rewriter, loc, tilePtr.getType(), tilePtr, offset); Value chunk = - loadScratchStrided2D(rewriter, loc, chunkPtr, rawTy, rowStride, stride); + loadMmaOperand(rewriter, loc, source, rawTy, isLhs, tileIdx, packedKIdx); if (packFactor == 2) { auto logicalLayout = getOptimizedBlockedEncoding(rewriter, logicalShape, @@ -1365,13 +1446,13 @@ Value loadScaledScaleK32(PatternRewriter &rewriter, Location loc, bool isLhs, } Value loadScaledOperandK32(PatternRewriter &rewriter, Location loc, bool isLhs, - Value tilePtr, RankedTensorType fullTileTy, + const MmaOperandSource &source, const DotScaleConfig &scale, Value tileIdx, - Value kI32, int64_t rowStride, int64_t stride) { + Value kI32) { int64_t packFactor = isLhs ? scale.aKPackFactor : scale.bKPackFactor; tt::ScaleDotElemType elemType = isLhs ? scale.aElemType : scale.bElemType; - Value chunk = loadOperandK32(rewriter, loc, isLhs, tilePtr, fullTileTy, kI32, - rowStride, stride, packFactor); + Value chunk = + loadOperandK32(rewriter, loc, isLhs, source, tileIdx, kI32, packFactor); Value payload = castDotScaledOperandToComputePayload( rewriter, loc, chunk, elemType, scale.computeElem); @@ -1488,17 +1569,17 @@ Value tryEmitI8DotDecomposition(PatternRewriter &rewriter, Location loc, } std::optional emitMmaEmulationLoops( - PatternRewriter &rewriter, Location loc, Value aPtr, Value bPtr, Value dPtr, - int64_t m, int64_t n, int64_t k, int64_t tileM, int64_t tileN, - RankedTensorType aTileTy, RankedTensorType bTileTy, - RankedTensorType accTileTy, ttg::DistributedEncodingTrait accLayout, - IntegerType accElem, Value useDInt, Value predInt, int64_t aStride, - int64_t bStride, int64_t dStride, const DotScaleConfig &scale = {}, - int64_t aRowStride = 1, int64_t bRowStride = 1, int64_t dRowStride = 1) { + PatternRewriter &rewriter, Location loc, const MmaOperandSource &aSource, + const MmaOperandSource &bSource, Value dPtr, int64_t m, int64_t n, + int64_t k, int64_t tileM, int64_t tileN, RankedTensorType accTileTy, + ttg::DistributedEncodingTrait accLayout, IntegerType accElem, Value useDInt, + Value predInt, int64_t dStride, const DotScaleConfig &scale = {}, + int64_t dRowStride = 1) { if ((m % tileM) != 0 || (n % tileN) != 0) return std::nullopt; OpBuilder::InsertionGuard guard(rewriter); + auto i32Ty = rewriter.getI32Type(); Value zero = arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(0)); Value mUpper = @@ -1509,27 +1590,19 @@ std::optional emitMmaEmulationLoops( rewriter.getI32IntegerAttr(tileM)); Value nStep = arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(tileN)); - auto mLoop = scf::ForOp::create(rewriter, loc, zero, mUpper, mStep); rewriter.setInsertionPointToStart(mLoop.getBody()); - Value mIdx = mLoop.getInductionVar(); auto nLoop = scf::ForOp::create(rewriter, loc, zero, nUpper, nStep); rewriter.setInsertionPointToStart(nLoop.getBody()); - Value nIdx = nLoop.getInductionVar(); + Value mIdxI32 = + arith::IndexCastOp::create(rewriter, loc, i32Ty, mLoop.getInductionVar()); + Value nIdxI32 = + arith::IndexCastOp::create(rewriter, loc, i32Ty, nLoop.getInductionVar()); - auto i32Ty = rewriter.getI32Type(); - Value mIdxI32 = arith::IndexCastOp::create(rewriter, loc, i32Ty, mIdx); - Value nIdxI32 = arith::IndexCastOp::create(rewriter, loc, i32Ty, nIdx); Value dRowStrideConst = arith::ConstantOp::create( rewriter, loc, rewriter.getI32IntegerAttr(dRowStride)); Value dStrideConst = arith::ConstantOp::create( rewriter, loc, rewriter.getI32IntegerAttr(dStride)); - Value aRowStrideConst = arith::ConstantOp::create( - rewriter, loc, rewriter.getI32IntegerAttr(aRowStride)); - Value bStrideConst = arith::ConstantOp::create( - rewriter, loc, rewriter.getI32IntegerAttr(bStride)); - Value bRowStrideConst = arith::ConstantOp::create( - rewriter, loc, rewriter.getI32IntegerAttr(bRowStride)); Value mDOffset = arith::MulIOp::create(rewriter, loc, mIdxI32, dRowStrideConst); @@ -1541,15 +1614,8 @@ std::optional emitMmaEmulationLoops( dRowStride, dStride); Value accTileI = embedToInt(rewriter, loc, accTile); - Value aMOffset = - arith::MulIOp::create(rewriter, loc, mIdxI32, aRowStrideConst); - Value aTilePtr = - tt::AddPtrOp::create(rewriter, loc, aPtr.getType(), aPtr, aMOffset); - Value bOffset = arith::MulIOp::create(rewriter, loc, nIdxI32, bStrideConst); - Value bTilePtr = - tt::AddPtrOp::create(rewriter, loc, bPtr.getType(), bPtr, bOffset); - Value sum; + bool hasSharedOperand = aSource.isShared() || bSource.isShared(); int numWarps = ttg::lookupNumWarps(rewriter.getInsertionBlock()->getParent()); auto isScaleK32Aligned = [](Value scalePtr, int64_t scaleFactor) { return !scalePtr || (scaleFactor > 0 && ((kI8MmaK % scaleFactor) == 0 || @@ -1561,11 +1627,11 @@ std::optional emitMmaEmulationLoops( isScaleK32Aligned(scale.bScalePtr, scale.bScaleFactor) && canUseI8MmaTile(tileM, tileN, numWarps); if (canUseI8Decomposition) { - if (!scale.computeElem && accElem.getWidth() <= 32) { - Value aTile = loadScratchStrided2D(rewriter, loc, aTilePtr, aTileTy, - aRowStride, aStride); - Value bTile = loadScratchStrided2D(rewriter, loc, bTilePtr, bTileTy, - bRowStride, bStride); + if (!hasSharedOperand && !scale.computeElem && accElem.getWidth() <= 32) { + Value aTile = loadMmaOperand(rewriter, loc, aSource, aSource.tileType, + /*isLhs=*/true, mIdxI32, zero); + Value bTile = loadMmaOperand(rewriter, loc, bSource, bSource.tileType, + /*isLhs=*/false, nIdxI32, zero); sum = tryEmitI8DotDecomposition( rewriter, loc, embedToInt(rewriter, loc, aTile), embedToInt(rewriter, loc, bTile), accLayout, accElem, numWarps); @@ -1585,17 +1651,15 @@ std::optional emitMmaEmulationLoops( Value aChunk; Value bChunk; if (scale.computeElem) { - aChunk = loadScaledOperandK32(rewriter, loc, /*isLhs=*/true, aTilePtr, - aTileTy, scale, mIdxI32, kI32, aRowStride, - aStride); - bChunk = loadScaledOperandK32(rewriter, loc, /*isLhs=*/false, bTilePtr, - bTileTy, scale, nIdxI32, kI32, bRowStride, - bStride); + aChunk = loadScaledOperandK32(rewriter, loc, /*isLhs=*/true, aSource, + scale, mIdxI32, kI32); + bChunk = loadScaledOperandK32(rewriter, loc, /*isLhs=*/false, bSource, + scale, nIdxI32, kI32); } else { - aChunk = loadOperandK32(rewriter, loc, /*isLhs=*/true, aTilePtr, - aTileTy, kI32, aRowStride, aStride); - bChunk = loadOperandK32(rewriter, loc, /*isLhs=*/false, bTilePtr, - bTileTy, kI32, bRowStride, bStride); + aChunk = loadOperandK32(rewriter, loc, /*isLhs=*/true, aSource, mIdxI32, + kI32); + bChunk = loadOperandK32(rewriter, loc, /*isLhs=*/false, bSource, + nIdxI32, kI32); } Value partial = tryEmitI8DotDecomposition( rewriter, loc, embedToInt(rewriter, loc, aChunk), @@ -1612,13 +1676,10 @@ std::optional emitMmaEmulationLoops( } if (!sum) { - auto aSliceTy = - RankedTensorType::get({tileM, 1}, aTileTy.getElementType(), accLayout); - auto bSliceTy = - RankedTensorType::get({1, tileN}, bTileTy.getElementType(), accLayout); - Value aStrideVal = arith::ConstantOp::create( - rewriter, loc, rewriter.getI32IntegerAttr(aStride)); - + auto aSliceTy = RankedTensorType::get( + {tileM, 1}, aSource.tileType.getElementType(), accLayout); + auto bSliceTy = RankedTensorType::get( + {1, tileN}, bSource.tileType.getElementType(), accLayout); Value zeroSum = getIntConstantLike(rewriter, loc, accTileI.getType(), 0); Value kUpper = arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(k)); @@ -1641,18 +1702,10 @@ std::optional emitMmaEmulationLoops( rewriter, loc, rewriter.getI32IntegerAttr(scale.bKPackFactor)); bKIdx = arith::DivUIOp::create(rewriter, loc, kI32, bPackFactor); } - Value aOffset = - arith::MulIOp::create(rewriter, loc, i32Ty, aKIdx, aStrideVal); - Value aSlicePtr = - tt::AddPtrOp::create(rewriter, loc, aPtr.getType(), aTilePtr, aOffset); - Value aSlice = loadScratchStrided2D(rewriter, loc, aSlicePtr, aSliceTy, - aRowStride, aStride); - Value bKOffset = - arith::MulIOp::create(rewriter, loc, bKIdx, bRowStrideConst); - Value bSlicePtr = - tt::AddPtrOp::create(rewriter, loc, bPtr.getType(), bTilePtr, bKOffset); - Value bSlice = loadScratchStrided2D(rewriter, loc, bSlicePtr, bSliceTy, - bRowStride, bStride); + Value aSlice = loadMmaOperand(rewriter, loc, aSource, aSliceTy, + /*isLhs=*/true, mIdxI32, aKIdx); + Value bSlice = loadMmaOperand(rewriter, loc, bSource, bSliceTy, + /*isLhs=*/false, nIdxI32, bKIdx); if (scale.aKPackFactor == 2) aSlice = unpackPackedFp4Slice(rewriter, loc, aSlice, kI32); if (scale.bKPackFactor == 2) @@ -1695,10 +1748,24 @@ std::optional emitMmaEmulationLoops( createGlobalScratchBarrier(rewriter, loc); storeScratchStrided2D(rewriter, loc, dTilePtr, out, accTileTy, dRowStride, dStride); - return mLoop; } +std::optional emitMmaEmulationLoops( + PatternRewriter &rewriter, Location loc, Value aPtr, Value bPtr, Value dPtr, + int64_t m, int64_t n, int64_t k, int64_t tileM, int64_t tileN, + RankedTensorType aTileTy, RankedTensorType bTileTy, + RankedTensorType accTileTy, ttg::DistributedEncodingTrait accLayout, + IntegerType accElem, Value useDInt, Value predInt, int64_t aStride, + int64_t bStride, int64_t dStride, const DotScaleConfig &scale = {}, + int64_t aRowStride = 1, int64_t bRowStride = 1, int64_t dRowStride = 1) { + MmaOperandSource aSource{aPtr, Value(), aTileTy, aRowStride, aStride}; + MmaOperandSource bSource{bPtr, Value(), bTileTy, bRowStride, bStride}; + return emitMmaEmulationLoops(rewriter, loc, aSource, bSource, dPtr, m, n, k, + tileM, tileN, accTileTy, accLayout, accElem, + useDInt, predInt, dStride, scale, dRowStride); +} + //---------------------------------------- // Patterns //---------------------------------------- @@ -2505,42 +2572,43 @@ struct TCGen5MMAPattern : public OpRewritePattern { arith::ExtUIOp::create(rewriter, loc, accElem, op.getPred()); rewriter.setInsertionPoint(op); - auto aScratch = createOperandScratch(rewriter, loc, *scratch, op.getA(), - aMemTy, aIsTmem, scope); - if (!aScratch) - return emitFpSanCodegenError(op.getOperation()); - auto bScratch = createOperandScratch(rewriter, loc, *scratch, op.getB(), - bMemTy, bIsTmem, scope); - if (!bScratch) - return emitFpSanCodegenError(op.getOperation()); - - // Each warp may only populate a subset of the operand scratch tiles, so - // synchronize before the emulation loops start reading them. - createGlobalScratchBarrier(rewriter, loc, - scratch->usesSharedClusterState()); - - auto [tileM, tileN] = getMmaEmulationTileShape(rewriter, m, n, k, accElem); + auto [tileM, tileN] = getMmaEmulationTileShape( + rewriter, m, n, k, accElem, /*directShared=*/!aIsTmem || !bIsTmem); auto accTileLayout = getOptimizedBlockedEncoding(rewriter, {tileM, tileN}, accElem); auto accTileTy = RankedTensorType::get({tileM, tileN}, accElem, accTileLayout); - auto aTileLayout = getOptimizedBlockedEncoding( - rewriter, {tileM, k}, - getScratchStorageElementType(aMemTy.getElementType())); - auto aTileTy = RankedTensorType::get( - {tileM, k}, getScratchStorageElementType(aMemTy.getElementType()), - aTileLayout); - auto bTileLayout = getOptimizedBlockedEncoding( - rewriter, {k, tileN}, - getScratchStorageElementType(bMemTy.getElementType())); - auto bTileTy = RankedTensorType::get( - {k, tileN}, getScratchStorageElementType(bMemTy.getElementType()), - bTileLayout); + Type aTileElem = aIsTmem + ? getScratchStorageElementType(aMemTy.getElementType()) + : aMemTy.getElementType(); + auto aTileLayout = + getOptimizedBlockedEncoding(rewriter, {tileM, k}, aTileElem); + auto aTileTy = RankedTensorType::get({tileM, k}, aTileElem, aTileLayout); + Type bTileElem = bIsTmem + ? getScratchStorageElementType(bMemTy.getElementType()) + : bMemTy.getElementType(); + auto bTileLayout = + getOptimizedBlockedEncoding(rewriter, {k, tileN}, bTileElem); + auto bTileTy = RankedTensorType::get({k, tileN}, bTileElem, bTileLayout); + + auto aSource = createMmaOperandSource(rewriter, loc, *scratch, op.getA(), + aMemTy, aIsTmem, aTileTy, scope, + /*rowStride=*/1, /*stride=*/m); + auto bSource = createMmaOperandSource(rewriter, loc, *scratch, op.getB(), + bMemTy, bIsTmem, bTileTy, scope, + /*rowStride=*/1, /*stride=*/k); + if (!aSource || !bSource) + return emitFpSanCodegenError(op.getOperation()); + + // TMEM and D scratch are written cooperatively. In two-CTA mode, the + // cluster barrier also makes both CTAs' shared operands visible before the + // first direct load. + createGlobalScratchBarrier(rewriter, loc, + scratch->usesSharedClusterState()); auto mLoop = emitMmaEmulationLoops( - rewriter, loc, aScratch->ptr, bScratch->ptr, dInfo->ptr, m, n, k, tileM, - tileN, aTileTy, bTileTy, accTileTy, accTileLayout, accElem, useDInt, - predInt, /*aStride=*/m, /*bStride=*/k, /*dStride=*/m); + rewriter, loc, *aSource, *bSource, dInfo->ptr, m, n, k, tileM, tileN, + accTileTy, accTileLayout, accElem, useDInt, predInt, /*dStride=*/m); if (!mLoop) return emitFpSanUnsupported(op.getOperation()); rewriter.setInsertionPointAfter(*mLoop); @@ -2661,24 +2729,17 @@ struct TCGen5MMAScaledPattern arith::ExtUIOp::create(rewriter, loc, accElem, op.getPred()); rewriter.setInsertionPoint(op); - auto aScratch = createOperandScratch(rewriter, loc, *scratch, op.getA(), - aMemTy, aIsTmem, scope); - if (!aScratch) - return emitFpSanCodegenError(op.getOperation()); - auto bScratch = createOperandScratch(rewriter, loc, *scratch, op.getB(), - bMemTy, bIsTmem, scope); - if (!bScratch) - return emitFpSanCodegenError(op.getOperation()); - auto aScaleScratch = createOperandScratch( - rewriter, loc, *scratch, op.getAScale(), aScaleMemTy, true, scope); + auto aScaleScratch = createTmemOperandScratch( + rewriter, loc, *scratch, op.getAScale(), aScaleMemTy, scope); if (!aScaleScratch) return emitFpSanCodegenError(op.getOperation()); - auto bScaleScratch = createOperandScratch( - rewriter, loc, *scratch, op.getBScale(), bScaleMemTy, true, scope); + auto bScaleScratch = createTmemOperandScratch( + rewriter, loc, *scratch, op.getBScale(), bScaleMemTy, scope); if (!bScaleScratch) return emitFpSanCodegenError(op.getOperation()); - auto [tileM, tileN] = getMmaEmulationTileShape(rewriter, m, n, k, accElem); + auto [tileM, tileN] = getMmaEmulationTileShape( + rewriter, m, n, k, accElem, /*directShared=*/!aIsTmem || !bIsTmem); auto accTileLayout = getOptimizedBlockedEncoding(rewriter, {tileM, tileN}, dMemTy.getElementType()); @@ -2693,6 +2754,15 @@ struct TCGen5MMAScaledPattern auto bTileTy = RankedTensorType::get({bPackedK, tileN}, bMemTy.getElementType(), bTileLayout); + auto aSource = createMmaOperandSource(rewriter, loc, *scratch, op.getA(), + aMemTy, aIsTmem, aTileTy, scope, + /*rowStride=*/1, /*stride=*/m); + auto bSource = createMmaOperandSource(rewriter, loc, *scratch, op.getB(), + bMemTy, bIsTmem, bTileTy, scope, + /*rowStride=*/1, /*stride=*/bPackedK); + if (!aSource || !bSource) + return emitFpSanCodegenError(op.getOperation()); + DotScaleConfig scale; scale.aElemType = op.getAType(); scale.bElemType = op.getBType(); @@ -2711,15 +2781,15 @@ struct TCGen5MMAScaledPattern scale.aScaleFactor = *aScaleFactor; scale.bScaleFactor = *bScaleFactor; - // The operand and scale scratch buffers are written cooperatively, so all - // warps must finish those stores before the emulation loop reads them. + // TMEM scales and D scratch are written cooperatively. In two-CTA mode, + // rendezvous before either CTA directly reads the shared operands. createGlobalScratchBarrier(rewriter, loc, scratch->usesSharedClusterState()); - auto mLoop = emitMmaEmulationLoops( - rewriter, loc, aScratch->ptr, bScratch->ptr, dInfo->ptr, m, n, k, tileM, - tileN, aTileTy, bTileTy, accTileTy, accTileLayout, accElem, useDInt, - predInt, /*aStride=*/m, /*bStride=*/bPackedK, /*dStride=*/m, scale); + auto mLoop = + emitMmaEmulationLoops(rewriter, loc, *aSource, *bSource, dInfo->ptr, m, + n, k, tileM, tileN, accTileTy, accTileLayout, + accElem, useDInt, predInt, /*dStride=*/m, scale); if (!mLoop) return emitFpSanUnsupported(op.getOperation()); rewriter.setInsertionPointAfter(*mLoop); diff --git a/test/Conversion/tritoninstrument_to_llvm.mlir b/test/Conversion/tritoninstrument_to_llvm.mlir index 6749f655c06f..9b266195d902 100644 --- a/test/Conversion/tritoninstrument_to_llvm.mlir +++ b/test/Conversion/tritoninstrument_to_llvm.mlir @@ -137,3 +137,26 @@ tt.func private @experimental_fpsan_unembed(%arg0: i32) -> f32 { tt.return %0 : f32 } } + +// ----- + +#local_gather_blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CGALayout = [[0, 1]]}> +#local_gather_shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CGALayout = [[0, 1]]}> + +module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32, ttg.target = "cuda:90"} { +// CHECK-LABEL: @experimental_local_gather +// CHECK: nvvm.mapa +// CHECK: llvm.load {{.*}} : !llvm.ptr<7> -> i32 +tt.func private @experimental_local_gather(%out: !tt.ptr) { + %src = ttg.local_alloc {allocation.offset = [0 : i32, 256 : i32]} : () -> !ttg.memdesc<2x32xi32, #local_gather_shared, #ttg.shared_memory, mutable> + %idx = arith.constant dense<0> : tensor<2x32xi32, #local_gather_blocked> + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %g = tti.experimental_local_gather %src[%idx] offsets = [%c1, %c0] {axis = 0 : i32} : !ttg.memdesc<2x32xi32, #local_gather_shared, #ttg.shared_memory, mutable>, tensor<2x32xi32, #local_gather_blocked> -> tensor<2x32xi32, #local_gather_blocked> + %ptrs = tt.splat %out : !tt.ptr -> tensor<2x32x!tt.ptr, #local_gather_blocked> + %offs = arith.constant dense<0> : tensor<2x32xi32, #local_gather_blocked> + %out_ptrs = tt.addptr %ptrs, %offs : tensor<2x32x!tt.ptr, #local_gather_blocked>, tensor<2x32xi32, #local_gather_blocked> + tt.store %out_ptrs, %g : tensor<2x32x!tt.ptr, #local_gather_blocked> + tt.return +} +} diff --git a/test/TritonGPU/nvidia-fpsan.mlir b/test/TritonGPU/nvidia-fpsan.mlir index 94a56218a400..fde1b2c22fb4 100644 --- a/test/TritonGPU/nvidia-fpsan.mlir +++ b/test/TritonGPU/nvidia-fpsan.mlir @@ -76,6 +76,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar // CHECK: ttg.global_scratch_alloc // CHECK: ttg.barrier global_read|global_write // CHECK-NEXT: scf.for + // CHECK: tti.experimental_local_gather + // CHECK: tti.experimental_local_gather // CHECK: tti.dot_i8 {{.*}} aSigned = true, bSigned = true // CHECK: tti.dot_i8 {{.*}} aSigned = false, bSigned = true // CHECK: tti.dot_i8 {{.*}} aSigned = true, bSigned = false @@ -97,6 +99,32 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar // ----- +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> +#smem = #ttg.shared_memory +#tmem_a = #ttng.tensor_memory_encoding +#tmem_d = #ttng.tensor_memory_encoding +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65544 : i32, ttg.tensor_memory_size = 0 : i32, "ttg.total-num-warps" = 1 : i32} { + // CHECK-LABEL: @tcgen05_mma_tmem_a_shared_b + tt.func public @tcgen05_mma_tmem_a_shared_b() { + // CHECK: ttg.global_scratch_alloc + // CHECK: tt.store + // CHECK: ttg.barrier global_read|global_write + // CHECK: tti.experimental_local_gather + // CHECK: tti.dot_i8 + // CHECK-NOT: ttng.tc_gen5_mma + %true = arith.constant true + %a = ttng.tmem_alloc {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} : () -> !ttg.memdesc<128x128xf16, #tmem_a, #ttng.tensor_memory, mutable> + %b = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> + %d = ttng.tmem_alloc {tensor_memory_col_offset = 128 : i32, tensor_memory_row_offset = 0 : i32} : () -> !ttg.memdesc<128x128xf32, #tmem_d, #ttng.tensor_memory, mutable> + %bar = ttg.local_alloc {allocation.offset = 8192 : i32} : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> + ttng.tc_gen5_mma %a, %b, %d, %true, %true, %bar[%true] {is_async} : !ttg.memdesc<128x128xf16, #tmem_a, #ttng.tensor_memory, mutable>, !ttg.memdesc<128x128xf16, #shared, #smem, mutable>, !ttg.memdesc<128x128xf32, #tmem_d, #ttng.tensor_memory, mutable>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> + tt.return + } +} + +// ----- + #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> #shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> #smem = #ttg.shared_memory @@ -135,6 +163,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar // CHECK: ttg.global_scratch_alloc // CHECK: ttg.barrier global_read|global_write // CHECK-NEXT: scf.for + // CHECK: tti.experimental_local_gather + // CHECK: tti.experimental_local_gather // CHECK: tti.dot_i8 {{.*}} aSigned = true, bSigned = true // CHECK: tti.dot_i8 {{.*}} aSigned = false, bSigned = true // CHECK: tti.dot_i8 {{.*}} aSigned = true, bSigned = false @@ -261,22 +291,19 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar #shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[1]]}> #smem = #ttg.shared_memory #tmem = #ttng.tensor_memory_encoding -// CHECK: #[[$SHARED_A_SCRATCH:[A-Za-z0-9_]+]] = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CGALayout = {{\[\[1, 0\]\]}}}> // CHECK: module attributes {{.*}}"ttng.two-ctas" = true module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65544 : i32, ttg.tensor_memory_size = 0 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @tcgen05_mma_two_ctas tt.func public @tcgen05_mma_two_ctas() { // CHECK: ttg.global_scratch_alloc {{.*}}shared_cluster_state // CHECK: ttng.cluster_barrier - // CHECK-NEXT: {{.*}} = ttg.local_load {{.*}} -> tensor<256x128xf16, #[[$SHARED_A_SCRATCH]]> - // CHECK: tt.store - // CHECK: ttg.barrier global_read|global_write - // CHECK-NEXT: ttng.cluster_barrier - // CHECK: scf.for + // CHECK-NEXT: scf.for + // CHECK: tti.experimental_local_gather + // CHECK: tti.experimental_local_gather // CHECK: tt.store // CHECK: ttg.barrier global_read|global_write // CHECK-NEXT: ttng.cluster_barrier - // CHECK: ttng.arrive_barrier + // CHECK-NEXT: ttng.arrive_barrier // CHECK-NOT: ttng.tc_gen5_mma %true = arith.constant true %a = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<256x128xf16, #shared_a, #smem, mutable> @@ -298,7 +325,6 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.shar #blocked_multibuffer = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CGALayout = [[1, 0]]}> #tmem = #ttng.tensor_memory_encoding #tmem_scales = #ttng.tensor_memory_scales_encoding -// CHECK: #[[$TMEM_VIEW_SCRATCH:[A-Za-z0-9_]+]] = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CGALayout = {{\[\[1, 0\]\]}}}> module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65544 : i32, ttg.tensor_memory_size = 0 : i32, "ttg.total-num-warps" = 1 : i32} { tt.func public @enable_two_ctas() { %true = arith.constant true @@ -312,7 +338,7 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.shar // CHECK-LABEL: @tmem_multibuffer_two_ctas tt.func public @tmem_multibuffer_two_ctas(%idx: i32) { // CHECK: ttg.global_scratch_alloc {{.*}}shared_cluster_state - // CHECK: tt.store {{.*}} {ignore_cta} : tensor<256x128x!tt.ptr, #[[$TMEM_VIEW_SCRATCH]]> + // CHECK: tt.store {{.*}} {ignore_cta} : tensor<256x128x!tt.ptr // CHECK-NOT: ttng.tmem_load // CHECK-NOT: ttng.tmem_store // CHECK-NOT: ttg.memdesc_index @@ -361,10 +387,13 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.shar tt.func public @tcgen05_mma_scaled_two_ctas() { // CHECK: ttg.global_scratch_alloc {{.*}}shared_cluster_state // CHECK: ttng.cluster_barrier - // CHECK-NEXT: {{.*}} = ttg.local_load + // CHECK-NEXT: scf.for + // CHECK: tti.experimental_local_gather + // CHECK: tti.experimental_local_gather + // CHECK: tt.store // CHECK: ttg.barrier global_read|global_write // CHECK-NEXT: ttng.cluster_barrier - // CHECK: ttng.arrive_barrier + // CHECK-NEXT: ttng.arrive_barrier // CHECK-NOT: ttng.tc_gen5_mma_scaled %true = arith.constant true %a = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<256x256xi8, #shared_a, #smem, mutable>