
Commit 229f05a

[BACKEND] Implement support for cross-CTA tt.reduce
The title of this PR is a bit of a lie: even though the lowering is now implemented to support cross-CTA reductions, it depends on `convert_layout` supporting cross-CTA conversions, which it does not yet do for LinearLayouts. We should generalise `convert_layout` first and then enable cross-CTA reductions here. We should also emit the correct cross-CTA barrier from `targetInfo` in the case of cross-CTA memory reuse.

In this PR, we take the chance to also generalise the lowering to avoid convert_layouts whenever possible.

stack-info: PR: #9221, branch: lezcano/stack/8
1 parent 7687a5e commit 229f05a

7 files changed

Lines changed: 287 additions & 196 deletions

File tree

include/triton/Analysis/Allocation.h

Lines changed: 4 additions & 0 deletions
@@ -20,6 +20,10 @@ using AllocationAnalysisScratchSizeFn = std::function<unsigned(Operation *)>;
 
 unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op);
 
+unsigned getNumScratchElemsSwizzledCvt(const LinearLayout &srcLayout,
+                                       const LinearLayout &dstLayout,
+                                       int bitwidth);
+
 unsigned getNumScratchElemsSwizzledCvt(RankedTensorType srcTy,
                                        RankedTensorType dstTy);
 
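Since `AllocationAnalysisScratchSizeFn` is a plain `std::function` callback, the default implementation declared above can be wrapped by downstream users. A minimal sketch, not part of this commit and assuming the declarations above live in `mlir::triton` as in the diff (`makeVerboseScratchSizeFn` is a hypothetical helper):

#include "triton/Analysis/Allocation.h"

using namespace mlir;

// Hypothetical wrapper around the default scratch-size callback: compute the
// scratch size as usual and emit a remark for ops that need shared memory.
triton::AllocationAnalysisScratchSizeFn makeVerboseScratchSizeFn() {
  return [](Operation *op) -> unsigned {
    unsigned bytes = triton::defaultAllocationAnalysisScratchSizeFn(op);
    if (bytes > 0)
      op->emitRemark() << "op requires " << bytes << " bytes of scratch memory";
    return bytes;
  };
}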

include/triton/Analysis/Utility.h

Lines changed: 3 additions & 6 deletions
@@ -61,8 +61,6 @@ class ReduceOpHelper {
 
   RankedTensorType getSrcTy() { return srcTy; }
 
-  bool isWarpSynchronous();
-
   unsigned getInterWarpSizeWithUniqueData();
 
   unsigned getIntraWarpSizeWithUniqueData();
@@ -71,6 +69,9 @@ class ReduceOpHelper {
 
   bool isAssociative();
 
+  // Get the shared memory scratch size required by this reduce op.
+  unsigned getScratchSizeInBytes();
+
   InThreadVectorizeOpKind getInThreadVectorizeOpKind(unsigned axisPack);
 
   static triton::ColumnAction
@@ -92,10 +93,6 @@ class ReduceOpHelper {
                                InThreadVectorizeOpKind kind,
                                Value lhs, Value rhs);
 
-  SmallVector<unsigned>
-  getScratchBytesForCvt(const triton::LinearLayout &srcLayout,
-                        const triton::LinearLayout &dstLayout);
-
 private:
   triton::ReduceOp op;
   RankedTensorType srcTy;

lib/Analysis/Allocation.cpp

Lines changed: 16 additions & 46 deletions
@@ -31,49 +31,27 @@ namespace mlir {
 //===----------------------------------------------------------------------===//
 namespace triton {
 
-unsigned getNumScratchElemsSwizzledCvt(RankedTensorType srcTy,
-                                       RankedTensorType dstTy) {
-  auto *ctx = srcTy.getContext();
-  auto srcLayout = gpu::toLinearLayout(srcTy);
-  auto dstLayout = gpu::toLinearLayout(dstTy);
-  srcLayout = actionRemoveBroadcastedRegs(srcLayout).apply(srcLayout);
-  dstLayout = actionRemoveBroadcastedRegs(dstLayout).apply(dstLayout);
-  auto bitwidth = getBitwidth(srcTy);
-  auto smem = gpu::optimalSwizzlingLdSt(srcLayout, dstLayout, bitwidth);
+unsigned getNumScratchElemsSwizzledCvt(const LinearLayout &srcLayout,
+                                       const LinearLayout &dstLayout,
+                                       int bitwidth) {
+  auto *ctx = srcLayout.getInDimNames().begin()->getContext();
+  auto srcLayoutNoBroadcast =
+      actionRemoveBroadcastedRegs(srcLayout).apply(srcLayout);
+  auto dstLayoutNoBroadcast =
+      actionRemoveBroadcastedRegs(dstLayout).apply(dstLayout);
+  auto smem = gpu::optimalSwizzlingLdSt(srcLayoutNoBroadcast,
+                                        dstLayoutNoBroadcast, bitwidth);
   auto reps = smem.getInDimSize(StringAttr::get(ctx, "reps"));
   return smem.getTotalOutDimSize() / reps;
 }
 
-namespace {
-constexpr int64_t kReduceScratchAlign = 16;
-
-Type getReduceMemElemTy(Type elemTy, MLIRContext *ctx) {
-  if (elemTy.isIntOrFloat() && elemTy.getIntOrFloatBitWidth() < 8)
-    return IntegerType::get(ctx, 8);
-  return elemTy;
-}
-
-int64_t getReduceScratchSizeBytes(triton::ReduceOp op,
-                                  ArrayRef<unsigned> bytesPerOperand) {
-  std::vector<unsigned> indices(op.getNumOperands());
-  std::iota(indices.begin(), indices.end(), 0);
-  auto *ctx = op.getContext();
-  std::sort(indices.begin(), indices.end(), [&](unsigned i, unsigned j) {
-    auto lhsTy = getReduceMemElemTy(op.getElementTypes()[i], ctx);
-    auto rhsTy = getReduceMemElemTy(op.getElementTypes()[j], ctx);
-    return getIntOrFloatOrPtrBitWidth(lhsTy) >
-           getIntOrFloatOrPtrBitWidth(rhsTy);
-  });
-  // Aling to 16 bytes to allow for vectorisation
-  int64_t offset = 0;
-  for (unsigned idx : indices) {
-    offset += llvm::alignTo(bytesPerOperand[idx], kReduceScratchAlign);
-  }
-  return offset;
+unsigned getNumScratchElemsSwizzledCvt(RankedTensorType srcTy,
+                                       RankedTensorType dstTy) {
+  return getNumScratchElemsSwizzledCvt(gpu::toLinearLayout(srcTy),
+                                       gpu::toLinearLayout(dstTy),
+                                       getBitwidth(srcTy));
 }
 
-} // namespace
-
 // Both `atomic_cas` and `atomic_rmw` may need scratch memory to store values
 // because Triton's block-based programming model ensures that
 // all threads sharing the same partition of the tensor see the same values,
@@ -100,15 +78,7 @@ static SmallVector<unsigned> getRepShapeForAtomic(Value result) {
 
 unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op) {
   if (auto reduceOp = dyn_cast<ReduceOp>(op)) {
-    ReduceOpHelper helper(reduceOp);
-    if (helper.isWarpSynchronous())
-      return 0;
-
-    auto regLl = ReduceOpHelper::reducedRegLaneLayout(helper.getSrcTy(),
-                                                      reduceOp.getAxis());
-    auto tmpLl = ReduceOpHelper::getInterLayout(regLl, reduceOp.getAxis());
-    auto bytesRegToTmp = helper.getScratchBytesForCvt(regLl, tmpLl);
-    return getReduceScratchSizeBytes(reduceOp, bytesRegToTmp);
+    return ReduceOpHelper(reduceOp).getScratchSizeInBytes();
   }
   if (auto scanOp = dyn_cast<ScanOp>(op)) {
     ScanLoweringHelper helper(scanOp);

lib/Analysis/Utility.cpp

Lines changed: 112 additions & 51 deletions
@@ -78,10 +78,6 @@ unsigned ReduceOpHelper::getIntraWarpSizeWithUniqueData() {
   return getThreadsPerWarp(srcEncoding, srcShape)[axis];
 }
 
-bool ReduceOpHelper::isWarpSynchronous() {
-  return getWarpsPerCTA(srcEncoding, srcShape)[axis] == 1;
-}
-
 bool ReduceOpHelper::isReduceWithinCTA() {
   // TODO: Support reduce across CTAS
   // Layout optimization passes such as PlanCTAPass and
@@ -109,6 +105,35 @@ bool ReduceOpHelper::isAssociative() {
   return !hasNoAssociativeOp;
 }
 
+unsigned ReduceOpHelper::getScratchSizeInBytes() {
+  auto kLane = StringAttr::get(op.getContext(), "lane");
+
+  auto isReduced = [axis = axis](const LinearLayout &layout) {
+    return layout.getOutDimSizes().begin()[axis] == 1;
+  };
+  auto regLl = reducedRegLaneLayout(srcTy, axis);
+
+  // All the inputs have the same layout so, since we order them from largest
+  // bitsize to smallest, and the first one is aligned, by induction, they are
+  // all aligned, so we don't need to align the byte numbers returned here.
+  unsigned bytesRegToTmp = 0;
+  while (!isReduced(regLl)) {
+    auto tmpLl = getInterLayout(regLl, axis);
+    // We take the maximum of the elements and multiply by the total bitwidth.
+    // We do this as otherwise it's quite tricky to find the correct
+    // BaseOffsets in the lowering.
+    int bytes = 0;
+    for (auto inputTy : op.getInputTypes()) {
+      auto nelem =
+          getNumScratchElemsSwizzledCvt(regLl, tmpLl, getBitwidth(inputTy));
+      bytes += nelem * (getBitwidth(inputTy) / 8);
+    }
+    bytesRegToTmp = std::max<unsigned>(bytesRegToTmp, bytes);
+    regLl = zeroBasesAlongDimAndReorder(tmpLl, axis, kLane);
+  }
+  return bytesRegToTmp;
+}
+
 ReduceOpHelper::InThreadVectorizeOpKind
 ReduceOpHelper::getInThreadVectorizeOpKind(unsigned axisPack) {
   Operation *reduceOperation = op.getOperation();
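To make the byte accounting in `getScratchSizeInBytes` concrete, here is a standalone worked example (not from the commit; the element counts are made up) of its sum-over-operands, max-over-rounds computation for a reduce with an f32 and an f16 operand:

#include <algorithm>
#include <cstdio>

int main() {
  // Assumed operand bitwidths (f32, f16) and, per reduction round, the number
  // of scratch elements that getNumScratchElemsSwizzledCvt would return.
  unsigned bitwidths[2] = {32, 16};
  unsigned nelemPerRound[2] = {128, 32}; // hypothetical values
  unsigned bytesRegToTmp = 0;
  for (unsigned round = 0; round < 2; ++round) {
    // Sum the bytes over all operands for this round...
    unsigned bytes = 0;
    for (unsigned bw : bitwidths)
      bytes += nelemPerRound[round] * (bw / 8);
    // ...and keep the maximum over rounds, as in getScratchSizeInBytes.
    bytesRegToTmp = std::max(bytesRegToTmp, bytes);
  }
  // Round 0: 128 * (4 + 2) = 768; round 1: 32 * 6 = 192; max = 768.
  std::printf("%u bytes\n", bytesRegToTmp);
}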
@@ -291,26 +316,90 @@ LinearLayout ReduceOpHelper::getInterLayout(const LinearLayout &layout,
   auto *ctx = layout.getOutDimNames().begin()->getContext();
   auto kLane = mlir::StringAttr::get(ctx, "lane");
   auto kWarp = mlir::StringAttr::get(ctx, "warp");
-  auto regBases = layout.getBases();
-  auto linearAttr = triton::gpu::LinearEncodingAttr::get(ctx, layout);
-  int laneBits = layout.getInDimSizeLog2(kLane);
-  int neededLaneBits = llvm::Log2_32(linearAttr.getWarpsPerCTA()[axis]);
-  // TODO move to verifier
-  assert(neededLaneBits <= laneBits && "NYI: more inter-warps than lanes");
-  // Move the warp axis bases we need to reduce into lane bases, while
-  // keeping non-axis components in their original in-dim.
-  auto &laneBases = regBases[kLane];
-  auto &warpBases = regBases[kWarp];
-  int moved = 0;
-  for (auto &warpBasis : warpBases) {
-    if (warpBasis[axis] == 0)
-      continue;
-    assert(moved < neededLaneBits && "unexpected warp axis bases count");
-    std::swap(laneBases[moved], warpBasis);
-    moved++;
+  auto kBlock = mlir::StringAttr::get(ctx, "block");
+  auto bases = layout.getBases();
+  auto &laneBases = bases[kLane];
+  auto &warpBases = bases[kWarp];
+  auto &blockBases = bases[kBlock];
+
+  auto collectAxisBases = [&](ArrayRef<std::vector<int32_t>> bases) {
+    SmallVector<unsigned> out;
+    for (unsigned i = 0; i < bases.size(); ++i) {
+      if (bases[i][axis] != 0)
+        out.push_back(i);
+    }
+    return out;
+  };
+
+  SmallVector<unsigned> warpAxisBases = collectAxisBases(warpBases);
+  SmallVector<unsigned> blockAxisBases = collectAxisBases(blockBases);
+
+  SmallVector<unsigned> zeroLaneBases;
+  for (unsigned i = 0; i < laneBases.size(); ++i) {
+    if (llvm::all_of(laneBases[i], [](int32_t v) { return v == 0; }))
+      zeroLaneBases.push_back(i);
   }
 
-  return LinearLayout(std::move(regBases), to_vector(layout.getOutDimNames()));
+  auto axisSize = to_vector(layout.getOutDimSizes())[axis];
+  auto totalAxisBases = warpAxisBases.size() + blockAxisBases.size();
+
+  // First try to place all warp/block axis bases into lane bases that are
+  // currently zero. If we can do this we will be able to perform the full
+  // reduction with just one convert_layout
+  if (zeroLaneBases.size() >= totalAxisBases) {
+    unsigned laneIdx = 0;
+    for (unsigned idx : warpAxisBases) {
+      std::swap(laneBases[zeroLaneBases[laneIdx]], warpBases[idx]);
+      ++laneIdx;
+    }
+    for (unsigned idx : blockAxisBases) {
+      std::swap(laneBases[zeroLaneBases[laneIdx]], blockBases[idx]);
+      ++laneIdx;
+    }
+    return LinearLayout(std::move(bases), to_vector(layout.getOutDimNames()));
+  }
+
+  // If we can fit all the bases inside the lane dimension, we can perform the
+  // reduction with two convert_layouts
+  // The first cvt to move the relevant bases to the lane dimension
+  // The second to move all the bases we moved out of the lane dimension back to
+  // their original positions
+  if (warpAxisBases.size() + blockAxisBases.size() <= laneBases.size()) {
+    assert(totalAxisBases <= laneBases.size() &&
+           "unexpected lane base count for axis layout");
+    unsigned laneIdx = 0;
+    for (unsigned idx : warpAxisBases) {
+      std::swap(laneBases[laneIdx], warpBases[idx]);
+      ++laneIdx;
+    }
+    for (unsigned idx : blockAxisBases) {
+      std::swap(laneBases[laneIdx], blockBases[idx]);
+      ++laneIdx;
+    }
+    return LinearLayout(std::move(bases), to_vector(layout.getOutDimNames()));
+  }
+
+  // Assumptions (easily relaxed if AMD needs it)
+  // We assume that
+  // max number of warps * max number of blocks <= (max number of lanes)^2
+  // We check this in logarithmic space (number of bases)
+  // This is true in nvidia as the max numbers are warps=64 ctas=16 so that
+  // 64 * 16 = 1024 = 32 * 32 = laneBases.size() * laneBases.size()
+  // This implies that, even if we have to perform 3 cvt_layouts, we can perform
+  // first one that does not cross CTAs, and then two that may cross CTAs
+  assert(blockBases.size() <= laneBases.size());
+  assert(warpBases.size() + blockBases.size() <= 2 * laneBases.size());
+
+  // Otherwise, fit as many warp bases as possible into the lane dimension
+  unsigned laneIdx = 0;
+  for (unsigned idx : warpAxisBases) {
+    std::swap(laneBases[laneIdx], warpBases[idx]);
+    ++laneIdx;
+    if (laneIdx >= laneBases.size())
+      break;
+  }
+
+  return LinearLayout(std::move(bases), to_vector(layout.getOutDimNames()));
 }
 
 LinearLayout ReduceOpHelper::reducedRegLaneLayout(RankedTensorType srcTy,
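The core trick in `getInterLayout` above is swapping the warp/block bases that act on the reduced axis into the lane dimension, so that after a `convert_layout` the cross-warp (or cross-CTA) reduction becomes a plain intra-warp one. A self-contained sketch of that swap on raw base vectors, for illustration only, using plain `std::vector` in place of `LinearLayout`:

#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

using Bases = std::vector<std::vector<int32_t>>;

// Swap every warp basis with a nonzero component along `axis` into laneBases,
// mirroring the fallback path of getInterLayout.
void moveAxisBasesToLanes(Bases &laneBases, Bases &warpBases, unsigned axis) {
  unsigned laneIdx = 0;
  for (auto &warpBasis : warpBases) {
    if (warpBasis[axis] == 0)
      continue; // this basis does not move data along the reduced axis
    assert(laneIdx < laneBases.size() && "more axis bases than lane bases");
    std::swap(laneBases[laneIdx++], warpBasis);
  }
}

int main() {
  // 2-D layout, reducing axis 0. Both lane bases are zero (broadcasted
  // lanes); one warp basis acts on axis 0.
  Bases laneBases = {{0, 0}, {0, 0}};
  Bases warpBases = {{4, 0}, {0, 2}};
  moveAxisBasesToLanes(laneBases, warpBases, /*axis=*/0);
  // Now laneBases == {{4, 0}, {0, 0}} and warpBases == {{0, 0}, {0, 2}}:
  // the warp dimension no longer moves data along the reduced axis.
  return 0;
}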
@@ -320,9 +409,7 @@ LinearLayout ReduceOpHelper::reducedRegLaneLayout(RankedTensorType srcTy,
   auto kLane = StringAttr::get(ctx, "lane");
   auto kWarp = StringAttr::get(ctx, "warp");
 
-  auto reduced = triton::gpu::toLinearLayout(srcTy);
-  reduced = reduced.sublayout({kReg, kLane, kWarp},
-                              to_vector(reduced.getOutDimNames()));
+  auto reduced = toLinearLayout(srcTy);
   reduced = actionRemoveBroadcastedRegs(reduced).apply(reduced);
 
   reduced = moveAxisBasesToFront(reduced, axis).apply(reduced);
@@ -332,32 +419,6 @@ LinearLayout ReduceOpHelper::reducedRegLaneLayout(RankedTensorType srcTy,
   return reduced;
 }
 
-SmallVector<unsigned>
-ReduceOpHelper::getScratchBytesForCvt(const LinearLayout &srcLayout,
-                                      const LinearLayout &dstLayout) {
-  SmallVector<unsigned> bytes(srcElementTypes.size(), 0);
-  auto *ctx = op.getContext();
-  SmallVector<int64_t> shape;
-  shape.reserve(srcLayout.getNumOutDims());
-  for (auto dim : srcLayout.getOutDimNames()) {
-    shape.push_back(srcLayout.getOutDimSize(dim));
-  }
-  auto srcEnc = triton::gpu::LinearEncodingAttr::get(ctx, srcLayout);
-  auto dstEnc = triton::gpu::LinearEncodingAttr::get(ctx, dstLayout);
-  for (unsigned i = 0; i < srcElementTypes.size(); ++i) {
-    auto elemTy = srcElementTypes[i];
-    if (elemTy.isIntOrFloat() && elemTy.getIntOrFloatBitWidth() < 8)
-      elemTy = IntegerType::get(ctx, 8);
-    auto srcTy = RankedTensorType::get(shape, elemTy, srcEnc);
-    auto dstTy = RankedTensorType::get(shape, elemTy, dstEnc);
-    if (!cvtNeedsSharedMemory(srcTy, dstTy))
-      continue;
-    auto elems = getNumScratchElemsSwizzledCvt(srcTy, dstTy);
-    bytes[i] = elems * getBitwidth(srcTy) / 8;
-  }
-  return bytes;
-}
-
 ScanLoweringHelper::ScanLoweringHelper(triton::ScanOp op) : scanOp(op) {
   auto firstTy = cast<RankedTensorType>(op.getOperands()[0].getType());
   srcShape = firstTy.getShape();
