Skip to content

Commit eb2406c

Browse files
committed
[BACKEND] Implement support for cross-CTA tt.reduce
The title of this PR is a bit of a lie. Even though the lowering is now implemented to support cross-CTA reductions, it depends on `convert_layout` supporting them, and it doesn't currently support LinearLayouts. We should generalise this one first and then enable it here. We should also emit the correct cross-CTA barrier from `targetInfo` in the case of cross-CTA memory reuse. In this PR, we take the chance to also generalise the lowering to avoid convert layouts whenever possible. stack-info: PR: #9221, branch: lezcano/stack/8
1 parent 7ce04be commit eb2406c

8 files changed

Lines changed: 292 additions & 198 deletions

File tree

include/triton/Analysis/Allocation.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,9 @@ using AllocationAnalysisScratchSizeFn = std::function<unsigned(Operation *)>;
2020

2121
unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op);
2222

23-
unsigned getNumScratchElemsSwizzledCvt(RankedTensorType srcTy,
24-
RankedTensorType dstTy);
23+
unsigned getNumScratchElemsSwizzledCvt(const LinearLayout &srcLayout,
24+
const LinearLayout &dstLayout,
25+
int bitwidth);
2526

2627
} // namespace triton
2728

include/triton/Analysis/Utility.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,6 @@ class ReduceOpHelper {
4444

4545
RankedTensorType getSrcTy() { return srcTy; }
4646

47-
bool isWarpSynchronous();
48-
4947
unsigned getInterWarpSizeWithUniqueData();
5048

5149
unsigned getIntraWarpSizeWithUniqueData();
@@ -67,10 +65,6 @@ class ReduceOpHelper {
6765
static triton::LinearLayout reducedRegLaneLayout(RankedTensorType srcTy,
6866
unsigned axis);
6967

70-
SmallVector<unsigned>
71-
getScratchBytesForCvt(const triton::LinearLayout &srcLayout,
72-
const triton::LinearLayout &dstLayout);
73-
7468
private:
7569
triton::ReduceOp op;
7670
RankedTensorType srcTy;

lib/Analysis/Allocation.cpp

Lines changed: 41 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -31,49 +31,20 @@ namespace mlir {
3131
//===----------------------------------------------------------------------===//
3232
namespace triton {
3333

34-
unsigned getNumScratchElemsSwizzledCvt(RankedTensorType srcTy,
35-
RankedTensorType dstTy) {
36-
auto *ctx = srcTy.getContext();
37-
auto srcLayout = gpu::toLinearLayout(srcTy);
38-
auto dstLayout = gpu::toLinearLayout(dstTy);
39-
srcLayout = actionRemoveBroadcastedRegs(srcLayout).apply(srcLayout);
40-
dstLayout = actionRemoveBroadcastedRegs(dstLayout).apply(dstLayout);
41-
auto bitwidth = getBitwidth(srcTy);
42-
auto smem = gpu::optimalSwizzlingLdSt(srcLayout, dstLayout, bitwidth);
34+
unsigned getNumScratchElemsSwizzledCvt(const LinearLayout &srcLayout,
35+
const LinearLayout &dstLayout,
36+
int bitwidth) {
37+
auto *ctx = srcLayout.getInDimNames().begin()->getContext();
38+
auto srcLayoutNoBroadcast =
39+
actionRemoveBroadcastedRegs(srcLayout).apply(srcLayout);
40+
auto dstLayoutNoBroadcast =
41+
actionRemoveBroadcastedRegs(dstLayout).apply(dstLayout);
42+
auto smem = gpu::optimalSwizzlingLdSt(srcLayoutNoBroadcast,
43+
dstLayoutNoBroadcast, bitwidth);
4344
auto reps = smem.getInDimSize(StringAttr::get(ctx, "reps"));
4445
return smem.getTotalOutDimSize() / reps;
4546
}
4647

47-
namespace {
48-
constexpr int64_t kReduceScratchAlign = 16;
49-
50-
Type getReduceMemElemTy(Type elemTy, MLIRContext *ctx) {
51-
if (elemTy.isIntOrFloat() && elemTy.getIntOrFloatBitWidth() < 8)
52-
return IntegerType::get(ctx, 8);
53-
return elemTy;
54-
}
55-
56-
int64_t getReduceScratchSizeBytes(triton::ReduceOp op,
57-
ArrayRef<unsigned> bytesPerOperand) {
58-
std::vector<unsigned> indices(op.getNumOperands());
59-
std::iota(indices.begin(), indices.end(), 0);
60-
auto *ctx = op.getContext();
61-
std::sort(indices.begin(), indices.end(), [&](unsigned i, unsigned j) {
62-
auto lhsTy = getReduceMemElemTy(op.getElementTypes()[i], ctx);
63-
auto rhsTy = getReduceMemElemTy(op.getElementTypes()[j], ctx);
64-
return getIntOrFloatOrPtrBitWidth(lhsTy) >
65-
getIntOrFloatOrPtrBitWidth(rhsTy);
66-
});
67-
// Align to 16 bytes to allow for vectorisation
68-
int64_t offset = 0;
69-
for (unsigned idx : indices) {
70-
offset += llvm::alignTo(bytesPerOperand[idx], kReduceScratchAlign);
71-
}
72-
return offset;
73-
}
74-
75-
} // namespace
76-
7748
// Both `atomic_cas` and `atomic_rmw` may need scratch memory to store values
7849
// because Triton's block-based programming model ensures that
7950
// all threads sharing the same partition of the tensor see the same values,
@@ -100,15 +71,35 @@ static SmallVector<unsigned> getRepShapeForAtomic(Value result) {
10071

10172
unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op) {
10273
if (auto reduceOp = dyn_cast<ReduceOp>(op)) {
103-
ReduceOpHelper helper(reduceOp);
104-
if (helper.isWarpSynchronous())
105-
return 0;
74+
auto srcTy = ReduceOpHelper(reduceOp).getSrcTy();
75+
auto axis = reduceOp.getAxis();
76+
auto kLane = StringAttr::get(reduceOp.getContext(), "lane");
10677

107-
auto regLl = ReduceOpHelper::reducedRegLaneLayout(helper.getSrcTy(),
108-
reduceOp.getAxis());
109-
auto tmpLl = ReduceOpHelper::getInterLayout(regLl, reduceOp.getAxis());
110-
auto bytesRegToTmp = helper.getScratchBytesForCvt(regLl, tmpLl);
111-
return getReduceScratchSizeBytes(reduceOp, bytesRegToTmp);
78+
auto isReduced = [axis](const LinearLayout &layout) {
79+
return layout.getOutDimSizes().begin()[axis] == 1;
80+
};
81+
auto regLl =
82+
ReduceOpHelper::reducedRegLaneLayout(srcTy, reduceOp.getAxis());
83+
84+
// All the inputs have the same layout so, since we order them from largest
85+
// bitsize to smallest, and the first one is aligned, by induction, they are
86+
// all aligned, so we don't need to align the byte numbers returned here.
87+
auto bytesRegToTmp = 0;
88+
while (!isReduced(regLl)) {
89+
auto tmpLl = ReduceOpHelper::getInterLayout(regLl, axis);
90+
// We take the maximum of the elements and multiply by the total bitwidth
91+
// We do this as otherwise it's quite tricky to find the correct
92+
// BaseOffsets in the lowering
93+
int bytes = 0;
94+
for (auto inputTy : reduceOp.getInputTypes()) {
95+
auto nelem =
96+
getNumScratchElemsSwizzledCvt(regLl, tmpLl, getBitwidth(inputTy));
97+
bytes += nelem * (getBitwidth(inputTy) / 8);
98+
}
99+
bytesRegToTmp = std::max<unsigned>(bytesRegToTmp, bytes);
100+
regLl = ReduceOpHelper::zeroBasesAlongDimAndReorder(tmpLl, axis, kLane);
101+
}
102+
return bytesRegToTmp;
112103
}
113104
if (auto scanOp = dyn_cast<ScanOp>(op)) {
114105
ScanLoweringHelper helper(scanOp);
@@ -131,7 +122,9 @@ unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op) {
131122
if (!cvtNeedsSharedMemory(srcTy, dstTy))
132123
return 0;
133124
// The generic pass uses swizzling
134-
auto elems = getNumScratchElemsSwizzledCvt(srcTy, dstTy);
125+
auto elems = getNumScratchElemsSwizzledCvt(gpu::toLinearLayout(srcTy),
126+
gpu::toLinearLayout(dstTy),
127+
getBitwidth(srcTy));
135128
return elems * getBitwidth(srcTy) / 8;
136129
}
137130
if (isa<AtomicRMWOp, AtomicCASOp>(op)) {

lib/Analysis/Utility.cpp

Lines changed: 91 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,6 @@ unsigned ReduceOpHelper::getIntraWarpSizeWithUniqueData() {
7676
return getThreadsPerWarp(srcEncoding, srcShape)[axis];
7777
}
7878

79-
bool ReduceOpHelper::isWarpSynchronous() {
80-
return getWarpsPerCTA(srcEncoding, srcShape)[axis] == 1;
81-
}
82-
8379
bool ReduceOpHelper::isReduceWithinCTA() {
8480
// TODO: Support reduce across CTAS
8581
// Layout optimization passes such as PlanCTAPass and
@@ -155,23 +151,97 @@ LinearLayout ReduceOpHelper::getInterLayout(const LinearLayout &layout,
155151
auto *ctx = layout.getOutDimNames().begin()->getContext();
156152
auto kLane = mlir::StringAttr::get(ctx, "lane");
157153
auto kWarp = mlir::StringAttr::get(ctx, "warp");
154+
auto kBlock = mlir::StringAttr::get(ctx, "block");
158155
auto regBases = layout.getBases();
159-
auto linearAttr = triton::gpu::LinearEncodingAttr::get(ctx, layout);
160-
int laneBits = layout.getInDimSizeLog2(kLane);
161-
int neededLaneBits = llvm::Log2_32(linearAttr.getWarpsPerCTA()[axis]);
162-
// TODO move to verifier
163-
assert(neededLaneBits <= laneBits && "NYI: more inter-warps than lanes");
164-
// Move the warp axis bases we need to reduce into lane bases, while
165-
// keeping non-axis components in their original in-dim.
166-
auto &laneBases = regBases[kLane];
167-
auto &warpBases = regBases[kWarp];
168-
int moved = 0;
169-
for (auto &warpBasis : warpBases) {
170-
if (warpBasis[axis] == 0)
171-
continue;
172-
assert(moved < neededLaneBits && "unexpected warp axis bases count");
173-
std::swap(laneBases[moved], warpBasis);
174-
moved++;
156+
auto laneIt = regBases.find(kLane);
157+
auto warpIt = regBases.find(kWarp);
158+
auto blockIt = regBases.find(kBlock);
159+
if (laneIt == regBases.end() || warpIt == regBases.end()) {
160+
return layout;
161+
}
162+
163+
auto &laneBases = laneIt->second;
164+
auto &warpBases = warpIt->second;
165+
auto &blockBases = blockIt->second;
166+
167+
auto collectAxisBases = [&](const std::vector<std::vector<int32_t>> &bases,
168+
SmallVector<unsigned> &out) {
169+
for (unsigned i = 0; i < bases.size(); ++i) {
170+
if (bases[i][axis] != 0)
171+
out.push_back(i);
172+
}
173+
};
174+
175+
SmallVector<unsigned> warpAxisBases;
176+
collectAxisBases(warpBases, warpAxisBases);
177+
SmallVector<unsigned> blockAxisBases;
178+
collectAxisBases(blockBases, blockAxisBases);
179+
180+
SmallVector<unsigned> zeroLaneBases;
181+
for (unsigned i = 0; i < laneBases.size(); ++i) {
182+
if (llvm::all_of(laneBases[i], [](int32_t v) { return v == 0; }))
183+
zeroLaneBases.push_back(i);
184+
}
185+
186+
auto axisSize = to_vector(layout.getOutDimSizes())[axis];
187+
auto totalAxisBases = warpAxisBases.size() + blockAxisBases.size();
188+
189+
// First try to place all warp/block axis bases into lane bases that are
190+
// currently zero. If we can do this we will be able to perform the full
191+
// reduction with just one convert_layout
192+
if (zeroLaneBases.size() >= totalAxisBases) {
193+
unsigned laneIdx = 0;
194+
for (unsigned idx : warpAxisBases) {
195+
std::swap(laneBases[zeroLaneBases[laneIdx]], warpBases[idx]);
196+
++laneIdx;
197+
}
198+
for (unsigned idx : blockAxisBases) {
199+
std::swap(laneBases[zeroLaneBases[laneIdx]], blockBases[idx]);
200+
++laneIdx;
201+
}
202+
return LinearLayout(std::move(regBases),
203+
to_vector(layout.getOutDimNames()));
204+
}
205+
206+
// If we can fit all the bases inside the lane dimension, we can perform the
207+
// reduction with two convert_layouts
208+
// The first cvt to move the relevant bases to the lane dimension
209+
// The second to move all the bases we moved out of the lane dimension back to
210+
// their original positions
211+
if (warpAxisBases.size() + blockAxisBases.size() <= laneBases.size()) {
212+
assert(totalAxisBases <= laneBases.size() &&
213+
"unexpected lane base count for axis layout");
214+
unsigned laneIdx = 0;
215+
for (unsigned idx : warpAxisBases) {
216+
std::swap(laneBases[laneIdx], warpBases[idx]);
217+
++laneIdx;
218+
}
219+
for (unsigned idx : blockAxisBases) {
220+
std::swap(laneBases[laneIdx], blockBases[idx]);
221+
++laneIdx;
222+
}
223+
return LinearLayout(std::move(regBases),
224+
to_vector(layout.getOutDimNames()));
225+
}
226+
227+
// Assumptions (easily relaxed if AMD needs it)
228+
// We assume that
229+
// max number of warps * max number of blocks <= (max number of lanes)^2
230+
// We check this in logarithmic space (number of bases)
231+
// This is true on NVIDIA, as the max numbers are warps=64, ctas=16, so that
232+
// 64 * 16 = 1024 = 32 * 32 = laneBases.size() * laneBases.size()
233+
// This implies that, even if we have to perform 3 cvt_layouts, we can perform
234+
// first one that does not cross CTAs, and then two that may cross CTAs
235+
assert(blockBases.size() <= laneBases.size());
236+
assert(warpBases.size() + blockBases.size() <= 2 * laneBases.size());
237+
238+
// Otherwise, fit as many warp bases as possible into the lane dimension
239+
unsigned laneIdx = 0;
240+
for (unsigned idx : warpAxisBases) {
241+
std::swap(laneBases[laneIdx], warpBases[idx]);
242+
++laneIdx;
243+
if (laneIdx >= laneBases.size())
244+
break;
175245
}
176246

177247
return LinearLayout(std::move(regBases), to_vector(layout.getOutDimNames()));
@@ -184,9 +254,7 @@ LinearLayout ReduceOpHelper::reducedRegLaneLayout(RankedTensorType srcTy,
184254
auto kLane = StringAttr::get(ctx, "lane");
185255
auto kWarp = StringAttr::get(ctx, "warp");
186256

187-
auto reduced = triton::gpu::toLinearLayout(srcTy);
188-
reduced = reduced.sublayout({kReg, kLane, kWarp},
189-
to_vector(reduced.getOutDimNames()));
257+
auto reduced = toLinearLayout(srcTy);
190258
reduced = actionRemoveBroadcastedRegs(reduced).apply(reduced);
191259
reduced = makeAxisContiguous(reduced, axis).apply(reduced);
192260
reduced = zeroBasesAlongDimAndReorder(reduced, axis, kReg);
@@ -195,32 +263,6 @@ LinearLayout ReduceOpHelper::reducedRegLaneLayout(RankedTensorType srcTy,
195263
return reduced;
196264
}
197265

198-
SmallVector<unsigned>
199-
ReduceOpHelper::getScratchBytesForCvt(const LinearLayout &srcLayout,
200-
const LinearLayout &dstLayout) {
201-
SmallVector<unsigned> bytes(srcElementTypes.size(), 0);
202-
auto *ctx = op.getContext();
203-
SmallVector<int64_t> shape;
204-
shape.reserve(srcLayout.getNumOutDims());
205-
for (auto dim : srcLayout.getOutDimNames()) {
206-
shape.push_back(srcLayout.getOutDimSize(dim));
207-
}
208-
auto srcEnc = triton::gpu::LinearEncodingAttr::get(ctx, srcLayout);
209-
auto dstEnc = triton::gpu::LinearEncodingAttr::get(ctx, dstLayout);
210-
for (unsigned i = 0; i < srcElementTypes.size(); ++i) {
211-
auto elemTy = srcElementTypes[i];
212-
if (elemTy.isIntOrFloat() && elemTy.getIntOrFloatBitWidth() < 8)
213-
elemTy = IntegerType::get(ctx, 8);
214-
auto srcTy = RankedTensorType::get(shape, elemTy, srcEnc);
215-
auto dstTy = RankedTensorType::get(shape, elemTy, dstEnc);
216-
if (!cvtNeedsSharedMemory(srcTy, dstTy))
217-
continue;
218-
auto elems = getNumScratchElemsSwizzledCvt(srcTy, dstTy);
219-
bytes[i] = elems * getBitwidth(srcTy) / 8;
220-
}
221-
return bytes;
222-
}
223-
224266
ScanLoweringHelper::ScanLoweringHelper(triton::ScanOp op) : scanOp(op) {
225267
auto firstTy = cast<RankedTensorType>(op.getOperands()[0].getType());
226268
srcShape = firstTy.getShape();

0 commit comments

Comments
 (0)