From 45e2e4b91f36ffbb1e9c90c013b415f10c2ad39f Mon Sep 17 00:00:00 2001 From: Jeff Niu Date: Wed, 3 Jun 2026 23:54:56 +0000 Subject: [PATCH 01/22] Accelerate FPSan MMA emulation with i8 decomposition --- include/triton/Analysis/Utility.h | 2 + .../IR/TritonInstrumentOps.td | 27 + lib/Analysis/Utility.cpp | 6 + lib/Dialect/TritonGPU/IR/Dialect.cpp | 4 +- .../Transforms/RemoveLayoutConversions.cpp | 6 +- lib/Dialect/TritonInstrument/IR/Ops.cpp | 30 + .../Transforms/FpSanitizer.cpp | 555 ++++++++++++++---- test/Conversion/tritongpu_to_llvm.mlir | 45 ++ test/TritonGPU/fpsan.mlir | 89 +++ test/TritonGPU/invalid.mlir | 40 ++ test/TritonGPU/nvidia-fpsan.mlir | 12 +- .../lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt | 1 + .../lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM.cpp | 35 ++ .../DotOpToLLVM/MMAv2.cpp | 76 ++- 14 files changed, 790 insertions(+), 138 deletions(-) diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h index 83acf8ed44bd..aa49a18d6e73 100644 --- a/include/triton/Analysis/Utility.h +++ b/include/triton/Analysis/Utility.h @@ -252,6 +252,8 @@ bool supportWMMA(triton::DotOp op); bool supportMMA(triton::DotOp op, int version); +bool supportMMA(triton::DotOpInterface op, int version); + bool supportMMA(Value value, int version); // Conversion from `srcTy` to `dstTy` involving the minimum amount of data diff --git a/include/triton/Dialect/TritonInstrument/IR/TritonInstrumentOps.td b/include/triton/Dialect/TritonInstrument/IR/TritonInstrumentOps.td index 3bff95749f0d..1a48850a3525 100644 --- a/include/triton/Dialect/TritonInstrument/IR/TritonInstrumentOps.td +++ b/include/triton/Dialect/TritonInstrument/IR/TritonInstrumentOps.td @@ -5,6 +5,7 @@ include "triton/Dialect/TritonInstrument/IR/TritonInstrumentDialect.td" include "triton/Dialect/TritonGPU/IR/TritonGPUTypes.td" include "triton/Dialect/Triton/IR/TritonTypes.td" include "triton/Dialect/Triton/IR/TritonInterfaces.td" +include "triton/Dialect/Triton/IR/TritonOpInterfaces.td" include "triton/Dialect/Triton/IR/TritonAttrDefs.td" include "mlir/IR/OpBase.td" include "mlir/Interfaces/SideEffectInterfaces.td" @@ -210,6 +211,32 @@ def TTI_ExperimentalLockReleaseOp : TTI_Op<"experimental_lock_release", [MemoryE // ===== FPSan ops ===== +def TTI_DotI8Op : TTI_Op<"dot_i8", [ + Pure, + DeclareOpInterfaceMethods, + TypesMatchWith<"result's type matches accumulator's type", + "d", "c", "$_self"> +]> { + let summary = "non-saturating NVIDIA MMAv2 i8 dot"; + let description = [{ + Performs a wrapping i8 matrix multiplication into an i32 accumulator using + NVIDIA MMAv2. The A and B operands have independent signedness. + }]; + let arguments = (ins + RankedTensorOf<[I8]>:$a, + RankedTensorOf<[I8]>:$b, + RankedTensorOf<[I32]>:$c, + BoolAttr:$aSigned, + BoolAttr:$bSigned + ); + let results = (outs RankedTensorOf<[I32]>:$d); + let assemblyFormat = [{ + $a `,` $b `,` $c `,` `aSigned` `=` $aSigned `,` `bSigned` `=` $bSigned + attr-dict `:` type($a) `*` type($b) `->` type($d) + }]; + let hasVerifier = 1; +} + def TTI_ExperimentalFPSanEmbedOp : TTI_Op<"experimental_fpsan_embed", [ Pure, Elementwise, diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index dd8a914703c0..b2381ca74f59 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -1242,6 +1242,12 @@ bool supportMMA(triton::DotOp op, int version) { return supportMMA(op.getA(), version) && supportMMA(op.getB(), version); } +bool supportMMA(triton::DotOpInterface op, int version) { + if (auto dotOp = dyn_cast(op.getOperation())) + return supportMMA(dotOp, version); + return supportMMA(op.getA(), version) && supportMMA(op.getB(), version); +} + bool supportMMA(Value value, int version) { // Tell whether a DotOp support MMA by the operand type(either $a or $b). // We cannot get both the operand types(in TypeConverter), here we assume the diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 5df13a344f00..ca91c231e015 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -3174,8 +3174,8 @@ struct TritonGPUInferLayoutInterface dyn_cast_or_null(aEncoding.getParent()); auto mmaBEncoding = dyn_cast_or_null(bEncoding.getParent()); - auto dotOp = cast(op); - auto resEnc = dotOp.getResult().getType().getEncoding(); + auto dotOp = cast(op); + auto resEnc = cast(dotOp.getD().getType()).getEncoding(); auto mmaResEncoding = dyn_cast(resEnc); if (mmaAEncoding || mmaBEncoding || mmaResEncoding) { // Check that they are all set and have the same version. diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index 9e0e6677e0bd..9dcc65de8c87 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -202,8 +202,8 @@ bool isLayoutAnchor(Operation *op) { return true; if (isa(op)) return isExpensiveLoadOrStore(op); - if (isa(op)) + if (isa(op)) return true; if (auto gatherOp = dyn_cast(op)) return gatherOp.getEfficientLayout(); @@ -653,7 +653,7 @@ void LayoutPropagation::rewriteOp(Operation *op) { bool canBeRemat(Operation *op) { if (isa(op)) return !isExpensiveLoadOrStore(op); - if (isa(op)) + if (isa(op)) return false; if (auto gather = dyn_cast(op)) return !gather.getEfficientLayout(); diff --git a/lib/Dialect/TritonInstrument/IR/Ops.cpp b/lib/Dialect/TritonInstrument/IR/Ops.cpp index 793bcd784307..5526f30a8624 100644 --- a/lib/Dialect/TritonInstrument/IR/Ops.cpp +++ b/lib/Dialect/TritonInstrument/IR/Ops.cpp @@ -16,6 +16,36 @@ namespace instrument { namespace tt = mlir::triton; namespace ttg = mlir::triton::gpu; +bool DotI8Op::verifyDims() { + auto aShape = getA().getType().getShape(); + auto bShape = getB().getType().getShape(); + return aShape.back() == bShape[bShape.size() - 2]; +} + +LogicalResult DotI8Op::verify() { + auto aEnc = + dyn_cast(getA().getType().getEncoding()); + auto bEnc = + dyn_cast(getB().getType().getEncoding()); + if (!aEnc || !bEnc) + return emitError("requires dot operand encodings for A and B"); + + auto aMma = dyn_cast(aEnc.getParent()); + auto bMma = dyn_cast(bEnc.getParent()); + auto dMma = + dyn_cast(getD().getType().getEncoding()); + if (!aMma || !bMma || !dMma || aMma.getVersionMajor() != 2 || + bMma.getVersionMajor() != 2 || dMma.getVersionMajor() != 2) + return emitError("requires NVIDIA MMAv2 operand and result layouts"); + if (aMma != bMma || aMma != dMma) + return emitError("requires matching NVIDIA MMAv2 layouts"); + + auto layoutInterface = + cast(&dMma.getDialect()); + return layoutInterface->verifyDotOpEncodingCompatibility(getOperation(), aEnc, + bEnc); +} + template struct PushFPSanThroughViewPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; diff --git a/lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp b/lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp index 55afa949216e..092788530ee2 100644 --- a/lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp +++ b/lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp @@ -53,6 +53,56 @@ static bool isValueAvailableInScope(Value value, Region *scope) { constexpr int64_t kTileM = 8; constexpr int64_t kTileN = 8; +constexpr int64_t kI8MmaK = 32; + +bool supportsI8DotDecomposition(PatternRewriter &rewriter, + IntegerType accElem) { + auto module = + rewriter.getInsertionBlock()->getParentOp()->getParentOfType(); + if (getAMDArch(module)) + return false; + return llvm::is_contained({16, 32, 64}, accElem.getWidth()); +} + +std::optional> getI8MmaWarpsPerCTA(int64_t m, int64_t n, + int numWarps) { + if (m < 16 || n < 8 || (m % 16) != 0 || (n % 8) != 0 || numWarps <= 0) + return std::nullopt; + + SmallVector warpsPerCTA{1, 1}; + SmallVector reps{m / 16, n / 8}; + while (warpsPerCTA[0] * warpsPerCTA[1] < numWarps) { + unsigned axis = reps[0] >= reps[1] ? 0 : 1; + if (reps[axis] <= 1 || (reps[axis] % 2) != 0) + axis = 1 - axis; + if (reps[axis] <= 1 || (reps[axis] % 2) != 0) + return std::nullopt; + warpsPerCTA[axis] *= 2; + reps[axis] /= 2; + } + if (warpsPerCTA[0] * warpsPerCTA[1] != numWarps) + return std::nullopt; + return warpsPerCTA; +} + +std::pair getMmaEmulationTileShape(PatternRewriter &rewriter, + int64_t m, int64_t n, + int64_t k, + IntegerType accElem) { + if (supportsI8DotDecomposition(rewriter, accElem) && (k % kI8MmaK) == 0) { + int64_t numWarps = + ttg::lookupNumWarps(rewriter.getInsertionBlock()->getParent()); + int64_t tileM = std::min(16 * numWarps, m); + while (tileM > 0 && (m % tileM) != 0) + tileM /= 2; + int64_t tileN = std::min(8 * numWarps, n); + while (tileN > 0 && (n % tileN) != 0) + tileN /= 2; + if (getI8MmaWarpsPerCTA(tileM, tileN, numWarps)) + return {tileM, tileN}; + } + return {std::min(kTileM, m), std::min(kTileN, n)}; +} Operation *createGlobalScratchBarrier(PatternRewriter &rewriter, Location loc, bool sharedClusterState = false) { @@ -1206,6 +1256,258 @@ Value emulateDotStep(PatternRewriter &rewriter, Location loc, Value aSlice, return arith::MulIOp::create(rewriter, loc, aFull, bFull); } +Value unpackPackedFp4Tensor(PatternRewriter &rewriter, Location loc, + Value packed, int64_t axis, + RankedTensorType logicalTy) { + Value packedI = embedToInt(rewriter, loc, packed); + auto packedTy = cast(packedI.getType()); + auto packedI8Ty = packedTy.clone(rewriter.getI8Type()); + packedI = castSignedIntValueToType(rewriter, loc, packedI, packedI8Ty); + + Value mask = getIntConstantLike(rewriter, loc, packedI8Ty, 0x0F); + Value four = getIntConstantLike(rewriter, loc, packedI8Ty, 4); + Value lo = arith::AndIOp::create(rewriter, loc, packedI, mask); + Value hi = arith::ShRUIOp::create(rewriter, loc, packedI, four); + auto halfTy = packedTy.clone(logicalTy.getElementType()); + lo = castSignedIntValueToType(rewriter, loc, lo, halfTy); + hi = castSignedIntValueToType(rewriter, loc, hi, halfTy); + Value joined = tt::JoinOp::create(rewriter, loc, lo, hi); + + int64_t rank = packedTy.getRank(); + auto order = llvm::to_vector(llvm::seq(axis + 1)); + order.push_back(rank); + llvm::append_range(order, llvm::seq(axis + 1, rank)); + Value transposed = tt::TransOp::create(rewriter, loc, joined, order); + + Value logical = + tt::ReshapeOp::create(rewriter, loc, logicalTy.getShape(), transposed); + if (logical.getType() != logicalTy) + logical = ttg::ConvertLayoutOp::create(rewriter, loc, logicalTy, logical); + return logical; +} + +Value loadOperandK32(PatternRewriter &rewriter, Location loc, bool isLhs, + Value tilePtr, RankedTensorType fullTileTy, Value kI32, + int64_t rowStride, int64_t stride, + int64_t packFactor = 1) { + SmallVector logicalShape{isLhs ? fullTileTy.getShape()[0] : kI8MmaK, + isLhs ? kI8MmaK : fullTileTy.getShape()[1]}; + SmallVector rawShape = logicalShape; + rawShape[isLhs ? 1 : 0] /= packFactor; + auto rawLayout = getOptimizedBlockedEncoding(rewriter, rawShape, + fullTileTy.getElementType()); + auto rawTy = + RankedTensorType::get(rawShape, fullTileTy.getElementType(), rawLayout); + + Value packedKIdx = kI32; + if (packFactor != 1) { + Value factor = arith::ConstantOp::create( + rewriter, loc, rewriter.getI32IntegerAttr(packFactor)); + packedKIdx = arith::DivUIOp::create(rewriter, loc, kI32, factor); + } + Value kStride = arith::ConstantOp::create( + rewriter, loc, rewriter.getI32IntegerAttr(isLhs ? stride : rowStride)); + Value offset = arith::MulIOp::create(rewriter, loc, packedKIdx, kStride); + Value chunkPtr = + tt::AddPtrOp::create(rewriter, loc, tilePtr.getType(), tilePtr, offset); + Value chunk = + loadScratchStrided2D(rewriter, loc, chunkPtr, rawTy, rowStride, stride); + + if (packFactor == 2) { + auto logicalLayout = getOptimizedBlockedEncoding(rewriter, logicalShape, + rewriter.getI8Type()); + auto logicalTy = RankedTensorType::get(logicalShape, rewriter.getI8Type(), + logicalLayout); + chunk = unpackPackedFp4Tensor(rewriter, loc, chunk, + /*axis=*/isLhs ? 1 : 0, logicalTy); + } + return chunk; +} + +Value loadScaledScaleK32(PatternRewriter &rewriter, Location loc, bool isLhs, + const DotScaleConfig &scale, Value tileIdx, Value kI32, + RankedTensorType targetTy) { + Value ptr = isLhs ? scale.aScalePtr : scale.bScalePtr; + if (!ptr) + return Value(); + + int64_t scaleFactor = isLhs ? scale.aScaleFactor : scale.bScaleFactor; + int64_t scaleStride = isLhs ? scale.aScaleStride : scale.bScaleStride; + auto scaleTileTy = isLhs ? scale.aScaleTileTy : scale.bScaleTileTy; + int64_t groups = scaleFactor < kI8MmaK ? kI8MmaK / scaleFactor : 1; + int64_t repeat = kI8MmaK / groups; + SmallVector compactShape = + isLhs ? SmallVector{targetTy.getShape()[0], groups} + : SmallVector{groups, targetTy.getShape()[1]}; + SmallVector broadcastShape = + isLhs ? SmallVector{targetTy.getShape()[0], groups, repeat} + : SmallVector{groups, repeat, targetTy.getShape()[1]}; + int64_t expandAxis = isLhs ? 2 : 1; + + auto broadcastLayout = getOptimizedBlockedEncoding( + rewriter, broadcastShape, scaleTileTy.getElementType()); + auto compactSliceLayout = ttg::SliceEncodingAttr::get( + rewriter.getContext(), expandAxis, broadcastLayout); + auto compactLoadLayout = getOptimizedBlockedEncoding( + rewriter, compactShape, scaleTileTy.getElementType()); + auto compactLoadTy = RankedTensorType::get( + compactShape, scaleTileTy.getElementType(), compactLoadLayout); + + Value tilePtr = + tt::AddPtrOp::create(rewriter, loc, ptr.getType(), ptr, tileIdx); + Value factor = arith::ConstantOp::create( + rewriter, loc, rewriter.getI32IntegerAttr(scaleFactor)); + Value groupIdx = arith::DivUIOp::create(rewriter, loc, kI32, factor); + Value stride = arith::ConstantOp::create( + rewriter, loc, rewriter.getI32IntegerAttr(scaleStride)); + Value groupOffset = arith::MulIOp::create(rewriter, loc, groupIdx, stride); + Value groupPtr = + tt::AddPtrOp::create(rewriter, loc, ptr.getType(), tilePtr, groupOffset); + Value compact = loadScratchStrided2D(rewriter, loc, groupPtr, compactLoadTy, + /*stride0=*/isLhs ? 1 : scaleStride, + /*stride1=*/isLhs ? scaleStride : 1); + compact = castDotScaledScaleToComputePayload(rewriter, loc, compact, + scale.computeElem); + auto compactSliceTy = cast(compact.getType()) + .cloneWithEncoding(compactSliceLayout); + compact = + ttg::ConvertLayoutOp::create(rewriter, loc, compactSliceTy, compact); + + Value expanded = tt::ExpandDimsOp::create(rewriter, loc, compact, expandAxis); + auto broadcastTy = + cast(expanded.getType()).clone(broadcastShape); + Value broadcast = + tt::BroadcastOp::create(rewriter, loc, broadcastTy, expanded); + Value reshaped = + tt::ReshapeOp::create(rewriter, loc, targetTy.getShape(), broadcast); + if (reshaped.getType() != targetTy) + reshaped = ttg::ConvertLayoutOp::create(rewriter, loc, targetTy, reshaped); + return reshaped; +} + +Value loadScaledOperandK32(PatternRewriter &rewriter, Location loc, bool isLhs, + Value tilePtr, RankedTensorType fullTileTy, + const DotScaleConfig &scale, Value tileIdx, + Value kI32, int64_t rowStride, int64_t stride) { + int64_t packFactor = isLhs ? scale.aKPackFactor : scale.bKPackFactor; + tt::ScaleDotElemType elemType = isLhs ? scale.aElemType : scale.bElemType; + Value chunk = loadOperandK32(rewriter, loc, isLhs, tilePtr, fullTileTy, kI32, + rowStride, stride, packFactor); + + Value payload = castDotScaledOperandToComputePayload( + rewriter, loc, chunk, elemType, scale.computeElem); + auto payloadTy = cast(payload.getType()); + Value scalePayload = + loadScaledScaleK32(rewriter, loc, isLhs, scale, tileIdx, kI32, payloadTy); + if (scalePayload) + payload = arith::MulIOp::create(rewriter, loc, payload, scalePayload); + return payload; +} + +Value tryEmitI8DotDecomposition(PatternRewriter &rewriter, Location loc, + Value aPayload, Value bPayload, + Attribute accLayout, IntegerType accElem, + int numWarps) { + auto aPayloadTy = cast(aPayload.getType()); + auto bPayloadTy = cast(bPayload.getType()); + auto aShape = aPayloadTy.getShape(); + auto bShape = bPayloadTy.getShape(); + int64_t m = aShape[0]; + int64_t k = aShape[1]; + int64_t n = bShape[1]; + if (bShape[0] != k || (k % kI8MmaK) != 0 || + !supportsI8DotDecomposition(rewriter, accElem)) + return Value(); + auto warpsPerCTA = getI8MmaWarpsPerCTA(m, n, numWarps); + if (!warpsPerCTA) + return Value(); + + auto aElem = cast(aPayloadTy.getElementType()); + auto bElem = cast(bPayloadTy.getElementType()); + assert((aElem.getWidth() % 8) == 0 && (bElem.getWidth() % 8) == 0); + int64_t aLimbs = aElem.getWidth() / 8; + int64_t bLimbs = bElem.getWidth() / 8; + int64_t accLimbs = accElem.getWidth() / 8; + int64_t highestDiagonal = std::min(accLimbs - 1, aLimbs + bLimbs - 2); + + auto *ctx = rewriter.getContext(); + auto i8Ty = rewriter.getI8Type(); + auto i32Ty = rewriter.getI32Type(); + auto mmaLayout = ttg::NvidiaMmaEncodingAttr::get( + ctx, /*versionMajor=*/2, /*versionMinor=*/0, *warpsPerCTA, + ttg::getCGALayout(accLayout), SmallVector{16, 8}); + auto aDotLayout = ttg::DotOperandEncodingAttr::get(ctx, 0, mmaLayout, i8Ty); + auto bDotLayout = ttg::DotOperandEncodingAttr::get(ctx, 1, mmaLayout, i8Ty); + auto accMmaTy = RankedTensorType::get({m, n}, i32Ty, mmaLayout); + + auto extractLimb = [&](Value payload, ttg::DotOperandEncodingAttr layout, + int64_t limb) { + auto payloadTy = cast(payload.getType()); + auto blockedLimbTy = payloadTy.clone(i8Ty); + auto dotLimbTy = blockedLimbTy.cloneWithEncoding(layout); + Value shifted = payload; + if (limb != 0) { + Value shift = getIntConstantLike(rewriter, loc, payloadTy, 8 * limb); + shifted = arith::ShRUIOp::create(rewriter, loc, shifted, shift); + } + Value truncated = + arith::TruncIOp::create(rewriter, loc, blockedLimbTy, shifted); + return ttg::ConvertLayoutOp::create(rewriter, loc, dotLimbTy, truncated); + }; + + auto emitByteDiagonal = [&](Value sum, int64_t diagonal) { + int64_t firstALimb = std::max(0, diagonal - bLimbs + 1); + int64_t lastALimb = std::min(diagonal, aLimbs - 1); + for (int64_t aLimb = firstALimb; aLimb <= lastALimb; ++aLimb) { + int64_t bLimb = diagonal - aLimb; + Value a = extractLimb(aPayload, aDotLayout, aLimb); + Value b = extractLimb(bPayload, bDotLayout, bLimb); + auto dot = DotI8Op::create(rewriter, loc, accMmaTy, a, b, sum, + aLimb == aLimbs - 1, bLimb == bLimbs - 1); + sum = dot.getResult(); + } + return sum; + }; + + if (accElem.getWidth() == 64) { + // Complete each K32 byte diagonal in i32 before widening it. The largest + // diagonal is 8 * kI8MmaK * 255^2 < 2^24, so its i32 result is exact. + auto accTy = RankedTensorType::get({m, n}, accElem, accLayout); + Value product = getIntConstantLike(rewriter, loc, accTy, 0); + for (int64_t diagonal = highestDiagonal; diagonal >= 0; --diagonal) { + if (diagonal != highestDiagonal) { + Value shift = getIntConstantLike(rewriter, loc, accTy, 8); + product = arith::ShLIOp::create(rewriter, loc, product, shift); + } + + Value diagonalSum = getIntConstantLike(rewriter, loc, accMmaTy, 0); + diagonalSum = emitByteDiagonal(diagonalSum, diagonal); + auto diagonalTy = RankedTensorType::get({m, n}, i32Ty, accLayout); + if (diagonalSum.getType() != diagonalTy) { + diagonalSum = ttg::ConvertLayoutOp::create(rewriter, loc, diagonalTy, + diagonalSum); + } + Value diagonalWide = + arith::ExtSIOp::create(rewriter, loc, accTy, diagonalSum); + product = arith::AddIOp::create(rewriter, loc, product, diagonalWide); + } + return product; + } + + Value product = getIntConstantLike(rewriter, loc, accMmaTy, 0); + for (int64_t diagonal = highestDiagonal; diagonal >= 0; --diagonal) { + if (diagonal != highestDiagonal) { + Value shift = getIntConstantLike(rewriter, loc, accMmaTy, 8); + product = arith::ShLIOp::create(rewriter, loc, product, shift); + } + product = emitByteDiagonal(product, diagonal); + } + auto accTy = RankedTensorType::get({m, n}, i32Ty, accLayout); + if (product.getType() == accTy) + return product; + return ttg::ConvertLayoutOp::create(rewriter, loc, accTy, product); +} + std::optional emitMmaEmulationLoops( PatternRewriter &rewriter, Location loc, Value aPtr, Value bPtr, Value dPtr, int64_t m, int64_t n, int64_t k, int64_t tileM, int64_t tileN, @@ -1268,67 +1570,133 @@ std::optional emitMmaEmulationLoops( Value bTilePtr = tt::AddPtrOp::create(rewriter, loc, bPtr.getType(), bPtr, bOffset); - auto aSliceTy = - RankedTensorType::get({tileM, 1}, aTileTy.getElementType(), accLayout); - auto bSliceTy = - RankedTensorType::get({1, tileN}, bTileTy.getElementType(), accLayout); - Value aStrideVal = arith::ConstantOp::create( - rewriter, loc, rewriter.getI32IntegerAttr(aStride)); - - Value zeroSum = getIntConstantLike(rewriter, loc, accTileI.getType(), 0); - Value kUpper = - arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(k)); - Value kStep = - arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(1)); - auto kLoop = scf::ForOp::create(rewriter, loc, zero, kUpper, kStep, zeroSum); - rewriter.setInsertionPointToStart(kLoop.getBody()); - Value kIdx = kLoop.getInductionVar(); - Value kI32 = arith::IndexCastOp::create(rewriter, loc, i32Ty, kIdx); - Value aKIdx = kI32; - Value bKIdx = kI32; - if (scale.aKPackFactor == 2) { - Value aPackFactor = arith::ConstantOp::create( - rewriter, loc, rewriter.getI32IntegerAttr(scale.aKPackFactor)); - aKIdx = arith::DivUIOp::create(rewriter, loc, kI32, aPackFactor); - } - if (scale.bKPackFactor == 2) { - Value bPackFactor = arith::ConstantOp::create( - rewriter, loc, rewriter.getI32IntegerAttr(scale.bKPackFactor)); - bKIdx = arith::DivUIOp::create(rewriter, loc, kI32, bPackFactor); + Value sum; + int numWarps = ttg::lookupNumWarps(rewriter.getInsertionBlock()->getParent()); + auto isScaleK32Aligned = [](Value scalePtr, int64_t scaleFactor) { + return !scalePtr || (scaleFactor > 0 && ((kI8MmaK % scaleFactor) == 0 || + (scaleFactor % kI8MmaK) == 0)); + }; + bool canUseI8Decomposition = + (k % kI8MmaK) == 0 && supportsI8DotDecomposition(rewriter, accElem) && + isScaleK32Aligned(scale.aScalePtr, scale.aScaleFactor) && + isScaleK32Aligned(scale.bScalePtr, scale.bScaleFactor) && + getI8MmaWarpsPerCTA(tileM, tileN, numWarps).has_value(); + if (canUseI8Decomposition) { + if (!scale.computeElem && accElem.getWidth() <= 32) { + Value aTile = loadScratchStrided2D(rewriter, loc, aTilePtr, aTileTy, + aRowStride, aStride); + Value bTile = loadScratchStrided2D(rewriter, loc, bTilePtr, bTileTy, + bRowStride, bStride); + sum = tryEmitI8DotDecomposition( + rewriter, loc, embedToInt(rewriter, loc, aTile), + embedToInt(rewriter, loc, bTile), accLayout, accElem, numWarps); + assert(sum && "i8 decomposition eligibility must match its emitter"); + sum = castSignedIntValueToType(rewriter, loc, sum, accTileI.getType()); + } else { + Value zeroSum = getIntConstantLike(rewriter, loc, accTileI.getType(), 0); + Value kUpper = arith::ConstantOp::create(rewriter, loc, + rewriter.getI32IntegerAttr(k)); + Value kStep = arith::ConstantOp::create( + rewriter, loc, rewriter.getI32IntegerAttr(kI8MmaK)); + auto kLoop = + scf::ForOp::create(rewriter, loc, zero, kUpper, kStep, zeroSum); + rewriter.setInsertionPointToStart(kLoop.getBody()); + Value kIdx = kLoop.getInductionVar(); + Value kI32 = arith::IndexCastOp::create(rewriter, loc, i32Ty, kIdx); + Value aChunk; + Value bChunk; + if (scale.computeElem) { + aChunk = loadScaledOperandK32(rewriter, loc, /*isLhs=*/true, aTilePtr, + aTileTy, scale, mIdxI32, kI32, aRowStride, + aStride); + bChunk = loadScaledOperandK32(rewriter, loc, /*isLhs=*/false, bTilePtr, + bTileTy, scale, nIdxI32, kI32, bRowStride, + bStride); + } else { + aChunk = loadOperandK32(rewriter, loc, /*isLhs=*/true, aTilePtr, + aTileTy, kI32, aRowStride, aStride); + bChunk = loadOperandK32(rewriter, loc, /*isLhs=*/false, bTilePtr, + bTileTy, kI32, bRowStride, bStride); + } + Value partial = tryEmitI8DotDecomposition( + rewriter, loc, embedToInt(rewriter, loc, aChunk), + embedToInt(rewriter, loc, bChunk), accLayout, accElem, numWarps); + assert(partial && "K32 decomposition must be eligible"); + partial = + castSignedIntValueToType(rewriter, loc, partial, accTileI.getType()); + Value next = arith::AddIOp::create(rewriter, loc, + kLoop.getRegionIterArgs()[0], partial); + scf::YieldOp::create(rewriter, loc, next); + rewriter.setInsertionPointAfter(kLoop); + sum = kLoop.getResult(0); + } } - Value aOffset = - arith::MulIOp::create(rewriter, loc, i32Ty, aKIdx, aStrideVal); - Value aSlicePtr = - tt::AddPtrOp::create(rewriter, loc, aPtr.getType(), aTilePtr, aOffset); - Value aSlice = loadScratchStrided2D(rewriter, loc, aSlicePtr, aSliceTy, - aRowStride, aStride); - Value bKOffset = arith::MulIOp::create(rewriter, loc, bKIdx, bRowStrideConst); - Value bSlicePtr = - tt::AddPtrOp::create(rewriter, loc, bPtr.getType(), bTilePtr, bKOffset); - Value bSlice = loadScratchStrided2D(rewriter, loc, bSlicePtr, bSliceTy, - bRowStride, bStride); - Value aScaleSlice; - if (scale.aScalePtr) { + + if (!sum) { + auto aSliceTy = + RankedTensorType::get({tileM, 1}, aTileTy.getElementType(), accLayout); + auto bSliceTy = + RankedTensorType::get({1, tileN}, bTileTy.getElementType(), accLayout); + Value aStrideVal = arith::ConstantOp::create( + rewriter, loc, rewriter.getI32IntegerAttr(aStride)); + + Value zeroSum = getIntConstantLike(rewriter, loc, accTileI.getType(), 0); + Value kUpper = + arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(k)); + Value kStep = + arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(1)); + auto kLoop = + scf::ForOp::create(rewriter, loc, zero, kUpper, kStep, zeroSum); + rewriter.setInsertionPointToStart(kLoop.getBody()); + Value kIdx = kLoop.getInductionVar(); + Value kI32 = arith::IndexCastOp::create(rewriter, loc, i32Ty, kIdx); + Value aKIdx = kI32; + Value bKIdx = kI32; + if (scale.aKPackFactor == 2) { + Value aPackFactor = arith::ConstantOp::create( + rewriter, loc, rewriter.getI32IntegerAttr(scale.aKPackFactor)); + aKIdx = arith::DivUIOp::create(rewriter, loc, kI32, aPackFactor); + } + if (scale.bKPackFactor == 2) { + Value bPackFactor = arith::ConstantOp::create( + rewriter, loc, rewriter.getI32IntegerAttr(scale.bKPackFactor)); + bKIdx = arith::DivUIOp::create(rewriter, loc, kI32, bPackFactor); + } + Value aOffset = + arith::MulIOp::create(rewriter, loc, i32Ty, aKIdx, aStrideVal); + Value aSlicePtr = + tt::AddPtrOp::create(rewriter, loc, aPtr.getType(), aTilePtr, aOffset); + Value aSlice = loadScratchStrided2D(rewriter, loc, aSlicePtr, aSliceTy, + aRowStride, aStride); + Value bKOffset = + arith::MulIOp::create(rewriter, loc, bKIdx, bRowStrideConst); + Value bSlicePtr = + tt::AddPtrOp::create(rewriter, loc, bPtr.getType(), bTilePtr, bKOffset); + Value bSlice = loadScratchStrided2D(rewriter, loc, bSlicePtr, bSliceTy, + bRowStride, bStride); if (scale.aKPackFactor == 2) aSlice = unpackPackedFp4Slice(rewriter, loc, aSlice, kI32); - aScaleSlice = - loadScaleSlice(rewriter, loc, /*isLhs=*/true, scale, mIdxI32, kI32); - } - Value bScaleSlice; - if (scale.bScalePtr) { if (scale.bKPackFactor == 2) bSlice = unpackPackedFp4Slice(rewriter, loc, bSlice, kI32); - bScaleSlice = - loadScaleSlice(rewriter, loc, /*isLhs=*/false, scale, nIdxI32, kI32); + Value aScaleSlice; + if (scale.aScalePtr) { + aScaleSlice = + loadScaleSlice(rewriter, loc, /*isLhs=*/true, scale, mIdxI32, kI32); + } + Value bScaleSlice; + if (scale.bScalePtr) { + bScaleSlice = + loadScaleSlice(rewriter, loc, /*isLhs=*/false, scale, nIdxI32, kI32); + } + Value partial = + emulateDotStep(rewriter, loc, aSlice, bSlice, aScaleSlice, bScaleSlice, + tileM, tileN, accLayout, accElem, scale); + Value acc = kLoop.getRegionIterArgs()[0]; + Value next = arith::AddIOp::create(rewriter, loc, acc, partial); + scf::YieldOp::create(rewriter, loc, next); + rewriter.setInsertionPointAfter(kLoop); + sum = kLoop.getResult(0); } - Value partial = - emulateDotStep(rewriter, loc, aSlice, bSlice, aScaleSlice, bScaleSlice, - tileM, tileN, accLayout, accElem, scale); - Value acc = kLoop.getRegionIterArgs()[0]; - Value next = arith::AddIOp::create(rewriter, loc, acc, partial); - scf::YieldOp::create(rewriter, loc, next); - rewriter.setInsertionPointAfter(kLoop); - Value sum = kLoop.getResult(0); Value useDMask = tt::SplatOp::create(rewriter, loc, accTileI.getType(), useDInt); @@ -1342,7 +1710,9 @@ std::optional emitMmaEmulationLoops( Value outMasked = arith::MulIOp::create(rewriter, loc, outI, predMask); Value accMasked = arith::MulIOp::create(rewriter, loc, accTileI, predInv); Value outSelI = arith::AddIOp::create(rewriter, loc, outMasked, accMasked); - Value out = unembedToFloat(rewriter, loc, outSelI, accTileTy); + Value out = isFloatLike(accTileTy) + ? unembedToFloat(rewriter, loc, outSelI, accTileTy) + : outSelI; createGlobalScratchBarrier(rewriter, loc); storeScratchStrided2D(rewriter, loc, dTilePtr, out, accTileTy, dRowStride, dStride); @@ -1565,30 +1935,10 @@ struct Fp4ToFpPattern : public OpRewritePattern { if (!srcElemTy || srcElemTy.getWidth() != 8) return emitFpSanInvariantError(op.getOperation()); - int64_t axis = op.getAxis(); - int64_t rank = srcTy.getRank(); auto dstIntTy = cast(getIntTypeLike(dstTy)); - auto halfIntTy = srcTy.clone(dstIntTy.getElementType()); - auto loc = op.getLoc(); - auto mask = getIntConstantLike(rewriter, loc, srcTy, 0x0F); - auto four = getIntConstantLike(rewriter, loc, srcTy, 4); - Value lo = arith::AndIOp::create(rewriter, loc, op.getSrc(), mask); - Value hi = arith::ShRUIOp::create(rewriter, loc, op.getSrc(), four); - auto loI = castSignedIntValueToType(rewriter, loc, lo, halfIntTy); - auto hiI = castSignedIntValueToType(rewriter, loc, hi, halfIntTy); - Value joined = tt::JoinOp::create(rewriter, loc, loI, hiI); - - auto order = llvm::to_vector(llvm::seq(axis + 1)); - order.push_back(rank); - llvm::append_range(order, llvm::seq(axis + 1, rank)); - auto transposed = tt::TransOp::create(rewriter, loc, joined, order); - - Value result = - tt::ReshapeOp::create(rewriter, loc, dstTy.getShape(), transposed); - if (result.getType() != dstIntTy) - result = ttg::ConvertLayoutOp::create(rewriter, loc, dstIntTy, result); - + Value result = unpackPackedFp4Tensor(rewriter, loc, op.getSrc(), + op.getAxis(), dstIntTy); rewriter.replaceOp(op, unembedToFloat(rewriter, loc, result, dstTy)); return success(); } @@ -1668,13 +2018,11 @@ struct DotPattern : public OpRewritePattern { Value predInt = arith::ConstantOp::create( rewriter, loc, rewriter.getIntegerAttr(accElem, 1)); - int64_t tileM = std::min(kTileM, m); - int64_t tileN = std::min(kTileN, n); + auto [tileM, tileN] = getMmaEmulationTileShape(rewriter, m, n, k, accElem); // Use optimized blocked layouts for emulation tiles instead of the // original dot encodings. Encodings like AMDWmmaEncodingAttr impose - // minimum shape requirements (e.g. >= 16x16) that the small emulation - // tiles (kTileM x kTileN = 8x8) cannot satisfy. + // minimum shape requirements that FMA fallback tiles cannot satisfy. auto accLayout = getOptimizedBlockedEncoding(rewriter, {tileM, tileN}, cTy.getElementType()); auto aLayout = @@ -1823,8 +2171,7 @@ struct DotScaledPattern : public OpRewritePattern { Value predInt = arith::ConstantOp::create( rewriter, loc, rewriter.getIntegerAttr(accElem, 1)); - int64_t tileM = std::min(kTileM, m); - int64_t tileN = std::min(kTileN, n); + auto [tileM, tileN] = getMmaEmulationTileShape(rewriter, m, n, k, accElem); auto accLayout = getOptimizedBlockedEncoding(rewriter, {tileM, tileN}, cTy.getElementType()); @@ -2093,8 +2440,7 @@ struct WarpGroupDotPattern : public OpRewritePattern { if (!aScratch || !bScratch || !dPtr) return emitFpSanCodegenError(op.getOperation()); - int64_t tileM = std::min(kTileM, m); - int64_t tileN = std::min(kTileN, n); + auto [tileM, tileN] = getMmaEmulationTileShape(rewriter, m, n, k, accElem); auto accTileLayout = getOptimizedBlockedEncoding(rewriter, {tileM, tileN}, cTy.getElementType()); @@ -2189,27 +2535,29 @@ struct TCGen5MMAPattern : public OpRewritePattern { if (!bScratch) return emitFpSanCodegenError(op.getOperation()); - int64_t tileM = std::min(kTileM, m); - int64_t tileN = std::min(kTileN, n); - - auto accTileLayout = getOptimizedBlockedEncoding(rewriter, {tileM, tileN}, - dMemTy.getElementType()); - auto accTileTy = RankedTensorType::get( - {tileM, tileN}, dMemTy.getElementType(), accTileLayout); - auto aTileLayout = getOptimizedBlockedEncoding(rewriter, {tileM, k}, - aMemTy.getElementType()); - auto aTileTy = - RankedTensorType::get({tileM, k}, aMemTy.getElementType(), aTileLayout); - auto bTileLayout = getOptimizedBlockedEncoding(rewriter, {k, tileN}, - bMemTy.getElementType()); - auto bTileTy = - RankedTensorType::get({k, tileN}, bMemTy.getElementType(), bTileLayout); - // Each warp may only populate a subset of the operand scratch tiles, so // synchronize before the emulation loops start reading them. createGlobalScratchBarrier(rewriter, loc, scratch->usesSharedClusterState()); + auto [tileM, tileN] = getMmaEmulationTileShape(rewriter, m, n, k, accElem); + auto accTileLayout = + getOptimizedBlockedEncoding(rewriter, {tileM, tileN}, accElem); + auto accTileTy = + RankedTensorType::get({tileM, tileN}, accElem, accTileLayout); + auto aTileLayout = getOptimizedBlockedEncoding( + rewriter, {tileM, k}, + getScratchStorageElementType(aMemTy.getElementType())); + auto aTileTy = RankedTensorType::get( + {tileM, k}, getScratchStorageElementType(aMemTy.getElementType()), + aTileLayout); + auto bTileLayout = getOptimizedBlockedEncoding( + rewriter, {k, tileN}, + getScratchStorageElementType(bMemTy.getElementType())); + auto bTileTy = RankedTensorType::get( + {k, tileN}, getScratchStorageElementType(bMemTy.getElementType()), + bTileLayout); + auto mLoop = emitMmaEmulationLoops( rewriter, loc, aScratch->ptr, bScratch->ptr, dInfo->ptr, m, n, k, tileM, tileN, aTileTy, bTileTy, accTileTy, accTileLayout, accElem, useDInt, @@ -2351,8 +2699,7 @@ struct TCGen5MMAScaledPattern if (!bScaleScratch) return emitFpSanCodegenError(op.getOperation()); - int64_t tileM = std::min(kTileM, m); - int64_t tileN = std::min(kTileN, n); + auto [tileM, tileN] = getMmaEmulationTileShape(rewriter, m, n, k, accElem); auto accTileLayout = getOptimizedBlockedEncoding(rewriter, {tileM, tileN}, dMemTy.getElementType()); diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 8947c5570cc1..275a1bb60262 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -1491,6 +1491,51 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // ----- +#mma = #ttg.nvidia_mma<{versionMajor=2, warpsPerCTA=[1, 1], instrShape = [16, 8]}> +#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma, kWidth=4}> +#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma, kWidth=4}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + // CHECK-LABEL: matmul_signed_i8dot + tt.func @matmul_signed_i8dot(%a: tensor<16x32xi8, #dot_operand_a>, + %b: tensor<32x8xi8, #dot_operand_b>, + %c: tensor<16x8xi32, #mma>) { + // CHECK: llvm.inline_asm + // CHECK-SAME: mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32 + %d = tt.dot %a, %b, %c : tensor<16x32xi8, #dot_operand_a> * tensor<32x8xi8, #dot_operand_b> -> tensor<16x8xi32, #mma> + tt.return + } +} + +// ----- + +#mma = #ttg.nvidia_mma<{versionMajor=2, warpsPerCTA=[1, 1], instrShape = [16, 8]}> +#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma, kWidth=4}> +#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma, kWidth=4}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + // CHECK-LABEL: matmul_mixed_signed_i8dot + tt.func @matmul_mixed_signed_i8dot(%a: tensor<16x32xi8, #dot_operand_a>, + %b: tensor<32x8xi8, #dot_operand_b>, + %c: tensor<16x8xi32, #mma>) { + // CHECK-NOT: satfinite + // CHECK: llvm.inline_asm + // CHECK-SAME: mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 + // CHECK: llvm.inline_asm + // CHECK-SAME: mma.sync.aligned.m16n8k32.row.col.s32.s8.u8.s32 + // CHECK: llvm.inline_asm + // CHECK-SAME: mma.sync.aligned.m16n8k32.row.col.s32.u8.s8.s32 + // CHECK: llvm.inline_asm + // CHECK-SAME: mma.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32 + // CHECK-NOT: satfinite + %d0 = tti.dot_i8 %a, %b, %c, aSigned = true, bSigned = true : tensor<16x32xi8, #dot_operand_a> * tensor<32x8xi8, #dot_operand_b> -> tensor<16x8xi32, #mma> + %d1 = tti.dot_i8 %a, %b, %d0, aSigned = true, bSigned = false : tensor<16x32xi8, #dot_operand_a> * tensor<32x8xi8, #dot_operand_b> -> tensor<16x8xi32, #mma> + %d2 = tti.dot_i8 %a, %b, %d1, aSigned = false, bSigned = true : tensor<16x32xi8, #dot_operand_a> * tensor<32x8xi8, #dot_operand_b> -> tensor<16x8xi32, #mma> + %d3 = tti.dot_i8 %a, %b, %d2, aSigned = false, bSigned = false : tensor<16x32xi8, #dot_operand_a> * tensor<32x8xi8, #dot_operand_b> -> tensor<16x8xi32, #mma> + tt.return + } +} + +// ----- + #blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { // CHECK-LABEL: atomic_add_f32 diff --git a/test/TritonGPU/fpsan.mlir b/test/TritonGPU/fpsan.mlir index 3a0b3ad00245..4029cd8efd7a 100644 --- a/test/TritonGPU/fpsan.mlir +++ b/test/TritonGPU/fpsan.mlir @@ -25,6 +25,68 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr // ----- +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#dot_operand_a = #ttg.dot_op<{opIdx = 0, parent = #blocked}> +#dot_operand_b = #ttg.dot_op<{opIdx = 1, parent = #blocked}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @dot_i8_decomposition + tt.func public @dot_i8_decomposition() -> tensor<32x32xf32, #blocked> { + // CHECK: scf.for + // CHECK: tti.dot_i8 {{.*}} aSigned = true, bSigned = true + // CHECK: tti.dot_i8 {{.*}} aSigned = false, bSigned = true + // CHECK: tti.dot_i8 {{.*}} aSigned = true, bSigned = false + // CHECK: tti.dot_i8 {{.*}} aSigned = false, bSigned = false + // CHECK-NOT: tti.dot_i8 + %one = arith.constant 1.000000e+00 : f16 + %zero = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked> + %a = tt.splat %one : f16 -> tensor<32x32xf16, #dot_operand_a> + %b = tt.splat %one : f16 -> tensor<32x32xf16, #dot_operand_b> + %out = tt.dot %a, %b, %zero : tensor<32x32xf16, #dot_operand_a> * tensor<32x32xf16, #dot_operand_b> -> tensor<32x32xf32, #blocked> + tt.return %out : tensor<32x32xf32, #blocked> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#dot_operand_a = #ttg.dot_op<{opIdx = 0, parent = #blocked}> +#dot_operand_b = #ttg.dot_op<{opIdx = 1, parent = #blocked}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @dot_f64_i8_decomposition + tt.func public @dot_f64_i8_decomposition() -> tensor<32x32xf64, #blocked> { + // CHECK-COUNT-36: tti.dot_i8 + // CHECK-NOT: tti.dot_i8 + // CHECK: arith.extsi {{.*}} : tensor<{{.*}}xi32, #{{.*}}> to tensor<{{.*}}xi64, #{{.*}}> + %one = arith.constant 1.000000e+00 : f64 + %zero = arith.constant dense<0.000000e+00> : tensor<32x32xf64, #blocked> + %a = tt.splat %one : f64 -> tensor<32x32xf64, #dot_operand_a> + %b = tt.splat %one : f64 -> tensor<32x32xf64, #dot_operand_b> + %out = tt.dot %a, %b, %zero : tensor<32x32xf64, #dot_operand_a> * tensor<32x32xf64, #dot_operand_b> -> tensor<32x32xf64, #blocked> + tt.return %out : tensor<32x32xf64, #blocked> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#dot_operand_a = #ttg.dot_op<{opIdx = 0, parent = #blocked}> +#dot_operand_b = #ttg.dot_op<{opIdx = 1, parent = #blocked}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @dot_f64_fma_fallback + tt.func public @dot_f64_fma_fallback() -> tensor<8x8xf64, #blocked> { + // CHECK: scf.for + // CHECK-NOT: tt.dot + %one = arith.constant 1.000000e+00 : f64 + %zero = arith.constant dense<0.000000e+00> : tensor<8x8xf64, #blocked> + %a = tt.splat %one : f64 -> tensor<8x4xf64, #dot_operand_a> + %b = tt.splat %one : f64 -> tensor<4x8xf64, #dot_operand_b> + %out = tt.dot %a, %b, %zero : tensor<8x4xf64, #dot_operand_a> * tensor<4x8xf64, #dot_operand_b> -> tensor<8x8xf64, #blocked> + tt.return %out : tensor<8x8xf64, #blocked> + } +} + +// ----- + #blocked = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 32, 1], warpsPerCTA = [1, 4, 1], order = [2, 1, 0]}> #dot_operand_a = #ttg.dot_op<{opIdx = 0, parent = #blocked}> #dot_operand_b = #ttg.dot_op<{opIdx = 1, parent = #blocked}> @@ -56,6 +118,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr // CHECK: scf.for // CHECK: ttg.barrier global_read|global_write // CHECK-NOT: ttg.dot_scaled + // CHECK-NOT: tt.dot // CHECK-NOT: ttg.convert_layout %cst = arith.constant 1.000000e+00 : f16 %zero = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #blocked> @@ -68,6 +131,30 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr // ----- +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#dot_A = #ttg.dot_op<{opIdx = 0, parent = #blocked}> +#dot_B = #ttg.dot_op<{opIdx = 1, parent = #blocked}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @dot_scaled_i8_decomposition + tt.func public @dot_scaled_i8_decomposition() -> tensor<32x32xf32, #blocked> { + // CHECK: scf.for + // CHECK: tti.dot_i8 {{.*}} aSigned = true, bSigned = true + // CHECK: tti.dot_i8 {{.*}} aSigned = false, bSigned = true + // CHECK: tti.dot_i8 {{.*}} aSigned = true, bSigned = false + // CHECK: tti.dot_i8 {{.*}} aSigned = false, bSigned = false + // CHECK-NOT: tti.dot_i8 + // CHECK-NOT: ttg.dot_scaled + %one = arith.constant 1.000000e+00 : f16 + %zero = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked> + %a = tt.splat %one : f16 -> tensor<32x64xf16, #dot_A> + %b = tt.splat %one : f16 -> tensor<64x32xf16, #dot_B> + %out = tt.dot_scaled %a, %b, %zero lhs = fp16 rhs = fp16 {fastMath = false} : tensor<32x64xf16, #dot_A> * tensor<64x32xf16, #dot_B> -> tensor<32x32xf32, #blocked> + tt.return %out : tensor<32x32xf32, #blocked> + } +} + +// ----- + #mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 32, 16]}> #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32}> #smem = #ttg.shared_memory @@ -79,6 +166,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr // CHECK: tt.store // CHECK: ttg.barrier global_read|global_write // CHECK: scf.for + // CHECK-COUNT-10: tti.dot_i8 + // CHECK-NOT: tti.dot_i8 // CHECK: ttg.barrier global_read|global_write // CHECK: %[[RAW:.*]] = tt.load // CHECK: %[[OUT:.*]] = tti.experimental_fpsan_unembed %[[RAW]] diff --git a/test/TritonGPU/invalid.mlir b/test/TritonGPU/invalid.mlir index e4d556ec098a..f04d96aaff1a 100644 --- a/test/TritonGPU/invalid.mlir +++ b/test/TritonGPU/invalid.mlir @@ -326,6 +326,46 @@ module attributes {"ttg.num-warps" = 1 : i32} { // ----- +#mma0 = #ttg.nvidia_mma<{versionMajor=2, warpsPerCTA=[1,1], instrShape = [16, 8]}> +#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma0, kWidth=4}> +#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma0, kWidth=4}> +module attributes {"ttg.num-warps" = 1 : i32} { + tt.func @dot_i8_invalid_operand_type(%A: tensor<16x32xi16, #dot_operand_a>, %B: tensor<32x8xi8, #dot_operand_b>, %C: tensor<16x8xi32, #mma0>) { + // expected-error@+1 {{operand #0 must be ranked tensor of 8-bit signless integer values}} + %D = "tti.dot_i8"(%A, %B, %C) {aSigned = true, bSigned = true} : (tensor<16x32xi16, #dot_operand_a>, tensor<32x8xi8, #dot_operand_b>, tensor<16x8xi32, #mma0>) -> tensor<16x8xi32, #mma0> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [1, 0]}> +#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#blocked}> +#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#blocked}> +module attributes {"ttg.num-warps" = 1 : i32} { + tt.func @dot_i8_non_mma_layout(%A: tensor<16x32xi8, #dot_operand_a>, %B: tensor<32x8xi8, #dot_operand_b>, %C: tensor<16x8xi32, #blocked>) { + // expected-error@+1 {{requires NVIDIA MMAv2 operand and result layouts}} + %D = tti.dot_i8 %A, %B, %C, aSigned = true, bSigned = true : tensor<16x32xi8, #dot_operand_a> * tensor<32x8xi8, #dot_operand_b> -> tensor<16x8xi32, #blocked> + tt.return + } +} + +// ----- + +#mma0 = #ttg.nvidia_mma<{versionMajor=2, warpsPerCTA=[1,1], instrShape = [16, 8]}> +#mma1 = #ttg.nvidia_mma<{versionMajor=2, warpsPerCTA=[1,1], instrShape = [32, 8]}> +#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma0, kWidth=4}> +#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma1, kWidth=4}> +module attributes {"ttg.num-warps" = 1 : i32} { + tt.func @dot_i8_mismatched_layout(%A: tensor<16x32xi8, #dot_operand_a>, %B: tensor<32x8xi8, #dot_operand_b>, %C: tensor<16x8xi32, #mma0>) { + // expected-error@+1 {{requires matching NVIDIA MMAv2 layouts}} + %D = tti.dot_i8 %A, %B, %C, aSigned = true, bSigned = true : tensor<16x32xi8, #dot_operand_a> * tensor<32x8xi8, #dot_operand_b> -> tensor<16x8xi32, #mma0> + tt.return + } +} + +// ----- + tt.func @warp_specialize_no_holder() { // expected-error @below {{'ttg.warp_specialize' op expected to find only a `ttg.warp_specialize.partitions` op inside its second region}} "ttg.warp_specialize"() ({ diff --git a/test/TritonGPU/nvidia-fpsan.mlir b/test/TritonGPU/nvidia-fpsan.mlir index 90a0fa72cfe9..94a56218a400 100644 --- a/test/TritonGPU/nvidia-fpsan.mlir +++ b/test/TritonGPU/nvidia-fpsan.mlir @@ -76,7 +76,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar // CHECK: ttg.global_scratch_alloc // CHECK: ttg.barrier global_read|global_write // CHECK-NEXT: scf.for - // CHECK: ttg.barrier global_read|global_write + // CHECK: tti.dot_i8 {{.*}} aSigned = true, bSigned = true + // CHECK: tti.dot_i8 {{.*}} aSigned = false, bSigned = true + // CHECK: tti.dot_i8 {{.*}} aSigned = true, bSigned = false + // CHECK: tti.dot_i8 {{.*}} aSigned = false, bSigned = false + // CHECK-NOT: tti.dot_i8 // CHECK: tt.store // CHECK: ttg.barrier global_read|global_write // CHECK: ttng.arrive_barrier @@ -103,7 +107,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar // CHECK: ttg.global_scratch_alloc // CHECK: ttg.barrier global_read|global_write // CHECK-NEXT: scf.for - // CHECK: ttg.barrier global_read|global_write // CHECK: tt.store // CHECK: ttg.barrier global_read|global_write // CHECK: ttng.arrive_barrier @@ -132,6 +135,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar // CHECK: ttg.global_scratch_alloc // CHECK: ttg.barrier global_read|global_write // CHECK-NEXT: scf.for + // CHECK: tti.dot_i8 {{.*}} aSigned = true, bSigned = true + // CHECK: tti.dot_i8 {{.*}} aSigned = false, bSigned = true + // CHECK: tti.dot_i8 {{.*}} aSigned = true, bSigned = false + // CHECK: tti.dot_i8 {{.*}} aSigned = false, bSigned = false + // CHECK-NOT: tti.dot_i8 // CHECK: ttg.barrier global_read|global_write // CHECK: tt.store // CHECK: ttg.barrier global_read|global_write diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt index c5b355c98fff..9d84b4a3bc62 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt @@ -28,6 +28,7 @@ add_triton_library(TritonNVIDIAGPUToLLVM GluonTransforms TritonAnalysis TritonGPUToLLVM + TritonInstrumentIR TritonInstrumentToLLVM MLIRReconcileUnrealizedCasts NVGPUIR diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM.cpp index 71cc98c11a1a..d378c70023b5 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM.cpp @@ -2,6 +2,7 @@ #include "Utility.h" #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Dialect/TritonInstrument/IR/Dialect.h" using namespace mlir; using namespace mlir::triton; @@ -15,6 +16,11 @@ LogicalResult convertMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor, const LLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter, bool isTuring); +LogicalResult convertMMA(triton::instrument::DotI8Op op, + triton::instrument::DotI8Op::Adaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, bool isTuring); + LogicalResult convertMMADotScaled(triton::DotScaledOp op, triton::DotScaledOp::Adaptor adaptor, const LLVMTypeConverter *typeConverter, @@ -85,6 +91,34 @@ struct DotOpConversion : public ConvertOpToLLVMPattern { } }; +struct DotI8OpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::instrument::DotI8Op>::ConvertOpToLLVMPattern; + + DotI8OpConversion(LLVMTypeConverter &converter, int, PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, + benefit) {} + + LogicalResult + matchAndRewrite(triton::instrument::DotI8Op op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto dType = op.getD().getType(); + auto dEncoding = dType.getEncoding(); + if (!isPermutationMatrixLayout(toLinearLayout(dType.getShape(), dEncoding))) + return rewriter.notifyMatchFailure( + op, "DotI8Op result encoding must have a permutation-matrix linear " + "layout"); + + auto mmaLayout = dyn_cast(dEncoding); + if (!mmaLayout || mmaLayout.getVersionMajor() != 2) + return rewriter.notifyMatchFailure(op, + "DotI8Op requires an MMAv2 layout"); + return convertMMA(op, adaptor, getTypeConverter(), rewriter, + mmaLayout.isTuring()); + } +}; + struct WarpGroupDotOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern< @@ -161,6 +195,7 @@ void mlir::triton::NVIDIA::populateDotOpToLLVMPatterns( LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, int computeCapability, PatternBenefit benefit) { patterns.add(typeConverter, computeCapability, benefit); + patterns.add(typeConverter, computeCapability, benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, computeCapability, diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp index 478d067fa901..fae80f64f5fa 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp @@ -4,6 +4,7 @@ #include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonInstrument/IR/Dialect.h" #include "llvm/ADT/SmallVector.h" using namespace mlir; @@ -715,13 +716,11 @@ using EmitMmaCallback = std::function &fc, RankedTensorType dTensorTy, int repK)>; -LogicalResult -convertMMAImpl(DotOpInterface op, Value llvmA, Value llvmB, Value llvmC, - const LLVMTypeConverter *typeConverter, - ConversionPatternRewriter &rewriter, TensorCoreType mmaType, - const NumRegisters &numRegisters, - const std::map &mmaInstructions, - const EmitMmaCallback &emitMma) { +LogicalResult convertMMAImpl( + DotOpInterface op, Value llvmA, Value llvmB, Value llvmC, + const LLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter, + TensorCoreType mmaType, const NumRegisters &numRegisters, + const std::string &mmaInstruction, const EmitMmaCallback &emitMma) { auto loc = op.getLoc(); auto aType = cast(op.getA().getType()); auto bType = cast(op.getB().getType()); @@ -787,7 +786,7 @@ convertMMAImpl(DotOpInterface op, Value llvmA, Value llvmB, Value llvmC, elemsPerThread[rank - 2] * elemsPerThread[rank - 1] / numCPackedElem; auto callMma = [&](unsigned b, unsigned m, unsigned n, unsigned k) { PTXBuilder builder; - auto &mma = *builder.create(mmaInstructions.at(mmaType)); + auto &mma = *builder.create(mmaInstruction); // using =r for float32 works but leads to less readable ptx. unsigned colsPerThread = repN * 2; emitMma(builder, b, static_cast(m), static_cast(n), @@ -834,21 +833,10 @@ convertMMAImpl(DotOpInterface op, Value llvmA, Value llvmB, Value llvmC, return success(); } -} // namespace - -LogicalResult convertMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor, - const LLVMTypeConverter *typeConverter, - ConversionPatternRewriter &rewriter, bool isTuring) { - auto aTensorTy = op.getA().getType(); - auto bTensorTy = op.getB().getType(); - auto dTensorTy = op.getD().getType(); - - TensorCoreType mmaType = getMmaTypeDot(op, aTensorTy, bTensorTy, dTensorTy); - const auto &instrMap = isTuring ? mmaInstrPtxTuring : mmaInstrPtxAmpere; - if (instrMap.find(mmaType) == instrMap.end()) - return op.emitError( - "unsupported MMA instruction for the given operand/result types"); - +LogicalResult convertMMAWithInstruction( + DotOpInterface op, Value llvmA, Value llvmB, Value llvmC, + const LLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter, + TensorCoreType mmaType, const std::string &mmaInstruction, bool isTuring) { NumRegisters numRegisters = (mmaType == TensorCoreType::FP64_FP64_FP64_FP64) ? NumRegisters{1, 1, 1} : NumRegisters{2, 1, 2}; @@ -885,9 +873,43 @@ LogicalResult convertMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor, } }; - return convertMMAImpl(op, adaptor.getA(), adaptor.getB(), adaptor.getC(), - typeConverter, rewriter, mmaType, numRegisters, - instrMap, emit); + return convertMMAImpl(op, llvmA, llvmB, llvmC, typeConverter, rewriter, + mmaType, numRegisters, mmaInstruction, emit); +} + +} // namespace + +LogicalResult convertMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, bool isTuring) { + auto aTensorTy = op.getA().getType(); + auto bTensorTy = op.getB().getType(); + auto dTensorTy = op.getD().getType(); + + TensorCoreType mmaType = getMmaTypeDot(op, aTensorTy, bTensorTy, dTensorTy); + const auto &instrMap = isTuring ? mmaInstrPtxTuring : mmaInstrPtxAmpere; + if (instrMap.find(mmaType) == instrMap.end()) + return op.emitError( + "unsupported MMA instruction for the given operand/result types"); + + return convertMMAWithInstruction(op, adaptor.getA(), adaptor.getB(), + adaptor.getC(), typeConverter, rewriter, + mmaType, instrMap.at(mmaType), isTuring); +} + +LogicalResult convertMMA(triton::instrument::DotI8Op op, + triton::instrument::DotI8Op::Adaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, bool isTuring) { + std::string mmaInstruction = isTuring + ? "mma.sync.aligned.m8n8k16.row.col.s32." + : "mma.sync.aligned.m16n8k32.row.col.s32."; + mmaInstruction += op.getASigned() ? "s8." : "u8."; + mmaInstruction += op.getBSigned() ? "s8.s32" : "u8.s32"; + return convertMMAWithInstruction(op, adaptor.getA(), adaptor.getB(), + adaptor.getC(), typeConverter, rewriter, + TensorCoreType::INT32_INT8_INT8_INT32, + mmaInstruction, isTuring); } LogicalResult convertMMADotScaled(triton::DotScaledOp op, @@ -955,5 +977,5 @@ LogicalResult convertMMADotScaled(triton::DotScaledOp op, return convertMMAImpl(op, adaptor.getA(), adaptor.getB(), adaptor.getC(), typeConverter, rewriter, mmaType, numRegisters, - mmaInstrPtxScaled, emit); + mmaInstrPtxScaled.at(mmaType), emit); } From a784db0c4948c487274f798c685f85cbd96e400d Mon Sep 17 00:00:00 2001 From: Jeff Niu Date: Wed, 3 Jun 2026 23:54:56 +0000 Subject: [PATCH 02/22] Test FPSan TCGen MMA in warp partitions --- python/test/gluon/test_fpsan.py | 64 +++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/python/test/gluon/test_fpsan.py b/python/test/gluon/test_fpsan.py index 16528ed39a59..e98b7de1d550 100644 --- a/python/test/gluon/test_fpsan.py +++ b/python/test/gluon/test_fpsan.py @@ -1770,6 +1770,70 @@ def kernel(a_ptr, b_ptr, c_ptr, out_ptr, USE_ACC: gl.constexpr): _assert_payload_equal(out, exp_bits) +@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell") +@pytest.mark.parametrize("partition_warps", [4, 2, 1]) +def test_tcgen05_mma_warp_specialize_partition(device, partition_warps, fresh_knobs): + _require_cuda_backend(device) + + M = gl.constexpr(64) + N = gl.constexpr(32) + K = gl.constexpr(32) + PARTITION_WARPS = gl.constexpr(partition_warps) + fresh_knobs.compilation.instrumentation_mode = "fpsan" + + @gluon.jit + def default_partition(): + pass + + @gluon.jit + def mma_partition(smem_a, smem_b, acc_tmem, bar): + tcgen05_mma(smem_a, smem_b.permute((1, 0)), acc_tmem, use_acc=False, pred=True, mbarriers=[bar]) + + @gluon.jit + def kernel(a_ptr, b_ptr, out_ptr): + layout: gl.constexpr = gl.BlockedLayout([1, 1], [32, 1], [gl.num_warps(), 1], [1, 0]) + offs_m = gl.arange(0, M, layout=gl.SliceLayout(1, layout))[:, None] + offs_n = gl.arange(0, N, layout=gl.SliceLayout(1, layout))[:, None] + offs_k = gl.arange(0, K, layout=gl.SliceLayout(0, layout))[None, :] + out_offs_n = gl.arange(0, N, layout=gl.SliceLayout(0, layout))[None, :] + + a = gl.load(a_ptr + offs_m * K + offs_k) + b = gl.load(b_ptr + offs_n * K + offs_k) + smem_a = gl.allocate_shared_memory(gl.float32, [M, K], gl.NVMMASharedLayout.get_default_for([M, K], gl.float32), + a) + smem_b = gl.allocate_shared_memory(gl.float32, [N, K], gl.NVMMASharedLayout.get_default_for([N, K], gl.float32), + b) + acc_tmem = allocate_tensor_memory(gl.float32, [M, N], layout=TensorMemoryLayout((M, N), col_stride=1)) + bar = mbarrier.allocate_mbarrier() + mbarrier.init(bar, count=1) + + gl.warp_specialize([ + (default_partition, ()), + (mma_partition, (smem_a, smem_b, acc_tmem, bar)), + ], [PARTITION_WARPS]) + + mbarrier.wait(bar, phase=0, deps=[smem_a, smem_b]) + mbarrier.invalidate(bar) + out = gl.convert_layout(acc_tmem.load(), layout) + gl.store(out_ptr + offs_m * N + out_offs_n, out) + + rs = np.random.RandomState(0) + a_bits = rs.randint(-(2**31), 2**31 - 1, size=(M.value, K.value), dtype=np.int32) + b_bits = rs.randint(-(2**31), 2**31 - 1, size=(N.value, K.value), dtype=np.int32) + exp_bits = _mm_payload_u32(a_bits, b_bits.T) + a = torch.tensor(a_bits, device=device, dtype=torch.int32) + b = torch.tensor(b_bits, device=device, dtype=torch.int32) + out = torch.empty((M.value, N.value), device=device, dtype=torch.int32) + + kernel[(1, )]( + triton.TensorWrapper(a, dtype=torch.float32), + triton.TensorWrapper(b, dtype=torch.float32), + triton.TensorWrapper(out, dtype=torch.float32), + num_warps=4, + ) + _assert_payload_equal(out, exp_bits) + + @pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell") @pytest.mark.parametrize("use_acc", [False, True]) def test_tcgen05_mma_two_ctas(device, use_acc, fresh_knobs): From 333348b36a9317bd64f8e3823956b71195cdf567 Mon Sep 17 00:00:00 2001 From: Jeff Niu Date: Thu, 4 Jun 2026 22:06:11 +0000 Subject: [PATCH 03/22] [FPSan] Address i8 decomposition review comments Centralize ordinary dot K verification in DotOpInterface, share MMAv2 warp distribution between matmul acceleration and FPSan, and simplify FPSan tile selection using TTGIR's power-of-two shape invariant. Preserve bounded emulation tiles and existing i8 decomposition eligibility. --- .../Dialect/Triton/IR/TritonOpInterfaces.td | 8 +++- include/triton/Dialect/TritonGPU/IR/Dialect.h | 6 +++ lib/Dialect/Triton/IR/Ops.cpp | 7 ---- lib/Dialect/TritonGPU/IR/Dialect.cpp | 24 ++++++++++++ .../TritonGPU/Transforms/AccelerateMatmul.cpp | 30 +------------- lib/Dialect/TritonInstrument/IR/Ops.cpp | 6 --- .../Transforms/FpSanitizer.cpp | 39 +++++-------------- lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp | 14 ------- 8 files changed, 48 insertions(+), 86 deletions(-) diff --git a/include/triton/Dialect/Triton/IR/TritonOpInterfaces.td b/include/triton/Dialect/Triton/IR/TritonOpInterfaces.td index 434a1c5d62d7..1f379c3c49db 100644 --- a/include/triton/Dialect/Triton/IR/TritonOpInterfaces.td +++ b/include/triton/Dialect/Triton/IR/TritonOpInterfaces.td @@ -58,7 +58,13 @@ def DotOpInterface : OpInterface<"DotOpInterface"> { /*desc=*/"Verify the dimensions of the A and B DotOp operands.", /*retType=*/"bool", /*methodName=*/"verifyDims", - /*args=*/(ins)>, + /*args=*/(ins), + /*methodBody=*/[{}], + /*defaultImpl=*/ [{ + auto aShape = cast($_op.getA().getType()).getShape(); + auto bShape = cast($_op.getB().getType()).getShape(); + return aShape.back() == bShape[bShape.size() - 2]; + }]>, InterfaceMethod< /*desc=*/"Verify the dimensions of the DotOp output.", /*retType=*/"bool", diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h index 7bb1658f604a..cac44099e35b 100644 --- a/include/triton/Dialect/TritonGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -278,6 +278,12 @@ SmallVector getAllocationShapePerCTA(Type type); unsigned getNumCTAs(Attribute layout); +// Returns the MMAv2 warp distribution for a matrix tile. This does not apply +// dot-chain policy and may oversubscribe tiles with fewer instruction +// repetitions than warps. +SmallVector getMmaV2WarpsPerCTA(ArrayRef shape, + int numWarps); + // Return the order that represents that the batch is in row-major or // column-major order for a batch of matrices of shape [*, m, n] with // len(shape) == rank. diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index fd67131a51ef..75ee5ed439ae 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -283,13 +283,6 @@ LogicalResult DotOp::verify() { bEncoding); } -bool DotOp::verifyDims() { - auto aShape = this->getA().getType().getShape(); - auto bShape = this->getB().getType().getShape(); - - return aShape[aShape.size() - 1] == bShape[aShape.size() - 2]; -} - //-- DotScaledOp -- bool DotScaledOp::verifyDims() { auto aShape = this->getA().getType().getShape(); diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index ca91c231e015..33da764d4208 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -472,6 +472,30 @@ SmallVector getAllocationShapePerCTA(Type type) { tensorType.getShape()); } +SmallVector getMmaV2WarpsPerCTA(ArrayRef shape, + int numWarps) { + if (shape.size() == 3) + return {static_cast(numWarps), 1, 1}; + + assert(shape.size() == 2); + SmallVector reps = {ceil(shape[0], int64_t{16}), + ceil(shape[1], int64_t{8})}; + SmallVector warps = {1, 1}; + // Balance repetitions to reduce register pressure, breaking ties toward M + // because the lhs instruction tile uses more registers than the rhs. + while (product(warps) < numWarps) { + if (reps[0] >= reps[1]) { + warps[0] *= 2; + if (reps[0] != 1) + reps[0] /= 2; + } else { + warps[1] *= 2; + reps[1] /= 2; + } + } + return warps; +} + unsigned getNumCTAs(Attribute layout) { return product(getCTAsPerCGA(layout)); } diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index 82789d800f9a..61ba23c4c9c8 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -88,7 +88,7 @@ SmallVector warpsPerTileV2(DotOpInterface dotOp, auto rank = shape.size(); // Early exit for batched matmul if (rank == 3) - return {(unsigned)numWarps, 1, 1}; + return getMmaV2WarpsPerCTA(shape, numWarps); auto filter = [&dotOp](Operation *op) { return op->getParentRegion() == dotOp->getParentRegion() && @@ -117,33 +117,7 @@ SmallVector warpsPerTileV2(DotOpInterface dotOp, } } - assert(rank == 2); - SmallVector shapePerWarp = {16, 8}; - SmallVector warps = {1, 1}; - // Compute repM and repN - SmallVector reps = {ceil(shape[0], shapePerWarp[0]), - ceil(shape[1], shapePerWarp[1])}; - // The formula for the number of registers given the reps is - // repM * 4 * repK + repN * 2 * repK + regsC - // where regsC = repM * repN * 4, which does not depend on the warp shape - // - // As such, to minimize the register pressure, we need to balance - // repM and repN. We then untie towards M, as the lhs tile has 4 elements, - // and the rhs tile has just 2. - while (product(warps) < numWarps) { - if (reps[0] >= reps[1]) { - warps[0] *= 2; - // Too many warps for this mma (repM == repN == 1). - // We allocate the remaining warps to the left (arbitrary choice) - if (reps[0] != 1) { - reps[0] /= 2; - } - } else { - warps[1] *= 2; - reps[1] /= 2; - } - } - return {(unsigned)warps[0], (unsigned)warps[1]}; + return getMmaV2WarpsPerCTA(shape, numWarps); } SmallVector warpsPerTileV3(DotOpInterface dotOp, const ArrayRef shape, diff --git a/lib/Dialect/TritonInstrument/IR/Ops.cpp b/lib/Dialect/TritonInstrument/IR/Ops.cpp index 5526f30a8624..9b8ec082eaff 100644 --- a/lib/Dialect/TritonInstrument/IR/Ops.cpp +++ b/lib/Dialect/TritonInstrument/IR/Ops.cpp @@ -16,12 +16,6 @@ namespace instrument { namespace tt = mlir::triton; namespace ttg = mlir::triton::gpu; -bool DotI8Op::verifyDims() { - auto aShape = getA().getType().getShape(); - auto bShape = getB().getType().getShape(); - return aShape.back() == bShape[bShape.size() - 2]; -} - LogicalResult DotI8Op::verify() { auto aEnc = dyn_cast(getA().getType().getEncoding()); diff --git a/lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp b/lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp index 092788530ee2..37c8dae6f580 100644 --- a/lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp +++ b/lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp @@ -57,32 +57,15 @@ constexpr int64_t kI8MmaK = 32; bool supportsI8DotDecomposition(PatternRewriter &rewriter, IntegerType accElem) { - auto module = + auto moduleOp = rewriter.getInsertionBlock()->getParentOp()->getParentOfType(); - if (getAMDArch(module)) + if (getAMDArch(moduleOp)) return false; return llvm::is_contained({16, 32, 64}, accElem.getWidth()); } -std::optional> getI8MmaWarpsPerCTA(int64_t m, int64_t n, - int numWarps) { - if (m < 16 || n < 8 || (m % 16) != 0 || (n % 8) != 0 || numWarps <= 0) - return std::nullopt; - - SmallVector warpsPerCTA{1, 1}; - SmallVector reps{m / 16, n / 8}; - while (warpsPerCTA[0] * warpsPerCTA[1] < numWarps) { - unsigned axis = reps[0] >= reps[1] ? 0 : 1; - if (reps[axis] <= 1 || (reps[axis] % 2) != 0) - axis = 1 - axis; - if (reps[axis] <= 1 || (reps[axis] % 2) != 0) - return std::nullopt; - warpsPerCTA[axis] *= 2; - reps[axis] /= 2; - } - if (warpsPerCTA[0] * warpsPerCTA[1] != numWarps) - return std::nullopt; - return warpsPerCTA; +bool canUseI8MmaTile(int64_t m, int64_t n, int numWarps) { + return m >= 16 && n >= 8 && (m / 16) * (n / 8) >= numWarps; } std::pair getMmaEmulationTileShape(PatternRewriter &rewriter, @@ -93,12 +76,8 @@ std::pair getMmaEmulationTileShape(PatternRewriter &rewriter, int64_t numWarps = ttg::lookupNumWarps(rewriter.getInsertionBlock()->getParent()); int64_t tileM = std::min(16 * numWarps, m); - while (tileM > 0 && (m % tileM) != 0) - tileM /= 2; int64_t tileN = std::min(8 * numWarps, n); - while (tileN > 0 && (n % tileN) != 0) - tileN /= 2; - if (getI8MmaWarpsPerCTA(tileM, tileN, numWarps)) + if (canUseI8MmaTile(tileM, tileN, numWarps)) return {tileM, tileN}; } return {std::min(kTileM, m), std::min(kTileN, n)}; @@ -1418,9 +1397,9 @@ Value tryEmitI8DotDecomposition(PatternRewriter &rewriter, Location loc, if (bShape[0] != k || (k % kI8MmaK) != 0 || !supportsI8DotDecomposition(rewriter, accElem)) return Value(); - auto warpsPerCTA = getI8MmaWarpsPerCTA(m, n, numWarps); - if (!warpsPerCTA) + if (!canUseI8MmaTile(m, n, numWarps)) return Value(); + auto warpsPerCTA = ttg::getMmaV2WarpsPerCTA({m, n}, numWarps); auto aElem = cast(aPayloadTy.getElementType()); auto bElem = cast(bPayloadTy.getElementType()); @@ -1434,7 +1413,7 @@ Value tryEmitI8DotDecomposition(PatternRewriter &rewriter, Location loc, auto i8Ty = rewriter.getI8Type(); auto i32Ty = rewriter.getI32Type(); auto mmaLayout = ttg::NvidiaMmaEncodingAttr::get( - ctx, /*versionMajor=*/2, /*versionMinor=*/0, *warpsPerCTA, + ctx, /*versionMajor=*/2, /*versionMinor=*/0, warpsPerCTA, ttg::getCGALayout(accLayout), SmallVector{16, 8}); auto aDotLayout = ttg::DotOperandEncodingAttr::get(ctx, 0, mmaLayout, i8Ty); auto bDotLayout = ttg::DotOperandEncodingAttr::get(ctx, 1, mmaLayout, i8Ty); @@ -1580,7 +1559,7 @@ std::optional emitMmaEmulationLoops( (k % kI8MmaK) == 0 && supportsI8DotDecomposition(rewriter, accElem) && isScaleK32Aligned(scale.aScalePtr, scale.aScaleFactor) && isScaleK32Aligned(scale.bScalePtr, scale.bScaleFactor) && - getI8MmaWarpsPerCTA(tileM, tileN, numWarps).has_value(); + canUseI8MmaTile(tileM, tileN, numWarps); if (canUseI8Decomposition) { if (!scale.computeElem && accElem.getWidth() <= 32) { Value aTile = loadScratchStrided2D(rewriter, loc, aTilePtr, aTileTy, diff --git a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp index d64b3d97ea07..44f5ae9c5ebb 100644 --- a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp @@ -149,13 +149,6 @@ bool WarpGroupDotOp::needsPartialAccumulator() { return isFP8 && accFP32 && maxNumImpreciseAcc <= aTensorTy.getShape()[1]; } -bool WarpGroupDotOp::verifyDims() { - auto aShape = this->getA().getType().getShape(); - auto bShape = this->getB().getType().getShape(); - - return aShape[aShape.size() - 1] == bShape[aShape.size() - 2]; -} - // -- WarpGroupDotWaitOp -- LogicalResult WarpGroupDotWaitOp::inferReturnTypes( MLIRContext *context, std::optional location, ValueRange operands, @@ -977,13 +970,6 @@ void TCGen5MMAOp::getEffects( SharedMemory::get()); } -bool TCGen5MMAOp::verifyDims() { - auto aShape = this->getA().getType().getShape(); - auto bShape = this->getB().getType().getShape(); - - return aShape[aShape.size() - 1] == bShape[aShape.size() - 2]; -} - Value TCGen5MMAOp::useAccumulator() { return getUseD(); } void TCGen5MMAOp::setUseAccumulator(Value flag) { From 4dd75e4efabbcbf845b42371147254bdb0306300 Mon Sep 17 00:00:00 2001 From: Jeff Niu Date: Thu, 4 Jun 2026 05:40:27 +0000 Subject: [PATCH 04/22] Support multi-CTA local gather and scatter --- .../Conversion/TritonGPUToLLVM/Utility.h | 16 ++++++- .../TritonGPUToLLVM/MemoryOpToLLVM.cpp | 46 ++++++++++-------- lib/Conversion/TritonGPUToLLVM/Utility.cpp | 47 +++++++++++++------ test/Conversion/tritonnvidiagpu_to_llvm.mlir | 23 +++++++++ 4 files changed, 97 insertions(+), 35 deletions(-) diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index cf33973062aa..417efeb9da68 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -605,9 +605,21 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target, const LinearLayout &layout, RankedTensorType type, bool withCTAOffset); -// Compute per-element shared-memory pointers for a local atomic/ldst update by +struct LocalSharedMemoryAddress { + Value ptr; + std::optional ctaId; +}; + +// Compute per-element shared-memory addresses for a local atomic/ldst update by // replacing `coords[*][axis]` with `idxValues[*]` and mapping the resulting -// logical coordinates back to shared-memory offsets. +// logical coordinates back to shared-memory offsets and target CTAs. +SmallVector +computeLocalAddrs(Location loc, triton::gpu::MemDescType memDescTy, + SharedMemoryObject smemObj, Type llvmElemTy, + ArrayRef idxValues, + ArrayRef> coords, unsigned axis, + RewriterBase &rewriter); + SmallVector computeLocalPtrs(Location loc, triton::gpu::MemDescType memDescTy, SharedMemoryObject smemObj, Type llvmElemTy, diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp index 613b6188e1b6..f89f7a838188 100644 --- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp @@ -25,19 +25,41 @@ lowerLocalScGt(Location loc, MLIRContext *ctx, MemDescType memDescTy, const TargetInfoBase &targetInfo) { auto b = TritonLLVMOpBuilder(loc, rewriter); bool isScatter = !storeVals.empty(); - SmallVector ptrs = computeLocalPtrs( + SmallVector addrs = computeLocalAddrs( loc, memDescTy, smemObj, llvmElemTy, idxValues, coords, axis, rewriter); + Value currentCtaId; + if (llvm::any_of(addrs, [](const LocalSharedMemoryAddress &addr) { + return addr.ctaId.has_value(); + })) + currentCtaId = targetInfo.getClusterCTAId(rewriter, loc); SmallVector results; if (!isScatter) results.resize(coords.size()); - for (auto [i, ptr] : llvm::enumerate(ptrs)) { + for (auto [i, addr] : llvm::enumerate(addrs)) { if (isScatter) { - targetInfo.storeShared(rewriter, loc, ptr, storeVals[i], b.true_val()); + if (addr.ctaId) { + Value isLocal = b.icmp_eq(*addr.ctaId, currentCtaId); + Value isRemote = b.icmp_ne(*addr.ctaId, currentCtaId); + targetInfo.storeShared(rewriter, loc, addr.ptr, storeVals[i], isLocal); + targetInfo.storeDShared(rewriter, loc, addr.ptr, addr.ctaId, + storeVals[i], isRemote); + } else { + targetInfo.storeShared(rewriter, loc, addr.ptr, storeVals[i], + b.true_val()); + } + } else if (addr.ctaId) { + Value isLocal = b.icmp_eq(*addr.ctaId, currentCtaId); + Value isRemote = b.icmp_ne(*addr.ctaId, currentCtaId); + Value local = + targetInfo.loadShared(rewriter, loc, addr.ptr, llvmElemTy, isLocal); + Value remote = targetInfo.loadDShared(rewriter, loc, addr.ptr, addr.ctaId, + llvmElemTy, isRemote); + results[i] = b.select(isLocal, local, remote); } else { - results[i] = - targetInfo.loadShared(rewriter, loc, ptr, llvmElemTy, b.true_val()); + results[i] = targetInfo.loadShared(rewriter, loc, addr.ptr, llvmElemTy, + b.true_val()); } } @@ -267,13 +289,6 @@ struct LocalGatherOpConversion : public ConvertOpToLLVMPattern { auto loc = op.getLoc(); auto *ctx = op.getContext(); auto memDescTy = cast(op.getSrc().getType()); - // TODO: PartitionedSharedEncoding lowering will be enabled in subsequent - // PRs. - if (isa( - memDescTy.getEncoding())) { - return rewriter.notifyMatchFailure( - op, "PartitionedSharedEncoding not yet supported in lowering"); - } auto regTy = cast(op.getType()); auto typeConverter = getTypeConverter(); @@ -316,13 +331,6 @@ struct LocalScatterOpConversion auto loc = op.getLoc(); auto *ctx = op.getContext(); auto memDescTy = cast(op.getDst().getType()); - // TODO: PartitionedSharedEncoding lowering will be enabled in subsequent - // PRs. - if (isa( - memDescTy.getEncoding())) { - return rewriter.notifyMatchFailure( - op, "PartitionedSharedEncoding not yet supported in lowering"); - } auto valuesTy = cast(op.getValues().getType()); auto typeConverter = getTypeConverter(); diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index 052a970bd9e6..651666f64818 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -540,12 +540,12 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target, return emitIndices(loc, rewriter, target, ll, type, withCTAOffset); } -SmallVector computeLocalPtrs(Location loc, - triton::gpu::MemDescType memDescTy, - SharedMemoryObject smemObj, Type llvmElemTy, - ArrayRef idxValues, - ArrayRef> coords, - unsigned axis, RewriterBase &rewriter) { +SmallVector +computeLocalAddrs(Location loc, triton::gpu::MemDescType memDescTy, + SharedMemoryObject smemObj, Type llvmElemTy, + ArrayRef idxValues, + ArrayRef> coords, unsigned axis, + RewriterBase &rewriter) { MLIRContext *ctx = memDescTy.getContext(); auto b = TritonLLVMOpBuilder(loc, rewriter); @@ -561,12 +561,15 @@ SmallVector computeLocalPtrs(Location loc, allDims.push_back(str_attr("dim" + Twine(dim))); auto kOffset = str_attr("offset"); + auto kBlock = str_attr("block"); + bool useBlockId = invSharedLayout.hasOutDim(kBlock) && + invSharedLayout.getOutDimSize(kBlock) > 1; // Get the subslice affine offset (non-zero for memdesc subslices) Value affineOffset = smemObj.getShmemOffset(loc, rewriter, memDescTy); auto bitwidth = getIntOrFloatOrPtrBitWidth(llvmElemTy); - SmallVector ptrs; - ptrs.reserve(coords.size()); + SmallVector addrs; + addrs.reserve(coords.size()); for (auto [i, idxVal] : llvm::enumerate(idxValues)) { Value idx = idxVal; @@ -589,15 +592,18 @@ SmallVector computeLocalPtrs(Location loc, auto outputs = applyLinearLayout(loc, rewriter, invSharedLayout, inputs); - // Extract the offset value + // Extract the offset and target CTA. Value offset = nullptr; + Value blockId = nullptr; for (auto [name, value] : outputs) { - if (name == kOffset) { + if (name == kOffset) offset = value; - break; - } + else if (name == kBlock) + blockId = value; } assert(offset && "expected offset output from inverted shared layout"); + assert((!useBlockId || blockId) && + "expected block output from multi-CTA shared layout"); // For subslices, the physical offset is computed as: // physical_offset = L⁻¹(coords) ⊕ L⁻¹(subslice_logical_offset) @@ -626,10 +632,23 @@ SmallVector computeLocalPtrs(Location loc, ptr = b.gep(smemObj.getBase().getType(), llvmElemTy, smemObj.getBase(), offset); } - ptrs.push_back(ptr); + addrs.push_back( + {ptr, useBlockId ? std::optional(blockId) : std::nullopt}); } - return ptrs; + return addrs; +} + +SmallVector computeLocalPtrs(Location loc, + triton::gpu::MemDescType memDescTy, + SharedMemoryObject smemObj, Type llvmElemTy, + ArrayRef idxValues, + ArrayRef> coords, + unsigned axis, RewriterBase &rewriter) { + return llvm::map_to_vector( + computeLocalAddrs(loc, memDescTy, smemObj, llvmElemTy, idxValues, coords, + axis, rewriter), + [](const LocalSharedMemoryAddress &addr) { return addr.ptr; }); } FailureOr prepareLocalAtomicScatterRMW( diff --git a/test/Conversion/tritonnvidiagpu_to_llvm.mlir b/test/Conversion/tritonnvidiagpu_to_llvm.mlir index 9c81ab8b9899..4d269832209e 100644 --- a/test/Conversion/tritonnvidiagpu_to_llvm.mlir +++ b/test/Conversion/tritonnvidiagpu_to_llvm.mlir @@ -802,3 +802,26 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, "ttg.tot llvm.return } } + +// ----- + +#local_gather_scatter_blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CGALayout = [[0, 1]]}> +#local_gather_scatter_shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CGALayout = [[0, 1]]}> + +module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @local_gather_scatter_two_ctas + // CHECK: ld.shared::cluster + // CHECK: nvvm.barrier0 + // CHECK: st.shared::cluster + tt.func @local_gather_scatter_two_ctas(%out: !tt.ptr, %vals: tensor<2x32xi32, #local_gather_scatter_blocked>) { + %src = ttg.local_alloc {allocation.offset = [0 : i32, 256 : i32]} : () -> !ttg.memdesc<2x32xi32, #local_gather_scatter_shared, #ttg.shared_memory, mutable> + %idx = arith.constant dense<0> : tensor<2x32xi32, #local_gather_scatter_blocked> + %g = ttg.local_gather %src[%idx] {axis = 0 : i32} : !ttg.memdesc<2x32xi32, #local_gather_scatter_shared, #ttg.shared_memory, mutable>, tensor<2x32xi32, #local_gather_scatter_blocked> -> tensor<2x32xi32, #local_gather_scatter_blocked> + ttg.local_scatter %src[%idx], %vals {axis = 0 : i32} : !ttg.memdesc<2x32xi32, #local_gather_scatter_shared, #ttg.shared_memory, mutable>, tensor<2x32xi32, #local_gather_scatter_blocked>, tensor<2x32xi32, #local_gather_scatter_blocked> + %ptrs = tt.splat %out : !tt.ptr -> tensor<2x32x!tt.ptr, #local_gather_scatter_blocked> + %offs = arith.constant dense<0> : tensor<2x32xi32, #local_gather_scatter_blocked> + %out_ptrs = tt.addptr %ptrs, %offs : tensor<2x32x!tt.ptr, #local_gather_scatter_blocked>, tensor<2x32xi32, #local_gather_scatter_blocked> + tt.store %out_ptrs, %g : tensor<2x32x!tt.ptr, #local_gather_scatter_blocked> + tt.return + } +} From 80bf7cca955572e1b4bc351ee57cfdb145b82d4c Mon Sep 17 00:00:00 2001 From: Jeff Niu Date: Thu, 4 Jun 2026 06:26:25 +0000 Subject: [PATCH 05/22] Simplify multi-CTA gather and scatter lowering --- .../TritonGPUToLLVM/MemoryOpToLLVM.cpp | 40 ++++--------------- 1 file changed, 7 insertions(+), 33 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp index f89f7a838188..65ad169a4047 100644 --- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp @@ -18,7 +18,7 @@ using namespace mlir::triton::gpu; // For gather: storeVals is empty, returns loaded values. // For scatter: storeVals contains values to store, returns empty. SmallVector -lowerLocalScGt(Location loc, MLIRContext *ctx, MemDescType memDescTy, +lowerLocalScGt(Location loc, MemDescType memDescTy, SharedMemoryObject smemObj, Type llvmElemTy, ArrayRef idxValues, ArrayRef> coords, unsigned axis, ArrayRef storeVals, RewriterBase &rewriter, @@ -27,39 +27,15 @@ lowerLocalScGt(Location loc, MLIRContext *ctx, MemDescType memDescTy, bool isScatter = !storeVals.empty(); SmallVector addrs = computeLocalAddrs( loc, memDescTy, smemObj, llvmElemTy, idxValues, coords, axis, rewriter); - Value currentCtaId; - if (llvm::any_of(addrs, [](const LocalSharedMemoryAddress &addr) { - return addr.ctaId.has_value(); - })) - currentCtaId = targetInfo.getClusterCTAId(rewriter, loc); SmallVector results; - if (!isScatter) - results.resize(coords.size()); - for (auto [i, addr] : llvm::enumerate(addrs)) { if (isScatter) { - if (addr.ctaId) { - Value isLocal = b.icmp_eq(*addr.ctaId, currentCtaId); - Value isRemote = b.icmp_ne(*addr.ctaId, currentCtaId); - targetInfo.storeShared(rewriter, loc, addr.ptr, storeVals[i], isLocal); - targetInfo.storeDShared(rewriter, loc, addr.ptr, addr.ctaId, - storeVals[i], isRemote); - } else { - targetInfo.storeShared(rewriter, loc, addr.ptr, storeVals[i], - b.true_val()); - } - } else if (addr.ctaId) { - Value isLocal = b.icmp_eq(*addr.ctaId, currentCtaId); - Value isRemote = b.icmp_ne(*addr.ctaId, currentCtaId); - Value local = - targetInfo.loadShared(rewriter, loc, addr.ptr, llvmElemTy, isLocal); - Value remote = targetInfo.loadDShared(rewriter, loc, addr.ptr, addr.ctaId, - llvmElemTy, isRemote); - results[i] = b.select(isLocal, local, remote); + targetInfo.storeDShared(rewriter, loc, addr.ptr, addr.ctaId, storeVals[i], + b.true_val()); } else { - results[i] = targetInfo.loadShared(rewriter, loc, addr.ptr, llvmElemTy, - b.true_val()); + results.push_back(targetInfo.loadDShared( + rewriter, loc, addr.ptr, addr.ctaId, llvmElemTy, b.true_val())); } } @@ -287,7 +263,6 @@ struct LocalGatherOpConversion : public ConvertOpToLLVMPattern { matchAndRewrite(LocalGatherOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); - auto *ctx = op.getContext(); auto memDescTy = cast(op.getSrc().getType()); auto regTy = cast(op.getType()); auto typeConverter = getTypeConverter(); @@ -302,7 +277,7 @@ struct LocalGatherOpConversion : public ConvertOpToLLVMPattern { emitIndices(loc, rewriter, targetInfo, regTy.getEncoding(), regTy, /*withCTAOffset=*/true); - auto results = lowerLocalScGt(loc, ctx, memDescTy, smemObj, llvmElemTy, + auto results = lowerLocalScGt(loc, memDescTy, smemObj, llvmElemTy, idxValues, dstIndices, op.getAxis(), /*storeVals=*/{}, rewriter, targetInfo); @@ -329,7 +304,6 @@ struct LocalScatterOpConversion matchAndRewrite(LocalScatterOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); - auto *ctx = op.getContext(); auto memDescTy = cast(op.getDst().getType()); auto valuesTy = cast(op.getValues().getType()); auto typeConverter = getTypeConverter(); @@ -346,7 +320,7 @@ struct LocalScatterOpConversion emitIndices(loc, rewriter, targetInfo, valuesTy.getEncoding(), valuesTy, /*withCTAOffset=*/true); - lowerLocalScGt(loc, ctx, memDescTy, smemObj, llvmElemTy, idxValues, + lowerLocalScGt(loc, memDescTy, smemObj, llvmElemTy, idxValues, srcIndices, op.getAxis(), values, rewriter, targetInfo); rewriter.eraseOp(op); From 1ffc0a11ea61bb0b513fea59fdf8f53786d2238a Mon Sep 17 00:00:00 2001 From: Jeff Niu Date: Thu, 4 Jun 2026 06:30:27 +0000 Subject: [PATCH 06/22] Preserve explicit cluster gather codegen --- .../TritonGPUToLLVM/MemoryOpToLLVM.cpp | 38 +++++++++++++++---- 1 file changed, 31 insertions(+), 7 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp index 65ad169a4047..37f8b7008b8c 100644 --- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp @@ -18,7 +18,7 @@ using namespace mlir::triton::gpu; // For gather: storeVals is empty, returns loaded values. // For scatter: storeVals contains values to store, returns empty. SmallVector -lowerLocalScGt(Location loc, MemDescType memDescTy, +lowerLocalScGt(Location loc, MLIRContext *ctx, MemDescType memDescTy, SharedMemoryObject smemObj, Type llvmElemTy, ArrayRef idxValues, ArrayRef> coords, unsigned axis, ArrayRef storeVals, RewriterBase &rewriter, @@ -27,15 +27,37 @@ lowerLocalScGt(Location loc, MemDescType memDescTy, bool isScatter = !storeVals.empty(); SmallVector addrs = computeLocalAddrs( loc, memDescTy, smemObj, llvmElemTy, idxValues, coords, axis, rewriter); + Value currentCtaId; + if (!addrs.empty() && addrs.front().ctaId) + currentCtaId = targetInfo.getClusterCTAId(rewriter, loc); SmallVector results; + if (!isScatter) + results.resize(coords.size()); + for (auto [i, addr] : llvm::enumerate(addrs)) { if (isScatter) { - targetInfo.storeDShared(rewriter, loc, addr.ptr, addr.ctaId, storeVals[i], - b.true_val()); + if (addr.ctaId) { + Value isLocal = b.icmp_eq(*addr.ctaId, currentCtaId); + Value isRemote = b.icmp_ne(*addr.ctaId, currentCtaId); + targetInfo.storeShared(rewriter, loc, addr.ptr, storeVals[i], isLocal); + targetInfo.storeDShared(rewriter, loc, addr.ptr, addr.ctaId, + storeVals[i], isRemote); + } else { + targetInfo.storeShared(rewriter, loc, addr.ptr, storeVals[i], + b.true_val()); + } + } else if (addr.ctaId) { + Value isLocal = b.icmp_eq(*addr.ctaId, currentCtaId); + Value isRemote = b.icmp_ne(*addr.ctaId, currentCtaId); + Value local = + targetInfo.loadShared(rewriter, loc, addr.ptr, llvmElemTy, isLocal); + Value remote = targetInfo.loadDShared(rewriter, loc, addr.ptr, addr.ctaId, + llvmElemTy, isRemote); + results[i] = b.select(isLocal, local, remote); } else { - results.push_back(targetInfo.loadDShared( - rewriter, loc, addr.ptr, addr.ctaId, llvmElemTy, b.true_val())); + results[i] = targetInfo.loadShared(rewriter, loc, addr.ptr, llvmElemTy, + b.true_val()); } } @@ -263,6 +285,7 @@ struct LocalGatherOpConversion : public ConvertOpToLLVMPattern { matchAndRewrite(LocalGatherOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); + auto *ctx = op.getContext(); auto memDescTy = cast(op.getSrc().getType()); auto regTy = cast(op.getType()); auto typeConverter = getTypeConverter(); @@ -277,7 +300,7 @@ struct LocalGatherOpConversion : public ConvertOpToLLVMPattern { emitIndices(loc, rewriter, targetInfo, regTy.getEncoding(), regTy, /*withCTAOffset=*/true); - auto results = lowerLocalScGt(loc, memDescTy, smemObj, llvmElemTy, + auto results = lowerLocalScGt(loc, ctx, memDescTy, smemObj, llvmElemTy, idxValues, dstIndices, op.getAxis(), /*storeVals=*/{}, rewriter, targetInfo); @@ -304,6 +327,7 @@ struct LocalScatterOpConversion matchAndRewrite(LocalScatterOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); + auto *ctx = op.getContext(); auto memDescTy = cast(op.getDst().getType()); auto valuesTy = cast(op.getValues().getType()); auto typeConverter = getTypeConverter(); @@ -320,7 +344,7 @@ struct LocalScatterOpConversion emitIndices(loc, rewriter, targetInfo, valuesTy.getEncoding(), valuesTy, /*withCTAOffset=*/true); - lowerLocalScGt(loc, memDescTy, smemObj, llvmElemTy, idxValues, + lowerLocalScGt(loc, ctx, memDescTy, smemObj, llvmElemTy, idxValues, srcIndices, op.getAxis(), values, rewriter, targetInfo); rewriter.eraseOp(op); From 6dc53acef0cd5a1ba762480fd2dd42b3f1e715b8 Mon Sep 17 00:00:00 2001 From: Jeff Niu Date: Thu, 4 Jun 2026 06:59:35 +0000 Subject: [PATCH 07/22] Apply pre-commit formatting --- include/triton/Conversion/TritonGPUToLLVM/Utility.h | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index 417efeb9da68..abd813304502 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -613,12 +613,10 @@ struct LocalSharedMemoryAddress { // Compute per-element shared-memory addresses for a local atomic/ldst update by // replacing `coords[*][axis]` with `idxValues[*]` and mapping the resulting // logical coordinates back to shared-memory offsets and target CTAs. -SmallVector -computeLocalAddrs(Location loc, triton::gpu::MemDescType memDescTy, - SharedMemoryObject smemObj, Type llvmElemTy, - ArrayRef idxValues, - ArrayRef> coords, unsigned axis, - RewriterBase &rewriter); +SmallVector computeLocalAddrs( + Location loc, triton::gpu::MemDescType memDescTy, + SharedMemoryObject smemObj, Type llvmElemTy, ArrayRef idxValues, + ArrayRef> coords, unsigned axis, RewriterBase &rewriter); SmallVector computeLocalPtrs(Location loc, triton::gpu::MemDescType memDescTy, From 486963793c4819551cc915cda2ed1b4fca876b2a Mon Sep 17 00:00:00 2001 From: Jeff Niu Date: Thu, 4 Jun 2026 05:49:23 +0000 Subject: [PATCH 08/22] Add instrumentation local gather for FPSan --- .../Conversion/TritonGPUToLLVM/Utility.h | 12 + .../IR/TritonInstrumentOps.td | 24 + .../TritonGPUToLLVM/MemoryOpToLLVM.cpp | 33 +- lib/Conversion/TritonGPUToLLVM/Utility.cpp | 46 +- .../InstrumentationToLLVM.cpp | 41 ++ lib/Dialect/TritonInstrument/IR/Ops.cpp | 34 ++ .../Transforms/FpSanitizer.cpp | 447 +++++++++++------- test/Conversion/tritoninstrument_to_llvm.mlir | 22 + test/TritonGPU/nvidia-fpsan.mlir | 51 +- 9 files changed, 508 insertions(+), 202 deletions(-) diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index abd813304502..e7b14224159b 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -618,6 +618,13 @@ SmallVector computeLocalAddrs( SharedMemoryObject smemObj, Type llvmElemTy, ArrayRef idxValues, ArrayRef> coords, unsigned axis, RewriterBase &rewriter); +SmallVector +computeLocalAddrs(Location loc, triton::gpu::MemDescType memDescTy, + SharedMemoryObject smemObj, Type llvmElemTy, + ArrayRef idxValues, + ArrayRef> coords, unsigned axis, + ArrayRef offsets, RewriterBase &rewriter); + SmallVector computeLocalPtrs(Location loc, triton::gpu::MemDescType memDescTy, SharedMemoryObject smemObj, Type llvmElemTy, @@ -625,6 +632,11 @@ SmallVector computeLocalPtrs(Location loc, ArrayRef> coords, unsigned axis, RewriterBase &rewriter); +SmallVector loadLocalAddrs(Location loc, Type llvmElemTy, + ArrayRef addrs, + RewriterBase &rewriter, + const TargetInfoBase &targetInfo); + // Backend-agnostic preparation for lowering LocalAtomicScatterRMWOp. struct LocalAtomicScatterRMWInfo { RankedTensorType valuesTy; diff --git a/include/triton/Dialect/TritonInstrument/IR/TritonInstrumentOps.td b/include/triton/Dialect/TritonInstrument/IR/TritonInstrumentOps.td index 1a48850a3525..8f1782e3846d 100644 --- a/include/triton/Dialect/TritonInstrument/IR/TritonInstrumentOps.td +++ b/include/triton/Dialect/TritonInstrument/IR/TritonInstrumentOps.td @@ -15,6 +15,7 @@ include "triton/Dialect/TritonInstrument/IR/TritonInstrumentAttrDefs.td" // Interfaces // def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">; +def SharedMemory : Resource<"::mlir::triton::gpu::SharedMemory">; // // Ops @@ -78,6 +79,29 @@ def TTI_ExperimentalClusterCTAIdOp let assemblyFormat = "attr-dict `:` type($result)"; } +def TTI_ExperimentalLocalGatherOp + : TTI_Op<"experimental_local_gather"> { + let summary = "Gather elements from shared memory with logical base offsets"; + let description = [{ + Gather elements from a shared memory descriptor using an index tensor along + one axis, after shifting the logical source coordinates by rank-sized scalar + offsets. This is intentionally private to instrumentation passes. + }]; + let arguments = (ins + Arg]>:$src, + TT_IntTensor:$indices, + Variadic:$offsets, + I32Attr:$axis + ); + let results = (outs TT_Tensor:$result); + + let assemblyFormat = [{ + $src `[` $indices `]` `offsets` `=` `[` $offsets `]` + attr-dict `:` qualified(type($src)) `,` type($indices) `->` type($result) + }]; + let hasVerifier = 1; +} + def TTI_ExperimentalGSanInitOp : TTI_Op<"experimental_gsan_init"> { let summary = "Initialize GSan thread"; diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp index 37f8b7008b8c..00239413efef 100644 --- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp @@ -27,37 +27,24 @@ lowerLocalScGt(Location loc, MLIRContext *ctx, MemDescType memDescTy, bool isScatter = !storeVals.empty(); SmallVector addrs = computeLocalAddrs( loc, memDescTy, smemObj, llvmElemTy, idxValues, coords, axis, rewriter); - Value currentCtaId; - if (!addrs.empty() && addrs.front().ctaId) - currentCtaId = targetInfo.getClusterCTAId(rewriter, loc); SmallVector results; if (!isScatter) - results.resize(coords.size()); + return loadLocalAddrs(loc, llvmElemTy, addrs, rewriter, targetInfo); + Value currentCtaId; + if (!addrs.empty() && addrs.front().ctaId) + currentCtaId = targetInfo.getClusterCTAId(rewriter, loc); for (auto [i, addr] : llvm::enumerate(addrs)) { - if (isScatter) { - if (addr.ctaId) { - Value isLocal = b.icmp_eq(*addr.ctaId, currentCtaId); - Value isRemote = b.icmp_ne(*addr.ctaId, currentCtaId); - targetInfo.storeShared(rewriter, loc, addr.ptr, storeVals[i], isLocal); - targetInfo.storeDShared(rewriter, loc, addr.ptr, addr.ctaId, - storeVals[i], isRemote); - } else { - targetInfo.storeShared(rewriter, loc, addr.ptr, storeVals[i], - b.true_val()); - } - } else if (addr.ctaId) { + if (addr.ctaId) { Value isLocal = b.icmp_eq(*addr.ctaId, currentCtaId); Value isRemote = b.icmp_ne(*addr.ctaId, currentCtaId); - Value local = - targetInfo.loadShared(rewriter, loc, addr.ptr, llvmElemTy, isLocal); - Value remote = targetInfo.loadDShared(rewriter, loc, addr.ptr, addr.ctaId, - llvmElemTy, isRemote); - results[i] = b.select(isLocal, local, remote); + targetInfo.storeShared(rewriter, loc, addr.ptr, storeVals[i], isLocal); + targetInfo.storeDShared(rewriter, loc, addr.ptr, addr.ctaId, storeVals[i], + isRemote); } else { - results[i] = targetInfo.loadShared(rewriter, loc, addr.ptr, llvmElemTy, - b.true_val()); + targetInfo.storeShared(rewriter, loc, addr.ptr, storeVals[i], + b.true_val()); } } diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index 651666f64818..fde30b0b53b5 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -546,6 +546,16 @@ computeLocalAddrs(Location loc, triton::gpu::MemDescType memDescTy, ArrayRef idxValues, ArrayRef> coords, unsigned axis, RewriterBase &rewriter) { + return computeLocalAddrs(loc, memDescTy, smemObj, llvmElemTy, idxValues, + coords, axis, /*offsets=*/{}, rewriter); +} + +SmallVector +computeLocalAddrs(Location loc, triton::gpu::MemDescType memDescTy, + SharedMemoryObject smemObj, Type llvmElemTy, + ArrayRef idxValues, + ArrayRef> coords, unsigned axis, + ArrayRef offsets, RewriterBase &rewriter) { MLIRContext *ctx = memDescTy.getContext(); auto b = TritonLLVMOpBuilder(loc, rewriter); @@ -581,9 +591,12 @@ computeLocalAddrs(Location loc, triton::gpu::MemDescType memDescTy, idx = b.zext(i32_ty, idx); } - // Copy coordinates and replace the axis coordinate with the index value + // Copy coordinates, replace the axis coordinate with the index value, and + // then shift all logical coordinates by the optional base offsets. SmallVector indices(coords[i]); indices[axis] = idx; + for (auto [dim, offset] : llvm::enumerate(offsets)) + indices[dim] = b.add(indices[dim], offset); // Apply inverted shared layout to compute offset SmallVector> inputs; @@ -651,6 +664,37 @@ SmallVector computeLocalPtrs(Location loc, [](const LocalSharedMemoryAddress &addr) { return addr.ptr; }); } +SmallVector loadLocalAddrs(Location loc, Type llvmElemTy, + ArrayRef addrs, + RewriterBase &rewriter, + const TargetInfoBase &targetInfo) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value currentCtaId; + if (llvm::any_of(addrs, [](const LocalSharedMemoryAddress &addr) { + return addr.ctaId.has_value(); + })) + currentCtaId = targetInfo.getClusterCTAId(rewriter, loc); + + SmallVector results; + results.reserve(addrs.size()); + for (const LocalSharedMemoryAddress &addr : addrs) { + if (!addr.ctaId) { + results.push_back(targetInfo.loadShared(rewriter, loc, addr.ptr, + llvmElemTy, b.true_val())); + continue; + } + + Value isLocal = b.icmp_eq(*addr.ctaId, currentCtaId); + Value isRemote = b.icmp_ne(*addr.ctaId, currentCtaId); + Value local = + targetInfo.loadShared(rewriter, loc, addr.ptr, llvmElemTy, isLocal); + Value remote = targetInfo.loadDShared(rewriter, loc, addr.ptr, addr.ctaId, + llvmElemTy, isRemote); + results.push_back(b.select(isLocal, local, remote)); + } + return results; +} + FailureOr prepareLocalAtomicScatterRMW( triton::gpu::LocalAtomicScatterRMWOp op, Value dst, Value indices, Value inputValues, Value mask, ConversionPatternRewriter &rewriter, diff --git a/lib/Conversion/TritonInstrumentToLLVM/InstrumentationToLLVM.cpp b/lib/Conversion/TritonInstrumentToLLVM/InstrumentationToLLVM.cpp index ba5303472036..ba0a374db97a 100644 --- a/lib/Conversion/TritonInstrumentToLLVM/InstrumentationToLLVM.cpp +++ b/lib/Conversion/TritonInstrumentToLLVM/InstrumentationToLLVM.cpp @@ -347,6 +347,46 @@ struct ClusterCTAIdOpConversion const TargetInfoBase &targetInfo; }; +struct LocalGatherOpConversion + : public ConvertOpToLLVMPattern { + LocalGatherOpConversion(const LLVMTypeConverter &converter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(converter, + benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(tti::ExperimentalLocalGatherOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto memDescTy = cast(op.getSrc().getType()); + auto regTy = cast(op.getType()); + auto typeConverter = getTypeConverter(); + + Type llvmElemTy = typeConverter->convertType(memDescTy.getElementType()); + auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), + llvmElemTy, rewriter); + auto idxValues = unpackLLElements(loc, adaptor.getIndices(), rewriter); + auto dstIndices = + emitIndices(loc, rewriter, targetInfo, regTy.getEncoding(), regTy, + /*withCTAOffset=*/true); + SmallVector offsets(adaptor.getOffsets()); + + auto addrs = + computeLocalAddrs(loc, memDescTy, smemObj, llvmElemTy, idxValues, + dstIndices, op.getAxis(), offsets, rewriter); + auto results = loadLocalAddrs(loc, llvmElemTy, addrs, rewriter, targetInfo); + Value result = packLLElements(loc, typeConverter, results, rewriter, regTy); + + rewriter.replaceOp(op, result); + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; + } // namespace void mlir::triton::populateInstrumentationToLLVMPatterns( @@ -358,4 +398,5 @@ void mlir::triton::populateInstrumentationToLLVMPatterns( patterns.add(typeConverter, targetInfo); patterns.add(typeConverter); patterns.add(typeConverter, targetInfo); + patterns.add(typeConverter, targetInfo); } diff --git a/lib/Dialect/TritonInstrument/IR/Ops.cpp b/lib/Dialect/TritonInstrument/IR/Ops.cpp index 9b8ec082eaff..468e34444986 100644 --- a/lib/Dialect/TritonInstrument/IR/Ops.cpp +++ b/lib/Dialect/TritonInstrument/IR/Ops.cpp @@ -40,6 +40,40 @@ LogicalResult DotI8Op::verify() { bEnc); } +LogicalResult ExperimentalLocalGatherOp::verify() { + auto srcTy = getSrc().getType(); + auto indicesTy = cast(getIndices().getType()); + auto dstTy = cast(getType()); + unsigned axis = getAxis(); + + if (!isa(srcTy.getEncoding())) + return emitError("source must have shared memory encoding"); + + if (!indicesTy.getElementType().isInteger()) + return emitError("indices must have integer element type"); + + if (dstTy.getShape() != indicesTy.getShape()) + return emitError("result shape must match indices shape"); + + if (srcTy.getRank() != indicesTy.getRank()) + return emitError("source and indices must have the same rank"); + + if (axis >= srcTy.getRank()) + return emitError("axis ") + << axis << " is out of bounds for source rank " << srcTy.getRank(); + + if (srcTy.getElementType() != dstTy.getElementType()) + return emitError("result element type must match source element type"); + + if (indicesTy.getEncoding() != dstTy.getEncoding()) + return emitError("indices and result must have the same layout"); + + if (static_cast(getOffsets().size()) != srcTy.getRank()) + return emitError("offset count must match source rank"); + + return success(); +} + template struct PushFPSanThroughViewPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; diff --git a/lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp b/lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp index 37c8dae6f580..17680e0cd514 100644 --- a/lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp +++ b/lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp @@ -3,6 +3,7 @@ #include "mlir/IR/Types.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/RegionUtils.h" +#include "triton/Analysis/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" @@ -83,6 +84,21 @@ std::pair getMmaEmulationTileShape(PatternRewriter &rewriter, return {std::min(kTileM, m), std::min(kTileN, n)}; } +std::pair +getDirectSharedMmaEmulationTileShape(PatternRewriter &rewriter, int64_t m, + int64_t n, int64_t k, + IntegerType accElem) { + auto [tileM, tileN] = getMmaEmulationTileShape(rewriter, m, n, k, accElem); + int64_t numWarps = + ttg::lookupNumWarps(rewriter.getInsertionBlock()->getParent()); + int64_t widerN = std::min(16 * numWarps, n); + while (widerN > tileN && (n % widerN) != 0) + widerN /= 2; + if (widerN > tileN && getI8MmaWarpsPerCTA(tileM, widerN, numWarps)) + tileN = widerN; + return {tileM, tileN}; +} + Operation *createGlobalScratchBarrier(PatternRewriter &rewriter, Location loc, bool sharedClusterState = false) { Operation *barrier = ttg::BarrierOp::create(rewriter, loc, @@ -152,6 +168,16 @@ struct ScratchInfo { RankedTensorType tensorType; }; +struct MmaOperandSource { + Value scratchPtr; + Value sharedMemdesc; + RankedTensorType tileType; + int64_t rowStride; + int64_t stride; + + bool isShared() const { return static_cast(sharedMemdesc); } +}; + struct ScratchState { std::optional canonical; DenseMap byScope; @@ -885,31 +911,18 @@ Value fpsanVariadicExternTagged(PatternRewriter &rewriter, Location loc, } std::optional -createOperandScratch(PatternRewriter &rewriter, Location loc, - TmemScratchManager &scratch, Value memdesc, - ttg::MemDescType memTy, bool isTmem, Region *scope) { +createTmemOperandScratch(PatternRewriter &rewriter, Location loc, + TmemScratchManager &scratch, Value memdesc, + ttg::MemDescType memTy, Region *scope) { auto layout = scratch.getScratchEncoding(rewriter, memdesc, memTy); auto tensorTy = RankedTensorType::get(memTy.getShape(), memTy.getElementType(), layout); - Value fullVal; - if (isTmem) { - auto info = scratch.getOrCreate(memdesc, rewriter, scope); - if (!info) - return std::nullopt; - fullVal = loadFpSanScratchMemory(rewriter, loc, info->ptr, tensorTy); - if (!fullVal) - return std::nullopt; - } else { - if (scratch.usesSharedClusterState()) { - // A two-CTA TMA barrier is waited on by the lead CTA. FPSAN executes the - // emulated MMA in both CTAs, so rendezvous before the partner snapshots - // the shared-memory operand. - ttng::ClusterBarrierOp::create(rewriter, loc); - } - fullVal = - ttg::LocalLoadOp::create(rewriter, loc, tensorTy, memdesc, Value()) - .getResult(); - } + auto info = scratch.getOrCreate(memdesc, rewriter, scope); + if (!info) + return std::nullopt; + Value fullVal = loadFpSanScratchMemory(rewriter, loc, info->ptr, tensorTy); + if (!fullVal) + return std::nullopt; int64_t elSize = memTy.getElementType().getIntOrFloatBitWidth() / 8; int64_t alignment = std::max(elSize, 16); int64_t sizeInBytes = product(memTy.getShape()) * elSize; @@ -1025,6 +1038,75 @@ Value loadScratchStrided2D(PatternRewriter &rewriter, Location loc, Value base, stride1); } +Value loadMmaOperand(PatternRewriter &rewriter, Location loc, + const MmaOperandSource &source, RankedTensorType resultTy, + bool isLhs, Value tileOffset, Value kOffset) { + if (!source.isShared()) { + Value rowOffset = isLhs ? tileOffset : kOffset; + Value colOffset = isLhs ? kOffset : tileOffset; + Value rowStride = arith::ConstantOp::create( + rewriter, loc, rewriter.getI32IntegerAttr(source.rowStride)); + Value stride = arith::ConstantOp::create( + rewriter, loc, rewriter.getI32IntegerAttr(source.stride)); + Value row = arith::MulIOp::create(rewriter, loc, rowOffset, rowStride); + Value col = arith::MulIOp::create(rewriter, loc, colOffset, stride); + Value offset = arith::AddIOp::create(rewriter, loc, row, col); + Value ptr = tt::AddPtrOp::create(rewriter, loc, source.scratchPtr.getType(), + source.scratchPtr, offset); + return loadScratchStrided2D(rewriter, loc, ptr, resultTy, source.rowStride, + source.stride); + } + + Value shared = source.sharedMemdesc; + auto sharedTy = cast(shared.getType()); + unsigned tileAxis = isLhs ? 0 : 1; + unsigned kAxis = 1 - tileAxis; + auto cgaLayout = ttg::getCGALayout(sharedTy.getEncoding()); + auto replicatedCGA = ttg::CGAEncodingAttr::fromSplitParams( + rewriter.getContext(), cgaLayout.getCTAsPerCGA(), + SmallVector(2, 1), cgaLayout.getCTAOrder()); + SmallVector loadShape(resultTy.getShape()); + int numWarps = ttg::lookupNumWarps(rewriter.getInsertionBlock()->getParent()); + int threadsPerWarp = ttg::lookupThreadsPerWarp(rewriter); + SmallVector threadsPerWarpShape(2, 1); + threadsPerWarpShape[tileAxis] = + std::min(loadShape[tileAxis], threadsPerWarp); + threadsPerWarpShape[kAxis] = threadsPerWarp / threadsPerWarpShape[tileAxis]; + SmallVector warpsPerCTA(2, 1); + warpsPerCTA[kAxis] = numWarps; + SmallVector sizePerThread(2, 1); + sizePerThread[tileAxis] = loadShape[tileAxis] / threadsPerWarpShape[tileAxis]; + auto loadLayout = ttg::BlockedEncodingAttr::get( + rewriter.getContext(), sizePerThread, threadsPerWarpShape, warpsPerCTA, + SmallVector{tileAxis, kAxis}, replicatedCGA); + auto loadTy = + RankedTensorType::get(loadShape, resultTy.getElementType(), loadLayout); + auto indicesTy = + RankedTensorType::get(loadShape, rewriter.getI32Type(), loadLayout); + auto kEncoding = getSingleDimSliceEncoding(loadLayout, kAxis); + auto kTy = RankedTensorType::get({loadShape[kAxis]}, rewriter.getI32Type(), + kEncoding); + Value indices = + tt::MakeRangeOp::create(rewriter, loc, kTy, 0, loadShape[kAxis]); + Value kSplat = tt::SplatOp::create(rewriter, loc, kTy, kOffset); + indices = arith::AddIOp::create(rewriter, loc, indices, kSplat); + indices = expandAllSlicedDims(rewriter, loc, indices); + if (cast(indices.getType()).getShape() != + ArrayRef(loadShape)) + indices = tt::BroadcastOp::create(rewriter, loc, indicesTy, indices); + Value zero = + arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(0)); + SmallVector offsets(2, zero); + offsets[tileAxis] = tileOffset; + Value loaded = ExperimentalLocalGatherOp::create( + rewriter, loc, loadTy, shared, indices, offsets, + rewriter.getI32IntegerAttr(kAxis)) + .getResult(); + if (loaded.getType() != resultTy) + loaded = ttg::ConvertLayoutOp::create(rewriter, loc, resultTy, loaded); + return loaded; +} + Operation *storeScratchStrided2D(PatternRewriter &rewriter, Location loc, Value base, Value tensor, RankedTensorType tensorTy, int64_t stride0, @@ -1266,17 +1348,17 @@ Value unpackPackedFp4Tensor(PatternRewriter &rewriter, Location loc, } Value loadOperandK32(PatternRewriter &rewriter, Location loc, bool isLhs, - Value tilePtr, RankedTensorType fullTileTy, Value kI32, - int64_t rowStride, int64_t stride, + const MmaOperandSource &source, Value tileIdx, Value kI32, int64_t packFactor = 1) { - SmallVector logicalShape{isLhs ? fullTileTy.getShape()[0] : kI8MmaK, - isLhs ? kI8MmaK : fullTileTy.getShape()[1]}; + SmallVector logicalShape{ + isLhs ? source.tileType.getShape()[0] : kI8MmaK, + isLhs ? kI8MmaK : source.tileType.getShape()[1]}; SmallVector rawShape = logicalShape; rawShape[isLhs ? 1 : 0] /= packFactor; - auto rawLayout = getOptimizedBlockedEncoding(rewriter, rawShape, - fullTileTy.getElementType()); - auto rawTy = - RankedTensorType::get(rawShape, fullTileTy.getElementType(), rawLayout); + auto rawLayout = getOptimizedBlockedEncoding( + rewriter, rawShape, source.tileType.getElementType()); + auto rawTy = RankedTensorType::get(rawShape, source.tileType.getElementType(), + rawLayout); Value packedKIdx = kI32; if (packFactor != 1) { @@ -1284,13 +1366,8 @@ Value loadOperandK32(PatternRewriter &rewriter, Location loc, bool isLhs, rewriter, loc, rewriter.getI32IntegerAttr(packFactor)); packedKIdx = arith::DivUIOp::create(rewriter, loc, kI32, factor); } - Value kStride = arith::ConstantOp::create( - rewriter, loc, rewriter.getI32IntegerAttr(isLhs ? stride : rowStride)); - Value offset = arith::MulIOp::create(rewriter, loc, packedKIdx, kStride); - Value chunkPtr = - tt::AddPtrOp::create(rewriter, loc, tilePtr.getType(), tilePtr, offset); Value chunk = - loadScratchStrided2D(rewriter, loc, chunkPtr, rawTy, rowStride, stride); + loadMmaOperand(rewriter, loc, source, rawTy, isLhs, tileIdx, packedKIdx); if (packFactor == 2) { auto logicalLayout = getOptimizedBlockedEncoding(rewriter, logicalShape, @@ -1365,13 +1442,13 @@ Value loadScaledScaleK32(PatternRewriter &rewriter, Location loc, bool isLhs, } Value loadScaledOperandK32(PatternRewriter &rewriter, Location loc, bool isLhs, - Value tilePtr, RankedTensorType fullTileTy, + const MmaOperandSource &source, const DotScaleConfig &scale, Value tileIdx, - Value kI32, int64_t rowStride, int64_t stride) { + Value kI32) { int64_t packFactor = isLhs ? scale.aKPackFactor : scale.bKPackFactor; tt::ScaleDotElemType elemType = isLhs ? scale.aElemType : scale.bElemType; - Value chunk = loadOperandK32(rewriter, loc, isLhs, tilePtr, fullTileTy, kI32, - rowStride, stride, packFactor); + Value chunk = + loadOperandK32(rewriter, loc, isLhs, source, tileIdx, kI32, packFactor); Value payload = castDotScaledOperandToComputePayload( rewriter, loc, chunk, elemType, scale.computeElem); @@ -1487,49 +1564,22 @@ Value tryEmitI8DotDecomposition(PatternRewriter &rewriter, Location loc, return ttg::ConvertLayoutOp::create(rewriter, loc, accTy, product); } -std::optional emitMmaEmulationLoops( - PatternRewriter &rewriter, Location loc, Value aPtr, Value bPtr, Value dPtr, - int64_t m, int64_t n, int64_t k, int64_t tileM, int64_t tileN, - RankedTensorType aTileTy, RankedTensorType bTileTy, - RankedTensorType accTileTy, ttg::DistributedEncodingTrait accLayout, - IntegerType accElem, Value useDInt, Value predInt, int64_t aStride, - int64_t bStride, int64_t dStride, const DotScaleConfig &scale = {}, - int64_t aRowStride = 1, int64_t bRowStride = 1, int64_t dRowStride = 1) { - if ((m % tileM) != 0 || (n % tileN) != 0) - return std::nullopt; +LogicalResult emitMmaEmulationTile( + PatternRewriter &rewriter, Location loc, const MmaOperandSource &aSource, + const MmaOperandSource &bSource, Value dPtr, int64_t k, int64_t tileM, + int64_t tileN, RankedTensorType accTileTy, + ttg::DistributedEncodingTrait accLayout, IntegerType accElem, Value useDInt, + Value predInt, Value mIdxI32, Value nIdxI32, int64_t dStride, + const DotScaleConfig &scale = {}, int64_t dRowStride = 1) { + bool hasSharedOperand = aSource.isShared() || bSource.isShared(); - OpBuilder::InsertionGuard guard(rewriter); + auto i32Ty = rewriter.getI32Type(); Value zero = arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(0)); - Value mUpper = - arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(m)); - Value nUpper = - arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(n)); - Value mStep = arith::ConstantOp::create(rewriter, loc, - rewriter.getI32IntegerAttr(tileM)); - Value nStep = arith::ConstantOp::create(rewriter, loc, - rewriter.getI32IntegerAttr(tileN)); - - auto mLoop = scf::ForOp::create(rewriter, loc, zero, mUpper, mStep); - rewriter.setInsertionPointToStart(mLoop.getBody()); - Value mIdx = mLoop.getInductionVar(); - auto nLoop = scf::ForOp::create(rewriter, loc, zero, nUpper, nStep); - rewriter.setInsertionPointToStart(nLoop.getBody()); - Value nIdx = nLoop.getInductionVar(); - - auto i32Ty = rewriter.getI32Type(); - Value mIdxI32 = arith::IndexCastOp::create(rewriter, loc, i32Ty, mIdx); - Value nIdxI32 = arith::IndexCastOp::create(rewriter, loc, i32Ty, nIdx); Value dRowStrideConst = arith::ConstantOp::create( rewriter, loc, rewriter.getI32IntegerAttr(dRowStride)); Value dStrideConst = arith::ConstantOp::create( rewriter, loc, rewriter.getI32IntegerAttr(dStride)); - Value aRowStrideConst = arith::ConstantOp::create( - rewriter, loc, rewriter.getI32IntegerAttr(aRowStride)); - Value bStrideConst = arith::ConstantOp::create( - rewriter, loc, rewriter.getI32IntegerAttr(bStride)); - Value bRowStrideConst = arith::ConstantOp::create( - rewriter, loc, rewriter.getI32IntegerAttr(bRowStride)); Value mDOffset = arith::MulIOp::create(rewriter, loc, mIdxI32, dRowStrideConst); @@ -1541,14 +1591,6 @@ std::optional emitMmaEmulationLoops( dRowStride, dStride); Value accTileI = embedToInt(rewriter, loc, accTile); - Value aMOffset = - arith::MulIOp::create(rewriter, loc, mIdxI32, aRowStrideConst); - Value aTilePtr = - tt::AddPtrOp::create(rewriter, loc, aPtr.getType(), aPtr, aMOffset); - Value bOffset = arith::MulIOp::create(rewriter, loc, nIdxI32, bStrideConst); - Value bTilePtr = - tt::AddPtrOp::create(rewriter, loc, bPtr.getType(), bPtr, bOffset); - Value sum; int numWarps = ttg::lookupNumWarps(rewriter.getInsertionBlock()->getParent()); auto isScaleK32Aligned = [](Value scalePtr, int64_t scaleFactor) { @@ -1561,11 +1603,11 @@ std::optional emitMmaEmulationLoops( isScaleK32Aligned(scale.bScalePtr, scale.bScaleFactor) && canUseI8MmaTile(tileM, tileN, numWarps); if (canUseI8Decomposition) { - if (!scale.computeElem && accElem.getWidth() <= 32) { - Value aTile = loadScratchStrided2D(rewriter, loc, aTilePtr, aTileTy, - aRowStride, aStride); - Value bTile = loadScratchStrided2D(rewriter, loc, bTilePtr, bTileTy, - bRowStride, bStride); + if (!hasSharedOperand && !scale.computeElem && accElem.getWidth() <= 32) { + Value aTile = loadMmaOperand(rewriter, loc, aSource, aSource.tileType, + /*isLhs=*/true, mIdxI32, zero); + Value bTile = loadMmaOperand(rewriter, loc, bSource, bSource.tileType, + /*isLhs=*/false, nIdxI32, zero); sum = tryEmitI8DotDecomposition( rewriter, loc, embedToInt(rewriter, loc, aTile), embedToInt(rewriter, loc, bTile), accLayout, accElem, numWarps); @@ -1585,17 +1627,15 @@ std::optional emitMmaEmulationLoops( Value aChunk; Value bChunk; if (scale.computeElem) { - aChunk = loadScaledOperandK32(rewriter, loc, /*isLhs=*/true, aTilePtr, - aTileTy, scale, mIdxI32, kI32, aRowStride, - aStride); - bChunk = loadScaledOperandK32(rewriter, loc, /*isLhs=*/false, bTilePtr, - bTileTy, scale, nIdxI32, kI32, bRowStride, - bStride); + aChunk = loadScaledOperandK32(rewriter, loc, /*isLhs=*/true, aSource, + scale, mIdxI32, kI32); + bChunk = loadScaledOperandK32(rewriter, loc, /*isLhs=*/false, bSource, + scale, nIdxI32, kI32); } else { - aChunk = loadOperandK32(rewriter, loc, /*isLhs=*/true, aTilePtr, - aTileTy, kI32, aRowStride, aStride); - bChunk = loadOperandK32(rewriter, loc, /*isLhs=*/false, bTilePtr, - bTileTy, kI32, bRowStride, bStride); + aChunk = loadOperandK32(rewriter, loc, /*isLhs=*/true, aSource, mIdxI32, + kI32); + bChunk = loadOperandK32(rewriter, loc, /*isLhs=*/false, bSource, + nIdxI32, kI32); } Value partial = tryEmitI8DotDecomposition( rewriter, loc, embedToInt(rewriter, loc, aChunk), @@ -1612,13 +1652,10 @@ std::optional emitMmaEmulationLoops( } if (!sum) { - auto aSliceTy = - RankedTensorType::get({tileM, 1}, aTileTy.getElementType(), accLayout); - auto bSliceTy = - RankedTensorType::get({1, tileN}, bTileTy.getElementType(), accLayout); - Value aStrideVal = arith::ConstantOp::create( - rewriter, loc, rewriter.getI32IntegerAttr(aStride)); - + auto aSliceTy = RankedTensorType::get( + {tileM, 1}, aSource.tileType.getElementType(), accLayout); + auto bSliceTy = RankedTensorType::get( + {1, tileN}, bSource.tileType.getElementType(), accLayout); Value zeroSum = getIntConstantLike(rewriter, loc, accTileI.getType(), 0); Value kUpper = arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(k)); @@ -1641,18 +1678,10 @@ std::optional emitMmaEmulationLoops( rewriter, loc, rewriter.getI32IntegerAttr(scale.bKPackFactor)); bKIdx = arith::DivUIOp::create(rewriter, loc, kI32, bPackFactor); } - Value aOffset = - arith::MulIOp::create(rewriter, loc, i32Ty, aKIdx, aStrideVal); - Value aSlicePtr = - tt::AddPtrOp::create(rewriter, loc, aPtr.getType(), aTilePtr, aOffset); - Value aSlice = loadScratchStrided2D(rewriter, loc, aSlicePtr, aSliceTy, - aRowStride, aStride); - Value bKOffset = - arith::MulIOp::create(rewriter, loc, bKIdx, bRowStrideConst); - Value bSlicePtr = - tt::AddPtrOp::create(rewriter, loc, bPtr.getType(), bTilePtr, bKOffset); - Value bSlice = loadScratchStrided2D(rewriter, loc, bSlicePtr, bSliceTy, - bRowStride, bStride); + Value aSlice = loadMmaOperand(rewriter, loc, aSource, aSliceTy, + /*isLhs=*/true, mIdxI32, aKIdx); + Value bSlice = loadMmaOperand(rewriter, loc, bSource, bSliceTy, + /*isLhs=*/false, nIdxI32, bKIdx); if (scale.aKPackFactor == 2) aSlice = unpackPackedFp4Slice(rewriter, loc, aSlice, kI32); if (scale.bKPackFactor == 2) @@ -1695,10 +1724,63 @@ std::optional emitMmaEmulationLoops( createGlobalScratchBarrier(rewriter, loc); storeScratchStrided2D(rewriter, loc, dTilePtr, out, accTileTy, dRowStride, dStride); + return success(); +} + +std::optional emitMmaEmulationLoops( + PatternRewriter &rewriter, Location loc, const MmaOperandSource &aSource, + const MmaOperandSource &bSource, Value dPtr, int64_t m, int64_t n, + int64_t k, int64_t tileM, int64_t tileN, RankedTensorType accTileTy, + ttg::DistributedEncodingTrait accLayout, IntegerType accElem, Value useDInt, + Value predInt, int64_t dStride, const DotScaleConfig &scale = {}, + int64_t dRowStride = 1) { + if ((m % tileM) != 0 || (n % tileN) != 0) + return std::nullopt; + + OpBuilder::InsertionGuard guard(rewriter); + Value zero = + arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(0)); + Value mUpper = + arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(m)); + Value nUpper = + arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(n)); + Value mStep = arith::ConstantOp::create(rewriter, loc, + rewriter.getI32IntegerAttr(tileM)); + Value nStep = arith::ConstantOp::create(rewriter, loc, + rewriter.getI32IntegerAttr(tileN)); + auto mLoop = scf::ForOp::create(rewriter, loc, zero, mUpper, mStep); + rewriter.setInsertionPointToStart(mLoop.getBody()); + Value mIdx = mLoop.getInductionVar(); + auto nLoop = scf::ForOp::create(rewriter, loc, zero, nUpper, nStep); + rewriter.setInsertionPointToStart(nLoop.getBody()); + auto i32Ty = rewriter.getI32Type(); + Value mIdxI32 = arith::IndexCastOp::create(rewriter, loc, i32Ty, mIdx); + Value nIdxI32 = + arith::IndexCastOp::create(rewriter, loc, i32Ty, nLoop.getInductionVar()); + if (failed(emitMmaEmulationTile(rewriter, loc, aSource, bSource, dPtr, k, + tileM, tileN, accTileTy, accLayout, accElem, + useDInt, predInt, mIdxI32, nIdxI32, dStride, + scale, dRowStride))) + return std::nullopt; return mLoop; } +std::optional emitMmaEmulationLoops( + PatternRewriter &rewriter, Location loc, Value aPtr, Value bPtr, Value dPtr, + int64_t m, int64_t n, int64_t k, int64_t tileM, int64_t tileN, + RankedTensorType aTileTy, RankedTensorType bTileTy, + RankedTensorType accTileTy, ttg::DistributedEncodingTrait accLayout, + IntegerType accElem, Value useDInt, Value predInt, int64_t aStride, + int64_t bStride, int64_t dStride, const DotScaleConfig &scale = {}, + int64_t aRowStride = 1, int64_t bRowStride = 1, int64_t dRowStride = 1) { + MmaOperandSource aSource{aPtr, Value(), aTileTy, aRowStride, aStride}; + MmaOperandSource bSource{bPtr, Value(), bTileTy, bRowStride, bStride}; + return emitMmaEmulationLoops(rewriter, loc, aSource, bSource, dPtr, m, n, k, + tileM, tileN, accTileTy, accLayout, accElem, + useDInt, predInt, dStride, scale, dRowStride); +} + //---------------------------------------- // Patterns //---------------------------------------- @@ -2505,42 +2587,57 @@ struct TCGen5MMAPattern : public OpRewritePattern { arith::ExtUIOp::create(rewriter, loc, accElem, op.getPred()); rewriter.setInsertionPoint(op); - auto aScratch = createOperandScratch(rewriter, loc, *scratch, op.getA(), - aMemTy, aIsTmem, scope); - if (!aScratch) - return emitFpSanCodegenError(op.getOperation()); - auto bScratch = createOperandScratch(rewriter, loc, *scratch, op.getB(), - bMemTy, bIsTmem, scope); - if (!bScratch) - return emitFpSanCodegenError(op.getOperation()); - - // Each warp may only populate a subset of the operand scratch tiles, so - // synchronize before the emulation loops start reading them. - createGlobalScratchBarrier(rewriter, loc, - scratch->usesSharedClusterState()); - - auto [tileM, tileN] = getMmaEmulationTileShape(rewriter, m, n, k, accElem); + auto [tileM, tileN] = + (!aIsTmem || !bIsTmem) + ? getDirectSharedMmaEmulationTileShape(rewriter, m, n, k, accElem) + : getMmaEmulationTileShape(rewriter, m, n, k, accElem); auto accTileLayout = getOptimizedBlockedEncoding(rewriter, {tileM, tileN}, accElem); auto accTileTy = RankedTensorType::get({tileM, tileN}, accElem, accTileLayout); - auto aTileLayout = getOptimizedBlockedEncoding( - rewriter, {tileM, k}, - getScratchStorageElementType(aMemTy.getElementType())); - auto aTileTy = RankedTensorType::get( - {tileM, k}, getScratchStorageElementType(aMemTy.getElementType()), - aTileLayout); - auto bTileLayout = getOptimizedBlockedEncoding( - rewriter, {k, tileN}, - getScratchStorageElementType(bMemTy.getElementType())); - auto bTileTy = RankedTensorType::get( - {k, tileN}, getScratchStorageElementType(bMemTy.getElementType()), - bTileLayout); + Type aTileElem = aIsTmem + ? getScratchStorageElementType(aMemTy.getElementType()) + : aMemTy.getElementType(); + auto aTileLayout = + getOptimizedBlockedEncoding(rewriter, {tileM, k}, aTileElem); + auto aTileTy = RankedTensorType::get({tileM, k}, aTileElem, aTileLayout); + Type bTileElem = bIsTmem + ? getScratchStorageElementType(bMemTy.getElementType()) + : bMemTy.getElementType(); + auto bTileLayout = + getOptimizedBlockedEncoding(rewriter, {k, tileN}, bTileElem); + auto bTileTy = RankedTensorType::get({k, tileN}, bTileElem, bTileLayout); + + std::optional aScratch; + if (aIsTmem) { + aScratch = createTmemOperandScratch(rewriter, loc, *scratch, op.getA(), + aMemTy, scope); + if (!aScratch) + return emitFpSanCodegenError(op.getOperation()); + } + std::optional bScratch; + if (bIsTmem) { + bScratch = createTmemOperandScratch(rewriter, loc, *scratch, op.getB(), + bMemTy, scope); + if (!bScratch) + return emitFpSanCodegenError(op.getOperation()); + } + MmaOperandSource aSource{aIsTmem ? aScratch->ptr : Value(), + aIsTmem ? Value() : op.getA(), aTileTy, + /*rowStride=*/1, /*stride=*/m}; + MmaOperandSource bSource{bIsTmem ? bScratch->ptr : Value(), + bIsTmem ? Value() : op.getB(), bTileTy, + /*rowStride=*/1, /*stride=*/k}; + + // TMEM and D scratch are written cooperatively. In two-CTA mode, the + // cluster barrier also makes both CTAs' shared operands visible before the + // first direct load. + createGlobalScratchBarrier(rewriter, loc, + scratch->usesSharedClusterState()); auto mLoop = emitMmaEmulationLoops( - rewriter, loc, aScratch->ptr, bScratch->ptr, dInfo->ptr, m, n, k, tileM, - tileN, aTileTy, bTileTy, accTileTy, accTileLayout, accElem, useDInt, - predInt, /*aStride=*/m, /*bStride=*/k, /*dStride=*/m); + rewriter, loc, aSource, bSource, dInfo->ptr, m, n, k, tileM, tileN, + accTileTy, accTileLayout, accElem, useDInt, predInt, /*dStride=*/m); if (!mLoop) return emitFpSanUnsupported(op.getOperation()); rewriter.setInsertionPointAfter(*mLoop); @@ -2661,24 +2758,19 @@ struct TCGen5MMAScaledPattern arith::ExtUIOp::create(rewriter, loc, accElem, op.getPred()); rewriter.setInsertionPoint(op); - auto aScratch = createOperandScratch(rewriter, loc, *scratch, op.getA(), - aMemTy, aIsTmem, scope); - if (!aScratch) - return emitFpSanCodegenError(op.getOperation()); - auto bScratch = createOperandScratch(rewriter, loc, *scratch, op.getB(), - bMemTy, bIsTmem, scope); - if (!bScratch) - return emitFpSanCodegenError(op.getOperation()); - auto aScaleScratch = createOperandScratch( - rewriter, loc, *scratch, op.getAScale(), aScaleMemTy, true, scope); + auto aScaleScratch = createTmemOperandScratch( + rewriter, loc, *scratch, op.getAScale(), aScaleMemTy, scope); if (!aScaleScratch) return emitFpSanCodegenError(op.getOperation()); - auto bScaleScratch = createOperandScratch( - rewriter, loc, *scratch, op.getBScale(), bScaleMemTy, true, scope); + auto bScaleScratch = createTmemOperandScratch( + rewriter, loc, *scratch, op.getBScale(), bScaleMemTy, scope); if (!bScaleScratch) return emitFpSanCodegenError(op.getOperation()); - auto [tileM, tileN] = getMmaEmulationTileShape(rewriter, m, n, k, accElem); + auto [tileM, tileN] = + (!aIsTmem || !bIsTmem) + ? getDirectSharedMmaEmulationTileShape(rewriter, m, n, k, accElem) + : getMmaEmulationTileShape(rewriter, m, n, k, accElem); auto accTileLayout = getOptimizedBlockedEncoding(rewriter, {tileM, tileN}, dMemTy.getElementType()); @@ -2693,6 +2785,27 @@ struct TCGen5MMAScaledPattern auto bTileTy = RankedTensorType::get({bPackedK, tileN}, bMemTy.getElementType(), bTileLayout); + std::optional aScratch; + if (aIsTmem) { + aScratch = createTmemOperandScratch(rewriter, loc, *scratch, op.getA(), + aMemTy, scope); + if (!aScratch) + return emitFpSanCodegenError(op.getOperation()); + } + std::optional bScratch; + if (bIsTmem) { + bScratch = createTmemOperandScratch(rewriter, loc, *scratch, op.getB(), + bMemTy, scope); + if (!bScratch) + return emitFpSanCodegenError(op.getOperation()); + } + MmaOperandSource aSource{aIsTmem ? aScratch->ptr : Value(), + aIsTmem ? Value() : op.getA(), aTileTy, + /*rowStride=*/1, /*stride=*/m}; + MmaOperandSource bSource{bIsTmem ? bScratch->ptr : Value(), + bIsTmem ? Value() : op.getB(), bTileTy, + /*rowStride=*/1, /*stride=*/bPackedK}; + DotScaleConfig scale; scale.aElemType = op.getAType(); scale.bElemType = op.getBType(); @@ -2711,15 +2824,15 @@ struct TCGen5MMAScaledPattern scale.aScaleFactor = *aScaleFactor; scale.bScaleFactor = *bScaleFactor; - // The operand and scale scratch buffers are written cooperatively, so all - // warps must finish those stores before the emulation loop reads them. + // TMEM scales and D scratch are written cooperatively. In two-CTA mode, + // rendezvous before either CTA directly reads the shared operands. createGlobalScratchBarrier(rewriter, loc, scratch->usesSharedClusterState()); - auto mLoop = emitMmaEmulationLoops( - rewriter, loc, aScratch->ptr, bScratch->ptr, dInfo->ptr, m, n, k, tileM, - tileN, aTileTy, bTileTy, accTileTy, accTileLayout, accElem, useDInt, - predInt, /*aStride=*/m, /*bStride=*/bPackedK, /*dStride=*/m, scale); + auto mLoop = + emitMmaEmulationLoops(rewriter, loc, aSource, bSource, dInfo->ptr, m, n, + k, tileM, tileN, accTileTy, accTileLayout, + accElem, useDInt, predInt, /*dStride=*/m, scale); if (!mLoop) return emitFpSanUnsupported(op.getOperation()); rewriter.setInsertionPointAfter(*mLoop); diff --git a/test/Conversion/tritoninstrument_to_llvm.mlir b/test/Conversion/tritoninstrument_to_llvm.mlir index 173335e69916..d1c00549034f 100644 --- a/test/Conversion/tritoninstrument_to_llvm.mlir +++ b/test/Conversion/tritoninstrument_to_llvm.mlir @@ -137,3 +137,25 @@ tt.func private @experimental_fpsan_unembed(%arg0: i32) -> f32 { tt.return %0 : f32 } } + +// ----- + +#local_gather_blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CGALayout = [[0, 1]]}> +#local_gather_shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CGALayout = [[0, 1]]}> + +module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32, ttg.target = "cuda:90"} { +// CHECK-LABEL: @experimental_local_gather +// CHECK: ld.shared::cluster +tt.func private @experimental_local_gather(%out: !tt.ptr) { + %src = ttg.local_alloc {allocation.offset = [0 : i32, 256 : i32]} : () -> !ttg.memdesc<2x32xi32, #local_gather_shared, #ttg.shared_memory, mutable> + %idx = arith.constant dense<0> : tensor<2x32xi32, #local_gather_blocked> + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %g = tti.experimental_local_gather %src[%idx] offsets = [%c1, %c0] {axis = 0 : i32} : !ttg.memdesc<2x32xi32, #local_gather_shared, #ttg.shared_memory, mutable>, tensor<2x32xi32, #local_gather_blocked> -> tensor<2x32xi32, #local_gather_blocked> + %ptrs = tt.splat %out : !tt.ptr -> tensor<2x32x!tt.ptr, #local_gather_blocked> + %offs = arith.constant dense<0> : tensor<2x32xi32, #local_gather_blocked> + %out_ptrs = tt.addptr %ptrs, %offs : tensor<2x32x!tt.ptr, #local_gather_blocked>, tensor<2x32xi32, #local_gather_blocked> + tt.store %out_ptrs, %g : tensor<2x32x!tt.ptr, #local_gather_blocked> + tt.return +} +} diff --git a/test/TritonGPU/nvidia-fpsan.mlir b/test/TritonGPU/nvidia-fpsan.mlir index 94a56218a400..fde1b2c22fb4 100644 --- a/test/TritonGPU/nvidia-fpsan.mlir +++ b/test/TritonGPU/nvidia-fpsan.mlir @@ -76,6 +76,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar // CHECK: ttg.global_scratch_alloc // CHECK: ttg.barrier global_read|global_write // CHECK-NEXT: scf.for + // CHECK: tti.experimental_local_gather + // CHECK: tti.experimental_local_gather // CHECK: tti.dot_i8 {{.*}} aSigned = true, bSigned = true // CHECK: tti.dot_i8 {{.*}} aSigned = false, bSigned = true // CHECK: tti.dot_i8 {{.*}} aSigned = true, bSigned = false @@ -97,6 +99,32 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar // ----- +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> +#smem = #ttg.shared_memory +#tmem_a = #ttng.tensor_memory_encoding +#tmem_d = #ttng.tensor_memory_encoding +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65544 : i32, ttg.tensor_memory_size = 0 : i32, "ttg.total-num-warps" = 1 : i32} { + // CHECK-LABEL: @tcgen05_mma_tmem_a_shared_b + tt.func public @tcgen05_mma_tmem_a_shared_b() { + // CHECK: ttg.global_scratch_alloc + // CHECK: tt.store + // CHECK: ttg.barrier global_read|global_write + // CHECK: tti.experimental_local_gather + // CHECK: tti.dot_i8 + // CHECK-NOT: ttng.tc_gen5_mma + %true = arith.constant true + %a = ttng.tmem_alloc {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} : () -> !ttg.memdesc<128x128xf16, #tmem_a, #ttng.tensor_memory, mutable> + %b = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> + %d = ttng.tmem_alloc {tensor_memory_col_offset = 128 : i32, tensor_memory_row_offset = 0 : i32} : () -> !ttg.memdesc<128x128xf32, #tmem_d, #ttng.tensor_memory, mutable> + %bar = ttg.local_alloc {allocation.offset = 8192 : i32} : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> + ttng.tc_gen5_mma %a, %b, %d, %true, %true, %bar[%true] {is_async} : !ttg.memdesc<128x128xf16, #tmem_a, #ttng.tensor_memory, mutable>, !ttg.memdesc<128x128xf16, #shared, #smem, mutable>, !ttg.memdesc<128x128xf32, #tmem_d, #ttng.tensor_memory, mutable>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> + tt.return + } +} + +// ----- + #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> #shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> #smem = #ttg.shared_memory @@ -135,6 +163,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar // CHECK: ttg.global_scratch_alloc // CHECK: ttg.barrier global_read|global_write // CHECK-NEXT: scf.for + // CHECK: tti.experimental_local_gather + // CHECK: tti.experimental_local_gather // CHECK: tti.dot_i8 {{.*}} aSigned = true, bSigned = true // CHECK: tti.dot_i8 {{.*}} aSigned = false, bSigned = true // CHECK: tti.dot_i8 {{.*}} aSigned = true, bSigned = false @@ -261,22 +291,19 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar #shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[1]]}> #smem = #ttg.shared_memory #tmem = #ttng.tensor_memory_encoding -// CHECK: #[[$SHARED_A_SCRATCH:[A-Za-z0-9_]+]] = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CGALayout = {{\[\[1, 0\]\]}}}> // CHECK: module attributes {{.*}}"ttng.two-ctas" = true module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65544 : i32, ttg.tensor_memory_size = 0 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @tcgen05_mma_two_ctas tt.func public @tcgen05_mma_two_ctas() { // CHECK: ttg.global_scratch_alloc {{.*}}shared_cluster_state // CHECK: ttng.cluster_barrier - // CHECK-NEXT: {{.*}} = ttg.local_load {{.*}} -> tensor<256x128xf16, #[[$SHARED_A_SCRATCH]]> - // CHECK: tt.store - // CHECK: ttg.barrier global_read|global_write - // CHECK-NEXT: ttng.cluster_barrier - // CHECK: scf.for + // CHECK-NEXT: scf.for + // CHECK: tti.experimental_local_gather + // CHECK: tti.experimental_local_gather // CHECK: tt.store // CHECK: ttg.barrier global_read|global_write // CHECK-NEXT: ttng.cluster_barrier - // CHECK: ttng.arrive_barrier + // CHECK-NEXT: ttng.arrive_barrier // CHECK-NOT: ttng.tc_gen5_mma %true = arith.constant true %a = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<256x128xf16, #shared_a, #smem, mutable> @@ -298,7 +325,6 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.shar #blocked_multibuffer = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CGALayout = [[1, 0]]}> #tmem = #ttng.tensor_memory_encoding #tmem_scales = #ttng.tensor_memory_scales_encoding -// CHECK: #[[$TMEM_VIEW_SCRATCH:[A-Za-z0-9_]+]] = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CGALayout = {{\[\[1, 0\]\]}}}> module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65544 : i32, ttg.tensor_memory_size = 0 : i32, "ttg.total-num-warps" = 1 : i32} { tt.func public @enable_two_ctas() { %true = arith.constant true @@ -312,7 +338,7 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.shar // CHECK-LABEL: @tmem_multibuffer_two_ctas tt.func public @tmem_multibuffer_two_ctas(%idx: i32) { // CHECK: ttg.global_scratch_alloc {{.*}}shared_cluster_state - // CHECK: tt.store {{.*}} {ignore_cta} : tensor<256x128x!tt.ptr, #[[$TMEM_VIEW_SCRATCH]]> + // CHECK: tt.store {{.*}} {ignore_cta} : tensor<256x128x!tt.ptr // CHECK-NOT: ttng.tmem_load // CHECK-NOT: ttng.tmem_store // CHECK-NOT: ttg.memdesc_index @@ -361,10 +387,13 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.shar tt.func public @tcgen05_mma_scaled_two_ctas() { // CHECK: ttg.global_scratch_alloc {{.*}}shared_cluster_state // CHECK: ttng.cluster_barrier - // CHECK-NEXT: {{.*}} = ttg.local_load + // CHECK-NEXT: scf.for + // CHECK: tti.experimental_local_gather + // CHECK: tti.experimental_local_gather + // CHECK: tt.store // CHECK: ttg.barrier global_read|global_write // CHECK-NEXT: ttng.cluster_barrier - // CHECK: ttng.arrive_barrier + // CHECK-NEXT: ttng.arrive_barrier // CHECK-NOT: ttng.tc_gen5_mma_scaled %true = arith.constant true %a = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<256x256xi8, #shared_a, #smem, mutable> From 2c9405f8383483f0c823e67ed0542d8cca8113be Mon Sep 17 00:00:00 2001 From: Jeff Niu Date: Thu, 4 Jun 2026 06:33:28 +0000 Subject: [PATCH 09/22] Simplify instrumentation local gather --- .../Conversion/TritonGPUToLLVM/Utility.h | 10 +- .../TritonGPUToLLVM/MemoryOpToLLVM.cpp | 7 +- lib/Conversion/TritonGPUToLLVM/Utility.cpp | 47 ++--- .../InstrumentationToLLVM.cpp | 2 +- .../Transforms/FpSanitizer.cpp | 193 +++++++----------- 5 files changed, 97 insertions(+), 162 deletions(-) diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index e7b14224159b..749ffa860d6c 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -616,14 +616,8 @@ struct LocalSharedMemoryAddress { SmallVector computeLocalAddrs( Location loc, triton::gpu::MemDescType memDescTy, SharedMemoryObject smemObj, Type llvmElemTy, ArrayRef idxValues, - ArrayRef> coords, unsigned axis, RewriterBase &rewriter); - -SmallVector -computeLocalAddrs(Location loc, triton::gpu::MemDescType memDescTy, - SharedMemoryObject smemObj, Type llvmElemTy, - ArrayRef idxValues, - ArrayRef> coords, unsigned axis, - ArrayRef offsets, RewriterBase &rewriter); + ArrayRef> coords, unsigned axis, RewriterBase &rewriter, + ArrayRef offsets = {}); SmallVector computeLocalPtrs(Location loc, triton::gpu::MemDescType memDescTy, diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp index 00239413efef..def2c9a5a2c8 100644 --- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp @@ -24,17 +24,16 @@ lowerLocalScGt(Location loc, MLIRContext *ctx, MemDescType memDescTy, unsigned axis, ArrayRef storeVals, RewriterBase &rewriter, const TargetInfoBase &targetInfo) { auto b = TritonLLVMOpBuilder(loc, rewriter); - bool isScatter = !storeVals.empty(); SmallVector addrs = computeLocalAddrs( loc, memDescTy, smemObj, llvmElemTy, idxValues, coords, axis, rewriter); - - SmallVector results; - if (!isScatter) + if (storeVals.empty()) return loadLocalAddrs(loc, llvmElemTy, addrs, rewriter, targetInfo); Value currentCtaId; if (!addrs.empty() && addrs.front().ctaId) currentCtaId = targetInfo.getClusterCTAId(rewriter, loc); + + SmallVector results; for (auto [i, addr] : llvm::enumerate(addrs)) { if (addr.ctaId) { Value isLocal = b.icmp_eq(*addr.ctaId, currentCtaId); diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index fde30b0b53b5..cb912c9953bc 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -545,17 +545,7 @@ computeLocalAddrs(Location loc, triton::gpu::MemDescType memDescTy, SharedMemoryObject smemObj, Type llvmElemTy, ArrayRef idxValues, ArrayRef> coords, unsigned axis, - RewriterBase &rewriter) { - return computeLocalAddrs(loc, memDescTy, smemObj, llvmElemTy, idxValues, - coords, axis, /*offsets=*/{}, rewriter); -} - -SmallVector -computeLocalAddrs(Location loc, triton::gpu::MemDescType memDescTy, - SharedMemoryObject smemObj, Type llvmElemTy, - ArrayRef idxValues, - ArrayRef> coords, unsigned axis, - ArrayRef offsets, RewriterBase &rewriter) { + RewriterBase &rewriter, ArrayRef offsets) { MLIRContext *ctx = memDescTy.getContext(); auto b = TritonLLVMOpBuilder(loc, rewriter); @@ -670,29 +660,22 @@ SmallVector loadLocalAddrs(Location loc, Type llvmElemTy, const TargetInfoBase &targetInfo) { auto b = TritonLLVMOpBuilder(loc, rewriter); Value currentCtaId; - if (llvm::any_of(addrs, [](const LocalSharedMemoryAddress &addr) { - return addr.ctaId.has_value(); - })) + if (!addrs.empty() && addrs.front().ctaId) currentCtaId = targetInfo.getClusterCTAId(rewriter, loc); - SmallVector results; - results.reserve(addrs.size()); - for (const LocalSharedMemoryAddress &addr : addrs) { - if (!addr.ctaId) { - results.push_back(targetInfo.loadShared(rewriter, loc, addr.ptr, - llvmElemTy, b.true_val())); - continue; - } - - Value isLocal = b.icmp_eq(*addr.ctaId, currentCtaId); - Value isRemote = b.icmp_ne(*addr.ctaId, currentCtaId); - Value local = - targetInfo.loadShared(rewriter, loc, addr.ptr, llvmElemTy, isLocal); - Value remote = targetInfo.loadDShared(rewriter, loc, addr.ptr, addr.ctaId, - llvmElemTy, isRemote); - results.push_back(b.select(isLocal, local, remote)); - } - return results; + return llvm::map_to_vector( + addrs, [&](const LocalSharedMemoryAddress &addr) -> Value { + if (!addr.ctaId) + return targetInfo.loadShared(rewriter, loc, addr.ptr, llvmElemTy, + b.true_val()); + Value isLocal = b.icmp_eq(*addr.ctaId, currentCtaId); + Value local = + targetInfo.loadShared(rewriter, loc, addr.ptr, llvmElemTy, isLocal); + Value remote = targetInfo.loadDShared( + rewriter, loc, addr.ptr, addr.ctaId, llvmElemTy, + b.icmp_ne(*addr.ctaId, currentCtaId)); + return b.select(isLocal, local, remote); + }); } FailureOr prepareLocalAtomicScatterRMW( diff --git a/lib/Conversion/TritonInstrumentToLLVM/InstrumentationToLLVM.cpp b/lib/Conversion/TritonInstrumentToLLVM/InstrumentationToLLVM.cpp index ba0a374db97a..af467612ed42 100644 --- a/lib/Conversion/TritonInstrumentToLLVM/InstrumentationToLLVM.cpp +++ b/lib/Conversion/TritonInstrumentToLLVM/InstrumentationToLLVM.cpp @@ -375,7 +375,7 @@ struct LocalGatherOpConversion auto addrs = computeLocalAddrs(loc, memDescTy, smemObj, llvmElemTy, idxValues, - dstIndices, op.getAxis(), offsets, rewriter); + dstIndices, op.getAxis(), rewriter, offsets); auto results = loadLocalAddrs(loc, llvmElemTy, addrs, rewriter, targetInfo); Value result = packLLElements(loc, typeConverter, results, rewriter, regTy); diff --git a/lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp b/lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp index 17680e0cd514..71c191e72d63 100644 --- a/lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp +++ b/lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp @@ -3,7 +3,6 @@ #include "mlir/IR/Types.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/RegionUtils.h" -#include "triton/Analysis/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" @@ -72,31 +71,25 @@ bool canUseI8MmaTile(int64_t m, int64_t n, int numWarps) { std::pair getMmaEmulationTileShape(PatternRewriter &rewriter, int64_t m, int64_t n, int64_t k, - IntegerType accElem) { + IntegerType accElem, + bool directShared = false) { + std::pair tile = { + std::min(kTileM, m), std::min(kTileN, n)}; + int64_t numWarps = + ttg::lookupNumWarps(rewriter.getInsertionBlock()->getParent()); if (supportsI8DotDecomposition(rewriter, accElem) && (k % kI8MmaK) == 0) { - int64_t numWarps = - ttg::lookupNumWarps(rewriter.getInsertionBlock()->getParent()); int64_t tileM = std::min(16 * numWarps, m); int64_t tileN = std::min(8 * numWarps, n); if (canUseI8MmaTile(tileM, tileN, numWarps)) - return {tileM, tileN}; + tile = {tileM, tileN}; } - return {std::min(kTileM, m), std::min(kTileN, n)}; -} - -std::pair -getDirectSharedMmaEmulationTileShape(PatternRewriter &rewriter, int64_t m, - int64_t n, int64_t k, - IntegerType accElem) { - auto [tileM, tileN] = getMmaEmulationTileShape(rewriter, m, n, k, accElem); - int64_t numWarps = - ttg::lookupNumWarps(rewriter.getInsertionBlock()->getParent()); - int64_t widerN = std::min(16 * numWarps, n); - while (widerN > tileN && (n % widerN) != 0) - widerN /= 2; - if (widerN > tileN && getI8MmaWarpsPerCTA(tileM, widerN, numWarps)) - tileN = widerN; - return {tileM, tileN}; + if (directShared) { + int64_t widerN = std::min(16 * numWarps, n); + if (widerN > tile.second && + canUseI8MmaTile(tile.first, widerN, numWarps)) + tile.second = widerN; + } + return tile; } Operation *createGlobalScratchBarrier(PatternRewriter &rewriter, Location loc, @@ -936,6 +929,19 @@ createTmemOperandScratch(PatternRewriter &rewriter, Location loc, return ScratchInfo{ptr, tensorTy}; } +std::optional createMmaOperandSource( + PatternRewriter &rewriter, Location loc, TmemScratchManager &scratch, + Value memdesc, ttg::MemDescType memTy, bool isTmem, + RankedTensorType tileTy, Region *scope, int64_t rowStride, int64_t stride) { + if (!isTmem) + return MmaOperandSource{Value(), memdesc, tileTy, rowStride, stride}; + auto info = + createTmemOperandScratch(rewriter, loc, scratch, memdesc, memTy, scope); + if (!info) + return std::nullopt; + return MmaOperandSource{info->ptr, Value(), tileTy, rowStride, stride}; +} + std::optional createWGMMAScratch(PatternRewriter &rewriter, Location loc, Value operand) { if (auto memTy = dyn_cast(operand.getType())) { @@ -1564,18 +1570,37 @@ Value tryEmitI8DotDecomposition(PatternRewriter &rewriter, Location loc, return ttg::ConvertLayoutOp::create(rewriter, loc, accTy, product); } -LogicalResult emitMmaEmulationTile( +std::optional emitMmaEmulationLoops( PatternRewriter &rewriter, Location loc, const MmaOperandSource &aSource, - const MmaOperandSource &bSource, Value dPtr, int64_t k, int64_t tileM, - int64_t tileN, RankedTensorType accTileTy, + const MmaOperandSource &bSource, Value dPtr, int64_t m, int64_t n, + int64_t k, int64_t tileM, int64_t tileN, RankedTensorType accTileTy, ttg::DistributedEncodingTrait accLayout, IntegerType accElem, Value useDInt, - Value predInt, Value mIdxI32, Value nIdxI32, int64_t dStride, - const DotScaleConfig &scale = {}, int64_t dRowStride = 1) { - bool hasSharedOperand = aSource.isShared() || bSource.isShared(); + Value predInt, int64_t dStride, const DotScaleConfig &scale = {}, + int64_t dRowStride = 1) { + if ((m % tileM) != 0 || (n % tileN) != 0) + return std::nullopt; + OpBuilder::InsertionGuard guard(rewriter); auto i32Ty = rewriter.getI32Type(); Value zero = arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(0)); + Value mUpper = + arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(m)); + Value nUpper = + arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(n)); + Value mStep = arith::ConstantOp::create(rewriter, loc, + rewriter.getI32IntegerAttr(tileM)); + Value nStep = arith::ConstantOp::create(rewriter, loc, + rewriter.getI32IntegerAttr(tileN)); + auto mLoop = scf::ForOp::create(rewriter, loc, zero, mUpper, mStep); + rewriter.setInsertionPointToStart(mLoop.getBody()); + auto nLoop = scf::ForOp::create(rewriter, loc, zero, nUpper, nStep); + rewriter.setInsertionPointToStart(nLoop.getBody()); + Value mIdxI32 = arith::IndexCastOp::create( + rewriter, loc, i32Ty, mLoop.getInductionVar()); + Value nIdxI32 = arith::IndexCastOp::create( + rewriter, loc, i32Ty, nLoop.getInductionVar()); + Value dRowStrideConst = arith::ConstantOp::create( rewriter, loc, rewriter.getI32IntegerAttr(dRowStride)); Value dStrideConst = arith::ConstantOp::create( @@ -1592,6 +1617,7 @@ LogicalResult emitMmaEmulationTile( Value accTileI = embedToInt(rewriter, loc, accTile); Value sum; + bool hasSharedOperand = aSource.isShared() || bSource.isShared(); int numWarps = ttg::lookupNumWarps(rewriter.getInsertionBlock()->getParent()); auto isScaleK32Aligned = [](Value scalePtr, int64_t scaleFactor) { return !scalePtr || (scaleFactor > 0 && ((kI8MmaK % scaleFactor) == 0 || @@ -1724,45 +1750,6 @@ LogicalResult emitMmaEmulationTile( createGlobalScratchBarrier(rewriter, loc); storeScratchStrided2D(rewriter, loc, dTilePtr, out, accTileTy, dRowStride, dStride); - return success(); -} - -std::optional emitMmaEmulationLoops( - PatternRewriter &rewriter, Location loc, const MmaOperandSource &aSource, - const MmaOperandSource &bSource, Value dPtr, int64_t m, int64_t n, - int64_t k, int64_t tileM, int64_t tileN, RankedTensorType accTileTy, - ttg::DistributedEncodingTrait accLayout, IntegerType accElem, Value useDInt, - Value predInt, int64_t dStride, const DotScaleConfig &scale = {}, - int64_t dRowStride = 1) { - if ((m % tileM) != 0 || (n % tileN) != 0) - return std::nullopt; - - OpBuilder::InsertionGuard guard(rewriter); - Value zero = - arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(0)); - Value mUpper = - arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(m)); - Value nUpper = - arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(n)); - Value mStep = arith::ConstantOp::create(rewriter, loc, - rewriter.getI32IntegerAttr(tileM)); - Value nStep = arith::ConstantOp::create(rewriter, loc, - rewriter.getI32IntegerAttr(tileN)); - - auto mLoop = scf::ForOp::create(rewriter, loc, zero, mUpper, mStep); - rewriter.setInsertionPointToStart(mLoop.getBody()); - Value mIdx = mLoop.getInductionVar(); - auto nLoop = scf::ForOp::create(rewriter, loc, zero, nUpper, nStep); - rewriter.setInsertionPointToStart(nLoop.getBody()); - auto i32Ty = rewriter.getI32Type(); - Value mIdxI32 = arith::IndexCastOp::create(rewriter, loc, i32Ty, mIdx); - Value nIdxI32 = - arith::IndexCastOp::create(rewriter, loc, i32Ty, nLoop.getInductionVar()); - if (failed(emitMmaEmulationTile(rewriter, loc, aSource, bSource, dPtr, k, - tileM, tileN, accTileTy, accLayout, accElem, - useDInt, predInt, mIdxI32, nIdxI32, dStride, - scale, dRowStride))) - return std::nullopt; return mLoop; } @@ -2587,10 +2574,8 @@ struct TCGen5MMAPattern : public OpRewritePattern { arith::ExtUIOp::create(rewriter, loc, accElem, op.getPred()); rewriter.setInsertionPoint(op); - auto [tileM, tileN] = - (!aIsTmem || !bIsTmem) - ? getDirectSharedMmaEmulationTileShape(rewriter, m, n, k, accElem) - : getMmaEmulationTileShape(rewriter, m, n, k, accElem); + auto [tileM, tileN] = getMmaEmulationTileShape( + rewriter, m, n, k, accElem, /*directShared=*/!aIsTmem || !bIsTmem); auto accTileLayout = getOptimizedBlockedEncoding(rewriter, {tileM, tileN}, accElem); auto accTileTy = @@ -2608,26 +2593,14 @@ struct TCGen5MMAPattern : public OpRewritePattern { getOptimizedBlockedEncoding(rewriter, {k, tileN}, bTileElem); auto bTileTy = RankedTensorType::get({k, tileN}, bTileElem, bTileLayout); - std::optional aScratch; - if (aIsTmem) { - aScratch = createTmemOperandScratch(rewriter, loc, *scratch, op.getA(), - aMemTy, scope); - if (!aScratch) - return emitFpSanCodegenError(op.getOperation()); - } - std::optional bScratch; - if (bIsTmem) { - bScratch = createTmemOperandScratch(rewriter, loc, *scratch, op.getB(), - bMemTy, scope); - if (!bScratch) - return emitFpSanCodegenError(op.getOperation()); - } - MmaOperandSource aSource{aIsTmem ? aScratch->ptr : Value(), - aIsTmem ? Value() : op.getA(), aTileTy, - /*rowStride=*/1, /*stride=*/m}; - MmaOperandSource bSource{bIsTmem ? bScratch->ptr : Value(), - bIsTmem ? Value() : op.getB(), bTileTy, - /*rowStride=*/1, /*stride=*/k}; + auto aSource = createMmaOperandSource(rewriter, loc, *scratch, op.getA(), + aMemTy, aIsTmem, aTileTy, scope, + /*rowStride=*/1, /*stride=*/m); + auto bSource = createMmaOperandSource(rewriter, loc, *scratch, op.getB(), + bMemTy, bIsTmem, bTileTy, scope, + /*rowStride=*/1, /*stride=*/k); + if (!aSource || !bSource) + return emitFpSanCodegenError(op.getOperation()); // TMEM and D scratch are written cooperatively. In two-CTA mode, the // cluster barrier also makes both CTAs' shared operands visible before the @@ -2636,7 +2609,7 @@ struct TCGen5MMAPattern : public OpRewritePattern { scratch->usesSharedClusterState()); auto mLoop = emitMmaEmulationLoops( - rewriter, loc, aSource, bSource, dInfo->ptr, m, n, k, tileM, tileN, + rewriter, loc, *aSource, *bSource, dInfo->ptr, m, n, k, tileM, tileN, accTileTy, accTileLayout, accElem, useDInt, predInt, /*dStride=*/m); if (!mLoop) return emitFpSanUnsupported(op.getOperation()); @@ -2767,10 +2740,8 @@ struct TCGen5MMAScaledPattern if (!bScaleScratch) return emitFpSanCodegenError(op.getOperation()); - auto [tileM, tileN] = - (!aIsTmem || !bIsTmem) - ? getDirectSharedMmaEmulationTileShape(rewriter, m, n, k, accElem) - : getMmaEmulationTileShape(rewriter, m, n, k, accElem); + auto [tileM, tileN] = getMmaEmulationTileShape( + rewriter, m, n, k, accElem, /*directShared=*/!aIsTmem || !bIsTmem); auto accTileLayout = getOptimizedBlockedEncoding(rewriter, {tileM, tileN}, dMemTy.getElementType()); @@ -2785,26 +2756,14 @@ struct TCGen5MMAScaledPattern auto bTileTy = RankedTensorType::get({bPackedK, tileN}, bMemTy.getElementType(), bTileLayout); - std::optional aScratch; - if (aIsTmem) { - aScratch = createTmemOperandScratch(rewriter, loc, *scratch, op.getA(), - aMemTy, scope); - if (!aScratch) - return emitFpSanCodegenError(op.getOperation()); - } - std::optional bScratch; - if (bIsTmem) { - bScratch = createTmemOperandScratch(rewriter, loc, *scratch, op.getB(), - bMemTy, scope); - if (!bScratch) - return emitFpSanCodegenError(op.getOperation()); - } - MmaOperandSource aSource{aIsTmem ? aScratch->ptr : Value(), - aIsTmem ? Value() : op.getA(), aTileTy, - /*rowStride=*/1, /*stride=*/m}; - MmaOperandSource bSource{bIsTmem ? bScratch->ptr : Value(), - bIsTmem ? Value() : op.getB(), bTileTy, - /*rowStride=*/1, /*stride=*/bPackedK}; + auto aSource = createMmaOperandSource(rewriter, loc, *scratch, op.getA(), + aMemTy, aIsTmem, aTileTy, scope, + /*rowStride=*/1, /*stride=*/m); + auto bSource = createMmaOperandSource( + rewriter, loc, *scratch, op.getB(), bMemTy, bIsTmem, bTileTy, scope, + /*rowStride=*/1, /*stride=*/bPackedK); + if (!aSource || !bSource) + return emitFpSanCodegenError(op.getOperation()); DotScaleConfig scale; scale.aElemType = op.getAType(); @@ -2830,8 +2789,8 @@ struct TCGen5MMAScaledPattern scratch->usesSharedClusterState()); auto mLoop = - emitMmaEmulationLoops(rewriter, loc, aSource, bSource, dInfo->ptr, m, n, - k, tileM, tileN, accTileTy, accTileLayout, + emitMmaEmulationLoops(rewriter, loc, *aSource, *bSource, dInfo->ptr, m, + n, k, tileM, tileN, accTileTy, accTileLayout, accElem, useDInt, predInt, /*dStride=*/m, scale); if (!mLoop) return emitFpSanUnsupported(op.getOperation()); From b1d02d788f9431ce2fd4ed70f5e8e45bd3916cab Mon Sep 17 00:00:00 2001 From: Jeff Niu Date: Thu, 4 Jun 2026 06:59:35 +0000 Subject: [PATCH 10/22] Apply pre-commit formatting --- .../Conversion/TritonGPUToLLVM/Utility.h | 11 ++++--- .../Transforms/FpSanitizer.cpp | 31 +++++++++---------- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index 749ffa860d6c..07a9df20fc59 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -613,11 +613,12 @@ struct LocalSharedMemoryAddress { // Compute per-element shared-memory addresses for a local atomic/ldst update by // replacing `coords[*][axis]` with `idxValues[*]` and mapping the resulting // logical coordinates back to shared-memory offsets and target CTAs. -SmallVector computeLocalAddrs( - Location loc, triton::gpu::MemDescType memDescTy, - SharedMemoryObject smemObj, Type llvmElemTy, ArrayRef idxValues, - ArrayRef> coords, unsigned axis, RewriterBase &rewriter, - ArrayRef offsets = {}); +SmallVector +computeLocalAddrs(Location loc, triton::gpu::MemDescType memDescTy, + SharedMemoryObject smemObj, Type llvmElemTy, + ArrayRef idxValues, + ArrayRef> coords, unsigned axis, + RewriterBase &rewriter, ArrayRef offsets = {}); SmallVector computeLocalPtrs(Location loc, triton::gpu::MemDescType memDescTy, diff --git a/lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp b/lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp index 71c191e72d63..12236fe8a9db 100644 --- a/lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp +++ b/lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp @@ -68,13 +68,12 @@ bool canUseI8MmaTile(int64_t m, int64_t n, int numWarps) { return m >= 16 && n >= 8 && (m / 16) * (n / 8) >= numWarps; } -std::pair getMmaEmulationTileShape(PatternRewriter &rewriter, - int64_t m, int64_t n, - int64_t k, - IntegerType accElem, - bool directShared = false) { - std::pair tile = { - std::min(kTileM, m), std::min(kTileN, n)}; +std::pair +getMmaEmulationTileShape(PatternRewriter &rewriter, int64_t m, int64_t n, + int64_t k, IntegerType accElem, + bool directShared = false) { + std::pair tile = {std::min(kTileM, m), + std::min(kTileN, n)}; int64_t numWarps = ttg::lookupNumWarps(rewriter.getInsertionBlock()->getParent()); if (supportsI8DotDecomposition(rewriter, accElem) && (k % kI8MmaK) == 0) { @@ -931,8 +930,8 @@ createTmemOperandScratch(PatternRewriter &rewriter, Location loc, std::optional createMmaOperandSource( PatternRewriter &rewriter, Location loc, TmemScratchManager &scratch, - Value memdesc, ttg::MemDescType memTy, bool isTmem, - RankedTensorType tileTy, Region *scope, int64_t rowStride, int64_t stride) { + Value memdesc, ttg::MemDescType memTy, bool isTmem, RankedTensorType tileTy, + Region *scope, int64_t rowStride, int64_t stride) { if (!isTmem) return MmaOperandSource{Value(), memdesc, tileTy, rowStride, stride}; auto info = @@ -1596,10 +1595,10 @@ std::optional emitMmaEmulationLoops( rewriter.setInsertionPointToStart(mLoop.getBody()); auto nLoop = scf::ForOp::create(rewriter, loc, zero, nUpper, nStep); rewriter.setInsertionPointToStart(nLoop.getBody()); - Value mIdxI32 = arith::IndexCastOp::create( - rewriter, loc, i32Ty, mLoop.getInductionVar()); - Value nIdxI32 = arith::IndexCastOp::create( - rewriter, loc, i32Ty, nLoop.getInductionVar()); + Value mIdxI32 = + arith::IndexCastOp::create(rewriter, loc, i32Ty, mLoop.getInductionVar()); + Value nIdxI32 = + arith::IndexCastOp::create(rewriter, loc, i32Ty, nLoop.getInductionVar()); Value dRowStrideConst = arith::ConstantOp::create( rewriter, loc, rewriter.getI32IntegerAttr(dRowStride)); @@ -2759,9 +2758,9 @@ struct TCGen5MMAScaledPattern auto aSource = createMmaOperandSource(rewriter, loc, *scratch, op.getA(), aMemTy, aIsTmem, aTileTy, scope, /*rowStride=*/1, /*stride=*/m); - auto bSource = createMmaOperandSource( - rewriter, loc, *scratch, op.getB(), bMemTy, bIsTmem, bTileTy, scope, - /*rowStride=*/1, /*stride=*/bPackedK); + auto bSource = createMmaOperandSource(rewriter, loc, *scratch, op.getB(), + bMemTy, bIsTmem, bTileTy, scope, + /*rowStride=*/1, /*stride=*/bPackedK); if (!aSource || !bSource) return emitFpSanCodegenError(op.getOperation()); From 8cc70b2e3b1d2f943d2e8aa35e91c9becbf47591 Mon Sep 17 00:00:00 2001 From: Jeff Niu Date: Fri, 5 Jun 2026 08:33:47 +0000 Subject: [PATCH 11/22] Apply post-restack formatting --- lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp b/lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp index 12236fe8a9db..c768609f5dd6 100644 --- a/lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp +++ b/lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp @@ -84,8 +84,7 @@ getMmaEmulationTileShape(PatternRewriter &rewriter, int64_t m, int64_t n, } if (directShared) { int64_t widerN = std::min(16 * numWarps, n); - if (widerN > tile.second && - canUseI8MmaTile(tile.first, widerN, numWarps)) + if (widerN > tile.second && canUseI8MmaTile(tile.first, widerN, numWarps)) tile.second = widerN; } return tile; From abde8915e462b0f674243b8a66b1e668cc46d1ff Mon Sep 17 00:00:00 2001 From: Jeff Niu Date: Fri, 5 Jun 2026 20:14:47 +0000 Subject: [PATCH 12/22] [NVIDIA] Address multi-CTA gather review Keep partitioned shared gather and scatter unsupported until their multiple-base addressing is implemented. Pass target CTA IDs through the shared-memory target API and let NVIDIA select local versus cluster accesses. Add forced cross-CTA gather and scatter correctness coverage by swapping the CGA-distributed column bit. Validate with the compiler build, conversion lit tests, focused Gluon gather/scatter tests, and pre-commit. --- .../TritonGPUToLLVM/TargetInfoBase.h | 8 +- .../TritonGPUToLLVM/MemoryOpToLLVM.cpp | 35 ++--- lib/Conversion/TritonGPUToLLVM/Utility.cpp | 21 +-- python/test/gluon/test_core.py | 58 ++++++++ test/Conversion/tritonnvidiagpu_to_llvm.mlir | 2 + .../lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp | 129 +++++++++++------- 6 files changed, 159 insertions(+), 94 deletions(-) diff --git a/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h b/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h index 29d0e927408e..15c4a942d31c 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h +++ b/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h @@ -30,11 +30,11 @@ class TargetInfoBase { // emit a block-level barrier with local address space visibility. virtual void warpSync(Location loc, RewriterBase &rewriter) const = 0; - // Store/load a value from shared memory, either in the same CTA or, if - // `ctaId` is non-nullopt, in another CTA in the same group. + // Store/load a value from shared memory. If `ctaId` is non-nullopt, it + // identifies the target CTA in the cluster and may identify the current CTA. // - // A target that does not support cross-CTA transfers will assert if ctaId is - // non-nullopt. + // A target that does not support clustered shared memory will assert if + // ctaId is non-nullopt. // // Assumes the address is aligned to the width of `val`. virtual void storeDShared(RewriterBase &rewriter, Location loc, Value ptr, diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp index 37f8b7008b8c..ebd20bbce39b 100644 --- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp @@ -27,9 +27,6 @@ lowerLocalScGt(Location loc, MLIRContext *ctx, MemDescType memDescTy, bool isScatter = !storeVals.empty(); SmallVector addrs = computeLocalAddrs( loc, memDescTy, smemObj, llvmElemTy, idxValues, coords, axis, rewriter); - Value currentCtaId; - if (!addrs.empty() && addrs.front().ctaId) - currentCtaId = targetInfo.getClusterCTAId(rewriter, loc); SmallVector results; if (!isScatter) @@ -37,27 +34,11 @@ lowerLocalScGt(Location loc, MLIRContext *ctx, MemDescType memDescTy, for (auto [i, addr] : llvm::enumerate(addrs)) { if (isScatter) { - if (addr.ctaId) { - Value isLocal = b.icmp_eq(*addr.ctaId, currentCtaId); - Value isRemote = b.icmp_ne(*addr.ctaId, currentCtaId); - targetInfo.storeShared(rewriter, loc, addr.ptr, storeVals[i], isLocal); - targetInfo.storeDShared(rewriter, loc, addr.ptr, addr.ctaId, - storeVals[i], isRemote); - } else { - targetInfo.storeShared(rewriter, loc, addr.ptr, storeVals[i], - b.true_val()); - } - } else if (addr.ctaId) { - Value isLocal = b.icmp_eq(*addr.ctaId, currentCtaId); - Value isRemote = b.icmp_ne(*addr.ctaId, currentCtaId); - Value local = - targetInfo.loadShared(rewriter, loc, addr.ptr, llvmElemTy, isLocal); - Value remote = targetInfo.loadDShared(rewriter, loc, addr.ptr, addr.ctaId, - llvmElemTy, isRemote); - results[i] = b.select(isLocal, local, remote); + targetInfo.storeDShared(rewriter, loc, addr.ptr, addr.ctaId, storeVals[i], + b.true_val()); } else { - results[i] = targetInfo.loadShared(rewriter, loc, addr.ptr, llvmElemTy, - b.true_val()); + results[i] = targetInfo.loadDShared(rewriter, loc, addr.ptr, addr.ctaId, + llvmElemTy, b.true_val()); } } @@ -287,6 +268,10 @@ struct LocalGatherOpConversion : public ConvertOpToLLVMPattern { auto loc = op.getLoc(); auto *ctx = op.getContext(); auto memDescTy = cast(op.getSrc().getType()); + if (isa(memDescTy.getEncoding())) { + return rewriter.notifyMatchFailure( + op, "PartitionedSharedEncoding not yet supported in lowering"); + } auto regTy = cast(op.getType()); auto typeConverter = getTypeConverter(); @@ -329,6 +314,10 @@ struct LocalScatterOpConversion auto loc = op.getLoc(); auto *ctx = op.getContext(); auto memDescTy = cast(op.getDst().getType()); + if (isa(memDescTy.getEncoding())) { + return rewriter.notifyMatchFailure( + op, "PartitionedSharedEncoding not yet supported in lowering"); + } auto valuesTy = cast(op.getValues().getType()); auto typeConverter = getTypeConverter(); diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index 651666f64818..808289ed5e68 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -562,8 +562,7 @@ computeLocalAddrs(Location loc, triton::gpu::MemDescType memDescTy, auto kOffset = str_attr("offset"); auto kBlock = str_attr("block"); - bool useBlockId = invSharedLayout.hasOutDim(kBlock) && - invSharedLayout.getOutDimSize(kBlock) > 1; + bool useBlockId = invSharedLayout.getOutDimSize(kBlock) > 1; // Get the subslice affine offset (non-zero for memdesc subslices) Value affineOffset = smemObj.getShmemOffset(loc, rewriter, memDescTy); auto bitwidth = getIntOrFloatOrPtrBitWidth(llvmElemTy); @@ -591,19 +590,11 @@ computeLocalAddrs(Location loc, triton::gpu::MemDescType memDescTy, inputs.push_back({allDims[dim], indices[dim]}); auto outputs = applyLinearLayout(loc, rewriter, invSharedLayout, inputs); - - // Extract the offset and target CTA. - Value offset = nullptr; - Value blockId = nullptr; - for (auto [name, value] : outputs) { - if (name == kOffset) - offset = value; - else if (name == kBlock) - blockId = value; - } - assert(offset && "expected offset output from inverted shared layout"); - assert((!useBlockId || blockId) && - "expected block output from multi-CTA shared layout"); + assert(outputs.size() == 2); + auto [offsetName, offset] = outputs[0]; + auto [blockName, blockId] = outputs[1]; + assert(offsetName == kOffset); + assert(blockName == kBlock); // For subslices, the physical offset is computed as: // physical_offset = L⁻¹(coords) ⊕ L⁻¹(subslice_logical_offset) diff --git a/python/test/gluon/test_core.py b/python/test/gluon/test_core.py index 82af34927889..f9d0ecbe4bb2 100644 --- a/python/test/gluon/test_core.py +++ b/python/test/gluon/test_core.py @@ -2710,6 +2710,64 @@ def test_shared_gather(N, M): torch.testing.assert_close(output, expected) +@gluon.jit +def shared_gather_scatter_two_ctas_kernel( + inp, + out, + GATHER: ttgl.constexpr, + layout: ttgl.constexpr, + shared_layout: ttgl.constexpr, +): + rows = ttgl.arange(0, 2, layout=ttgl.SliceLayout(1, layout)) + cols = ttgl.arange(0, 32, layout=ttgl.SliceLayout(0, layout)) + offsets = rows[:, None] * 32 + cols[None, :] + values = ttgl.load(inp + offsets) + smem = ttgl.allocate_shared_memory(ttgl.int32, [2, 32], shared_layout, value=values) + ttgl.barrier(cluster=True) + + peer_cols = (cols ^ 1)[None, :] + rows[:, None] * 0 + if GATHER: + result = smem.gather(peer_cols, axis=1) + else: + smem.scatter(values, peer_cols, axis=1) + ttgl.barrier(cluster=True) + result = smem.load(layout) + ttgl.store(out + offsets, result) + + +@pytest.mark.skipif(not is_hopper_or_newer(), reason="Requires Hopper") +@pytest.mark.parametrize("gather", [True, False], ids=["gather", "scatter"]) +def test_shared_gather_scatter_two_ctas(gather): + layout = ttgl.BlockedLayout( + size_per_thread=[1, 1], + threads_per_warp=[1, THREADS_PER_WARP], + warps_per_cta=[1, 4], + order=[1, 0], + cga_layout=[[0, 1]], + ) + shared_layout = ttgl.SwizzledSharedLayout( + vec=1, + per_phase=1, + max_phase=1, + order=[1, 0], + cga_layout=[[0, 1]], + ) + inp = torch.arange(64, dtype=torch.int32, device="cuda").reshape(2, 32) + out = torch.empty_like(inp) + + shared_gather_scatter_two_ctas_kernel[(1, )]( + inp, + out, + GATHER=gather, + layout=layout, + shared_layout=shared_layout, + num_warps=4, + num_ctas=2, + ) + + torch.testing.assert_close(out, inp.reshape(2, 16, 2).flip(-1).reshape(2, 32)) + + @gluon.jit def shared_scatter_kernel( indices_ptr, diff --git a/test/Conversion/tritonnvidiagpu_to_llvm.mlir b/test/Conversion/tritonnvidiagpu_to_llvm.mlir index 4d269832209e..8fbd4a5c6595 100644 --- a/test/Conversion/tritonnvidiagpu_to_llvm.mlir +++ b/test/Conversion/tritonnvidiagpu_to_llvm.mlir @@ -810,8 +810,10 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, "ttg.tot module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @local_gather_scatter_two_ctas + // CHECK: ld.shared::cta // CHECK: ld.shared::cluster // CHECK: nvvm.barrier0 + // CHECK: st.shared::cta // CHECK: st.shared::cluster tt.func @local_gather_scatter_two_ctas(%out: !tt.ptr, %vals: tensor<2x32xi32, #local_gather_scatter_blocked>) { %src = ttg.local_alloc {allocation.offset = [0 : i32, 256 : i32]} : () -> !ttg.memdesc<2x32xi32, #local_gather_scatter_shared, #ttg.shared_memory, mutable> diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp index 8c62f27b0412..2465b0b22bfa 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp @@ -296,21 +296,16 @@ void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr, assert(1 <= vec && vec <= 4); assert(vec * elemBitwidth <= 128); - // Get pointer to remote shared memory if needed. - if (ctaId.has_value()) { - ptr = mapa(rewriter, loc, ptr, *ctaId, pred); - } - - PTXBuilder builder; - auto st = builder.create("st") - ->o(ctaId.has_value() ? "shared::cluster" : "shared::cta") - .v(vec, /*predicate=*/vec > 1) - .b(elemBitwidth); - auto *ptrOpr = builder.newAddrOperand(ptr, "r"); - - if (isConstantTruePred(pred)) { - b.store(val, ptr, /*align=*/vec * elemBitwidth / 8); - } else { + auto emitStore = [&](Value storePtr, bool isRemote, Value storePred) { + if (isConstantTruePred(storePred)) { + b.store(val, storePtr, /*align=*/vec * elemBitwidth / 8); + return; + } + PTXBuilder builder; + auto st = builder.create("st") + ->o(isRemote ? "shared::cluster" : "shared::cta") + .v(vec, /*predicate=*/vec > 1) + .b(elemBitwidth); PTXBuilder::Operand *valOpr; std::string constraint = getConstraintForBitwidth(elemBitwidth); if (vec > 1) { @@ -322,9 +317,22 @@ void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr, } else { valOpr = builder.newOperand(val, constraint); } - st(ptrOpr, valOpr).predicate(pred, "b"); + st(builder.newAddrOperand(storePtr, "r"), valOpr).predicate(storePred, "b"); builder.launch(rewriter, loc, void_ty(ctx)); + }; + + if (!ctaId) { + emitStore(ptr, /*isRemote=*/false, pred); + return; } + + Value currentCtaId = getClusterCTAId(rewriter, loc); + Value isLocal = b.icmp_eq(*ctaId, currentCtaId); + Value localPred = b.and_(pred, isLocal); + Value remotePred = b.and_(pred, b.icmp_ne(*ctaId, currentCtaId)); + emitStore(ptr, /*isRemote=*/false, localPred); + emitStore(mapa(rewriter, loc, ptr, *ctaId, remotePred), + /*isRemote=*/true, remotePred); } Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr, @@ -413,45 +421,62 @@ Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr, assert(1 <= vec && vec <= 4); assert(vec * elemBitwidth <= 128); - // Get pointer to remote shared memory if needed. - if (ctaId.has_value()) { - ptr = mapa(rewriter, loc, ptr, *ctaId, pred); - } - - PTXBuilder builder; - auto ld = builder.create("ld") - ->o(ctaId.has_value() ? "shared::cluster" : "shared::cta") - .v(vec, /*predicate=*/vec > 1) - .b(elemBitwidth); - - Value load; - if (isConstantTruePred(pred)) { - Type resultTy = vec == 1 ? Type(int_ty(elemBitwidth)) - : Type(vec_ty(int_ty(elemBitwidth), vec)); - load = b.load(resultTy, ptr, /*align=*/vec * elemBitwidth / 8); - if (vec > 1) { - Type structTy = struct_ty(SmallVector(vec, int_ty(elemBitwidth))); - Value structValue = b.undef(structTy); - for (int i = 0; i < vec; i++) { - structValue = b.insert_val(structTy, structValue, - b.extract_element(load, b.i32_val(i)), i); + auto emitLoad = [&](Value loadPtr, bool isRemote, Value loadPred) { + Value load; + if (isConstantTruePred(loadPred)) { + Type resultTy = vec == 1 ? Type(int_ty(elemBitwidth)) + : Type(vec_ty(int_ty(elemBitwidth), vec)); + load = b.load(resultTy, loadPtr, /*align=*/vec * elemBitwidth / 8); + if (vec > 1) { + Type structTy = struct_ty(SmallVector(vec, int_ty(elemBitwidth))); + Value structValue = b.undef(structTy); + for (int i = 0; i < vec; i++) { + structValue = b.insert_val(structTy, structValue, + b.extract_element(load, b.i32_val(i)), i); + } + load = structValue; } - load = structValue; + } else { + PTXBuilder builder; + auto ld = builder.create("ld") + ->o(isRemote ? "shared::cluster" : "shared::cta") + .v(vec, /*predicate=*/vec > 1) + .b(elemBitwidth); + std::string elemConstraint = "=" + getConstraintForBitwidth(elemBitwidth); + auto *outOpr = vec == 1 ? builder.newOperand(elemConstraint) + : builder.newListOperand(vec, elemConstraint); + ld(outOpr, builder.newAddrOperand(loadPtr, "r")).predicate(loadPred, "b"); + + Type resultTy = + vec == 1 + ? Type(int_ty(elemBitwidth)) + : Type(struct_ty(SmallVector(vec, int_ty(elemBitwidth)))); + load = builder.launch(rewriter, loc, resultTy, /*hasSideEffects=*/true); } - } else { - std::string elemConstraint = "=" + getConstraintForBitwidth(elemBitwidth); - auto *outOpr = vec == 1 ? builder.newOperand(elemConstraint) - : builder.newListOperand(vec, elemConstraint); - ld(outOpr, builder.newAddrOperand(ptr, "r")).predicate(pred, "b"); - - Type resultTy = - vec == 1 - ? Type(int_ty(elemBitwidth)) - : Type(struct_ty(SmallVector(vec, int_ty(elemBitwidth)))); - load = builder.launch(rewriter, loc, resultTy, /*hasSideEffects=*/true); + SmallVector resultVals = unpackLLElements(loc, load, rewriter); + return packLLVector(loc, resultVals, rewriter); + }; + + if (!ctaId) + return emitLoad(ptr, /*isRemote=*/false, pred); + + Value currentCtaId = getClusterCTAId(rewriter, loc); + Value isLocal = b.icmp_eq(*ctaId, currentCtaId); + Value localPred = b.and_(pred, isLocal); + Value remotePred = b.and_(pred, b.icmp_ne(*ctaId, currentCtaId)); + Value local = emitLoad(ptr, /*isRemote=*/false, localPred); + Value remote = emitLoad(mapa(rewriter, loc, ptr, *ctaId, remotePred), + /*isRemote=*/true, remotePred); + if (vec == 1) + return b.select(isLocal, local, remote); + + SmallVector selected; + for (auto [localVal, remoteVal] : + llvm::zip(unpackLLVector(loc, local, rewriter), + unpackLLVector(loc, remote, rewriter))) { + selected.push_back(b.select(isLocal, localVal, remoteVal)); } - SmallVector resultVals = unpackLLElements(loc, load, rewriter); - return packLLVector(loc, resultVals, rewriter); + return packLLVector(loc, selected, rewriter); } Value TargetInfo::shuffleXor(RewriterBase &rewriter, Location loc, Value val, From 68037d457596948475d96f8f550fb10ca98e478b Mon Sep 17 00:00:00 2001 From: Jeff Niu Date: Fri, 5 Jun 2026 20:23:23 +0000 Subject: [PATCH 13/22] [NVIDIA] Minimize multi-CTA shared dispatch --- .../TritonGPUToLLVM/TargetInfoBase.h | 8 +- .../TritonGPUToLLVM/MemoryOpToLLVM.cpp | 4 + .../lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp | 207 ++++++++++-------- .../lib/TritonNVIDIAGPUToLLVM/TargetInfo.h | 7 + 4 files changed, 130 insertions(+), 96 deletions(-) diff --git a/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h b/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h index 15c4a942d31c..690853ac6cd3 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h +++ b/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h @@ -30,11 +30,11 @@ class TargetInfoBase { // emit a block-level barrier with local address space visibility. virtual void warpSync(Location loc, RewriterBase &rewriter) const = 0; - // Store/load a value from shared memory. If `ctaId` is non-nullopt, it - // identifies the target CTA in the cluster and may identify the current CTA. + // Store/load a value from shared memory, either in the same CTA or, if + // `ctaId` is non-nullopt, in the specified CTA in the same group. // - // A target that does not support clustered shared memory will assert if - // ctaId is non-nullopt. + // A target that does not support cross-CTA transfers will assert if ctaId is + // non-nullopt. // // Assumes the address is aligned to the width of `val`. virtual void storeDShared(RewriterBase &rewriter, Location loc, Value ptr, diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp index ebd20bbce39b..98a3379e4b40 100644 --- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp @@ -268,6 +268,8 @@ struct LocalGatherOpConversion : public ConvertOpToLLVMPattern { auto loc = op.getLoc(); auto *ctx = op.getContext(); auto memDescTy = cast(op.getSrc().getType()); + // TODO: PartitionedSharedEncoding lowering will be enabled in subsequent + // PRs. if (isa(memDescTy.getEncoding())) { return rewriter.notifyMatchFailure( op, "PartitionedSharedEncoding not yet supported in lowering"); @@ -314,6 +316,8 @@ struct LocalScatterOpConversion auto loc = op.getLoc(); auto *ctx = op.getContext(); auto memDescTy = cast(op.getDst().getType()); + // TODO: PartitionedSharedEncoding lowering will be enabled in subsequent + // PRs. if (isa(memDescTy.getEncoding())) { return rewriter.notifyMatchFailure( op, "PartitionedSharedEncoding not yet supported in lowering"); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp index 2465b0b22bfa..a076348c5d12 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp @@ -209,14 +209,31 @@ static std::string getConstraintForBitwidth(unsigned bitwidth) { void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr, std::optional ctaId, Value val, Value pred) const { + if (!ctaId) { + storeDSharedImpl(rewriter, loc, ptr, ctaId, val, pred); + return; + } + + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value currentCtaId = getClusterCTAId(rewriter, loc); + Value isLocal = b.icmp_eq(*ctaId, currentCtaId); + storeDSharedImpl(rewriter, loc, ptr, std::nullopt, val, + b.and_(pred, isLocal)); + storeDSharedImpl(rewriter, loc, ptr, ctaId, val, + b.and_(pred, b.icmp_ne(*ctaId, currentCtaId))); +} + +void TargetInfo::storeDSharedImpl(RewriterBase &rewriter, Location loc, + Value ptr, std::optional ctaId, + Value val, Value pred) const { auto b = TritonLLVMOpBuilder(loc, rewriter); MLIRContext *ctx = rewriter.getContext(); auto ptrTy = cast(ptr.getType()); assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem"); if (!isa(val.getType())) { - storeDShared(rewriter, loc, ptr, ctaId, packLLVector(loc, {val}, rewriter), - pred); + storeDSharedImpl(rewriter, loc, ptr, ctaId, + packLLVector(loc, {val}, rewriter), pred); return; } @@ -233,8 +250,8 @@ void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr, for (Value &v : vals) { v = b.zext(int_ty(8), b.bitcast(v, int_ty(elemBitwidth))); } - storeDShared(rewriter, loc, ptr, ctaId, packLLVector(loc, vals, rewriter), - pred); + storeDSharedImpl(rewriter, loc, ptr, ctaId, + packLLVector(loc, vals, rewriter), pred); return; } @@ -247,8 +264,8 @@ void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr, v = b.bitcast(v, int_ty(elemBitwidth)); } } - storeDShared(rewriter, loc, ptr, ctaId, packLLVector(loc, vals, rewriter), - pred); + storeDSharedImpl(rewriter, loc, ptr, ctaId, + packLLVector(loc, vals, rewriter), pred); return; } @@ -268,8 +285,8 @@ void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr, rewriter); newVals.push_back(b.bitcast(v, i32_ty)); } - storeDShared(rewriter, loc, ptr, ctaId, - packLLVector(loc, newVals, rewriter), pred); + storeDSharedImpl(rewriter, loc, ptr, ctaId, + packLLVector(loc, newVals, rewriter), pred); return; } @@ -282,7 +299,7 @@ void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr, for (int i = 0; i < vec / maxVec; i++) { auto newPtr = b.gep(ptr.getType(), elemTy, ptr, b.i32_val(i * maxVec), LLVM::GEPNoWrapFlags::inbounds); - storeDShared( + storeDSharedImpl( rewriter, loc, newPtr, ctaId, packLLVector(loc, ArrayRef(vals).slice(i * maxVec, maxVec), rewriter), pred); @@ -296,16 +313,21 @@ void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr, assert(1 <= vec && vec <= 4); assert(vec * elemBitwidth <= 128); - auto emitStore = [&](Value storePtr, bool isRemote, Value storePred) { - if (isConstantTruePred(storePred)) { - b.store(val, storePtr, /*align=*/vec * elemBitwidth / 8); - return; - } - PTXBuilder builder; - auto st = builder.create("st") - ->o(isRemote ? "shared::cluster" : "shared::cta") - .v(vec, /*predicate=*/vec > 1) - .b(elemBitwidth); + // Get pointer to remote shared memory if needed. + if (ctaId.has_value()) { + ptr = mapa(rewriter, loc, ptr, *ctaId, pred); + } + + PTXBuilder builder; + auto st = builder.create("st") + ->o(ctaId.has_value() ? "shared::cluster" : "shared::cta") + .v(vec, /*predicate=*/vec > 1) + .b(elemBitwidth); + auto *ptrOpr = builder.newAddrOperand(ptr, "r"); + + if (isConstantTruePred(pred)) { + b.store(val, ptr, /*align=*/vec * elemBitwidth / 8); + } else { PTXBuilder::Operand *valOpr; std::string constraint = getConstraintForBitwidth(elemBitwidth); if (vec > 1) { @@ -317,27 +339,42 @@ void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr, } else { valOpr = builder.newOperand(val, constraint); } - st(builder.newAddrOperand(storePtr, "r"), valOpr).predicate(storePred, "b"); + st(ptrOpr, valOpr).predicate(pred, "b"); builder.launch(rewriter, loc, void_ty(ctx)); - }; - - if (!ctaId) { - emitStore(ptr, /*isRemote=*/false, pred); - return; } - - Value currentCtaId = getClusterCTAId(rewriter, loc); - Value isLocal = b.icmp_eq(*ctaId, currentCtaId); - Value localPred = b.and_(pred, isLocal); - Value remotePred = b.and_(pred, b.icmp_ne(*ctaId, currentCtaId)); - emitStore(ptr, /*isRemote=*/false, localPred); - emitStore(mapa(rewriter, loc, ptr, *ctaId, remotePred), - /*isRemote=*/true, remotePred); } Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr, std::optional ctaId, Type loadTy, Value pred, Operation *localLoadOp) const { + if (!ctaId) + return loadDSharedImpl(rewriter, loc, ptr, ctaId, loadTy, pred, + localLoadOp); + + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value currentCtaId = getClusterCTAId(rewriter, loc); + Value isLocal = b.icmp_eq(*ctaId, currentCtaId); + Value local = loadDSharedImpl(rewriter, loc, ptr, std::nullopt, loadTy, + b.and_(pred, isLocal), localLoadOp); + Value remote = loadDSharedImpl(rewriter, loc, ptr, ctaId, loadTy, + b.and_(pred, b.icmp_ne(*ctaId, currentCtaId)), + localLoadOp); + if (!isa(loadTy)) + return b.select(isLocal, local, remote); + + SmallVector selected; + for (auto [localVal, remoteVal] : + llvm::zip(unpackLLVector(loc, local, rewriter), + unpackLLVector(loc, remote, rewriter))) { + selected.push_back(b.select(isLocal, localVal, remoteVal)); + } + return packLLVector(loc, selected, rewriter); +} + +Value TargetInfo::loadDSharedImpl(RewriterBase &rewriter, Location loc, + Value ptr, std::optional ctaId, + Type loadTy, Value pred, + Operation *localLoadOp) const { auto b = TritonLLVMOpBuilder(loc, rewriter); MLIRContext *ctx = rewriter.getContext(); auto ptrTy = cast(ptr.getType()); @@ -345,7 +382,8 @@ Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr, if (!isa(loadTy)) { SmallVector values = unpackLLVector( - loc, loadDShared(rewriter, loc, ptr, ctaId, vec_ty(loadTy, 1), pred), + loc, + loadDSharedImpl(rewriter, loc, ptr, ctaId, vec_ty(loadTy, 1), pred), rewriter); assert(values.size() == 1); return values[0]; @@ -361,7 +399,8 @@ Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr, assert(vec == 1 && "don't know how to load/store vectors of sub-byte elems"); SmallVector vals = unpackLLVector( - loc, loadDShared(rewriter, loc, ptr, ctaId, int_ty(8), pred), rewriter); + loc, loadDSharedImpl(rewriter, loc, ptr, ctaId, int_ty(8), pred), + rewriter); assert(vals.size() == 1); return b.bitcast(b.trunc(int_ty(elemBitwidth), vals[0]), elemTy); } @@ -370,7 +409,8 @@ Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr, if (!elemTy.isInteger()) { Type newLoadTy = vec_ty(int_ty(elemBitwidth), vec); SmallVector vals = unpackLLVector( - loc, loadDShared(rewriter, loc, ptr, ctaId, newLoadTy, pred), rewriter); + loc, loadDSharedImpl(rewriter, loc, ptr, ctaId, newLoadTy, pred), + rewriter); for (Value &v : vals) { v = b.bitcast(v, elemTy); } @@ -384,7 +424,7 @@ Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr, if (vec > 4 && elemBitwidth < 32) { int newVec = vec / (32 / elemBitwidth); auto newVecTy = vec_ty(i32_ty, newVec); - auto res = loadDShared(rewriter, loc, ptr, ctaId, newVecTy, pred); + auto res = loadDSharedImpl(rewriter, loc, ptr, ctaId, newVecTy, pred); // Unpack the b32's into the original vector type. SmallVector vals; @@ -406,8 +446,8 @@ Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr, for (int i = 0; i < vec / maxVec; i++) { auto newPtr = b.gep(ptr.getType(), elemTy, ptr, b.i32_val(i * maxVec), LLVM::GEPNoWrapFlags::inbounds); - auto newVal = loadDShared(rewriter, loc, newPtr, ctaId, - vec_ty(elemTy, maxVec), pred); + auto newVal = loadDSharedImpl(rewriter, loc, newPtr, ctaId, + vec_ty(elemTy, maxVec), pred); for (Value v : unpackLLVector(loc, newVal, rewriter)) { vals.push_back(v); } @@ -421,62 +461,45 @@ Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr, assert(1 <= vec && vec <= 4); assert(vec * elemBitwidth <= 128); - auto emitLoad = [&](Value loadPtr, bool isRemote, Value loadPred) { - Value load; - if (isConstantTruePred(loadPred)) { - Type resultTy = vec == 1 ? Type(int_ty(elemBitwidth)) - : Type(vec_ty(int_ty(elemBitwidth), vec)); - load = b.load(resultTy, loadPtr, /*align=*/vec * elemBitwidth / 8); - if (vec > 1) { - Type structTy = struct_ty(SmallVector(vec, int_ty(elemBitwidth))); - Value structValue = b.undef(structTy); - for (int i = 0; i < vec; i++) { - structValue = b.insert_val(structTy, structValue, - b.extract_element(load, b.i32_val(i)), i); - } - load = structValue; - } - } else { - PTXBuilder builder; - auto ld = builder.create("ld") - ->o(isRemote ? "shared::cluster" : "shared::cta") - .v(vec, /*predicate=*/vec > 1) - .b(elemBitwidth); - std::string elemConstraint = "=" + getConstraintForBitwidth(elemBitwidth); - auto *outOpr = vec == 1 ? builder.newOperand(elemConstraint) - : builder.newListOperand(vec, elemConstraint); - ld(outOpr, builder.newAddrOperand(loadPtr, "r")).predicate(loadPred, "b"); - - Type resultTy = - vec == 1 - ? Type(int_ty(elemBitwidth)) - : Type(struct_ty(SmallVector(vec, int_ty(elemBitwidth)))); - load = builder.launch(rewriter, loc, resultTy, /*hasSideEffects=*/true); - } - SmallVector resultVals = unpackLLElements(loc, load, rewriter); - return packLLVector(loc, resultVals, rewriter); - }; - - if (!ctaId) - return emitLoad(ptr, /*isRemote=*/false, pred); + // Get pointer to remote shared memory if needed. + if (ctaId.has_value()) { + ptr = mapa(rewriter, loc, ptr, *ctaId, pred); + } - Value currentCtaId = getClusterCTAId(rewriter, loc); - Value isLocal = b.icmp_eq(*ctaId, currentCtaId); - Value localPred = b.and_(pred, isLocal); - Value remotePred = b.and_(pred, b.icmp_ne(*ctaId, currentCtaId)); - Value local = emitLoad(ptr, /*isRemote=*/false, localPred); - Value remote = emitLoad(mapa(rewriter, loc, ptr, *ctaId, remotePred), - /*isRemote=*/true, remotePred); - if (vec == 1) - return b.select(isLocal, local, remote); + PTXBuilder builder; + auto ld = builder.create("ld") + ->o(ctaId.has_value() ? "shared::cluster" : "shared::cta") + .v(vec, /*predicate=*/vec > 1) + .b(elemBitwidth); - SmallVector selected; - for (auto [localVal, remoteVal] : - llvm::zip(unpackLLVector(loc, local, rewriter), - unpackLLVector(loc, remote, rewriter))) { - selected.push_back(b.select(isLocal, localVal, remoteVal)); + Value load; + if (isConstantTruePred(pred)) { + Type resultTy = vec == 1 ? Type(int_ty(elemBitwidth)) + : Type(vec_ty(int_ty(elemBitwidth), vec)); + load = b.load(resultTy, ptr, /*align=*/vec * elemBitwidth / 8); + if (vec > 1) { + Type structTy = struct_ty(SmallVector(vec, int_ty(elemBitwidth))); + Value structValue = b.undef(structTy); + for (int i = 0; i < vec; i++) { + structValue = b.insert_val(structTy, structValue, + b.extract_element(load, b.i32_val(i)), i); + } + load = structValue; + } + } else { + std::string elemConstraint = "=" + getConstraintForBitwidth(elemBitwidth); + auto *outOpr = vec == 1 ? builder.newOperand(elemConstraint) + : builder.newListOperand(vec, elemConstraint); + ld(outOpr, builder.newAddrOperand(ptr, "r")).predicate(pred, "b"); + + Type resultTy = + vec == 1 + ? Type(int_ty(elemBitwidth)) + : Type(struct_ty(SmallVector(vec, int_ty(elemBitwidth)))); + load = builder.launch(rewriter, loc, resultTy, /*hasSideEffects=*/true); } - return packLLVector(loc, selected, rewriter); + SmallVector resultVals = unpackLLElements(loc, load, rewriter); + return packLLVector(loc, resultVals, rewriter); } Value TargetInfo::shuffleXor(RewriterBase &rewriter, Location loc, Value val, diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h index 00efb5f5f469..7e0a3cdd2bda 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h @@ -94,6 +94,13 @@ class TargetInfo : public mlir::triton::TargetInfoBase { bool isCuda() const override { return true; } private: + void storeDSharedImpl(RewriterBase &rewriter, Location loc, Value ptr, + std::optional ctaId, Value val, + Value pred) const; + Value loadDSharedImpl(RewriterBase &rewriter, Location loc, Value ptr, + std::optional ctaId, Type elemTy, Value pred, + Operation *localLoadOp = nullptr) const; + triton::nvidia_gpu::TargetFeatures targetFeatures; int ptxVersion; }; From 3d0f65f608e85c975e2448228d5a1bb68d350cf1 Mon Sep 17 00:00:00 2001 From: Jeff Niu Date: Sat, 6 Jun 2026 22:52:52 +0000 Subject: [PATCH 14/22] [NVIDIA] Trim multi-CTA gather changes --- .../TritonGPUToLLVM/TargetInfoBase.h | 2 +- .../Conversion/TritonGPUToLLVM/Utility.h | 7 -- .../TritonGPUToLLVM/MemoryOpToLLVM.cpp | 6 +- lib/Conversion/TritonGPUToLLVM/Utility.cpp | 19 +---- python/test/gluon/test_core.py | 22 +---- test/Conversion/tritonnvidiagpu_to_llvm.mlir | 25 ------ .../lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp | 83 +++++++++---------- .../lib/TritonNVIDIAGPUToLLVM/TargetInfo.h | 7 -- 8 files changed, 52 insertions(+), 119 deletions(-) diff --git a/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h b/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h index 690853ac6cd3..29d0e927408e 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h +++ b/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h @@ -31,7 +31,7 @@ class TargetInfoBase { virtual void warpSync(Location loc, RewriterBase &rewriter) const = 0; // Store/load a value from shared memory, either in the same CTA or, if - // `ctaId` is non-nullopt, in the specified CTA in the same group. + // `ctaId` is non-nullopt, in another CTA in the same group. // // A target that does not support cross-CTA transfers will assert if ctaId is // non-nullopt. diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index abd813304502..15fff8556cc5 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -618,13 +618,6 @@ SmallVector computeLocalAddrs( SharedMemoryObject smemObj, Type llvmElemTy, ArrayRef idxValues, ArrayRef> coords, unsigned axis, RewriterBase &rewriter); -SmallVector computeLocalPtrs(Location loc, - triton::gpu::MemDescType memDescTy, - SharedMemoryObject smemObj, Type llvmElemTy, - ArrayRef idxValues, - ArrayRef> coords, - unsigned axis, RewriterBase &rewriter); - // Backend-agnostic preparation for lowering LocalAtomicScatterRMWOp. struct LocalAtomicScatterRMWInfo { RankedTensorType valuesTy; diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp index 98a3379e4b40..aea0c5de70c9 100644 --- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp @@ -270,7 +270,8 @@ struct LocalGatherOpConversion : public ConvertOpToLLVMPattern { auto memDescTy = cast(op.getSrc().getType()); // TODO: PartitionedSharedEncoding lowering will be enabled in subsequent // PRs. - if (isa(memDescTy.getEncoding())) { + if (isa( + memDescTy.getEncoding())) { return rewriter.notifyMatchFailure( op, "PartitionedSharedEncoding not yet supported in lowering"); } @@ -318,7 +319,8 @@ struct LocalScatterOpConversion auto memDescTy = cast(op.getDst().getType()); // TODO: PartitionedSharedEncoding lowering will be enabled in subsequent // PRs. - if (isa(memDescTy.getEncoding())) { + if (isa( + memDescTy.getEncoding())) { return rewriter.notifyMatchFailure( op, "PartitionedSharedEncoding not yet supported in lowering"); } diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index 808289ed5e68..51d17dde4263 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -630,18 +630,6 @@ computeLocalAddrs(Location loc, triton::gpu::MemDescType memDescTy, return addrs; } -SmallVector computeLocalPtrs(Location loc, - triton::gpu::MemDescType memDescTy, - SharedMemoryObject smemObj, Type llvmElemTy, - ArrayRef idxValues, - ArrayRef> coords, - unsigned axis, RewriterBase &rewriter) { - return llvm::map_to_vector( - computeLocalAddrs(loc, memDescTy, smemObj, llvmElemTy, idxValues, coords, - axis, rewriter), - [](const LocalSharedMemoryAddress &addr) { return addr.ptr; }); -} - FailureOr prepareLocalAtomicScatterRMW( triton::gpu::LocalAtomicScatterRMWOp op, Value dst, Value indices, Value inputValues, Value mask, ConversionPatternRewriter &rewriter, @@ -680,9 +668,10 @@ FailureOr prepareLocalAtomicScatterRMW( emitIndices(loc, rewriter, targetInfo, activeRegLayout, valuesTy, /*withCTAOffset=*/true); - SmallVector ptrs = - computeLocalPtrs(loc, memDescTy, smemObj, llvmElemTy, idxValues, - srcIndices, op.getAxis(), rewriter); + SmallVector ptrs = llvm::map_to_vector( + computeLocalAddrs(loc, memDescTy, smemObj, llvmElemTy, idxValues, + srcIndices, op.getAxis(), rewriter), + [](const LocalSharedMemoryAddress &addr) { return addr.ptr; }); return LocalAtomicScatterRMWInfo{valuesTy, llvmElemTy, regLayout, removeBroadcast, threadPred, values, diff --git a/python/test/gluon/test_core.py b/python/test/gluon/test_core.py index f9d0ecbe4bb2..bc739d9db28f 100644 --- a/python/test/gluon/test_core.py +++ b/python/test/gluon/test_core.py @@ -2714,10 +2714,10 @@ def test_shared_gather(N, M): def shared_gather_scatter_two_ctas_kernel( inp, out, - GATHER: ttgl.constexpr, layout: ttgl.constexpr, - shared_layout: ttgl.constexpr, + GATHER: ttgl.constexpr, ): + shared_layout: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, order=[1, 0], cga_layout=[[0, 1]]) rows = ttgl.arange(0, 2, layout=ttgl.SliceLayout(1, layout)) cols = ttgl.arange(0, 32, layout=ttgl.SliceLayout(0, layout)) offsets = rows[:, None] * 32 + cols[None, :] @@ -2738,29 +2738,15 @@ def shared_gather_scatter_two_ctas_kernel( @pytest.mark.skipif(not is_hopper_or_newer(), reason="Requires Hopper") @pytest.mark.parametrize("gather", [True, False], ids=["gather", "scatter"]) def test_shared_gather_scatter_two_ctas(gather): - layout = ttgl.BlockedLayout( - size_per_thread=[1, 1], - threads_per_warp=[1, THREADS_PER_WARP], - warps_per_cta=[1, 4], - order=[1, 0], - cga_layout=[[0, 1]], - ) - shared_layout = ttgl.SwizzledSharedLayout( - vec=1, - per_phase=1, - max_phase=1, - order=[1, 0], - cga_layout=[[0, 1]], - ) inp = torch.arange(64, dtype=torch.int32, device="cuda").reshape(2, 32) out = torch.empty_like(inp) + layout = ttgl.BlockedLayout([1, 1], [1, 32], [1, 4], [1, 0], cga_layout=[[0, 1]]) shared_gather_scatter_two_ctas_kernel[(1, )]( inp, out, + layout, GATHER=gather, - layout=layout, - shared_layout=shared_layout, num_warps=4, num_ctas=2, ) diff --git a/test/Conversion/tritonnvidiagpu_to_llvm.mlir b/test/Conversion/tritonnvidiagpu_to_llvm.mlir index 8fbd4a5c6595..9c81ab8b9899 100644 --- a/test/Conversion/tritonnvidiagpu_to_llvm.mlir +++ b/test/Conversion/tritonnvidiagpu_to_llvm.mlir @@ -802,28 +802,3 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, "ttg.tot llvm.return } } - -// ----- - -#local_gather_scatter_blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CGALayout = [[0, 1]]}> -#local_gather_scatter_shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CGALayout = [[0, 1]]}> - -module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { - // CHECK-LABEL: @local_gather_scatter_two_ctas - // CHECK: ld.shared::cta - // CHECK: ld.shared::cluster - // CHECK: nvvm.barrier0 - // CHECK: st.shared::cta - // CHECK: st.shared::cluster - tt.func @local_gather_scatter_two_ctas(%out: !tt.ptr, %vals: tensor<2x32xi32, #local_gather_scatter_blocked>) { - %src = ttg.local_alloc {allocation.offset = [0 : i32, 256 : i32]} : () -> !ttg.memdesc<2x32xi32, #local_gather_scatter_shared, #ttg.shared_memory, mutable> - %idx = arith.constant dense<0> : tensor<2x32xi32, #local_gather_scatter_blocked> - %g = ttg.local_gather %src[%idx] {axis = 0 : i32} : !ttg.memdesc<2x32xi32, #local_gather_scatter_shared, #ttg.shared_memory, mutable>, tensor<2x32xi32, #local_gather_scatter_blocked> -> tensor<2x32xi32, #local_gather_scatter_blocked> - ttg.local_scatter %src[%idx], %vals {axis = 0 : i32} : !ttg.memdesc<2x32xi32, #local_gather_scatter_shared, #ttg.shared_memory, mutable>, tensor<2x32xi32, #local_gather_scatter_blocked>, tensor<2x32xi32, #local_gather_scatter_blocked> - %ptrs = tt.splat %out : !tt.ptr -> tensor<2x32x!tt.ptr, #local_gather_scatter_blocked> - %offs = arith.constant dense<0> : tensor<2x32xi32, #local_gather_scatter_blocked> - %out_ptrs = tt.addptr %ptrs, %offs : tensor<2x32x!tt.ptr, #local_gather_scatter_blocked>, tensor<2x32xi32, #local_gather_scatter_blocked> - tt.store %out_ptrs, %g : tensor<2x32x!tt.ptr, #local_gather_scatter_blocked> - tt.return - } -} diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp index a076348c5d12..956b2d250e28 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp @@ -206,26 +206,9 @@ static std::string getConstraintForBitwidth(unsigned bitwidth) { } } -void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr, - std::optional ctaId, Value val, - Value pred) const { - if (!ctaId) { - storeDSharedImpl(rewriter, loc, ptr, ctaId, val, pred); - return; - } - - auto b = TritonLLVMOpBuilder(loc, rewriter); - Value currentCtaId = getClusterCTAId(rewriter, loc); - Value isLocal = b.icmp_eq(*ctaId, currentCtaId); - storeDSharedImpl(rewriter, loc, ptr, std::nullopt, val, - b.and_(pred, isLocal)); - storeDSharedImpl(rewriter, loc, ptr, ctaId, val, - b.and_(pred, b.icmp_ne(*ctaId, currentCtaId))); -} - -void TargetInfo::storeDSharedImpl(RewriterBase &rewriter, Location loc, - Value ptr, std::optional ctaId, - Value val, Value pred) const { +static void storeDSharedImpl(RewriterBase &rewriter, Location loc, Value ptr, + std::optional ctaId, Value val, + Value pred) { auto b = TritonLLVMOpBuilder(loc, rewriter); MLIRContext *ctx = rewriter.getContext(); auto ptrTy = cast(ptr.getType()); @@ -344,37 +327,24 @@ void TargetInfo::storeDSharedImpl(RewriterBase &rewriter, Location loc, } } -Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr, - std::optional ctaId, Type loadTy, - Value pred, Operation *localLoadOp) const { +void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr, + std::optional ctaId, Value val, + Value pred) const { if (!ctaId) - return loadDSharedImpl(rewriter, loc, ptr, ctaId, loadTy, pred, - localLoadOp); + return storeDSharedImpl(rewriter, loc, ptr, ctaId, val, pred); auto b = TritonLLVMOpBuilder(loc, rewriter); Value currentCtaId = getClusterCTAId(rewriter, loc); Value isLocal = b.icmp_eq(*ctaId, currentCtaId); - Value local = loadDSharedImpl(rewriter, loc, ptr, std::nullopt, loadTy, - b.and_(pred, isLocal), localLoadOp); - Value remote = loadDSharedImpl(rewriter, loc, ptr, ctaId, loadTy, - b.and_(pred, b.icmp_ne(*ctaId, currentCtaId)), - localLoadOp); - if (!isa(loadTy)) - return b.select(isLocal, local, remote); - - SmallVector selected; - for (auto [localVal, remoteVal] : - llvm::zip(unpackLLVector(loc, local, rewriter), - unpackLLVector(loc, remote, rewriter))) { - selected.push_back(b.select(isLocal, localVal, remoteVal)); - } - return packLLVector(loc, selected, rewriter); + storeDSharedImpl(rewriter, loc, ptr, std::nullopt, val, + b.and_(pred, isLocal)); + storeDSharedImpl(rewriter, loc, ptr, ctaId, val, + b.and_(pred, b.icmp_ne(*ctaId, currentCtaId))); } -Value TargetInfo::loadDSharedImpl(RewriterBase &rewriter, Location loc, - Value ptr, std::optional ctaId, - Type loadTy, Value pred, - Operation *localLoadOp) const { +static Value loadDSharedImpl(RewriterBase &rewriter, Location loc, Value ptr, + std::optional ctaId, Type loadTy, + Value pred) { auto b = TritonLLVMOpBuilder(loc, rewriter); MLIRContext *ctx = rewriter.getContext(); auto ptrTy = cast(ptr.getType()); @@ -502,6 +472,31 @@ Value TargetInfo::loadDSharedImpl(RewriterBase &rewriter, Location loc, return packLLVector(loc, resultVals, rewriter); } +Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr, + std::optional ctaId, Type loadTy, + Value pred, Operation *) const { + if (!ctaId) + return loadDSharedImpl(rewriter, loc, ptr, ctaId, loadTy, pred); + + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value currentCtaId = getClusterCTAId(rewriter, loc); + Value isLocal = b.icmp_eq(*ctaId, currentCtaId); + Value local = loadDSharedImpl(rewriter, loc, ptr, std::nullopt, loadTy, + b.and_(pred, isLocal)); + Value remote = loadDSharedImpl(rewriter, loc, ptr, ctaId, loadTy, + b.and_(pred, b.icmp_ne(*ctaId, currentCtaId))); + if (!isa(loadTy)) + return b.select(isLocal, local, remote); + + SmallVector selected; + for (auto [localVal, remoteVal] : + llvm::zip(unpackLLVector(loc, local, rewriter), + unpackLLVector(loc, remote, rewriter))) { + selected.push_back(b.select(isLocal, localVal, remoteVal)); + } + return packLLVector(loc, selected, rewriter); +} + Value TargetInfo::shuffleXor(RewriterBase &rewriter, Location loc, Value val, int i) const { return LLVM::NVIDIA::shuffleXor(loc, rewriter, val, i); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h index 7e0a3cdd2bda..00efb5f5f469 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h @@ -94,13 +94,6 @@ class TargetInfo : public mlir::triton::TargetInfoBase { bool isCuda() const override { return true; } private: - void storeDSharedImpl(RewriterBase &rewriter, Location loc, Value ptr, - std::optional ctaId, Value val, - Value pred) const; - Value loadDSharedImpl(RewriterBase &rewriter, Location loc, Value ptr, - std::optional ctaId, Type elemTy, Value pred, - Operation *localLoadOp = nullptr) const; - triton::nvidia_gpu::TargetFeatures targetFeatures; int ptxVersion; }; From d098a0e9235733793eddc0d8eb61d0fe6e78a1ff Mon Sep 17 00:00:00 2001 From: Jeff Niu Date: Sat, 6 Jun 2026 22:56:21 +0000 Subject: [PATCH 15/22] [NVIDIA] Restore multi-CTA lowering coverage --- python/test/gluon/test_core.py | 22 +++++++++++++---- test/Conversion/tritonnvidiagpu_to_llvm.mlir | 25 ++++++++++++++++++++ 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/python/test/gluon/test_core.py b/python/test/gluon/test_core.py index bc739d9db28f..f9d0ecbe4bb2 100644 --- a/python/test/gluon/test_core.py +++ b/python/test/gluon/test_core.py @@ -2714,10 +2714,10 @@ def test_shared_gather(N, M): def shared_gather_scatter_two_ctas_kernel( inp, out, - layout: ttgl.constexpr, GATHER: ttgl.constexpr, + layout: ttgl.constexpr, + shared_layout: ttgl.constexpr, ): - shared_layout: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, order=[1, 0], cga_layout=[[0, 1]]) rows = ttgl.arange(0, 2, layout=ttgl.SliceLayout(1, layout)) cols = ttgl.arange(0, 32, layout=ttgl.SliceLayout(0, layout)) offsets = rows[:, None] * 32 + cols[None, :] @@ -2738,15 +2738,29 @@ def shared_gather_scatter_two_ctas_kernel( @pytest.mark.skipif(not is_hopper_or_newer(), reason="Requires Hopper") @pytest.mark.parametrize("gather", [True, False], ids=["gather", "scatter"]) def test_shared_gather_scatter_two_ctas(gather): + layout = ttgl.BlockedLayout( + size_per_thread=[1, 1], + threads_per_warp=[1, THREADS_PER_WARP], + warps_per_cta=[1, 4], + order=[1, 0], + cga_layout=[[0, 1]], + ) + shared_layout = ttgl.SwizzledSharedLayout( + vec=1, + per_phase=1, + max_phase=1, + order=[1, 0], + cga_layout=[[0, 1]], + ) inp = torch.arange(64, dtype=torch.int32, device="cuda").reshape(2, 32) out = torch.empty_like(inp) - layout = ttgl.BlockedLayout([1, 1], [1, 32], [1, 4], [1, 0], cga_layout=[[0, 1]]) shared_gather_scatter_two_ctas_kernel[(1, )]( inp, out, - layout, GATHER=gather, + layout=layout, + shared_layout=shared_layout, num_warps=4, num_ctas=2, ) diff --git a/test/Conversion/tritonnvidiagpu_to_llvm.mlir b/test/Conversion/tritonnvidiagpu_to_llvm.mlir index 9c81ab8b9899..266734636f83 100644 --- a/test/Conversion/tritonnvidiagpu_to_llvm.mlir +++ b/test/Conversion/tritonnvidiagpu_to_llvm.mlir @@ -611,6 +611,31 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, "ttg.tot // ----- +#local_gather_scatter_blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CGALayout = [[0, 1]]}> +#local_gather_scatter_shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CGALayout = [[0, 1]]}> + +module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @local_gather_scatter_two_ctas + // CHECK: ld.shared::cta + // CHECK: ld.shared::cluster + // CHECK: nvvm.barrier0 + // CHECK: st.shared::cta + // CHECK: st.shared::cluster + tt.func @local_gather_scatter_two_ctas(%out: !tt.ptr, %vals: tensor<2x32xi32, #local_gather_scatter_blocked>) { + %src = ttg.local_alloc {allocation.offset = [0 : i32, 256 : i32]} : () -> !ttg.memdesc<2x32xi32, #local_gather_scatter_shared, #ttg.shared_memory, mutable> + %idx = arith.constant dense<0> : tensor<2x32xi32, #local_gather_scatter_blocked> + %g = ttg.local_gather %src[%idx] {axis = 0 : i32} : !ttg.memdesc<2x32xi32, #local_gather_scatter_shared, #ttg.shared_memory, mutable>, tensor<2x32xi32, #local_gather_scatter_blocked> -> tensor<2x32xi32, #local_gather_scatter_blocked> + ttg.local_scatter %src[%idx], %vals {axis = 0 : i32} : !ttg.memdesc<2x32xi32, #local_gather_scatter_shared, #ttg.shared_memory, mutable>, tensor<2x32xi32, #local_gather_scatter_blocked>, tensor<2x32xi32, #local_gather_scatter_blocked> + %ptrs = tt.splat %out : !tt.ptr -> tensor<2x32x!tt.ptr, #local_gather_scatter_blocked> + %offs = arith.constant dense<0> : tensor<2x32xi32, #local_gather_scatter_blocked> + %out_ptrs = tt.addptr %ptrs, %offs : tensor<2x32x!tt.ptr, #local_gather_scatter_blocked>, tensor<2x32xi32, #local_gather_scatter_blocked> + tt.store %out_ptrs, %g : tensor<2x32x!tt.ptr, #local_gather_scatter_blocked> + tt.return + } +} + +// ----- + module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 5 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK: module attributes { // CHECK-DAG: ttg.shared = 24 : i32 From 96e29f06232d5a63945d4949eb2c73a179a3cf58 Mon Sep 17 00:00:00 2001 From: Jeff Niu Date: Sat, 6 Jun 2026 22:58:43 +0000 Subject: [PATCH 16/22] [NVIDIA] Simplify multi-CTA runtime test setup --- python/test/gluon/test_core.py | 22 ++++------------------ 1 file changed, 4 insertions(+), 18 deletions(-) diff --git a/python/test/gluon/test_core.py b/python/test/gluon/test_core.py index f9d0ecbe4bb2..bc739d9db28f 100644 --- a/python/test/gluon/test_core.py +++ b/python/test/gluon/test_core.py @@ -2714,10 +2714,10 @@ def test_shared_gather(N, M): def shared_gather_scatter_two_ctas_kernel( inp, out, - GATHER: ttgl.constexpr, layout: ttgl.constexpr, - shared_layout: ttgl.constexpr, + GATHER: ttgl.constexpr, ): + shared_layout: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, order=[1, 0], cga_layout=[[0, 1]]) rows = ttgl.arange(0, 2, layout=ttgl.SliceLayout(1, layout)) cols = ttgl.arange(0, 32, layout=ttgl.SliceLayout(0, layout)) offsets = rows[:, None] * 32 + cols[None, :] @@ -2738,29 +2738,15 @@ def shared_gather_scatter_two_ctas_kernel( @pytest.mark.skipif(not is_hopper_or_newer(), reason="Requires Hopper") @pytest.mark.parametrize("gather", [True, False], ids=["gather", "scatter"]) def test_shared_gather_scatter_two_ctas(gather): - layout = ttgl.BlockedLayout( - size_per_thread=[1, 1], - threads_per_warp=[1, THREADS_PER_WARP], - warps_per_cta=[1, 4], - order=[1, 0], - cga_layout=[[0, 1]], - ) - shared_layout = ttgl.SwizzledSharedLayout( - vec=1, - per_phase=1, - max_phase=1, - order=[1, 0], - cga_layout=[[0, 1]], - ) inp = torch.arange(64, dtype=torch.int32, device="cuda").reshape(2, 32) out = torch.empty_like(inp) + layout = ttgl.BlockedLayout([1, 1], [1, 32], [1, 4], [1, 0], cga_layout=[[0, 1]]) shared_gather_scatter_two_ctas_kernel[(1, )]( inp, out, + layout, GATHER=gather, - layout=layout, - shared_layout=shared_layout, num_warps=4, num_ctas=2, ) From 1e55b3cb8ece14bbff6ac78a6ce4947fbf18d4ab Mon Sep 17 00:00:00 2001 From: Jeff Niu Date: Sat, 6 Jun 2026 23:04:29 +0000 Subject: [PATCH 17/22] [NVIDIA] Use nullable values for distributed shared memory --- .../TritonGPUToLLVM/TargetInfoBase.h | 14 +++---- .../Conversion/TritonGPUToLLVM/Utility.h | 25 ++++++------ lib/Conversion/TritonGPUToLLVM/Utility.cpp | 35 ++++++++--------- .../TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp | 5 +-- .../lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp | 5 +-- .../amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp | 11 +++--- .../amd/lib/TritonAMDGPUToLLVM/TargetInfo.h | 5 +-- .../LoadStoreOpToLLVM.cpp | 4 +- .../TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp | 4 +- .../lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp | 38 +++++++++---------- .../lib/TritonNVIDIAGPUToLLVM/TargetInfo.h | 5 +-- .../AMDPatternProtonGPUOpToLLVM.cpp | 6 +-- .../NvidiaPatternProtonGPUOpToLLVM.cpp | 6 +-- 13 files changed, 74 insertions(+), 89 deletions(-) diff --git a/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h b/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h index 29d0e927408e..be74d1894a6a 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h +++ b/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h @@ -31,27 +31,25 @@ class TargetInfoBase { virtual void warpSync(Location loc, RewriterBase &rewriter) const = 0; // Store/load a value from shared memory, either in the same CTA or, if - // `ctaId` is non-nullopt, in another CTA in the same group. + // `ctaId` is non-null, in another CTA in the same group. // // A target that does not support cross-CTA transfers will assert if ctaId is - // non-nullopt. + // non-null. // // Assumes the address is aligned to the width of `val`. virtual void storeDShared(RewriterBase &rewriter, Location loc, Value ptr, - std::optional ctaId, Value val, - Value pred) const = 0; + Value ctaId, Value val, Value pred) const = 0; virtual Value loadDShared(RewriterBase &rewriter, Location loc, Value ptr, - std::optional ctaId, Type elemTy, Value pred, + Value ctaId, Type elemTy, Value pred, Operation *localLoadOp = nullptr) const = 0; void storeShared(RewriterBase &rewriter, Location loc, Value ptr, Value val, Value pred) const { - storeDShared(rewriter, loc, ptr, /*ctaId=*/std::nullopt, val, pred); + storeDShared(rewriter, loc, ptr, /*ctaId=*/Value(), val, pred); } Value loadShared(RewriterBase &rewriter, Location loc, Value ptr, Type elemTy, Value pred) const { - return loadDShared(rewriter, loc, ptr, /*ctaId=*/std::nullopt, elemTy, - pred); + return loadDShared(rewriter, loc, ptr, /*ctaId=*/Value(), elemTy, pred); } virtual Value shuffleXor(RewriterBase &rewriter, Location loc, Value val, diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index 15fff8556cc5..3948751981b5 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -607,7 +607,7 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target, struct LocalSharedMemoryAddress { Value ptr; - std::optional ctaId; + Value ctaId; }; // Compute per-element shared-memory addresses for a local atomic/ldst update by @@ -670,18 +670,17 @@ lowerLdStShared(Location loc, MLIRContext *ctx, LinearLayout cvt, // and computes a new offset (mlir::Value) by applying padding based on // shared memory layout. // cvt: Maps (reg, lane, warp, block) → (offset[, partition]). -SmallVector -lowerLdSt(Location loc, MLIRContext *ctx, LinearLayout cvt, - ArrayRef valsArray, // Input for store, output for load - Type llvmElemTy, ArrayRef smemBases, - ArrayRef> paddingShifts, - Value affineOffset, uint64_t maskSpanAffineOffset, Value laneId, - Value warpId, RewriterBase &rewriter, - const TargetInfoBase &targetInfo, std::optional maybeMaxVecElems, - std::function(RewriterBase &, Location, - ArrayRef, Value, int, - VectorType, std::optional)> - lowerInst); +SmallVector lowerLdSt( + Location loc, MLIRContext *ctx, LinearLayout cvt, + ArrayRef valsArray, // Input for store, output for load + Type llvmElemTy, ArrayRef smemBases, + ArrayRef> paddingShifts, Value affineOffset, + uint64_t maskSpanAffineOffset, Value laneId, Value warpId, + RewriterBase &rewriter, const TargetInfoBase &targetInfo, + std::optional maybeMaxVecElems, + std::function(RewriterBase &, Location, ArrayRef, + Value, int, VectorType, Value)> + lowerInst); // Lower local_load/local_store via ld.shared/st.shared SmallVector diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index 51d17dde4263..ac7fb894396a 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -623,8 +623,7 @@ computeLocalAddrs(Location loc, triton::gpu::MemDescType memDescTy, ptr = b.gep(smemObj.getBase().getType(), llvmElemTy, smemObj.getBase(), offset); } - addrs.push_back( - {ptr, useBlockId ? std::optional(blockId) : std::nullopt}); + addrs.push_back({ptr, useBlockId ? blockId : Value()}); } return addrs; @@ -736,8 +735,7 @@ lowerLdStShared(Location loc, MLIRContext *ctx, LinearLayout cvt, auto emitLdSt = [&](RewriterBase &rewriter, Location loc, ArrayRef vals, Value shmemAddr, int idx, - VectorType vecTy, - std::optional ctaId) -> SmallVector { + VectorType vecTy, Value ctaId) -> SmallVector { auto length = vecTy.getNumElements(); if (isStore) { Value valsVec = @@ -759,18 +757,17 @@ lowerLdStShared(Location loc, MLIRContext *ctx, LinearLayout cvt, warpId, rewriter, targetInfo, maybeMaxVecElems, emitLdSt); } -SmallVector -lowerLdSt(Location loc, MLIRContext *ctx, LinearLayout cvt, - ArrayRef valsArray, // Input for store, output for load - Type llvmElemTy, ArrayRef smemBases, - ArrayRef> paddingShifts, - Value affineOffset, uint64_t maskSpanAffineOffset, Value laneId, - Value warpId, RewriterBase &rewriter, - const TargetInfoBase &targetInfo, std::optional maybeMaxVecElems, - std::function(RewriterBase &, Location, - ArrayRef, Value, int, - VectorType, std::optional)> - lowerInst) { +SmallVector lowerLdSt( + Location loc, MLIRContext *ctx, LinearLayout cvt, + ArrayRef valsArray, // Input for store, output for load + Type llvmElemTy, ArrayRef smemBases, + ArrayRef> paddingShifts, Value affineOffset, + uint64_t maskSpanAffineOffset, Value laneId, Value warpId, + RewriterBase &rewriter, const TargetInfoBase &targetInfo, + std::optional maybeMaxVecElems, + std::function(RewriterBase &, Location, ArrayRef, + Value, int, VectorType, Value)> + lowerInst) { assert(!smemBases.empty() && "smemBases cannot be empty"); auto vals = to_vector(valsArray); bool isStore = !vals.empty(); @@ -916,7 +913,7 @@ lowerLdSt(Location loc, MLIRContext *ctx, LinearLayout cvt, smemBase = b.extract_element(basesVec, partitionIdx); } - std::optional innerCtaOffset; + Value innerCtaOffset; if (useBlockId) { innerCtaOffset = b.add(ctaOffset, b.i32_val(idxAndBlockAdd[1].second)); } @@ -2019,7 +2016,7 @@ void finalizeTensorAtomicResults(Operation *op, RankedTensorType tensorTy, auto emitSt = [&](RewriterBase &rewriter, Location loc, ArrayRef vals, Value shmemAddr, int idx, VectorType vecTy, - std::optional ctaId) -> SmallVector { + Value ctaId) -> SmallVector { auto length = vecTy.getNumElements(); Value valsVec = packLLVector(loc, ArrayRef(vals).slice(idx, length), rewriter); @@ -2030,7 +2027,7 @@ void finalizeTensorAtomicResults(Operation *op, RankedTensorType tensorTy, auto emitLd = [&](RewriterBase &rewriter, Location loc, ArrayRef vals, Value shmemAddr, int idx, VectorType vecTy, - std::optional ctaId) -> SmallVector { + Value ctaId) -> SmallVector { Value loadedVec = targetInfo.loadDShared(rewriter, loc, shmemAddr, ctaId, vecTy, b.true_val()); return unpackLLVector(loc, loadedVec, rewriter); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index 1d3ff1e8e1ca..4b368639dcf3 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -493,9 +493,8 @@ struct DirectToLdsLoadConversionBase : public LoadStoreConversionBase { auto lowerInstForwardMulticastMask = [&](RewriterBase &rewriter, Location loc, ArrayRef vals, - Value shmemAddr, int idx, VectorType vecTy, - std::optional ctaId) { - assert(!ctaId.has_value() && "NYI"); + Value shmemAddr, int idx, VectorType vecTy, Value ctaId) { + assert(!ctaId && "NYI"); return lowerInst(rewriter, loc, vals, shmemAddr, idx, vecTy, ctaMulticastMask); }; diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp index dd7e361b79d6..a135a8847a75 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp @@ -592,9 +592,8 @@ class LocalLoadPackedTransposedOpConversion assert(cvt.isTrivialOver({kBlock}) && "NYI"); auto lowerInst = [&](RewriterBase &rewriter, Location loc, ArrayRef inVals, Value vecAddr, int idx, - VectorType vTy, - std::optional ctaId) -> SmallVector { - assert(!ctaId.has_value() && "NYI"); + VectorType vTy, Value ctaId) -> SmallVector { + assert(!ctaId && "NYI"); auto numElemsI32 = (vTy.getNumElements() * bitWidth / 32); auto vTyI32 = VectorType::get(numElemsI32, i32_ty); Value dsReadTr = diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp index 38fe0cab2439..c6d7a08d3fb1 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp @@ -183,9 +183,8 @@ void TargetInfo::warpSync(Location loc, RewriterBase &rewriter) const { } void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr, - std::optional ctaId, Value val, - Value pred) const { - if (ctaId.has_value()) { + Value ctaId, Value val, Value pred) const { + if (ctaId) { llvm::report_fatal_error( "AMDGPU does not support cross-CTA shared memory transfers"); } @@ -201,9 +200,9 @@ TargetInfo::queryLDSTransLoadParams(int bitWidth) const { } Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr, - std::optional ctaId, Type elemTy, - Value pred, Operation *localLoadOp) const { - if (ctaId.has_value()) { + Value ctaId, Type elemTy, Value pred, + Operation *localLoadOp) const { + if (ctaId) { llvm::report_fatal_error( "AMDGPU does not support cross-CTA shared memory transfers"); } diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h index 83a9792a3fbf..dda0625143fc 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h @@ -45,10 +45,9 @@ class TargetInfo : public mlir::triton::TargetInfoBase { void warpSync(Location loc, RewriterBase &rewriter) const override; void storeDShared(RewriterBase &rewriter, Location loc, Value ptr, - std::optional ctaId, Value val, - Value pred) const override; + Value ctaId, Value val, Value pred) const override; Value loadDShared(RewriterBase &rewriter, Location loc, Value ptr, - std::optional ctaId, Type elemTy, Value pred, + Value ctaId, Type elemTy, Value pred, Operation *localLoadOp = nullptr) const override; // Describes the parameters of ds_read_tr for a particular data type. diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp index d50dcc9b4105..962c0ed4f6de 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -1037,8 +1037,8 @@ struct AsyncCopyGlobalToLocalOpConversion RewriterBase &rewriter, Location loc, ArrayRef vals, Value shmemAddr, int startIdx, VectorType vecTy, - std::optional ctaId) -> SmallVector { - assert(!ctaId.has_value() && "cp.async does not support cross-cta loads"); + Value ctaId) -> SmallVector { + assert(!ctaId && "cp.async does not support cross-cta loads"); assert(isa(vecTy)); auto *ctx = rewriter.getContext(); auto elemTy = vecTy.getElementType(); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp index e6fd6b73ad62..fe40e5a72585 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp @@ -382,8 +382,8 @@ static void lowerAsyncSharedStore(Location loc, MLIRContext *ctx, Value currentCTAId = targetInfo.getClusterCTAId(rewriter, loc); auto emitSt = [&](RewriterBase &, Location storeLoc, ArrayRef values, Value shmemAddr, int idx, VectorType vecTy, - std::optional ctaId) -> SmallVector { - Value targetCTAId = ctaId.value_or(currentCTAId); + Value ctaId) -> SmallVector { + Value targetCTAId = ctaId ? ctaId : currentCTAId; Value dst = mapSharedToCluster(storeLoc, shmemAddr, targetCTAId, rewriter); Value mbarrier = mapSharedToCluster(storeLoc, mbarrierPtr, targetCTAId, rewriter); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp index 956b2d250e28..0dadf29d902c 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp @@ -207,8 +207,7 @@ static std::string getConstraintForBitwidth(unsigned bitwidth) { } static void storeDSharedImpl(RewriterBase &rewriter, Location loc, Value ptr, - std::optional ctaId, Value val, - Value pred) { + Value ctaId, Value val, Value pred) { auto b = TritonLLVMOpBuilder(loc, rewriter); MLIRContext *ctx = rewriter.getContext(); auto ptrTy = cast(ptr.getType()); @@ -297,13 +296,13 @@ static void storeDSharedImpl(RewriterBase &rewriter, Location loc, Value ptr, assert(vec * elemBitwidth <= 128); // Get pointer to remote shared memory if needed. - if (ctaId.has_value()) { - ptr = mapa(rewriter, loc, ptr, *ctaId, pred); + if (ctaId) { + ptr = mapa(rewriter, loc, ptr, ctaId, pred); } PTXBuilder builder; auto st = builder.create("st") - ->o(ctaId.has_value() ? "shared::cluster" : "shared::cta") + ->o(ctaId ? "shared::cluster" : "shared::cta") .v(vec, /*predicate=*/vec > 1) .b(elemBitwidth); auto *ptrOpr = builder.newAddrOperand(ptr, "r"); @@ -328,23 +327,20 @@ static void storeDSharedImpl(RewriterBase &rewriter, Location loc, Value ptr, } void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr, - std::optional ctaId, Value val, - Value pred) const { + Value ctaId, Value val, Value pred) const { if (!ctaId) return storeDSharedImpl(rewriter, loc, ptr, ctaId, val, pred); auto b = TritonLLVMOpBuilder(loc, rewriter); Value currentCtaId = getClusterCTAId(rewriter, loc); - Value isLocal = b.icmp_eq(*ctaId, currentCtaId); - storeDSharedImpl(rewriter, loc, ptr, std::nullopt, val, - b.and_(pred, isLocal)); + Value isLocal = b.icmp_eq(ctaId, currentCtaId); + storeDSharedImpl(rewriter, loc, ptr, Value(), val, b.and_(pred, isLocal)); storeDSharedImpl(rewriter, loc, ptr, ctaId, val, - b.and_(pred, b.icmp_ne(*ctaId, currentCtaId))); + b.and_(pred, b.icmp_ne(ctaId, currentCtaId))); } static Value loadDSharedImpl(RewriterBase &rewriter, Location loc, Value ptr, - std::optional ctaId, Type loadTy, - Value pred) { + Value ctaId, Type loadTy, Value pred) { auto b = TritonLLVMOpBuilder(loc, rewriter); MLIRContext *ctx = rewriter.getContext(); auto ptrTy = cast(ptr.getType()); @@ -432,13 +428,13 @@ static Value loadDSharedImpl(RewriterBase &rewriter, Location loc, Value ptr, assert(vec * elemBitwidth <= 128); // Get pointer to remote shared memory if needed. - if (ctaId.has_value()) { - ptr = mapa(rewriter, loc, ptr, *ctaId, pred); + if (ctaId) { + ptr = mapa(rewriter, loc, ptr, ctaId, pred); } PTXBuilder builder; auto ld = builder.create("ld") - ->o(ctaId.has_value() ? "shared::cluster" : "shared::cta") + ->o(ctaId ? "shared::cluster" : "shared::cta") .v(vec, /*predicate=*/vec > 1) .b(elemBitwidth); @@ -473,18 +469,18 @@ static Value loadDSharedImpl(RewriterBase &rewriter, Location loc, Value ptr, } Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr, - std::optional ctaId, Type loadTy, - Value pred, Operation *) const { + Value ctaId, Type loadTy, Value pred, + Operation *) const { if (!ctaId) return loadDSharedImpl(rewriter, loc, ptr, ctaId, loadTy, pred); auto b = TritonLLVMOpBuilder(loc, rewriter); Value currentCtaId = getClusterCTAId(rewriter, loc); - Value isLocal = b.icmp_eq(*ctaId, currentCtaId); - Value local = loadDSharedImpl(rewriter, loc, ptr, std::nullopt, loadTy, + Value isLocal = b.icmp_eq(ctaId, currentCtaId); + Value local = loadDSharedImpl(rewriter, loc, ptr, Value(), loadTy, b.and_(pred, isLocal)); Value remote = loadDSharedImpl(rewriter, loc, ptr, ctaId, loadTy, - b.and_(pred, b.icmp_ne(*ctaId, currentCtaId))); + b.and_(pred, b.icmp_ne(ctaId, currentCtaId))); if (!isa(loadTy)) return b.select(isLocal, local, remote); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h index 00efb5f5f469..d5de52ee7dbb 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h @@ -26,10 +26,9 @@ class TargetInfo : public mlir::triton::TargetInfoBase { void warpSync(Location loc, RewriterBase &rewriter) const override; void storeDShared(RewriterBase &rewriter, Location loc, Value ptr, - std::optional ctaId, Value val, - Value pred) const override; + Value ctaId, Value val, Value pred) const override; Value loadDShared(RewriterBase &rewriter, Location loc, Value ptr, - std::optional ctaId, Type elemTy, Value pred, + Value ctaId, Type elemTy, Value pred, Operation *localLoadOp = nullptr) const override; bool supportLdMatrix() const override { diff --git a/third_party/proton/Dialect/lib/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/AMDPatternProtonGPUOpToLLVM.cpp b/third_party/proton/Dialect/lib/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/AMDPatternProtonGPUOpToLLVM.cpp index 11a21f9f50dd..3e91f4e8b1c7 100644 --- a/third_party/proton/Dialect/lib/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/AMDPatternProtonGPUOpToLLVM.cpp +++ b/third_party/proton/Dialect/lib/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/AMDPatternProtonGPUOpToLLVM.cpp @@ -42,9 +42,9 @@ struct CircularStoreOpConversion mlir::LLVM::AMD::llStore(rewriter, loc, dataPack.ptr, dataPack.record, dataPack.isWriter); } else if (addrSpace == 3) { - targetInfo.getTritonTargetInfo().storeDShared( - rewriter, loc, dataPack.ptr, std::nullopt, dataPack.record, - dataPack.isWriter); + targetInfo.getTritonTargetInfo().storeDShared(rewriter, loc, dataPack.ptr, + Value(), dataPack.record, + dataPack.isWriter); } else { llvm::report_fatal_error("unsupported address space in circular store"); } diff --git a/third_party/proton/Dialect/lib/ProtonGPUToLLVM/ProtonNvidiaGPUToLLVM/NvidiaPatternProtonGPUOpToLLVM.cpp b/third_party/proton/Dialect/lib/ProtonGPUToLLVM/ProtonNvidiaGPUToLLVM/NvidiaPatternProtonGPUOpToLLVM.cpp index b9cf47c25c3a..977e1a274106 100644 --- a/third_party/proton/Dialect/lib/ProtonGPUToLLVM/ProtonNvidiaGPUToLLVM/NvidiaPatternProtonGPUOpToLLVM.cpp +++ b/third_party/proton/Dialect/lib/ProtonGPUToLLVM/ProtonNvidiaGPUToLLVM/NvidiaPatternProtonGPUOpToLLVM.cpp @@ -81,9 +81,9 @@ struct CircularStoreOpConversion builder.launch(rewriter, loc, void_ty(rewriter.getContext())); } } else if (addrSpace == 3) { - targetInfo.getTritonTargetInfo().storeDShared( - rewriter, loc, dataPack.ptr, std::nullopt, dataPack.record, - /*pred=*/dataPack.isWriter); + targetInfo.getTritonTargetInfo().storeDShared(rewriter, loc, dataPack.ptr, + Value(), dataPack.record, + /*pred=*/dataPack.isWriter); } else { llvm::report_fatal_error("unsupported address space in circular store"); } From d9bd4b7cba87ede62d9ee17dc3c0bb45e00d2a12 Mon Sep 17 00:00:00 2001 From: jeffniu-openai Date: Sat, 6 Jun 2026 16:06:37 -0700 Subject: [PATCH 18/22] merge --- lib/Conversion/TritonGPUToLLVM/Utility.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index 808289ed5e68..357cd8ddbcca 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -593,8 +593,7 @@ computeLocalAddrs(Location loc, triton::gpu::MemDescType memDescTy, assert(outputs.size() == 2); auto [offsetName, offset] = outputs[0]; auto [blockName, blockId] = outputs[1]; - assert(offsetName == kOffset); - assert(blockName == kBlock); + assert(offsetName == kOffset && blockName == kBlock); // For subslices, the physical offset is computed as: // physical_offset = L⁻¹(coords) ⊕ L⁻¹(subslice_logical_offset) From ac28f5d9fdffb221c4fad4d5c2207ea258f3a6fe Mon Sep 17 00:00:00 2001 From: jeffniu-openai Date: Sat, 6 Jun 2026 16:10:34 -0700 Subject: [PATCH 19/22] cleanup --- lib/Conversion/TritonGPUToLLVM/Utility.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index aad233ea63d0..9365a77d66c6 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -562,7 +562,7 @@ computeLocalAddrs(Location loc, triton::gpu::MemDescType memDescTy, auto kOffset = str_attr("offset"); auto kBlock = str_attr("block"); - bool useBlockId = invSharedLayout.getOutDimSize(kBlock) > 1; + bool isMultiCTA = invSharedLayout.getOutDimSize(kBlock) > 1; // Get the subslice affine offset (non-zero for memdesc subslices) Value affineOffset = smemObj.getShmemOffset(loc, rewriter, memDescTy); auto bitwidth = getIntOrFloatOrPtrBitWidth(llvmElemTy); @@ -622,7 +622,7 @@ computeLocalAddrs(Location loc, triton::gpu::MemDescType memDescTy, ptr = b.gep(smemObj.getBase().getType(), llvmElemTy, smemObj.getBase(), offset); } - addrs.push_back({ptr, useBlockId ? blockId : Value()}); + addrs.push_back({ptr, isMultiCTA ? blockId : Value()}); } return addrs; From ac6263d625ffcbe4ac22ad84db17b4c442bdb01b Mon Sep 17 00:00:00 2001 From: Jeff Niu Date: Sat, 6 Jun 2026 23:26:31 +0000 Subject: [PATCH 20/22] [NVIDIA] Always map distributed shared accesses --- test/Conversion/tritonnvidiagpu_to_llvm.mlir | 8 +- .../lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp | 80 +++++-------------- 2 files changed, 24 insertions(+), 64 deletions(-) diff --git a/test/Conversion/tritonnvidiagpu_to_llvm.mlir b/test/Conversion/tritonnvidiagpu_to_llvm.mlir index 266734636f83..4c9ace6eb7fd 100644 --- a/test/Conversion/tritonnvidiagpu_to_llvm.mlir +++ b/test/Conversion/tritonnvidiagpu_to_llvm.mlir @@ -616,11 +616,11 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, "ttg.tot module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @local_gather_scatter_two_ctas - // CHECK: ld.shared::cta - // CHECK: ld.shared::cluster + // CHECK: nvvm.mapa + // CHECK: llvm.load {{.*}} : !llvm.ptr<7> -> i32 // CHECK: nvvm.barrier0 - // CHECK: st.shared::cta - // CHECK: st.shared::cluster + // CHECK: nvvm.mapa + // CHECK: llvm.store {{.*}} : vector<1xi32>, !llvm.ptr<7> tt.func @local_gather_scatter_two_ctas(%out: !tt.ptr, %vals: tensor<2x32xi32, #local_gather_scatter_blocked>) { %src = ttg.local_alloc {allocation.offset = [0 : i32, 256 : i32]} : () -> !ttg.memdesc<2x32xi32, #local_gather_scatter_shared, #ttg.shared_memory, mutable> %idx = arith.constant dense<0> : tensor<2x32xi32, #local_gather_scatter_blocked> diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp index 0dadf29d902c..29eac55e5db8 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp @@ -206,16 +206,16 @@ static std::string getConstraintForBitwidth(unsigned bitwidth) { } } -static void storeDSharedImpl(RewriterBase &rewriter, Location loc, Value ptr, - Value ctaId, Value val, Value pred) { +void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr, + Value ctaId, Value val, Value pred) const { auto b = TritonLLVMOpBuilder(loc, rewriter); MLIRContext *ctx = rewriter.getContext(); auto ptrTy = cast(ptr.getType()); assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem"); if (!isa(val.getType())) { - storeDSharedImpl(rewriter, loc, ptr, ctaId, - packLLVector(loc, {val}, rewriter), pred); + storeDShared(rewriter, loc, ptr, ctaId, packLLVector(loc, {val}, rewriter), + pred); return; } @@ -232,8 +232,8 @@ static void storeDSharedImpl(RewriterBase &rewriter, Location loc, Value ptr, for (Value &v : vals) { v = b.zext(int_ty(8), b.bitcast(v, int_ty(elemBitwidth))); } - storeDSharedImpl(rewriter, loc, ptr, ctaId, - packLLVector(loc, vals, rewriter), pred); + storeDShared(rewriter, loc, ptr, ctaId, packLLVector(loc, vals, rewriter), + pred); return; } @@ -246,8 +246,8 @@ static void storeDSharedImpl(RewriterBase &rewriter, Location loc, Value ptr, v = b.bitcast(v, int_ty(elemBitwidth)); } } - storeDSharedImpl(rewriter, loc, ptr, ctaId, - packLLVector(loc, vals, rewriter), pred); + storeDShared(rewriter, loc, ptr, ctaId, packLLVector(loc, vals, rewriter), + pred); return; } @@ -267,8 +267,8 @@ static void storeDSharedImpl(RewriterBase &rewriter, Location loc, Value ptr, rewriter); newVals.push_back(b.bitcast(v, i32_ty)); } - storeDSharedImpl(rewriter, loc, ptr, ctaId, - packLLVector(loc, newVals, rewriter), pred); + storeDShared(rewriter, loc, ptr, ctaId, + packLLVector(loc, newVals, rewriter), pred); return; } @@ -281,7 +281,7 @@ static void storeDSharedImpl(RewriterBase &rewriter, Location loc, Value ptr, for (int i = 0; i < vec / maxVec; i++) { auto newPtr = b.gep(ptr.getType(), elemTy, ptr, b.i32_val(i * maxVec), LLVM::GEPNoWrapFlags::inbounds); - storeDSharedImpl( + storeDShared( rewriter, loc, newPtr, ctaId, packLLVector(loc, ArrayRef(vals).slice(i * maxVec, maxVec), rewriter), pred); @@ -326,21 +326,9 @@ static void storeDSharedImpl(RewriterBase &rewriter, Location loc, Value ptr, } } -void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr, - Value ctaId, Value val, Value pred) const { - if (!ctaId) - return storeDSharedImpl(rewriter, loc, ptr, ctaId, val, pred); - - auto b = TritonLLVMOpBuilder(loc, rewriter); - Value currentCtaId = getClusterCTAId(rewriter, loc); - Value isLocal = b.icmp_eq(ctaId, currentCtaId); - storeDSharedImpl(rewriter, loc, ptr, Value(), val, b.and_(pred, isLocal)); - storeDSharedImpl(rewriter, loc, ptr, ctaId, val, - b.and_(pred, b.icmp_ne(ctaId, currentCtaId))); -} - -static Value loadDSharedImpl(RewriterBase &rewriter, Location loc, Value ptr, - Value ctaId, Type loadTy, Value pred) { +Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr, + Value ctaId, Type loadTy, Value pred, + Operation *) const { auto b = TritonLLVMOpBuilder(loc, rewriter); MLIRContext *ctx = rewriter.getContext(); auto ptrTy = cast(ptr.getType()); @@ -348,8 +336,7 @@ static Value loadDSharedImpl(RewriterBase &rewriter, Location loc, Value ptr, if (!isa(loadTy)) { SmallVector values = unpackLLVector( - loc, - loadDSharedImpl(rewriter, loc, ptr, ctaId, vec_ty(loadTy, 1), pred), + loc, loadDShared(rewriter, loc, ptr, ctaId, vec_ty(loadTy, 1), pred), rewriter); assert(values.size() == 1); return values[0]; @@ -365,8 +352,7 @@ static Value loadDSharedImpl(RewriterBase &rewriter, Location loc, Value ptr, assert(vec == 1 && "don't know how to load/store vectors of sub-byte elems"); SmallVector vals = unpackLLVector( - loc, loadDSharedImpl(rewriter, loc, ptr, ctaId, int_ty(8), pred), - rewriter); + loc, loadDShared(rewriter, loc, ptr, ctaId, int_ty(8), pred), rewriter); assert(vals.size() == 1); return b.bitcast(b.trunc(int_ty(elemBitwidth), vals[0]), elemTy); } @@ -375,8 +361,7 @@ static Value loadDSharedImpl(RewriterBase &rewriter, Location loc, Value ptr, if (!elemTy.isInteger()) { Type newLoadTy = vec_ty(int_ty(elemBitwidth), vec); SmallVector vals = unpackLLVector( - loc, loadDSharedImpl(rewriter, loc, ptr, ctaId, newLoadTy, pred), - rewriter); + loc, loadDShared(rewriter, loc, ptr, ctaId, newLoadTy, pred), rewriter); for (Value &v : vals) { v = b.bitcast(v, elemTy); } @@ -390,7 +375,7 @@ static Value loadDSharedImpl(RewriterBase &rewriter, Location loc, Value ptr, if (vec > 4 && elemBitwidth < 32) { int newVec = vec / (32 / elemBitwidth); auto newVecTy = vec_ty(i32_ty, newVec); - auto res = loadDSharedImpl(rewriter, loc, ptr, ctaId, newVecTy, pred); + auto res = loadDShared(rewriter, loc, ptr, ctaId, newVecTy, pred); // Unpack the b32's into the original vector type. SmallVector vals; @@ -412,8 +397,8 @@ static Value loadDSharedImpl(RewriterBase &rewriter, Location loc, Value ptr, for (int i = 0; i < vec / maxVec; i++) { auto newPtr = b.gep(ptr.getType(), elemTy, ptr, b.i32_val(i * maxVec), LLVM::GEPNoWrapFlags::inbounds); - auto newVal = loadDSharedImpl(rewriter, loc, newPtr, ctaId, - vec_ty(elemTy, maxVec), pred); + auto newVal = loadDShared(rewriter, loc, newPtr, ctaId, + vec_ty(elemTy, maxVec), pred); for (Value v : unpackLLVector(loc, newVal, rewriter)) { vals.push_back(v); } @@ -468,31 +453,6 @@ static Value loadDSharedImpl(RewriterBase &rewriter, Location loc, Value ptr, return packLLVector(loc, resultVals, rewriter); } -Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr, - Value ctaId, Type loadTy, Value pred, - Operation *) const { - if (!ctaId) - return loadDSharedImpl(rewriter, loc, ptr, ctaId, loadTy, pred); - - auto b = TritonLLVMOpBuilder(loc, rewriter); - Value currentCtaId = getClusterCTAId(rewriter, loc); - Value isLocal = b.icmp_eq(ctaId, currentCtaId); - Value local = loadDSharedImpl(rewriter, loc, ptr, Value(), loadTy, - b.and_(pred, isLocal)); - Value remote = loadDSharedImpl(rewriter, loc, ptr, ctaId, loadTy, - b.and_(pred, b.icmp_ne(ctaId, currentCtaId))); - if (!isa(loadTy)) - return b.select(isLocal, local, remote); - - SmallVector selected; - for (auto [localVal, remoteVal] : - llvm::zip(unpackLLVector(loc, local, rewriter), - unpackLLVector(loc, remote, rewriter))) { - selected.push_back(b.select(isLocal, localVal, remoteVal)); - } - return packLLVector(loc, selected, rewriter); -} - Value TargetInfo::shuffleXor(RewriterBase &rewriter, Location loc, Value val, int i) const { return LLVM::NVIDIA::shuffleXor(loc, rewriter, val, i); From 32f914d0745e0ed3e0931335c08ed308a096d518 Mon Sep 17 00:00:00 2001 From: Jeff Niu Date: Sat, 6 Jun 2026 23:54:52 +0000 Subject: [PATCH 21/22] [GPUToLLVM] Lookup local address outputs by name --- lib/Conversion/TritonGPUToLLVM/Utility.cpp | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index 9365a77d66c6..d9f98c3a3303 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -590,10 +590,16 @@ computeLocalAddrs(Location loc, triton::gpu::MemDescType memDescTy, inputs.push_back({allDims[dim], indices[dim]}); auto outputs = applyLinearLayout(loc, rewriter, invSharedLayout, inputs); - assert(outputs.size() == 2); - auto [offsetName, offset] = outputs[0]; - auto [blockName, blockId] = outputs[1]; - assert(offsetName == kOffset && blockName == kBlock); + Value offset; + Value blockId; + for (auto [name, value] : outputs) { + if (name == kOffset) + offset = value; + else if (name == kBlock) + blockId = value; + } + assert(offset && "expected offset output from inverted shared layout"); + assert(blockId && "expected block output from inverted shared layout"); // For subslices, the physical offset is computed as: // physical_offset = L⁻¹(coords) ⊕ L⁻¹(subslice_logical_offset) From a54b8e9e1650f0bc9c5466cb0315c9d3565e1418 Mon Sep 17 00:00:00 2001 From: Jeff Niu Date: Sun, 7 Jun 2026 00:06:04 +0000 Subject: [PATCH 22/22] [NVIDIA] Relax local gather barrier check --- lib/Conversion/TritonGPUToLLVM/Utility.cpp | 14 ++++---------- test/Conversion/tritonnvidiagpu_to_llvm.mlir | 2 +- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index d9f98c3a3303..9365a77d66c6 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -590,16 +590,10 @@ computeLocalAddrs(Location loc, triton::gpu::MemDescType memDescTy, inputs.push_back({allDims[dim], indices[dim]}); auto outputs = applyLinearLayout(loc, rewriter, invSharedLayout, inputs); - Value offset; - Value blockId; - for (auto [name, value] : outputs) { - if (name == kOffset) - offset = value; - else if (name == kBlock) - blockId = value; - } - assert(offset && "expected offset output from inverted shared layout"); - assert(blockId && "expected block output from inverted shared layout"); + assert(outputs.size() == 2); + auto [offsetName, offset] = outputs[0]; + auto [blockName, blockId] = outputs[1]; + assert(offsetName == kOffset && blockName == kBlock); // For subslices, the physical offset is computed as: // physical_offset = L⁻¹(coords) ⊕ L⁻¹(subslice_logical_offset) diff --git a/test/Conversion/tritonnvidiagpu_to_llvm.mlir b/test/Conversion/tritonnvidiagpu_to_llvm.mlir index 4c9ace6eb7fd..f035fe9dfb5a 100644 --- a/test/Conversion/tritonnvidiagpu_to_llvm.mlir +++ b/test/Conversion/tritonnvidiagpu_to_llvm.mlir @@ -618,7 +618,7 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr // CHECK-LABEL: @local_gather_scatter_two_ctas // CHECK: nvvm.mapa // CHECK: llvm.load {{.*}} : !llvm.ptr<7> -> i32 - // CHECK: nvvm.barrier0 + // CHECK: nvvm.barrier // CHECK: nvvm.mapa // CHECK: llvm.store {{.*}} : vector<1xi32>, !llvm.ptr<7> tt.func @local_gather_scatter_two_ctas(%out: !tt.ptr, %vals: tensor<2x32xi32, #local_gather_scatter_blocked>) {