
Commit bb75a87

[BACKEND] Implement support for cross-CTA tt.reduce (#9221)
Stacked PRs:
* #9327
* #9318
* #9317
* __->__ #9221

---

### [BACKEND] Implement support for cross-CTA tt.reduce

The title of this PR is a bit of a lie. Even though the lowering is now implemented to support cross-CTA reductions, it depends on `convert_layout` supporting them, and `convert_layout` does not currently support LinearLayouts. We should generalise `convert_layout` first and then enable cross-CTA reductions here. We should also emit the correct cross-CTA barrier from `targetInfo` in the case of cross-CTA memory reuse.

In this PR, we also take the chance to generalise the lowering to avoid convert_layouts whenever possible.
1 parent ba570e1 commit bb75a87

7 files changed: 286 additions & 196 deletions


include/triton/Analysis/Allocation.h

Lines changed: 4 additions & 0 deletions
```diff
@@ -20,6 +20,10 @@ using AllocationAnalysisScratchSizeFn = std::function<unsigned(Operation *)>;
 
 unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op);
 
+unsigned getNumScratchElemsSwizzledCvt(const LinearLayout &srcLayout,
+                                       const LinearLayout &dstLayout,
+                                       int bitwidth);
+
 unsigned getNumScratchElemsSwizzledCvt(RankedTensorType srcTy,
                                        RankedTensorType dstTy);
```

include/triton/Analysis/Utility.h

Lines changed: 2 additions & 6 deletions
```diff
@@ -61,8 +61,6 @@ class ReduceOpHelper {
 
   RankedTensorType getSrcTy() { return srcTy; }
 
-  bool isWarpSynchronous();
-
   unsigned getInterWarpSizeWithUniqueData();
 
   unsigned getIntraWarpSizeWithUniqueData();
@@ -71,6 +69,8 @@
 
   bool isAssociative();
 
+  unsigned getScratchSizeInBytes();
+
   InThreadVectorizeOpKind
   getInThreadVectorizeOpKind(unsigned axisPack,
                              bool supportBitwidth16Elementwise,
@@ -95,10 +95,6 @@
                              InThreadVectorizeOpKind kind,
                              Value lhs, Value rhs);
 
-  SmallVector<unsigned>
-  getScratchBytesForCvt(const triton::LinearLayout &srcLayout,
-                        const triton::LinearLayout &dstLayout);
-
 private:
   triton::ReduceOp op;
   RankedTensorType srcTy;
```

lib/Analysis/Allocation.cpp

Lines changed: 16 additions & 46 deletions
```diff
@@ -38,49 +38,27 @@ namespace mlir {
 //===----------------------------------------------------------------------===//
 namespace triton {
 
-unsigned getNumScratchElemsSwizzledCvt(RankedTensorType srcTy,
-                                       RankedTensorType dstTy) {
-  auto *ctx = srcTy.getContext();
-  auto srcLayout = gpu::toLinearLayout(srcTy);
-  auto dstLayout = gpu::toLinearLayout(dstTy);
-  srcLayout = actionRemoveBroadcastedRegs(srcLayout).apply(srcLayout);
-  dstLayout = actionRemoveBroadcastedRegs(dstLayout).apply(dstLayout);
-  auto bitwidth = getBitwidth(srcTy);
-  auto smem = gpu::optimalSwizzlingLdSt(srcLayout, dstLayout, bitwidth);
+unsigned getNumScratchElemsSwizzledCvt(const LinearLayout &srcLayout,
+                                       const LinearLayout &dstLayout,
+                                       int bitwidth) {
+  auto *ctx = srcLayout.getInDimNames().begin()->getContext();
+  auto srcLayoutNoBroadcast =
+      actionRemoveBroadcastedRegs(srcLayout).apply(srcLayout);
+  auto dstLayoutNoBroadcast =
+      actionRemoveBroadcastedRegs(dstLayout).apply(dstLayout);
+  auto smem = gpu::optimalSwizzlingLdSt(srcLayoutNoBroadcast,
+                                        dstLayoutNoBroadcast, bitwidth);
   auto reps = smem.getInDimSize(StringAttr::get(ctx, "reps"));
   return smem.getTotalOutDimSize() / reps;
 }
 
-namespace {
-constexpr int64_t kReduceScratchAlign = 16;
-
-Type getReduceMemElemTy(Type elemTy, MLIRContext *ctx) {
-  if (elemTy.isIntOrFloat() && elemTy.getIntOrFloatBitWidth() < 8)
-    return IntegerType::get(ctx, 8);
-  return elemTy;
-}
-
-int64_t getReduceScratchSizeBytes(triton::ReduceOp op,
-                                  ArrayRef<unsigned> bytesPerOperand) {
-  std::vector<unsigned> indices(op.getNumOperands());
-  std::iota(indices.begin(), indices.end(), 0);
-  auto *ctx = op.getContext();
-  std::sort(indices.begin(), indices.end(), [&](unsigned i, unsigned j) {
-    auto lhsTy = getReduceMemElemTy(op.getElementTypes()[i], ctx);
-    auto rhsTy = getReduceMemElemTy(op.getElementTypes()[j], ctx);
-    return getIntOrFloatOrPtrBitWidth(lhsTy) >
-           getIntOrFloatOrPtrBitWidth(rhsTy);
-  });
-  // Aling to 16 bytes to allow for vectorisation
-  int64_t offset = 0;
-  for (unsigned idx : indices) {
-    offset += llvm::alignTo(bytesPerOperand[idx], kReduceScratchAlign);
-  }
-  return offset;
+unsigned getNumScratchElemsSwizzledCvt(RankedTensorType srcTy,
+                                       RankedTensorType dstTy) {
+  return getNumScratchElemsSwizzledCvt(gpu::toLinearLayout(srcTy),
+                                       gpu::toLinearLayout(dstTy),
+                                       getBitwidth(srcTy));
 }
 
-} // namespace
-
 // Both `atomic_cas` and `atomic_rmw` may need scratch memory to store values
 // because Triton's block-based programming model ensures that
 // all threads sharing the same partition of the tensor see the same values,
```
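For intuition, here is a minimal standalone sketch of the arithmetic the reworked `getNumScratchElemsSwizzledCvt` performs (plain C++, not Triton's LinearLayout API; all sizes are hypothetical): the swizzled shared-memory layout is reused across its `reps` input dimension, so the scratch requirement is the layout's total output size divided by the number of repetitions.

```cpp
// Hedged sketch of the scratch-size arithmetic above. The numbers are made
// up; Triton derives them from optimalSwizzlingLdSt.
#include <cassert>
#include <cstdio>

int main() {
  // Suppose a convert_layout round-trips a 128x128 f16 tensor through shared
  // memory in 4 repetitions of one shared-memory tile.
  unsigned totalOutDimSize = 128 * 128; // elements the smem layout addresses
  unsigned reps = 4;                    // size of the "reps" input dimension
  int bitwidth = 16;

  // As in getNumScratchElemsSwizzledCvt: one tile, reused `reps` times.
  unsigned scratchElems = totalOutDimSize / reps;
  unsigned scratchBytes = scratchElems * (bitwidth / 8);

  printf("scratch: %u elements, %u bytes\n", scratchElems, scratchBytes);
  assert(scratchElems == 4096 && scratchBytes == 8192);
  return 0;
}
```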
```diff
@@ -107,15 +85,7 @@ static SmallVector<unsigned> getRepShapeForAtomic(Value result) {
 
 unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op) {
   if (auto reduceOp = dyn_cast<ReduceOp>(op)) {
-    ReduceOpHelper helper(reduceOp);
-    if (helper.isWarpSynchronous())
-      return 0;
-
-    auto regLl = ReduceOpHelper::reducedRegLaneLayout(helper.getSrcTy(),
-                                                      reduceOp.getAxis());
-    auto tmpLl = ReduceOpHelper::getInterLayout(regLl, reduceOp.getAxis());
-    auto bytesRegToTmp = helper.getScratchBytesForCvt(regLl, tmpLl);
-    return getReduceScratchSizeBytes(reduceOp, bytesRegToTmp);
+    return ReduceOpHelper(reduceOp).getScratchSizeInBytes();
   }
   if (auto scanOp = dyn_cast<ScanOp>(op)) {
     ScanLoweringHelper helper(scanOp);
```

lib/Analysis/Utility.cpp

Lines changed: 112 additions & 51 deletions
```diff
@@ -78,10 +78,6 @@ unsigned ReduceOpHelper::getIntraWarpSizeWithUniqueData() {
   return getThreadsPerWarp(srcEncoding, srcShape)[axis];
 }
 
-bool ReduceOpHelper::isWarpSynchronous() {
-  return getWarpsPerCTA(srcEncoding, srcShape)[axis] == 1;
-}
-
 bool ReduceOpHelper::isReduceWithinCTA() {
   // TODO: Support reduce across CTAS
   // Layout optimization passes such as PlanCTAPass and
```
```diff
@@ -109,6 +105,35 @@ bool ReduceOpHelper::isAssociative() {
   return !hasNoAssociativeOp;
 }
 
+unsigned ReduceOpHelper::getScratchSizeInBytes() {
+  auto kLane = StringAttr::get(op.getContext(), "lane");
+
+  auto isReduced = [axis = axis](const LinearLayout &layout) {
+    return layout.getOutDimSizes().begin()[axis] == 1;
+  };
+  auto regLl = reducedRegLaneLayout(srcTy, axis);
+
+  // All the inputs have the same layout so, since we order them from largest
+  // bitsize to smallest, and the first one is aligned, by induction, they are
+  // all aligned, so we don't need to align the byte numbers returned here.
+  unsigned bytesRegToTmp = 0;
+  while (!isReduced(regLl)) {
+    auto tmpLl = getInterLayout(regLl, axis);
+    // We take the maximum of the elements and multiply by the total bitwidth.
+    // We do this as otherwise it's quite tricky to find the correct
+    // BaseOffsets in the lowering.
+    int bytes = 0;
+    for (auto inputTy : op.getInputTypes()) {
+      auto nelem =
+          getNumScratchElemsSwizzledCvt(regLl, tmpLl, getBitwidth(inputTy));
+      bytes += nelem * (getBitwidth(inputTy) / 8);
+    }
+    bytesRegToTmp = std::max<unsigned>(bytesRegToTmp, bytes);
+    regLl = zeroBasesAlongDimAndReorder(tmpLl, axis, kLane);
+  }
+  return bytesRegToTmp;
+}
+
 ReduceOpHelper::InThreadVectorizeOpKind
 ReduceOpHelper::getInThreadVectorizeOpKind(unsigned axisPack,
                                            bool supportBitwidth16Elementwise,
```
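To make the accounting in `getScratchSizeInBytes` concrete, here is a hedged standalone sketch (hypothetical element counts and bitwidths, plain C++ rather than Triton's API): within one round every operand gets its own slice of the scratch buffer, so their byte counts add up; across rounds the buffer is reused, so only the largest round counts.

```cpp
// Hedged sketch of the max-over-rounds scratch accounting; the element counts
// per round and the operand bitwidths are made up for illustration.
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  // Two hypothetical rounds (e.g. cross-warp, then cross-CTA), each needing
  // some number of scratch elements per operand.
  std::vector<unsigned> elemsPerRound = {256, 64};
  std::vector<int> operandBitwidths = {32, 16}; // e.g. an f32 and an f16 input

  unsigned bytesRegToTmp = 0;
  for (unsigned nelem : elemsPerRound) {
    unsigned bytes = 0;
    for (int bw : operandBitwidths)
      bytes += nelem * (bw / 8); // operands in a round sit side by side
    // Rounds reuse the same buffer, so the requirement is the max, not the sum.
    bytesRegToTmp = std::max(bytesRegToTmp, bytes);
  }
  printf("scratch bytes: %u\n", bytesRegToTmp); // 256*4 + 256*2 = 1536
  return 0;
}
```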
```diff
@@ -298,26 +323,90 @@ LinearLayout ReduceOpHelper::getInterLayout(const LinearLayout &layout,
   auto *ctx = layout.getOutDimNames().begin()->getContext();
   auto kLane = mlir::StringAttr::get(ctx, "lane");
   auto kWarp = mlir::StringAttr::get(ctx, "warp");
-  auto regBases = layout.getBases();
-  auto linearAttr = triton::gpu::LinearEncodingAttr::get(ctx, layout);
-  int laneBits = layout.getInDimSizeLog2(kLane);
-  int neededLaneBits = llvm::Log2_32(linearAttr.getWarpsPerCTA()[axis]);
-  // TODO move to verifier
-  assert(neededLaneBits <= laneBits && "NYI: more inter-warps than lanes");
-  // Move the warp axis bases we need to reduce into lane bases, while
-  // keeping non-axis components in their original in-dim.
-  auto &laneBases = regBases[kLane];
-  auto &warpBases = regBases[kWarp];
-  int moved = 0;
-  for (auto &warpBasis : warpBases) {
-    if (warpBasis[axis] == 0)
-      continue;
-    assert(moved < neededLaneBits && "unexpected warp axis bases count");
-    std::swap(laneBases[moved], warpBasis);
-    moved++;
+  auto kBlock = mlir::StringAttr::get(ctx, "block");
+  auto bases = layout.getBases();
+  auto &laneBases = bases[kLane];
+  auto &warpBases = bases[kWarp];
+  auto &blockBases = bases[kBlock];
+
+  auto collectAxisBases = [&](ArrayRef<std::vector<int32_t>> bases) {
+    SmallVector<unsigned> out;
+    for (unsigned i = 0; i < bases.size(); ++i) {
+      if (bases[i][axis] != 0)
+        out.push_back(i);
+    }
+    return out;
+  };
+
+  SmallVector<unsigned> warpAxisBases = collectAxisBases(warpBases);
+  SmallVector<unsigned> blockAxisBases = collectAxisBases(blockBases);
+
+  SmallVector<unsigned> zeroLaneBases;
+  for (unsigned i = 0; i < laneBases.size(); ++i) {
+    if (llvm::all_of(laneBases[i], [](int32_t v) { return v == 0; }))
+      zeroLaneBases.push_back(i);
   }
 
-  return LinearLayout(std::move(regBases), to_vector(layout.getOutDimNames()));
+  auto axisSize = to_vector(layout.getOutDimSizes())[axis];
+  auto totalAxisBases = warpAxisBases.size() + blockAxisBases.size();
+
+  // First try to place all warp/block axis bases into lane bases that are
+  // currently zero. If we can do this we will be able to perform the full
+  // reduction with just one convert_layout
+  if (zeroLaneBases.size() >= totalAxisBases) {
+    unsigned laneIdx = 0;
+    for (unsigned idx : warpAxisBases) {
+      std::swap(laneBases[zeroLaneBases[laneIdx]], warpBases[idx]);
+      ++laneIdx;
+    }
+    for (unsigned idx : blockAxisBases) {
+      std::swap(laneBases[zeroLaneBases[laneIdx]], blockBases[idx]);
+      ++laneIdx;
+    }
+    return LinearLayout(std::move(bases), to_vector(layout.getOutDimNames()));
+  }
+
+  // If we can fit all the bases inside the lane dimension, we can perform the
+  // reduction with two convert_layouts
+  // The first cvt to move the relevant bases to the lane dimension
+  // The second to move all the bases we moved out of the lane dimension back to
+  // their original positions
+  if (warpAxisBases.size() + blockAxisBases.size() <= laneBases.size()) {
+    assert(totalAxisBases <= laneBases.size() &&
+           "unexpected lane base count for axis layout");
+    unsigned laneIdx = 0;
+    for (unsigned idx : warpAxisBases) {
+      std::swap(laneBases[laneIdx], warpBases[idx]);
+      ++laneIdx;
+    }
+    for (unsigned idx : blockAxisBases) {
+      std::swap(laneBases[laneIdx], blockBases[idx]);
+      ++laneIdx;
+    }
+    return LinearLayout(std::move(bases), to_vector(layout.getOutDimNames()));
+  }
+
+  // Assumptions (easily relaxed if AMD needs it)
+  // We assume that
+  // max number of warps * max number of blocks <= (max number of lanes)^2
+  // We check this in logarithmic space (number of bases)
+  // This is true in nvidia as the max numbers are warps=64 ctas=16 so that
+  // 64 * 16 = 1024 = 32 * 32 = laneBases.size() * laneBases.size()
+  // This implies that, even if we have to perform 3 cvt_layouts, we can perform
+  // first one that does not cross CTAs, and then two that may cross CTAs
+  assert(blockBases.size() <= laneBases.size());
+  assert(warpBases.size() + blockBases.size() <= 2 * laneBases.size());
+
+  // Otherwise, fit as many warp bases as possible into the lane dimension
+  unsigned laneIdx = 0;
+  for (unsigned idx : warpAxisBases) {
+    std::swap(laneBases[laneIdx], warpBases[idx]);
+    ++laneIdx;
+    if (laneIdx >= laneBases.size())
+      break;
+  }
+
+  return LinearLayout(std::move(bases), to_vector(layout.getOutDimNames()));
 }
 
 LinearLayout ReduceOpHelper::reducedRegLaneLayout(RankedTensorType srcTy,
```
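All three cases in `getInterLayout` come down to swapping axis-carrying warp/block bases into the lane dimension. Below is a hedged toy model of the first, single-convert_layout case (plain vectors, one output dimension, made-up bases; not Triton's LinearLayout API): every warp basis that moves data along the reduction axis is swapped into a lane basis that is currently zero. Note also that the comment's bound holds in log space: log2(64) + log2(16) = 6 + 4 = 10 = 2·log2(32), so on NVIDIA the warp and block bases always fit into twice the lane bases.

```cpp
// Toy model of the "one convert_layout" case: swap each warp basis that acts
// on the reduction axis into a lane basis that is currently zero. Bases and
// sizes are hypothetical; Triton's real bases span several output dimensions.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

using Basis = std::vector<int32_t>; // image of one index bit, per out-dim

int main() {
  int axis = 0; // single output dimension, so the reduction axis is 0
  // 32 lanes: two bits address the axis, three are zero (broadcast).
  std::vector<Basis> laneBases = {{1}, {2}, {0}, {0}, {0}};
  // 4 warps: both warp bits move data along the reduction axis.
  std::vector<Basis> warpBases = {{4}, {8}};

  // Mirror collectAxisBases / zeroLaneBases from getInterLayout.
  std::vector<unsigned> warpAxisBases, zeroLaneBases;
  for (unsigned i = 0; i < warpBases.size(); ++i)
    if (warpBases[i][axis] != 0)
      warpAxisBases.push_back(i);
  for (unsigned i = 0; i < laneBases.size(); ++i)
    if (std::all_of(laneBases[i].begin(), laneBases[i].end(),
                    [](int32_t v) { return v == 0; }))
      zeroLaneBases.push_back(i);

  // Case 1: every axis basis fits into a zero lane slot -> one convert_layout.
  if (zeroLaneBases.size() >= warpAxisBases.size()) {
    unsigned laneIdx = 0;
    for (unsigned idx : warpAxisBases)
      std::swap(laneBases[zeroLaneBases[laneIdx++]], warpBases[idx]);
  }

  // Lanes now cover the whole axis (1 2 4 8 0) and warps are broadcast (0 0),
  // so the cross-warp part of the reduction has become warp-local.
  for (const Basis &b : laneBases)
    printf("%d ", b[axis]);
  printf("\n");
  return 0;
}
```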
```diff
@@ -327,9 +416,7 @@ LinearLayout ReduceOpHelper::reducedRegLaneLayout(RankedTensorType srcTy,
   auto kLane = StringAttr::get(ctx, "lane");
   auto kWarp = StringAttr::get(ctx, "warp");
 
-  auto reduced = triton::gpu::toLinearLayout(srcTy);
-  reduced = reduced.sublayout({kReg, kLane, kWarp},
-                              to_vector(reduced.getOutDimNames()));
+  auto reduced = toLinearLayout(srcTy);
   reduced = actionRemoveBroadcastedRegs(reduced).apply(reduced);
 
   reduced = moveAxisBasesToFront(reduced, axis).apply(reduced);
```
```diff
@@ -339,32 +426,6 @@
   return reduced;
 }
 
-SmallVector<unsigned>
-ReduceOpHelper::getScratchBytesForCvt(const LinearLayout &srcLayout,
-                                      const LinearLayout &dstLayout) {
-  SmallVector<unsigned> bytes(srcElementTypes.size(), 0);
-  auto *ctx = op.getContext();
-  SmallVector<int64_t> shape;
-  shape.reserve(srcLayout.getNumOutDims());
-  for (auto dim : srcLayout.getOutDimNames()) {
-    shape.push_back(srcLayout.getOutDimSize(dim));
-  }
-  auto srcEnc = triton::gpu::LinearEncodingAttr::get(ctx, srcLayout);
-  auto dstEnc = triton::gpu::LinearEncodingAttr::get(ctx, dstLayout);
-  for (unsigned i = 0; i < srcElementTypes.size(); ++i) {
-    auto elemTy = srcElementTypes[i];
-    if (elemTy.isIntOrFloat() && elemTy.getIntOrFloatBitWidth() < 8)
-      elemTy = IntegerType::get(ctx, 8);
-    auto srcTy = RankedTensorType::get(shape, elemTy, srcEnc);
-    auto dstTy = RankedTensorType::get(shape, elemTy, dstEnc);
-    if (!cvtNeedsSharedMemory(srcTy, dstTy))
-      continue;
-    auto elems = getNumScratchElemsSwizzledCvt(srcTy, dstTy);
-    bytes[i] = elems * getBitwidth(srcTy) / 8;
-  }
-  return bytes;
-}
-
 ScanLoweringHelper::ScanLoweringHelper(triton::ScanOp op) : scanOp(op) {
   auto firstTy = cast<RankedTensorType>(op.getOperands()[0].getType());
   srcShape = firstTy.getShape();
```
