Commit 57a329d
[BACKEND] Support generic multi-cta convert_layouts
We generalise the swizzling algorithm to work with blocks, and generalise most of the memory lowerings to support layouts with blocks. We remove the legacy lowering.

The generic swizzling algorithm for blocks might be fine, but we didn't try to be particularly clever, so there may be some performance left on the table; we can look into this later if it becomes relevant. We also activate multi-CTA reductions in the process and test both features there.

TODO: Add some tests that exercise `convert_layout` on its own, not just the `convert_layout` inside a reduction.

TODO: Check how to perform multi-CTA barriers on AMD, and perhaps merge cluster barriers into ttg.barrier, predicate broadcasting blocks, etc.

stack-info: PR: #9317, branch: lezcano/stack/9
1 parent adc76aa commit 57a329d

29 files changed

Lines changed: 424 additions & 437 deletions


include/triton/Analysis/Utility.h

Lines changed: 2 additions & 2 deletions
@@ -417,8 +417,8 @@ template <typename T> class CallGraph {
 // Create a basic DataFlowSolver with constant and dead code analysis included.
 std::unique_ptr<DataFlowSolver> createDataFlowSolver();

-bool isCvtWarpSync(const triton::LinearLayout &srcLayout,
-                   const triton::LinearLayout &dstLayout);
+bool isCvtDimSync(const triton::LinearLayout &srcLayout,
+                  const triton::LinearLayout &dstLayout, StringAttr dim);

 } // namespace mlir
include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h

Lines changed: 2 additions & 0 deletions
@@ -20,6 +20,8 @@ class TargetInfoBase {
   // target address space
   virtual void barrier(Location loc, RewriterBase &rewriter,
                        triton::gpu::AddrSpace targets) const = 0;
+  // Emit a cluster-level barrier when supported. Defaults to CTA barrier.
+  virtual void clusterBarrier(Location loc, RewriterBase &rewriter) const = 0;
   // Insert a warp syncronization barrier that also guarantees local address
   // space visibility at warp level when supported by the backend.
   // Backends that do not support warp-level barriers should conservatively
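
To make the new hook concrete, here is a minimal stand-in (plain C++, not the real TargetInfoBase or any backend in this commit) of the contract implied by "Defaults to CTA barrier": a backend with no notion of CTA clusters can implement the hook by reusing its CTA barrier, while a cluster-capable backend emits something genuinely wider. The HopperLikeTarget name and the PTX mnemonics in the comments are illustrative assumptions, not taken from this diff.

#include <cstdio>

struct Target {
  virtual ~Target() = default;
  virtual void barrier() const { std::puts("bar.sync (CTA-wide)"); }
  // New hook: a target without clusters can satisfy it by degrading to its
  // CTA barrier, which is already "cluster-wide" there.
  virtual void clusterBarrier() const { barrier(); }
};

struct HopperLikeTarget : Target {
  void clusterBarrier() const override {
    // On cluster-capable hardware this would be a wider barrier (e.g. PTX
    // barrier.cluster.arrive/wait on sm_90 -- an assumption, not something
    // this diff shows).
    std::puts("barrier.cluster.arrive; barrier.cluster.wait");
  }
};

int main() {
  Target plain;
  HopperLikeTarget hopper;
  plain.clusterBarrier();  // falls back to the CTA barrier
  hopper.clusterBarrier(); // synchronises the whole cluster
}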

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 19 additions & 16 deletions
@@ -15,6 +15,8 @@
 #include "triton/Tools/StrUtil.h"
 #include "llvm/ADT/STLExtras.h"

+#include <optional>
+
 #define DEBUG_TYPE "ttgpu_to_llvm"
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
@@ -331,7 +333,7 @@ namespace triton {
 namespace gpu {

 std::pair<SmallVector<LocalMemOpTile>, SmallVector<LocalMemOpTile>>
-getSrcDstTiles(const TargetInfoBase &targetInfo, int bitwidth);
+getSrcDstTiles(const TargetInfoBase &targetInfo, int bitwidth, bool crossCTA);

 Type getFunctionType(Type resultType, ValueRange operands);

@@ -561,17 +563,18 @@ lowerLdStShared(Location loc, MLIRContext *ctx, LinearLayout cvt,
 // calcPaddedOffset is a lambda that takes a base offset (mlir::Value)
 // and computes a new offset (mlir::Value) by applying padding based on
 // shared memory layout.
-SmallVector<Value> lowerLdSt(
-    Location loc, MLIRContext *ctx, LinearLayout cvt,
-    ArrayRef<Value> valsArray, // Input for store, output for load
-    Type llvmElemTy, Value smemBase,
-    ArrayRef<std::pair<unsigned, unsigned>> paddingShifts, Value affineOffset,
-    uint64_t maskSpanAffineOffset, Value laneId, Value warpId,
-    RewriterBase &rewriter, const TargetInfoBase &targetInfo,
-    std::optional<int> maybeMaxVecElems,
-    std::function<SmallVector<Value>(RewriterBase &, Location, ArrayRef<Value>,
-                                     Value, int, VectorType)>
-        lowerInst);
+SmallVector<Value>
+lowerLdSt(Location loc, MLIRContext *ctx, LinearLayout cvt,
+          ArrayRef<Value> valsArray, // Input for store, output for load
+          Type llvmElemTy, Value smemBase,
+          ArrayRef<std::pair<unsigned, unsigned>> paddingShifts,
+          Value affineOffset, uint64_t maskSpanAffineOffset, Value laneId,
+          Value warpId, RewriterBase &rewriter,
+          const TargetInfoBase &targetInfo, std::optional<int> maybeMaxVecElems,
+          std::function<SmallVector<Value>(RewriterBase &, Location,
+                                           ArrayRef<Value>, Value, int,
+                                           VectorType, std::optional<Value>)>
+              lowerInst);

 // Lower local_load/local_store via ld.shared/st.shared
 SmallVector<Value>
@@ -613,10 +616,10 @@ void makeAllWarpGroupsIsolatedFromAbove(Operation *op);
 // Set the correct loop annotation on LLVM branch ops.
 void fixUpLoopAnnotation(ModuleOp mod);

-void transferWithinBlockSwizzling(triton::gpu::ConvertLayoutOp op, Value src,
-                                  const TargetInfoBase &targetInfo,
-                                  const LLVMTypeConverter *typeConverter,
-                                  RewriterBase &rewriter);
+void transferSwizzlingLocalMem(triton::gpu::ConvertLayoutOp op, Value src,
+                               const TargetInfoBase &targetInfo,
+                               const LLVMTypeConverter *typeConverter,
+                               RewriterBase &rewriter);

 SmallVector<Value> inlineRegionImpl(RewriterBase &rewriter, Region &region,
                                     ArrayRef<Value> args,

lib/Analysis/Allocation.cpp

Lines changed: 5 additions & 1 deletion
@@ -42,7 +42,11 @@ unsigned getNumScratchElemsSwizzledCvt(const LinearLayout &srcLayout,
   auto smem = gpu::optimalSwizzlingLdSt(srcLayoutNoBroadcast,
                                         dstLayoutNoBroadcast, bitwidth);
   auto reps = smem.getInDimSize(StringAttr::get(ctx, "reps"));
-  return smem.getTotalOutDimSize() / reps;
+  // The smem has the same cta layout as the srcLayout, so we use that instead
+  // We remove the number of elements that are duplicated in the cta layout
+  auto nBlocks = product(triton::gpu::getCTASplitNum(
+      gpu::LinearEncodingAttr::get(ctx, srcLayout)));
+  return smem.getTotalOutDimSize() / (reps * nBlocks);
 }

 unsigned getNumScratchElemsSwizzledCvt(RankedTensorType srcTy,
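
A quick sanity check on the new divisor, with invented numbers (none of these come from the commit): suppose the swizzled smem layout spans 4096 elements in total, the conversion runs in reps = 2 rounds that reuse the same scratch, and the CTA layout splits the tensor across nBlocks = 2 blocks. Each CTA then only needs 4096 / (2 * 2) = 1024 scratch elements.

#include <cassert>

int main() {
  unsigned totalOutDimSize = 4096; // smem.getTotalOutDimSize()
  unsigned reps = 2;               // smem.getInDimSize("reps")
  unsigned nBlocks = 2;            // product(getCTASplitNum(...))
  assert(totalOutDimSize / (reps * nBlocks) == 1024);
}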

lib/Analysis/Membar.cpp

Lines changed: 2 additions & 1 deletion
@@ -324,7 +324,8 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
     auto dstTy = cast<RankedTensorType>(cvt.getType());
     auto srcLayout = triton::gpu::toLinearLayout(srcTy);
     auto dstLayout = triton::gpu::toLinearLayout(dstTy);
-    isWarpSync = mlir::isCvtWarpSync(srcLayout, dstLayout);
+    auto kWarp = StringAttr::get(op->getContext(), "warp");
+    isWarpSync = mlir::isCvtDimSync(srcLayout, dstLayout, kWarp);
   }

   if (!curBlockInfo.syncReadSlices.empty() ||

lib/Analysis/Utility.cpp

Lines changed: 16 additions & 10 deletions
@@ -1414,17 +1414,23 @@ std::unique_ptr<DataFlowSolver> createDataFlowSolver() {
   return solver;
 }

-bool isCvtWarpSync(const triton::LinearLayout &srcLayout,
-                   const triton::LinearLayout &dstLayout) {
-  // We can use warp.sync when the warp dimension in the convert is trival
-  // and there is no broadcasting at a warp level (otherwise reads may be
-  // wrong)
+bool isCvtDimSync(const triton::LinearLayout &srcLayout,
+                  const triton::LinearLayout &dstLayout, StringAttr dim) {
+  // We can use a dimension-level sync when the conversion is trivial over that
+  // dimension and there is no broadcasting over it.
   auto *ctx = srcLayout.getInDimNames().begin()->getContext();
-  auto comp = dstLayout.invertAndCompose(srcLayout);
   auto kWarp = StringAttr::get(ctx, "warp");
-  return comp.isTrivialOver(kWarp) &&
-         srcLayout.getFreeVariableMasks()[kWarp] == 0 &&
-         dstLayout.getFreeVariableMasks()[kWarp] == 0;
+  auto kBlock = StringAttr::get(ctx, "block");
+  assert((dim == kWarp || dim == kBlock) && "expected dim to be warp or block");
+  assert(srcLayout.hasInDim(dim) && dstLayout.hasInDim(dim) &&
+         "expected dim to be present in both layouts");
+  auto parentTrivial = true;
+  if (dim == kWarp) {
+    parentTrivial = isCvtDimSync(srcLayout, dstLayout, kBlock);
+  }
+  auto comp = dstLayout.invertAndCompose(srcLayout);
+  return parentTrivial && comp.isTrivialOver(dim) &&
+         srcLayout.getFreeVariableMasks()[dim] == 0 &&
+         dstLayout.getFreeVariableMasks()[dim] == 0;
 }
-
 } // namespace mlir
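
Why parameterise the predicate by dim: sync scopes nest. A conversion that is trivial over warps (which, per the recursion above, also requires it to be trivial over blocks) only needs a warp-level sync; one that is merely trivial over blocks needs a CTA barrier; anything else crosses CTAs and needs a cluster barrier. A standalone model of that decision (plain C++, mirroring the emitBarrier lambda in ConvertLayoutOpToLLVM.cpp further down; not Triton code):

#include <cstdio>

enum class SyncLevel { Warp, Cta, Cluster };

// trivialOverWarp corresponds to isCvtDimSync(src, dst, "warp"),
// trivialOverBlock to isCvtDimSync(src, dst, "block").
SyncLevel pickSync(bool trivialOverWarp, bool trivialOverBlock) {
  if (trivialOverWarp && trivialOverBlock)
    return SyncLevel::Warp; // targetInfo.warpSync
  if (trivialOverBlock)
    return SyncLevel::Cta; // targetInfo.barrier(..., AddrSpace::Local)
  return SyncLevel::Cluster; // targetInfo.clusterBarrier
}

int main() {
  std::printf("%d\n", static_cast<int>(pickSync(true, true)));   // 0: Warp
  std::printf("%d\n", static_cast<int>(pickSync(false, true)));  // 1: Cta
  std::printf("%d\n", static_cast<int>(pickSync(false, false))); // 2: Cluster
}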

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 49 additions & 54 deletions
@@ -4,6 +4,8 @@
 #include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h"
 #include "triton/Conversion/TritonGPUToLLVM/Utility.h"

+#include <optional>
+
 #include "triton/Analysis/Allocation.h"
 #include "triton/Dialect/Triton/IR/Types.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
@@ -45,23 +47,18 @@ struct ConvertLayoutOpConversion
     LinearLayout srcLayout = toLinearLayout(srcTy);
     LinearLayout dstLayout = toLinearLayout(dstTy);

-    StringAttr kBlock = str_attr("block");
-    StringAttr kWarp = str_attr("warp");
-    StringAttr kLane = str_attr("lane");
-    StringAttr kRegister = str_attr("register");
+    auto kBlock = str_attr("block");
+    auto kWarp = str_attr("warp");
+    auto kLane = str_attr("lane");
+    auto kRegister = str_attr("register");

     assert(to_vector(conversion.getInDimNames()) ==
            to_vector(conversion.getOutDimNames()));
     auto dims = conversion.getInDimNames();
-    if (llvm::is_contained(dims, kBlock)) {
-      // Case 1: Transfer between values in different CTAs.
-      // This requires moving values through distributed shared memory.
-      return rewriter.notifyMatchFailure(
-          op, "NYI: Transfer between different CTAs");
-    } else if (llvm::is_contained(dims, kWarp)) {
-      // Case 2: Transfer between values in the same CTA, in which case we move
-      // values through shared memory.
-      transferWithinBlockSwizzling(op, adaptor.getSrc(), rewriter);
+    if (llvm::is_contained(dims, kBlock) || llvm::is_contained(dims, kWarp)) {
+      // Transfer between values in the same CTA, or across CTAs. We move values
+      // through (distributed) shared memory.
+      transferSwizzlingLocalMem(op, adaptor.getSrc(), rewriter);
       return success();
     } else if (llvm::is_contained(dims, kLane)) {
       // Case 3. Transfer between values in the same warp, in which case we try
@@ -70,7 +67,7 @@ struct ConvertLayoutOpConversion
       if (cvtNeedsWarpShuffle(srcTy, dstTy))
         return transferWithinWarp(op, adaptor, rewriter);

-      transferWithinBlockSwizzling(op, adaptor.getSrc(), rewriter);
+      transferSwizzlingLocalMem(op, adaptor.getSrc(), rewriter);
       return success();
     } else if (llvm::is_contained(dims, kRegister)) {
       // Case 4. Transfer between values in the same thread, in which case we
@@ -90,7 +87,7 @@ struct ConvertLayoutOpConversion
                       ConversionPatternRewriter &rewriter) const {
     MLIRContext *ctx = op.getContext();
     auto loc = op.getLoc();
-    StringAttr kRegister = str_attr("register");
+    auto kRegister = str_attr("register");
     assert(!cvtNeedsSharedMemory(op.getSrc().getType(), op.getType()));

     auto srcTy = op.getSrc().getType();
@@ -107,7 +104,7 @@ struct ConvertLayoutOpConversion
     return success();
   }

-  SmallVector<Value> transferWithinBlockSwizzlingImpl(
+  SmallVector<Value> transferSwizzlingLocalMemImpl(
       Location loc, ConversionPatternRewriter &rewriter,
       const LinearLayout &srcLayout, const LinearLayout &dstLayout,
      ArrayRef<Value> inVals, Type llvmElemTy, Value smemBase) const {
@@ -123,8 +120,8 @@ struct ConvertLayoutOpConversion
       return b.ptrtoint(llvmElemTyPtr, v).getResult();
     }));
     auto outVals =
-        transferWithinBlockSwizzlingImpl(loc, rewriter, srcLayout, dstLayout,
-                                         newInVals, llvmElemTyPtr, smemBase);
+        transferSwizzlingLocalMemImpl(loc, rewriter, srcLayout, dstLayout,
+                                      newInVals, llvmElemTyPtr, smemBase);
     for (auto &v : outVals) {
       v = b.inttoptr(llvmElemTy, v);
     }
@@ -137,7 +134,7 @@ struct ConvertLayoutOpConversion
     auto i8ElemTy = i8_ty;
     auto newInVals = llvm::to_vector(llvm::map_range(
         inVals, [&](Value v) { return b.zext(i8ElemTy, v).getResult(); }));
-    auto outVals = transferWithinBlockSwizzlingImpl(
+    auto outVals = transferSwizzlingLocalMemImpl(
         loc, rewriter, srcLayout, dstLayout, newInVals, i8ElemTy, smemBase);
     for (auto &v : outVals) {
       v = b.trunc(llvmElemTy, v);
@@ -150,15 +147,15 @@ struct ConvertLayoutOpConversion
     if (!removeBroadcastSrc.isIdentity()) {
       auto prmtSrc = removeBroadcastSrc.apply(srcLayout);
       auto newInVals = removeBroadcastSrc.apply(inVals);
-      return transferWithinBlockSwizzlingImpl(loc, rewriter, prmtSrc, dstLayout,
-                                              newInVals, llvmElemTy, smemBase);
+      return transferSwizzlingLocalMemImpl(loc, rewriter, prmtSrc, dstLayout,
+                                           newInVals, llvmElemTy, smemBase);
     }

     // Remove broadcasting in dst
     auto removeBroadcastDst = actionRemoveBroadcastedRegs(dstLayout);
     if (!removeBroadcastDst.isIdentity()) {
       auto prmtDst = removeBroadcastDst.apply(dstLayout);
-      auto outVals = transferWithinBlockSwizzlingImpl(
+      auto outVals = transferSwizzlingLocalMemImpl(
           loc, rewriter, srcLayout, prmtDst, inVals, llvmElemTy, smemBase);
       return broadcastAs(outVals, dstLayout);
     }
@@ -170,6 +167,8 @@ struct ConvertLayoutOpConversion

     // Extract reps from smem
     auto kReg = str_attr("register");
+    auto kWarp = str_attr("warp");
+    auto kBlock = str_attr("block");
     auto kReps = str_attr("reps");
     auto nReps = smem.getInDimSize(kReps);
     auto reps = LinearLayout::identity1D(nReps, kReg, kReps);
@@ -191,8 +190,11 @@ struct ConvertLayoutOpConversion
     auto storeCvt = *divideRight(totalStoreCvt, reps);
     auto loadCvt = *divideRight(totalLoadCvt, reps);
     auto kOffset = str_attr("offset");
-    storeCvt = storeCvt.reshapeOuts({{kOffset, storeCvt.getTotalOutDimSize()}});
-    loadCvt = loadCvt.reshapeOuts({{kOffset, loadCvt.getTotalOutDimSize()}});
+    auto nBlock = storeCvt.getInDimSize(kBlock);
+    storeCvt = storeCvt.reshapeOuts(
+        {{kOffset, storeCvt.getTotalOutDimSize() / nBlock}, {kBlock, nBlock}});
+    loadCvt = loadCvt.reshapeOuts(
+        {{kOffset, loadCvt.getTotalOutDimSize() / nBlock}, {kBlock, nBlock}});

     auto tileSize = storeCvt.getInDimSize(kReg);
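The reshape now keeps an explicit block dimension instead of folding the whole output space into a single offset dimension. A toy check of the arithmetic (sizes invented for illustration):

#include <cassert>

int main() {
  // A conversion whose flattened output space has 2048 addresses, spread
  // over nBlock = 4 CTAs, becomes a per-CTA "offset" dimension of 512 plus
  // a "block" dimension of 4.
  unsigned totalOutDimSize = 2048, nBlock = 4;
  unsigned offsetSize = totalOutDimSize / nBlock;
  assert(offsetSize == 512 && offsetSize * nBlock == totalOutDimSize);
}
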
@@ -201,28 +203,30 @@ struct ConvertLayoutOpConversion
     auto affineOffset = b.i32_val(0);
     auto maskSpanAffineOffset = 0;

-    bool isWarpSync = mlir::isCvtWarpSync(srcLayout, dstLayout);
-    for (int i = 0; i < nReps; ++i) {
-      if (i > 0) {
-        if (isWarpSync) {
-          targetInfo.warpSync(loc, rewriter);
-        } else {
-          targetInfo.barrier(loc, rewriter, triton::gpu::AddrSpace::Local);
-        }
+    bool isWarpSync = mlir::isCvtDimSync(srcLayout, dstLayout, kWarp);
+    bool isBlockSync = mlir::isCvtDimSync(srcLayout, dstLayout, kBlock);
+    auto emitBarrier = [&]() {
+      if (isWarpSync) {
+        targetInfo.warpSync(loc, rewriter);
+      } else if (isBlockSync) {
+        targetInfo.barrier(loc, rewriter, triton::gpu::AddrSpace::Local);
+      } else {
+        targetInfo.clusterBarrier(loc, rewriter);
       }
+    };
+
+    for (int i = 0; i < nReps; ++i) {
+      if (i > 0)
+        emitBarrier();
       auto tileInVals =
           ArrayRef<Value>(permutedInVals).slice(i * tileSize, tileSize);
       // Store
       lowerLdStShared(loc, ctx, storeCvt, tileInVals, llvmElemTy, smemBase,
                       /*paddingShifts=*/{}, affineOffset, maskSpanAffineOffset,
                       rewriter, targetInfo);
-      if (isWarpSync) {
-        targetInfo.warpSync(loc, rewriter);
-      } else {
-        targetInfo.barrier(loc, rewriter, triton::gpu::AddrSpace::Local);
-      }
+      emitBarrier();
       // Load
-      SmallVector<Value> tileOutVals = lowerLdStShared(
+      auto tileOutVals = lowerLdStShared(
           loc, ctx, loadCvt, {}, llvmElemTy, smemBase, /*paddingShifts=*/{},
           affineOffset, maskSpanAffineOffset, rewriter, targetInfo);
       llvm::append_range(outVals, tileOutVals);
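
The barrier placement in this loop carries the correctness argument: the sync after each store publishes the tile before any thread (or any CTA, in the cluster case) loads from it, and the sync at the top of every rep after the first keeps a new store from clobbering a tile some reader is still consuming. A plain-C++ sketch of that structure (a model, not the lowering itself):

#include <cstdio>

// Stand-in for the emitBarrier lambda above: warp / CTA / cluster scope,
// whichever isCvtDimSync says is sufficient.
void emitBarrier() { std::puts("sync"); }

int main() {
  int nReps = 3; // smem.getInDimSize("reps") in the real code
  for (int i = 0; i < nReps; ++i) {
    if (i > 0)
      emitBarrier();         // previous tile fully read before smem is reused
    std::puts("store tile"); // lowerLdStShared(..., storeCvt, ...)
    emitBarrier();           // stores visible before anyone loads
    std::puts("load tile");  // lowerLdStShared(..., loadCvt, ...)
  }
}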
@@ -233,30 +237,21 @@ struct ConvertLayoutOpConversion
     return outVals;
   }

-  void transferWithinBlockSwizzling(ConvertLayoutOp op, Value src,
-                                    ConversionPatternRewriter &rewriter) const {
+  void transferSwizzlingLocalMem(ConvertLayoutOp op, Value src,
+                                 ConversionPatternRewriter &rewriter) const {
     auto loc = op.getLoc();
     auto *ctx = op.getContext();
     auto srcTy = op.getSrc().getType();
     auto dstTy = op.getType();

-    // Remove the kBlock dimension from the layout as it's the identity in the
-    // cvt
     auto srcLayout = toLinearLayout(srcTy);
     auto dstLayout = toLinearLayout(dstTy);
-    auto kReg = str_attr("register");
-    auto kLane = str_attr("lane");
-    auto kWarp = str_attr("warp");
-    srcLayout = srcLayout.sublayout({kReg, kLane, kWarp},
-                                    to_vector(srcLayout.getOutDimNames()));
-    dstLayout = dstLayout.sublayout({kReg, kLane, kWarp},
-                                    to_vector(dstLayout.getOutDimNames()));

     auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
     auto smemBase =
         LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation());
     auto inVals = unpackLLElements(loc, src, rewriter);
-    auto outVals = transferWithinBlockSwizzlingImpl(
+    auto outVals = transferSwizzlingLocalMemImpl(
         loc, rewriter, srcLayout, dstLayout, inVals, llvmElemTy, smemBase);

     Value result =
@@ -273,8 +268,8 @@ struct ConvertLayoutOpConversion
     auto b = TritonLLVMOpBuilder(loc, rewriter);
     auto srcTy = op.getSrc().getType();
     auto dstTy = op.getType();
-    StringAttr kReg = str_attr("register");
-    StringAttr kLane = str_attr("lane");
+    auto kReg = str_attr("register");
+    auto kLane = str_attr("lane");
     auto elemTy = getTypeConverter()->convertType(srcTy.getElementType());
     int bitwidth = getIntOrFloatOrPtrBitWidth(elemTy);
@@ -431,8 +426,8 @@ struct ConvertLayoutOpConversion
                     ArrayRef<TranspositionInfo> mixedTranspositions) const {
     auto *ctx = rewriter.getContext();
     auto b = TritonLLVMOpBuilder(loc, rewriter);
-    StringAttr kReg = str_attr("register");
-    StringAttr kLane = str_attr("lane");
+    auto kReg = str_attr("register");
+    auto kLane = str_attr("lane");

     SmallVector<Value> vals(inVals.begin(), inVals.end());
     int m = mixedTranspositions.size();

lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 2 additions & 4 deletions
@@ -150,11 +150,10 @@ LogicalResult lowerLocalStore(Location loc, MLIRContext *ctx, Value regVal,
     cvt = regLayout.invertAndCompose(sharedLayout);
   }
   auto kBlock = str_attr("block");
-  // NYI. We would need to emit a map.shared::cluster instruction.
+  // We could support it by removing this check if we ever want to
   if (!cvt.isTrivialOver({kBlock})) {
     return failure();
   }
-  cvt = cvt.sublayout({kReg, kLane, kWarp}, {kOffset});
   lowerLocalLdSt(loc, ctx, cvt, inVals, llvmElemTy, memDescTy, smemObj,
                  rewriter, targetInfo);

@@ -287,11 +286,10 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
     cvt = regLayout.invertAndCompose(sharedLayout);
   }
   auto kBlock = str_attr("block");
-  // NYI. We would need to emit a map.shared::cluster instruction.
+  // We could support it by removing this check if we ever want to
   if (!cvt.isTrivialOver({kBlock})) {
     return failure();
   }
-  cvt = cvt.sublayout({kReg, kLane, kWarp}, {kOffset});

   auto outVals = lowerLocalLdSt(loc, ctx, cvt, {}, llvmElemTy, memDescTy,
                                 smemObj, rewriter, targetInfo, op);
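
Both patterns still bail out when the conversion would move data between CTAs. A standalone illustration of what "trivial over the block dimension" means here (a toy model, not Triton's LinearLayout API): the map may shuffle elements freely within a block, but every element's block coordinate must be preserved.

#include <cassert>

// A conversion is trivial over "block" if no element changes block:
// outBlock(elem, block) == block for all inputs.
bool isTrivialOverBlock(int nElems, int nBlocks, int (*outBlock)(int, int)) {
  for (int e = 0; e < nElems; ++e)
    for (int b = 0; b < nBlocks; ++b)
      if (outBlock(e, b) != b)
        return false;
  return true;
}

int main() {
  auto stays = [](int, int b) { return b; };     // data never leaves its CTA
  auto swaps = [](int, int b) { return b ^ 1; }; // crosses CTAs: not trivial
  assert(isTrivialOverBlock(16, 2, stays));
  assert(!isTrivialOverBlock(16, 2, swaps));
}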
