Commit 977774b

[BACKEND] Support generic multi-cta convert_layouts
We generalise the swizzling algorithm to work with blocks, and we generalise most of the memory lowerings to support layouts with blocks, removing the legacy lowering. The generic swizzling algorithm for blocks might be fine as-is, but we didn't try to be particularly clever, so there may be some performance left on the table; we can look into this at a later point if it becomes relevant. In the process we also activate multi-CTA reductions and test both features there.

TODO: Add some funky tests that exercise `convert_layout` on its own, not just the `convert_layout` within the reduction.
TODO: Check how to perform multi-CTA barriers on AMD, and perhaps merge cluster barriers into ttg.barrier, predicate broadcasting blocks, etc.

stack-info: PR: #9317, branch: lezcano/stack/9
1 parent eb2406c commit 977774b

25 files changed

Lines changed: 357 additions & 404 deletions
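At its core, the change picks the weakest synchronization that is still correct for a given convert_layout. The snippet below condenses the emitBarrier lambda added in ConvertLayoutOpToLLVM.cpp (see that file's hunks further down); isCvtDimSync and clusterBarrier are the new APIs this commit introduces:

bool isWarpSync = mlir::isCvtDimSync(srcLayout, dstLayout, kWarp);
bool isBlockSync = mlir::isCvtDimSync(srcLayout, dstLayout, kBlock);
auto emitBarrier = [&]() {
  if (isWarpSync) {
    targetInfo.warpSync(loc, rewriter); // data stays within each warp
  } else if (isBlockSync) {
    // data stays within each CTA
    targetInfo.barrier(loc, rewriter, triton::gpu::AddrSpace::Local);
  } else {
    targetInfo.clusterBarrier(loc, rewriter); // data crosses CTAs
  }
};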

include/triton/Analysis/Utility.h

Lines changed: 2 additions & 2 deletions
@@ -389,8 +389,8 @@ template <typename T> class CallGraph {
 // Create a basic DataFlowSolver with constant and dead code analysis included.
 std::unique_ptr<DataFlowSolver> createDataFlowSolver();
 
-bool isCvtWarpSync(const triton::LinearLayout &srcLayout,
-                   const triton::LinearLayout &dstLayout);
+bool isCvtDimSync(const triton::LinearLayout &srcLayout,
+                  const triton::LinearLayout &dstLayout, StringAttr dim);
 
 } // namespace mlir

include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h

Lines changed: 2 additions & 0 deletions
@@ -20,6 +20,8 @@ class TargetInfoBase {
   // target address space
   virtual void barrier(Location loc, RewriterBase &rewriter,
                        triton::gpu::AddrSpace targets) const = 0;
+  // Emit a cluster-level barrier when supported. Defaults to CTA barrier.
+  virtual void clusterBarrier(Location loc, RewriterBase &rewriter) const = 0;
   // Insert a warp syncronization barrier that also guarantees local address
   // space visibility at warp level when supported by the backend.
   // Backends that do not support warp-level barriers should conservatively
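For reference, a minimal sketch of how an NVIDIA-style TargetInfo subclass might implement the new hook, assuming the MLIR NVVM dialect's nvvm.cluster.arrive/nvvm.cluster.wait ops; the backend side is not part of the hunks shown on this page:

#include "mlir/Dialect/LLVMIR/NVVMDialect.h"

// Sketch only: synchronize all CTAs in the cluster so that writes to
// distributed shared memory are visible to the subsequent loads.
void TargetInfo::clusterBarrier(Location loc, RewriterBase &rewriter) const {
  rewriter.create<mlir::NVVM::ClusterArriveOp>(loc, rewriter.getUnitAttr());
  rewriter.create<mlir::NVVM::ClusterWaitOp>(loc, rewriter.getUnitAttr());
}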

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 18 additions & 15 deletions
@@ -15,6 +15,8 @@
 #include "triton/Tools/StrUtil.h"
 #include "llvm/ADT/STLExtras.h"
 
+#include <optional>
+
 #define DEBUG_TYPE "ttgpu_to_llvm"
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

@@ -561,17 +563,18 @@ lowerLdStShared(Location loc, MLIRContext *ctx, LinearLayout cvt,
 // calcPaddedOffset is a lambda that takes a base offset (mlir::Value)
 // and computes a new offset (mlir::Value) by applying padding based on
 // shared memory layout.
-SmallVector<Value> lowerLdSt(
-    Location loc, MLIRContext *ctx, LinearLayout cvt,
-    ArrayRef<Value> valsArray, // Input for store, output for load
-    Type llvmElemTy, Value smemBase,
-    ArrayRef<std::pair<unsigned, unsigned>> paddingShifts, Value affineOffset,
-    uint64_t maskSpanAffineOffset, Value laneId, Value warpId,
-    RewriterBase &rewriter, const TargetInfoBase &targetInfo,
-    std::optional<int> maybeMaxVecElems,
-    std::function<SmallVector<Value>(RewriterBase &, Location, ArrayRef<Value>,
-                                     Value, int, VectorType)>
-        lowerInst);
+SmallVector<Value>
+lowerLdSt(Location loc, MLIRContext *ctx, LinearLayout cvt,
+          ArrayRef<Value> valsArray, // Input for store, output for load
+          Type llvmElemTy, Value smemBase,
+          ArrayRef<std::pair<unsigned, unsigned>> paddingShifts, Value affineOffset,
+          uint64_t maskSpanAffineOffset, Value laneId, Value warpId,
+          RewriterBase &rewriter, const TargetInfoBase &targetInfo,
+          std::optional<int> maybeMaxVecElems,
+          std::function<SmallVector<Value>(RewriterBase &, Location,
+                                           ArrayRef<Value>, Value, int,
+                                           VectorType, std::optional<Value>)>
+              lowerInst);
 
 // Lower local_load/local_store via ld.shared/st.shared
 SmallVector<Value>

@@ -613,10 +616,10 @@ void makeAllWarpGroupsIsolatedFromAbove(Operation *op);
 // Set the correct loop annotation on LLVM branch ops.
 void fixUpLoopAnnotation(ModuleOp mod);
 
-void transferWithinBlockSwizzling(triton::gpu::ConvertLayoutOp op, Value src,
-                                  const TargetInfoBase &targetInfo,
-                                  const LLVMTypeConverter *typeConverter,
-                                  RewriterBase &rewriter);
+void transferSwizzlingLocalMem(triton::gpu::ConvertLayoutOp op, Value src,
+                               const TargetInfoBase &targetInfo,
+                               const LLVMTypeConverter *typeConverter,
+                               RewriterBase &rewriter);
 
 SmallVector<Value> inlineRegionImpl(RewriterBase &rewriter, Region &region,
                                     ArrayRef<Value> args,
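Note the lowerInst callback gains a trailing std::optional<Value>. Judging from the multi-CTA context, it plausibly carries a remote block operand when an access targets another CTA's shared memory; this diff does not confirm that, so the sketch below is a hypothetical CTA-local store callback that simply rejects it:

// Hypothetical callback matching the new signature; maybeBlock is assumed
// (not confirmed by this diff) to be set only for cross-CTA accesses.
auto lowerInst = [&](RewriterBase &rewriter, Location loc,
                     ArrayRef<Value> vals, Value shmemAddr, int idx,
                     VectorType vecTy,
                     std::optional<Value> maybeBlock) -> SmallVector<Value> {
  assert(!maybeBlock && "CTA-local sketch: remote accesses not handled");
  Value vec = packLLVector(loc, vals, rewriter); // Triton helper
  rewriter.create<LLVM::StoreOp>(loc, vec, shmemAddr);
  return {}; // stores yield no values; a load would return the elements
};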

lib/Analysis/Membar.cpp

Lines changed: 2 additions & 1 deletion
@@ -324,7 +324,8 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
       auto dstTy = cast<RankedTensorType>(cvt.getType());
       auto srcLayout = triton::gpu::toLinearLayout(srcTy);
       auto dstLayout = triton::gpu::toLinearLayout(dstTy);
-      isWarpSync = mlir::isCvtWarpSync(srcLayout, dstLayout);
+      auto kWarp = StringAttr::get(op->getContext(), "warp");
+      isWarpSync = mlir::isCvtDimSync(srcLayout, dstLayout, kWarp);
     }
 
     if (!curBlockInfo.syncReadSlices.empty() ||

lib/Analysis/Utility.cpp

Lines changed: 16 additions & 10 deletions
@@ -1258,17 +1258,23 @@ std::unique_ptr<DataFlowSolver> createDataFlowSolver() {
   return solver;
 }
 
-bool isCvtWarpSync(const triton::LinearLayout &srcLayout,
-                   const triton::LinearLayout &dstLayout) {
-  // We can use warp.sync when the warp dimension in the convert is trival
-  // and there is no broadcasting at a warp level (otherwise reads may be
-  // wrong)
+bool isCvtDimSync(const triton::LinearLayout &srcLayout,
+                  const triton::LinearLayout &dstLayout, StringAttr dim) {
+  // We can use a dimension-level sync when the conversion is trivial over that
+  // dimension and there is no broadcasting over it.
   auto *ctx = srcLayout.getInDimNames().begin()->getContext();
-  auto comp = dstLayout.invertAndCompose(srcLayout);
   auto kWarp = StringAttr::get(ctx, "warp");
-  return comp.isTrivialOver(kWarp) &&
-         srcLayout.getFreeVariableMasks()[kWarp] == 0 &&
-         dstLayout.getFreeVariableMasks()[kWarp] == 0;
+  auto kBlock = StringAttr::get(ctx, "block");
+  assert((dim == kWarp || dim == kBlock) && "expected dim to be warp or block");
+  assert(srcLayout.hasInDim(dim) && dstLayout.hasInDim(dim) &&
+         "expected dim to be present in both layouts");
+  auto parentTrivial = true;
+  if (dim == kWarp) {
+    parentTrivial = isCvtDimSync(srcLayout, dstLayout, kBlock);
+  }
+  auto comp = dstLayout.invertAndCompose(srcLayout);
+  return parentTrivial && comp.isTrivialOver(dim) &&
+         srcLayout.getFreeVariableMasks()[dim] == 0 &&
+         dstLayout.getFreeVariableMasks()[dim] == 0;
 }
-
 } // namespace mlir
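A small usage sketch (srcLayout/dstLayout assumed in scope, as in the Membar.cpp caller above): because the warp query recurses into the block query first, warp-level synchronizability implies block-level synchronizability:

auto kWarp = StringAttr::get(ctx, "warp");
auto kBlock = StringAttr::get(ctx, "block");
bool warpSyncOk = mlir::isCvtDimSync(srcLayout, dstLayout, kWarp);
bool blockSyncOk = mlir::isCvtDimSync(srcLayout, dstLayout, kBlock);
// Guaranteed by the recursion in isCvtDimSync: warpSyncOk implies blockSyncOk.
assert(!warpSyncOk || blockSyncOk);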

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 41 additions & 45 deletions
@@ -4,6 +4,8 @@
 #include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h"
 #include "triton/Conversion/TritonGPUToLLVM/Utility.h"
 
+#include <optional>
+
 #include "triton/Analysis/Allocation.h"
 #include "triton/Dialect/Triton/IR/Types.h"
 #include "triton/Dialect/Triton/IR/Utility.h"

@@ -53,15 +55,10 @@ struct ConvertLayoutOpConversion
     assert(to_vector(conversion.getInDimNames()) ==
            to_vector(conversion.getOutDimNames()));
     auto dims = conversion.getInDimNames();
-    if (llvm::is_contained(dims, kBlock)) {
-      // Case 1: Transfer between values in different CTAs.
-      // This requires moving values through distributed shared memory.
-      return rewriter.notifyMatchFailure(
-          op, "NYI: Transfer between different CTAs");
-    } else if (llvm::is_contained(dims, kWarp)) {
-      // Case 2: Transfer between values in the same CTA, in which case we move
-      // values through shared memory.
-      transferWithinBlockSwizzling(op, adaptor.getSrc(), rewriter);
+    if (llvm::is_contained(dims, kBlock) || llvm::is_contained(dims, kWarp)) {
+      // Transfer between values in the same CTA, or across CTAs. We move values
+      // through (distributed) shared memory.
+      transferSwizzlingLocalMem(op, adaptor.getSrc(), rewriter);
       return success();
     } else if (llvm::is_contained(dims, kLane)) {
       // Case 3. Transfer between values in the same warp, in which case we try

@@ -70,7 +67,7 @@ struct ConvertLayoutOpConversion
       if (cvtNeedsWarpShuffle(srcTy, dstTy))
         return transferWithinWarp(op, adaptor, rewriter);
 
-      transferWithinBlockSwizzling(op, adaptor.getSrc(), rewriter);
+      transferSwizzlingLocalMem(op, adaptor.getSrc(), rewriter);
       return success();
     } else if (llvm::is_contained(dims, kRegister)) {
       // Case 4. Transfer between values in the same thread, in which case we

@@ -107,7 +104,7 @@ struct ConvertLayoutOpConversion
     return success();
   }
 
-  SmallVector<Value> transferWithinBlockSwizzlingImpl(
+  SmallVector<Value> transferSwizzlingLocalMemImpl(
      Location loc, ConversionPatternRewriter &rewriter,
      const LinearLayout &srcLayout, const LinearLayout &dstLayout,
      ArrayRef<Value> inVals, Type llvmElemTy, Value smemBase) const {

@@ -123,8 +120,8 @@ struct ConvertLayoutOpConversion
        return b.ptrtoint(llvmElemTyPtr, v).getResult();
      }));
      auto outVals =
-          transferWithinBlockSwizzlingImpl(loc, rewriter, srcLayout, dstLayout,
-                                           newInVals, llvmElemTyPtr, smemBase);
+          transferSwizzlingLocalMemImpl(loc, rewriter, srcLayout, dstLayout,
+                                        newInVals, llvmElemTyPtr, smemBase);
      for (auto &v : outVals) {
        v = b.inttoptr(llvmElemTy, v);
      }

@@ -137,7 +134,7 @@ struct ConvertLayoutOpConversion
      auto i8ElemTy = i8_ty;
      auto newInVals = llvm::to_vector(llvm::map_range(
          inVals, [&](Value v) { return b.zext(i8ElemTy, v).getResult(); }));
-      auto outVals = transferWithinBlockSwizzlingImpl(
+      auto outVals = transferSwizzlingLocalMemImpl(
          loc, rewriter, srcLayout, dstLayout, newInVals, i8ElemTy, smemBase);
      for (auto &v : outVals) {
        v = b.trunc(llvmElemTy, v);

@@ -150,26 +147,29 @@ struct ConvertLayoutOpConversion
    if (!removeBroadcastSrc.isIdentity()) {
      auto prmtSrc = removeBroadcastSrc.apply(srcLayout);
      auto newInVals = removeBroadcastSrc.apply(inVals);
-      return transferWithinBlockSwizzlingImpl(loc, rewriter, prmtSrc, dstLayout,
-                                              newInVals, llvmElemTy, smemBase);
+      return transferSwizzlingLocalMemImpl(loc, rewriter, prmtSrc, dstLayout,
+                                           newInVals, llvmElemTy, smemBase);
    }
 
    // Remove broadcasting in dst
    auto removeBroadcastDst = actionRemoveBroadcastedRegs(dstLayout);
    if (!removeBroadcastDst.isIdentity()) {
      auto prmtDst = removeBroadcastDst.apply(dstLayout);
-      auto outVals = transferWithinBlockSwizzlingImpl(
+      auto outVals = transferSwizzlingLocalMemImpl(
          loc, rewriter, srcLayout, prmtDst, inVals, llvmElemTy, smemBase);
      return broadcastAs(outVals, dstLayout);
    }
 
    // At this point we have a type that's at least 8-bit
    // and we don't have broadcasting in the registers
+
    auto bitwidth = llvmElemTy.getIntOrFloatBitWidth();
    auto smem = optimalSwizzlingLdSt(srcLayout, dstLayout, bitwidth);
 
    // Extract reps from smem
    auto kReg = str_attr("register");
+    auto kWarp = StringAttr::get(ctx, "warp");
+    auto kBlock = StringAttr::get(ctx, "block");
    auto kReps = str_attr("reps");
    auto nReps = smem.getInDimSize(kReps);
    auto reps = LinearLayout::identity1D(nReps, kReg, kReps);

@@ -191,8 +191,11 @@ struct ConvertLayoutOpConversion
    auto storeCvt = *divideRight(totalStoreCvt, reps);
    auto loadCvt = *divideRight(totalLoadCvt, reps);
    auto kOffset = str_attr("offset");
-    storeCvt = storeCvt.reshapeOuts({{kOffset, storeCvt.getTotalOutDimSize()}});
-    loadCvt = loadCvt.reshapeOuts({{kOffset, loadCvt.getTotalOutDimSize()}});
+    auto nBlock = storeCvt.getInDimSize(kBlock);
+    storeCvt = storeCvt.reshapeOuts(
+        {{kOffset, storeCvt.getTotalOutDimSize() / nBlock}, {kBlock, nBlock}});
+    loadCvt = loadCvt.reshapeOuts(
+        {{kOffset, loadCvt.getTotalOutDimSize() / nBlock}, {kBlock, nBlock}});
 
    auto tileSize = storeCvt.getInDimSize(kReg);
 
@@ -201,28 +204,30 @@ struct ConvertLayoutOpConversion
    auto affineOffset = b.i32_val(0);
    auto maskSpanAffineOffset = 0;
 
-    bool isWarpSync = mlir::isCvtWarpSync(srcLayout, dstLayout);
-    for (int i = 0; i < nReps; ++i) {
-      if (i > 0) {
-        if (isWarpSync) {
-          targetInfo.warpSync(loc, rewriter);
-        } else {
-          targetInfo.barrier(loc, rewriter, triton::gpu::AddrSpace::Local);
-        }
+    bool isWarpSync = mlir::isCvtDimSync(srcLayout, dstLayout, kWarp);
+    bool isBlockSync = mlir::isCvtDimSync(srcLayout, dstLayout, kBlock);
+    auto emitBarrier = [&]() {
+      if (isWarpSync) {
+        targetInfo.warpSync(loc, rewriter);
+      } else if (isBlockSync) {
+        targetInfo.barrier(loc, rewriter, triton::gpu::AddrSpace::Local);
+      } else {
+        targetInfo.clusterBarrier(loc, rewriter);
      }
+    };
+
+    for (int i = 0; i < nReps; ++i) {
+      if (i > 0)
+        emitBarrier();
      auto tileInVals =
          ArrayRef<Value>(permutedInVals).slice(i * tileSize, tileSize);
      // Store
      lowerLdStShared(loc, ctx, storeCvt, tileInVals, llvmElemTy, smemBase,
                      /*paddingShifts=*/{}, affineOffset, maskSpanAffineOffset,
                      rewriter, targetInfo);
-      if (isWarpSync) {
-        targetInfo.warpSync(loc, rewriter);
-      } else {
-        targetInfo.barrier(loc, rewriter, triton::gpu::AddrSpace::Local);
-      }
+      emitBarrier();
      // Load
-      SmallVector<Value> tileOutVals = lowerLdStShared(
+      auto tileOutVals = lowerLdStShared(
          loc, ctx, loadCvt, {}, llvmElemTy, smemBase, /*paddingShifts=*/{},
          affineOffset, maskSpanAffineOffset, rewriter, targetInfo);
      llvm::append_range(outVals, tileOutVals);

@@ -233,30 +238,21 @@ struct ConvertLayoutOpConversion
    return outVals;
  }
 
-  void transferWithinBlockSwizzling(ConvertLayoutOp op, Value src,
-                                    ConversionPatternRewriter &rewriter) const {
+  void transferSwizzlingLocalMem(ConvertLayoutOp op, Value src,
+                                 ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto *ctx = op.getContext();
    auto srcTy = op.getSrc().getType();
    auto dstTy = op.getType();
 
-    // Remove the kBlock dimension from the layout as it's the identity in the
-    // cvt
    auto srcLayout = toLinearLayout(srcTy);
    auto dstLayout = toLinearLayout(dstTy);
-    auto kReg = str_attr("register");
-    auto kLane = str_attr("lane");
-    auto kWarp = str_attr("warp");
-    srcLayout = srcLayout.sublayout({kReg, kLane, kWarp},
-                                    to_vector(srcLayout.getOutDimNames()));
-    dstLayout = dstLayout.sublayout({kReg, kLane, kWarp},
-                                    to_vector(dstLayout.getOutDimNames()));
 
    auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
    auto smemBase =
        LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation());
    auto inVals = unpackLLElements(loc, src, rewriter);
-    auto outVals = transferWithinBlockSwizzlingImpl(
+    auto outVals = transferSwizzlingLocalMemImpl(
        loc, rewriter, srcLayout, dstLayout, inVals, llvmElemTy, smemBase);
 
    Value result =
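To make the reshapeOuts step concrete, here is a standalone toy example (sizes invented) of splitting a flat offset output into a per-CTA offset plus a block dimension, mirroring the storeCvt/loadCvt reshape above:

// A conversion covering 2048 shared-memory slots across nBlock = 2 CTAs is
// reshaped to 1024 offsets per CTA plus a 2-way kBlock output dimension.
auto kReg = StringAttr::get(ctx, "register");
auto kOffset = StringAttr::get(ctx, "offset");
auto kBlock = StringAttr::get(ctx, "block");
auto cvt = LinearLayout::identity1D(2048, kReg, kOffset);
int nBlock = 2;
auto split = cvt.reshapeOuts(
    {{kOffset, cvt.getTotalOutDimSize() / nBlock}, {kBlock, nBlock}});
assert(split.getOutDimSize(kOffset) == 1024);
assert(split.getOutDimSize(kBlock) == 2);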

lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 2 additions & 4 deletions
@@ -150,11 +150,10 @@ LogicalResult lowerLocalStore(Location loc, MLIRContext *ctx, Value regVal,
    cvt = regLayout.invertAndCompose(sharedLayout);
  }
  auto kBlock = str_attr("block");
-  // NYI. We would need to emit a map.shared::cluster instruction.
+  // We could support it by removing this check if we ever want to
  if (!cvt.isTrivialOver({kBlock})) {
    return failure();
  }
-  cvt = cvt.sublayout({kReg, kLane, kWarp}, {kOffset});
  lowerLocalLdSt(loc, ctx, cvt, inVals, llvmElemTy, memDescTy, smemObj,
                 rewriter, targetInfo);

@@ -287,11 +286,10 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
      cvt = regLayout.invertAndCompose(sharedLayout);
    }
    auto kBlock = str_attr("block");
-    // NYI. We would need to emit a map.shared::cluster instruction.
+    // We could support it by removing this check if we ever want to
    if (!cvt.isTrivialOver({kBlock})) {
      return failure();
    }
-    cvt = cvt.sublayout({kReg, kLane, kWarp}, {kOffset});
 
    auto outVals = lowerLocalLdSt(loc, ctx, cvt, {}, llvmElemTy, memDescTy,
                                  smemObj, rewriter, targetInfo, op);
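With the sublayout call gone, cvt keeps its kBlock input dimension, and the isTrivialOver check is what guarantees the local lowering never has to address another CTA's shared memory. A hypothetical probe (not part of the patch) mirroring that guard:

// Local load/store lowering applies only when no data would cross CTAs.
static bool canLowerLocally(const triton::LinearLayout &cvt, MLIRContext *ctx) {
  auto kBlock = StringAttr::get(ctx, "block");
  return cvt.isTrivialOver({kBlock});
}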

lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp

Lines changed: 12 additions & 15 deletions
@@ -37,16 +37,6 @@ struct ReduceOpConversion
    auto accs = unpackInputs(loc, op, adaptor, rewriter);
    unsigned axis = op.getAxis();
 
-    // The lowering already supports cross-CTA reductions in principle
-    // We are only missing:
-    // - Supporting them in convert_layout for LinearLayouts
-    // - Emitting cross-CTA barriers between convert_layouts when the second
-    //   convert_layout crosses CTAs
-    // After this, we can uncomment the tests in test_reduce_funky_layout
-    if (!helper.isReduceWithinCTA()) {
-      return failure();
-    }
-
    auto *ctx = op.getContext();
 
    // Remove block as we don't currently support it

@@ -80,15 +70,18 @@ struct ReduceOpConversion
    // That is, they fit in 2 rounds of warp reductions
    // Even more, if we do two rounds, getInterLayout will make sure that the
    // first one does not cross CTAs
+    auto kBlock = StringAttr::get(ctx, "block");
+    bool lastCvtCrossesCTAs = false;
    int i = 0;
    while (to_vector(regLl.getOutDimSizes())[axis] != 1) {
      LinearLayout tmpLl = ReduceOpHelper::getInterLayout(regLl, axis);
 
      // Emit a barrier if we are reusing the shmem
      if (i > 0) {
-        sync(rewriter, loc);
+        sync(rewriter, loc, lastCvtCrossesCTAs);
      }
      accs = convertLayoutValues(loc, rewriter, op, regLl, tmpLl, accs);
+      lastCvtCrossesCTAs = !mlir::isCvtDimSync(regLl, tmpLl, kBlock);
 
      std::tie(regLl, accs) =
          reduceWithinWarps(op, std::move(tmpLl), std::move(accs), rewriter);

@@ -105,7 +98,7 @@ struct ReduceOpConversion
    auto outputLayout = triton::gpu::toLinearLayout(resultTy);
    if (regLl != outputLayout) {
      // Reuse the shmem
-      sync(rewriter, loc);
+      sync(rewriter, loc, lastCvtCrossesCTAs);
      accs =
          convertLayoutValues(loc, rewriter, op, regLl, outputLayout, accs);
    }

@@ -276,9 +269,13 @@ struct ReduceOpConversion
    return srcValues;
  }
 
-  void sync(ConversionPatternRewriter &rewriter, Location loc) const {
-    auto b = TritonLLVMOpBuilder(loc, rewriter);
-    b.barrier(triton::gpu::AddrSpace::Local);
+  void sync(ConversionPatternRewriter &rewriter, Location loc,
+            bool crossCTA) const {
+    if (crossCTA) {
+      targetInfo.clusterBarrier(loc, rewriter);
+    } else {
+      targetInfo.barrier(loc, rewriter, triton::gpu::AddrSpace::Local);
+    }
  }
 
  // Reduce along op axis for elements that are in the same thread. The
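Condensing the control flow above: the barrier protecting shared-memory reuse only needs to be cluster-wide when the previous convert moved data across CTAs, which is exactly what lastCvtCrossesCTAs records. In outline (reductionDone/notFirstRound are illustrative stand-ins for the loop conditions in the real code):

bool lastCvtCrossesCTAs = false;
while (!reductionDone()) {   // stand-in for the getOutDimSizes()[axis] test
  if (notFirstRound())       // stand-in for the `i > 0` shmem-reuse test
    sync(rewriter, loc, /*crossCTA=*/lastCvtCrossesCTAs);
  accs = convertLayoutValues(loc, rewriter, op, regLl, tmpLl, accs);
  // If this convert was not already block-synchronized, its shared-memory
  // traffic crossed CTAs, so the *next* barrier must be a cluster barrier.
  lastCvtCrossesCTAs = !mlir::isCvtDimSync(regLl, tmpLl, kBlock);
}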
