Skip to content

Commit c155e4a

Browse files
authored
[BACKEND] Support generic multi-cta convert_layouts (#9317)
Stacked PRs: * #9327 * #9318 * __->__#9317 --- --- --- ### [BACKEND] Support generic multi-cta convert_layouts We generalise the swizzling algorithm to work with blocks and generalise most of the memory lowerings to support layouts with blocks. We remove the legacy lowering. The generic swizzling algorithm for blocks might be fine, but we didn't try to be super clever. There might be some perf left on the table. We can look into this at a later point if it becomes relevant. We also activate multi-cta reductions in the process and test both there. TODO: Add some funky tests that just test the `convert_layout`, not the `convert_layout` within the reduction. TODO: Check how to perform multiCTA barriers in AMD and perhaps merge cluster barriers into ttg.barrier, predicate broadcasting blocks, etc.
1 parent 2a41426 commit c155e4a

29 files changed

Lines changed: 426 additions & 439 deletions

File tree

include/triton/Analysis/Utility.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -419,8 +419,8 @@ template <typename T> class CallGraph {
419419
// Create a basic DataFlowSolver with constant and dead code analysis included.
420420
std::unique_ptr<DataFlowSolver> createDataFlowSolver();
421421

422-
bool isCvtWarpSync(const triton::LinearLayout &srcLayout,
423-
const triton::LinearLayout &dstLayout);
422+
bool isCvtDimSync(const triton::LinearLayout &srcLayout,
423+
const triton::LinearLayout &dstLayout, StringAttr dim);
424424

425425
} // namespace mlir
426426

include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ class TargetInfoBase {
2020
// target address space
2121
virtual void barrier(Location loc, RewriterBase &rewriter,
2222
triton::gpu::AddrSpace targets) const = 0;
23+
// Emit a cluster-level barrier when supported. Defaults to CTA barrier.
24+
virtual void clusterBarrier(Location loc, RewriterBase &rewriter) const = 0;
2325
// Insert a warp syncronization barrier that also guarantees local address
2426
// space visibility at warp level when supported by the backend.
2527
// Backends that do not support warp-level barriers should conservatively

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
#include "triton/Tools/StrUtil.h"
1616
#include "llvm/ADT/STLExtras.h"
1717

18+
#include <optional>
19+
1820
#define DEBUG_TYPE "ttgpu_to_llvm"
1921
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
2022
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
@@ -331,7 +333,7 @@ namespace triton {
331333
namespace gpu {
332334

333335
std::pair<SmallVector<LocalMemOpTile>, SmallVector<LocalMemOpTile>>
334-
getSrcDstTiles(const TargetInfoBase &targetInfo, int bitwidth);
336+
getSrcDstTiles(const TargetInfoBase &targetInfo, int bitwidth, bool crossCTA);
335337

336338
Type getFunctionType(Type resultType, ValueRange operands);
337339

@@ -567,17 +569,18 @@ lowerLdStShared(Location loc, MLIRContext *ctx, LinearLayout cvt,
567569
// calcPaddedOffset is a lambda that takes a base offset (mlir::Value)
568570
// and computes a new offset (mlir::Value) by applying padding based on
569571
// shared memory layout.
570-
SmallVector<Value> lowerLdSt(
571-
Location loc, MLIRContext *ctx, LinearLayout cvt,
572-
ArrayRef<Value> valsArray, // Input for store, output for load
573-
Type llvmElemTy, Value smemBase,
574-
ArrayRef<std::pair<unsigned, unsigned>> paddingShifts, Value affineOffset,
575-
uint64_t maskSpanAffineOffset, Value laneId, Value warpId,
576-
RewriterBase &rewriter, const TargetInfoBase &targetInfo,
577-
std::optional<int> maybeMaxVecElems,
578-
std::function<SmallVector<Value>(RewriterBase &, Location, ArrayRef<Value>,
579-
Value, int, VectorType)>
580-
lowerInst);
572+
SmallVector<Value>
573+
lowerLdSt(Location loc, MLIRContext *ctx, LinearLayout cvt,
574+
ArrayRef<Value> valsArray, // Input for store, output for load
575+
Type llvmElemTy, Value smemBase,
576+
ArrayRef<std::pair<unsigned, unsigned>> paddingShifts,
577+
Value affineOffset, uint64_t maskSpanAffineOffset, Value laneId,
578+
Value warpId, RewriterBase &rewriter,
579+
const TargetInfoBase &targetInfo, std::optional<int> maybeMaxVecElems,
580+
std::function<SmallVector<Value>(RewriterBase &, Location,
581+
ArrayRef<Value>, Value, int,
582+
VectorType, std::optional<Value>)>
583+
lowerInst);
581584

582585
// Lower local_load/local_store via ld.shared/st.shared
583586
SmallVector<Value>
@@ -619,10 +622,10 @@ void makeAllWarpGroupsIsolatedFromAbove(Operation *op);
619622
// Set the correct loop annotation on LLVM branch ops.
620623
void fixUpLoopAnnotation(ModuleOp mod);
621624

622-
void transferWithinBlockSwizzling(triton::gpu::ConvertLayoutOp op, Value src,
623-
const TargetInfoBase &targetInfo,
624-
const LLVMTypeConverter *typeConverter,
625-
RewriterBase &rewriter);
625+
void transferSwizzlingLocalMem(triton::gpu::ConvertLayoutOp op, Value src,
626+
const TargetInfoBase &targetInfo,
627+
const LLVMTypeConverter *typeConverter,
628+
RewriterBase &rewriter);
626629

627630
SmallVector<Value> inlineRegionImpl(RewriterBase &rewriter, Region &region,
628631
ArrayRef<Value> args,

lib/Analysis/Allocation.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,11 @@ unsigned getNumScratchElemsSwizzledCvt(const LinearLayout &srcLayout,
4949
auto smem = gpu::optimalSwizzlingLdSt(srcLayoutNoBroadcast,
5050
dstLayoutNoBroadcast, bitwidth);
5151
auto reps = smem.getInDimSize(StringAttr::get(ctx, "reps"));
52-
return smem.getTotalOutDimSize() / reps;
52+
// The smem has the same cta layout as the srcLayout, so we use that instead
53+
// We remove the number of elements that are duplicated in the cta layout
54+
auto nBlocks = product(triton::gpu::getCTASplitNum(
55+
gpu::LinearEncodingAttr::get(ctx, srcLayout)));
56+
return smem.getTotalOutDimSize() / (reps * nBlocks);
5357
}
5458

5559
unsigned getNumScratchElemsSwizzledCvt(RankedTensorType srcTy,

lib/Analysis/Membar.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,8 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
324324
auto dstTy = cast<RankedTensorType>(cvt.getType());
325325
auto srcLayout = triton::gpu::toLinearLayout(srcTy);
326326
auto dstLayout = triton::gpu::toLinearLayout(dstTy);
327-
isWarpSync = mlir::isCvtWarpSync(srcLayout, dstLayout);
327+
auto kWarp = StringAttr::get(op->getContext(), "warp");
328+
isWarpSync = mlir::isCvtDimSync(srcLayout, dstLayout, kWarp);
328329
}
329330

330331
if (!curBlockInfo.syncReadSlices.empty() ||

lib/Analysis/Utility.cpp

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1421,17 +1421,23 @@ std::unique_ptr<DataFlowSolver> createDataFlowSolver() {
14211421
return solver;
14221422
}
14231423

1424-
bool isCvtWarpSync(const triton::LinearLayout &srcLayout,
1425-
const triton::LinearLayout &dstLayout) {
1426-
// We can use warp.sync when the warp dimension in the convert is trival
1427-
// and there is no broadcasting at a warp level (otherwise reads may be
1428-
// wrong)
1424+
bool isCvtDimSync(const triton::LinearLayout &srcLayout,
1425+
const triton::LinearLayout &dstLayout, StringAttr dim) {
1426+
// We can use a dimension-level sync when the conversion is trivial over that
1427+
// dimension and there is no broadcasting over it.
14291428
auto *ctx = srcLayout.getInDimNames().begin()->getContext();
1430-
auto comp = dstLayout.invertAndCompose(srcLayout);
14311429
auto kWarp = StringAttr::get(ctx, "warp");
1432-
return comp.isTrivialOver(kWarp) &&
1433-
srcLayout.getFreeVariableMasks()[kWarp] == 0 &&
1434-
dstLayout.getFreeVariableMasks()[kWarp] == 0;
1430+
auto kBlock = StringAttr::get(ctx, "block");
1431+
assert((dim == kWarp || dim == kBlock) && "expected dim to be warp or block");
1432+
assert(srcLayout.hasInDim(dim) && dstLayout.hasInDim(dim) &&
1433+
"expected dim to be present in both layouts");
1434+
auto parentTrivial = true;
1435+
if (dim == kWarp) {
1436+
parentTrivial = isCvtDimSync(srcLayout, dstLayout, kBlock);
1437+
}
1438+
auto comp = dstLayout.invertAndCompose(srcLayout);
1439+
return parentTrivial && comp.isTrivialOver(dim) &&
1440+
srcLayout.getFreeVariableMasks()[dim] == 0 &&
1441+
dstLayout.getFreeVariableMasks()[dim] == 0;
14351442
}
1436-
14371443
} // namespace mlir

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 50 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h"
55
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
66

7+
#include <optional>
8+
79
#include "triton/Analysis/Allocation.h"
810
#include "triton/Dialect/Triton/IR/Types.h"
911
#include "triton/Dialect/Triton/IR/Utility.h"
@@ -45,26 +47,20 @@ struct ConvertLayoutOpConversion
4547
LinearLayout srcLayout = toLinearLayout(srcTy);
4648
LinearLayout dstLayout = toLinearLayout(dstTy);
4749

48-
StringAttr kBlock = str_attr("block");
49-
StringAttr kWarp = str_attr("warp");
50-
StringAttr kLane = str_attr("lane");
51-
StringAttr kRegister = str_attr("register");
50+
auto kBlock = str_attr("block");
51+
auto kWarp = str_attr("warp");
52+
auto kLane = str_attr("lane");
53+
auto kRegister = str_attr("register");
5254

5355
auto dims = conversion.getInDimNames();
5456
bool alwaysUseWarpShuffle = cvtAlwaysUseWarpShuffle(op);
55-
assert(!alwaysUseWarpShuffle || (!llvm::is_contained(dims, kBlock) &&
56-
!llvm::is_contained(dims, kWarp)));
5757
assert(to_vector(conversion.getInDimNames()) ==
5858
to_vector(conversion.getOutDimNames()));
59-
if (llvm::is_contained(dims, kBlock)) {
60-
// Case 1: Transfer between values in different CTAs.
61-
// This requires moving values through distributed shared memory.
62-
return rewriter.notifyMatchFailure(
63-
op, "NYI: Transfer between different CTAs");
64-
} else if (llvm::is_contained(dims, kWarp)) {
65-
// Case 2: Transfer between values in the same CTA, in which case we move
66-
// values through shared memory.
67-
transferWithinBlockSwizzling(op, adaptor.getSrc(), rewriter);
59+
if (llvm::is_contained(dims, kBlock) || llvm::is_contained(dims, kWarp)) {
60+
assert(!alwaysUseWarpShuffle);
61+
// Transfer between values in the same CTA, or across CTAs. We move values
62+
// through (distributed) shared memory.
63+
transferSwizzlingLocalMem(op, adaptor.getSrc(), rewriter);
6864
return success();
6965
} else if (llvm::is_contained(dims, kLane)) {
7066
// Case 3. Transfer between values in the same warp, in which case we try
@@ -73,7 +69,7 @@ struct ConvertLayoutOpConversion
7369
if (cvtNeedsWarpShuffle(srcTy, dstTy) || alwaysUseWarpShuffle)
7470
return transferWithinWarp(op, adaptor, rewriter);
7571

76-
transferWithinBlockSwizzling(op, adaptor.getSrc(), rewriter);
72+
transferSwizzlingLocalMem(op, adaptor.getSrc(), rewriter);
7773
return success();
7874
} else if (llvm::is_contained(dims, kRegister)) {
7975
// Case 4. Transfer between values in the same thread, in which case we
@@ -93,7 +89,7 @@ struct ConvertLayoutOpConversion
9389
ConversionPatternRewriter &rewriter) const {
9490
MLIRContext *ctx = op.getContext();
9591
auto loc = op.getLoc();
96-
StringAttr kRegister = str_attr("register");
92+
auto kRegister = str_attr("register");
9793
assert(!cvtNeedsSharedMemory(op.getSrc().getType(), op.getType()));
9894

9995
auto srcTy = op.getSrc().getType();
@@ -110,7 +106,7 @@ struct ConvertLayoutOpConversion
110106
return success();
111107
}
112108

113-
SmallVector<Value> transferWithinBlockSwizzlingImpl(
109+
SmallVector<Value> transferSwizzlingLocalMemImpl(
114110
Location loc, ConversionPatternRewriter &rewriter,
115111
const LinearLayout &srcLayout, const LinearLayout &dstLayout,
116112
ArrayRef<Value> inVals, Type llvmElemTy, Value smemBase) const {
@@ -126,8 +122,8 @@ struct ConvertLayoutOpConversion
126122
return b.ptrtoint(llvmElemTyPtr, v).getResult();
127123
}));
128124
auto outVals =
129-
transferWithinBlockSwizzlingImpl(loc, rewriter, srcLayout, dstLayout,
130-
newInVals, llvmElemTyPtr, smemBase);
125+
transferSwizzlingLocalMemImpl(loc, rewriter, srcLayout, dstLayout,
126+
newInVals, llvmElemTyPtr, smemBase);
131127
for (auto &v : outVals) {
132128
v = b.inttoptr(llvmElemTy, v);
133129
}
@@ -140,7 +136,7 @@ struct ConvertLayoutOpConversion
140136
auto i8ElemTy = i8_ty;
141137
auto newInVals = llvm::to_vector(llvm::map_range(
142138
inVals, [&](Value v) { return b.zext(i8ElemTy, v).getResult(); }));
143-
auto outVals = transferWithinBlockSwizzlingImpl(
139+
auto outVals = transferSwizzlingLocalMemImpl(
144140
loc, rewriter, srcLayout, dstLayout, newInVals, i8ElemTy, smemBase);
145141
for (auto &v : outVals) {
146142
v = b.trunc(llvmElemTy, v);
@@ -153,15 +149,15 @@ struct ConvertLayoutOpConversion
153149
if (!removeBroadcastSrc.isIdentity()) {
154150
auto prmtSrc = removeBroadcastSrc.apply(srcLayout);
155151
auto newInVals = removeBroadcastSrc.apply(inVals);
156-
return transferWithinBlockSwizzlingImpl(loc, rewriter, prmtSrc, dstLayout,
157-
newInVals, llvmElemTy, smemBase);
152+
return transferSwizzlingLocalMemImpl(loc, rewriter, prmtSrc, dstLayout,
153+
newInVals, llvmElemTy, smemBase);
158154
}
159155

160156
// Remove broadcasting in dst
161157
auto removeBroadcastDst = actionRemoveBroadcastedRegs(dstLayout);
162158
if (!removeBroadcastDst.isIdentity()) {
163159
auto prmtDst = removeBroadcastDst.apply(dstLayout);
164-
auto outVals = transferWithinBlockSwizzlingImpl(
160+
auto outVals = transferSwizzlingLocalMemImpl(
165161
loc, rewriter, srcLayout, prmtDst, inVals, llvmElemTy, smemBase);
166162
return broadcastAs(outVals, dstLayout);
167163
}
@@ -173,6 +169,8 @@ struct ConvertLayoutOpConversion
173169

174170
// Extract reps from smem
175171
auto kReg = str_attr("register");
172+
auto kWarp = str_attr("warp");
173+
auto kBlock = str_attr("block");
176174
auto kReps = str_attr("reps");
177175
auto nReps = smem.getInDimSize(kReps);
178176
auto reps = LinearLayout::identity1D(nReps, kReg, kReps);
@@ -194,8 +192,11 @@ struct ConvertLayoutOpConversion
194192
auto storeCvt = *divideRight(totalStoreCvt, reps);
195193
auto loadCvt = *divideRight(totalLoadCvt, reps);
196194
auto kOffset = str_attr("offset");
197-
storeCvt = storeCvt.reshapeOuts({{kOffset, storeCvt.getTotalOutDimSize()}});
198-
loadCvt = loadCvt.reshapeOuts({{kOffset, loadCvt.getTotalOutDimSize()}});
195+
auto nBlock = storeCvt.getInDimSize(kBlock);
196+
storeCvt = storeCvt.reshapeOuts(
197+
{{kOffset, storeCvt.getTotalOutDimSize() / nBlock}, {kBlock, nBlock}});
198+
loadCvt = loadCvt.reshapeOuts(
199+
{{kOffset, loadCvt.getTotalOutDimSize() / nBlock}, {kBlock, nBlock}});
199200

200201
auto tileSize = storeCvt.getInDimSize(kReg);
201202

@@ -204,28 +205,30 @@ struct ConvertLayoutOpConversion
204205
auto affineOffset = b.i32_val(0);
205206
auto maskSpanAffineOffset = 0;
206207

207-
bool isWarpSync = mlir::isCvtWarpSync(srcLayout, dstLayout);
208-
for (int i = 0; i < nReps; ++i) {
209-
if (i > 0) {
210-
if (isWarpSync) {
211-
targetInfo.warpSync(loc, rewriter);
212-
} else {
213-
targetInfo.barrier(loc, rewriter, triton::gpu::AddrSpace::Local);
214-
}
208+
bool isWarpSync = mlir::isCvtDimSync(srcLayout, dstLayout, kWarp);
209+
bool isBlockSync = mlir::isCvtDimSync(srcLayout, dstLayout, kBlock);
210+
auto emitBarrier = [&]() {
211+
if (isWarpSync) {
212+
targetInfo.warpSync(loc, rewriter);
213+
} else if (isBlockSync) {
214+
targetInfo.barrier(loc, rewriter, triton::gpu::AddrSpace::Local);
215+
} else {
216+
targetInfo.clusterBarrier(loc, rewriter);
215217
}
218+
};
219+
220+
for (int i = 0; i < nReps; ++i) {
221+
if (i > 0)
222+
emitBarrier();
216223
auto tileInVals =
217224
ArrayRef<Value>(permutedInVals).slice(i * tileSize, tileSize);
218225
// Store
219226
lowerLdStShared(loc, ctx, storeCvt, tileInVals, llvmElemTy, smemBase,
220227
/*paddingShifts=*/{}, affineOffset, maskSpanAffineOffset,
221228
rewriter, targetInfo);
222-
if (isWarpSync) {
223-
targetInfo.warpSync(loc, rewriter);
224-
} else {
225-
targetInfo.barrier(loc, rewriter, triton::gpu::AddrSpace::Local);
226-
}
229+
emitBarrier();
227230
// Load
228-
SmallVector<Value> tileOutVals = lowerLdStShared(
231+
auto tileOutVals = lowerLdStShared(
229232
loc, ctx, loadCvt, {}, llvmElemTy, smemBase, /*paddingShifts=*/{},
230233
affineOffset, maskSpanAffineOffset, rewriter, targetInfo);
231234
llvm::append_range(outVals, tileOutVals);
@@ -236,30 +239,21 @@ struct ConvertLayoutOpConversion
236239
return outVals;
237240
}
238241

239-
void transferWithinBlockSwizzling(ConvertLayoutOp op, Value src,
240-
ConversionPatternRewriter &rewriter) const {
242+
void transferSwizzlingLocalMem(ConvertLayoutOp op, Value src,
243+
ConversionPatternRewriter &rewriter) const {
241244
auto loc = op.getLoc();
242245
auto *ctx = op.getContext();
243246
auto srcTy = op.getSrc().getType();
244247
auto dstTy = op.getType();
245248

246-
// Remove the kBlock dimension from the layout as it's the identity in the
247-
// cvt
248249
auto srcLayout = toLinearLayout(srcTy);
249250
auto dstLayout = toLinearLayout(dstTy);
250-
auto kReg = str_attr("register");
251-
auto kLane = str_attr("lane");
252-
auto kWarp = str_attr("warp");
253-
srcLayout = srcLayout.sublayout({kReg, kLane, kWarp},
254-
to_vector(srcLayout.getOutDimNames()));
255-
dstLayout = dstLayout.sublayout({kReg, kLane, kWarp},
256-
to_vector(dstLayout.getOutDimNames()));
257251

258252
auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
259253
auto smemBase =
260254
LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation());
261255
auto inVals = unpackLLElements(loc, src, rewriter);
262-
auto outVals = transferWithinBlockSwizzlingImpl(
256+
auto outVals = transferSwizzlingLocalMemImpl(
263257
loc, rewriter, srcLayout, dstLayout, inVals, llvmElemTy, smemBase);
264258

265259
Value result =
@@ -276,8 +270,8 @@ struct ConvertLayoutOpConversion
276270
auto b = TritonLLVMOpBuilder(loc, rewriter);
277271
auto srcTy = op.getSrc().getType();
278272
auto dstTy = op.getType();
279-
StringAttr kReg = str_attr("register");
280-
StringAttr kLane = str_attr("lane");
273+
auto kReg = str_attr("register");
274+
auto kLane = str_attr("lane");
281275
auto elemTy = getTypeConverter()->convertType(srcTy.getElementType());
282276
int bitwidth = getIntOrFloatOrPtrBitWidth(elemTy);
283277

@@ -434,8 +428,8 @@ struct ConvertLayoutOpConversion
434428
ArrayRef<TranspositionInfo> mixedTranspositions) const {
435429
auto *ctx = rewriter.getContext();
436430
auto b = TritonLLVMOpBuilder(loc, rewriter);
437-
StringAttr kReg = str_attr("register");
438-
StringAttr kLane = str_attr("lane");
431+
auto kReg = str_attr("register");
432+
auto kLane = str_attr("lane");
439433

440434
SmallVector<Value> vals(inVals.begin(), inVals.end());
441435
int m = mixedTranspositions.size();

0 commit comments

Comments
 (0)