Commit 7ff455e
[BACKEND] Implement support for cross-CTA tt.reduce
The title of this PR is a bit of a lie. Even though the lowering now supports cross-CTA reductions, it depends on `convert_layout` supporting them, and `convert_layout` does not currently handle LinearLayouts. We should generalise `convert_layout` first and then enable cross-CTA reductions here. We should also emit the correct cross-CTA barrier from `targetInfo` when cross-CTA memory is reused.

In this PR we also take the chance to generalise the lowering so that it avoids convert_layouts whenever possible.

stack-info: PR: #9221, branch: lezcano/stack/8
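To make "cross-CTA" concrete: in a linear layout, a reduction along an axis has to exchange data across a hardware dimension (lane, warp, or CTA/block) exactly when one of that dimension's basis vectors has a nonzero component along the reduced axis. The sketch below illustrates that check on a toy bases representation in plain C++; it is only an illustration under that simplified model, not Triton's LinearLayout API, and the layout values are made up.

#include <cstdint>
#include <map>
#include <string>
#include <vector>

// Toy stand-in for a linear layout: for each hardware in-dim ("lane", "warp",
// "block"), a list of basis vectors over the tensor's logical dimensions.
using Bases = std::map<std::string, std::vector<std::vector<int32_t>>>;

// A reduction along `axis` needs data exchange across an in-dim iff one of
// that in-dim's bases moves along `axis`.
bool crossesDim(const Bases &bases, const std::string &dim, int axis) {
  auto it = bases.find(dim);
  if (it == bases.end())
    return false;
  for (const auto &basis : it->second)
    if (basis[axis] != 0)
      return true;
  return false;
}

int main() {
  // 2-D tensor, reduced along axis 1. The "block" in-dim owns a chunk of
  // axis 1 (basis (0, 64)), so the reduction crosses CTAs and the lowering
  // needs a cross-CTA convert_layout (and barrier) to finish it.
  Bases bases = {
      {"lane", {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}}},
      {"warp", {{1, 0}, {2, 0}}},
      {"block", {{0, 64}}},
  };
  bool crossCTA = crossesDim(bases, "block", /*axis=*/1);  // true
  bool crossWarp = crossesDim(bases, "warp", /*axis=*/1);  // false
  (void)crossCTA;
  (void)crossWarp;
}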
1 parent ae06838 commit 7ff455e

7 files changed

Lines changed: 291 additions & 182 deletions

include/triton/Analysis/Allocation.h

Lines changed: 3 additions & 2 deletions
@@ -20,8 +20,9 @@ using AllocationAnalysisScratchSizeFn = std::function<unsigned(Operation *)>;
 
 unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op);
 
-unsigned getNumScratchElemsSwizzledCvt(RankedTensorType srcTy,
-                                       RankedTensorType dstTy);
+unsigned getNumScratchElemsSwizzledCvt(const LinearLayout &srcLayout,
+                                       const LinearLayout &dstLayout,
+                                       int bitwidth);
 
 } // namespace triton
 

lib/Analysis/Allocation.cpp

Lines changed: 42 additions & 48 deletions
@@ -31,49 +31,20 @@ namespace mlir {
 //===----------------------------------------------------------------------===//
 namespace triton {
 
-unsigned getNumScratchElemsSwizzledCvt(RankedTensorType srcTy,
-                                       RankedTensorType dstTy) {
-  auto *ctx = srcTy.getContext();
-  auto srcLayout = gpu::toLinearLayout(srcTy);
-  auto dstLayout = gpu::toLinearLayout(dstTy);
-  srcLayout = actionRemoveBroadcastedRegs(srcLayout).apply(srcLayout);
-  dstLayout = actionRemoveBroadcastedRegs(dstLayout).apply(dstLayout);
-  auto bitwidth = getBitwidth(srcTy);
-  auto smem = gpu::optimalSwizzlingLdSt(srcLayout, dstLayout, bitwidth);
+unsigned getNumScratchElemsSwizzledCvt(const LinearLayout &srcLayout,
+                                       const LinearLayout &dstLayout,
+                                       int bitwidth) {
+  auto *ctx = srcLayout.getInDimNames().begin()->getContext();
+  auto srcLayoutNoBroadcast =
+      actionRemoveBroadcastedRegs(srcLayout).apply(srcLayout);
+  auto dstLayoutNoBroadcast =
+      actionRemoveBroadcastedRegs(dstLayout).apply(dstLayout);
+  auto smem = gpu::optimalSwizzlingLdSt(srcLayoutNoBroadcast,
+                                        dstLayoutNoBroadcast, bitwidth);
   auto reps = smem.getInDimSize(StringAttr::get(ctx, "reps"));
   return smem.getTotalOutDimSize() / reps;
 }
 
-namespace {
-constexpr int64_t kReduceScratchAlign = 16;
-
-Type getReduceMemElemTy(Type elemTy, MLIRContext *ctx) {
-  if (elemTy.isIntOrFloat() && elemTy.getIntOrFloatBitWidth() < 8)
-    return IntegerType::get(ctx, 8);
-  return elemTy;
-}
-
-int64_t getReduceScratchSizeBytes(triton::ReduceOp op,
-                                  ArrayRef<unsigned> bytesPerOperand) {
-  std::vector<unsigned> indices(op.getNumOperands());
-  std::iota(indices.begin(), indices.end(), 0);
-  auto *ctx = op.getContext();
-  std::sort(indices.begin(), indices.end(), [&](unsigned i, unsigned j) {
-    auto lhsTy = getReduceMemElemTy(op.getElementTypes()[i], ctx);
-    auto rhsTy = getReduceMemElemTy(op.getElementTypes()[j], ctx);
-    return getIntOrFloatOrPtrBitWidth(lhsTy) >
-           getIntOrFloatOrPtrBitWidth(rhsTy);
-  });
-  // Aling to 16 bytes to allow for vectorisation
-  int64_t offset = 0;
-  for (unsigned idx : indices) {
-    offset += llvm::alignTo(bytesPerOperand[idx], kReduceScratchAlign);
-  }
-  return offset;
-}
-
-} // namespace
-
 // Both `atomic_cas` and `atomic_rmw` may need scratch memory to store values
 // because Triton's block-based programming model ensures that
 // all threads sharing the same partition of the tensor see the same values,
@@ -100,15 +71,36 @@ static SmallVector<unsigned> getRepShapeForAtomic(Value result) {
 
 unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op) {
   if (auto reduceOp = dyn_cast<ReduceOp>(op)) {
-    ReduceOpHelper helper(reduceOp);
-    if (helper.isWarpSynchronous())
-      return 0;
+    auto srcTy = ReduceOpHelper(reduceOp).getSrcTy();
+    auto axis = reduceOp.getAxis();
+    auto kLane = StringAttr::get(reduceOp.getContext(), "lane");
 
-    auto regLl = ReduceOpHelper::reducedRegLaneLayout(helper.getSrcTy(),
-                                                      reduceOp.getAxis());
-    auto tmpLl = ReduceOpHelper::getInterLayout(regLl, reduceOp.getAxis());
-    auto bytesRegToTmp = helper.getScratchBytesForCvt(regLl, tmpLl);
-    return getReduceScratchSizeBytes(reduceOp, bytesRegToTmp);
+    auto isReduced = [axis](const LinearLayout &layout) {
+      return layout.getOutDimSizes().begin()[axis] == 1;
+    };
+    auto regLl =
+        ReduceOpHelper::reducedRegLaneLayout(srcTy, reduceOp.getAxis());
+
+    // All the inputs have the same layout so, since we order them from largest
+    // bitsize to smallest, and the first one is aligned, by induction, they are
+    // all aligned, so we don't need to align the byte numbers returned by
+    // helper.getScratchBytesForCvt
+    auto bytesRegToTmp = 0;
+    while (!isReduced(regLl)) {
+      auto tmpLl = ReduceOpHelper::getInterLayout(regLl, axis);
+      // We take the maximum of the elements and multiply by the total bitwidth
+      // We do this as otherwise it's quite tricky to find the correct
+      // BaseOffsets in the lowering
+      int bytes = 0;
+      for (auto inputTy : reduceOp.getInputTypes()) {
+        auto nelem =
+            getNumScratchElemsSwizzledCvt(regLl, tmpLl, getBitwidth(inputTy));
+        bytes += nelem * (getBitwidth(inputTy) / 8);
+      }
+      bytesRegToTmp = std::max<unsigned>(bytesRegToTmp, bytes);
+      regLl = ReduceOpHelper::zeroBasesAlongDimAndReorder(tmpLl, axis, kLane);
+    }
+    return bytesRegToTmp;
   }
   if (auto scanOp = dyn_cast<ScanOp>(op)) {
     ScanLoweringHelper helper(scanOp);
@@ -131,7 +123,9 @@ unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op) {
     if (!cvtNeedsSharedMemory(srcTy, dstTy))
      return 0;
     // The generic pass uses swizzling
-    auto elems = getNumScratchElemsSwizzledCvt(srcTy, dstTy);
+    auto elems = getNumScratchElemsSwizzledCvt(gpu::toLinearLayout(srcTy),
+                                               gpu::toLinearLayout(dstTy),
+                                               getBitwidth(srcTy));
     return elems * getBitwidth(srcTy) / 8;
   }
   if (isa<AtomicRMWOp, AtomicCASOp>(op)) {
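The reduce branch above sizes the scratch buffer as follows: for each intermediate layout (each round of the staged reduction) it sums, over the operands, the number of scratch elements times the element size in bytes, and keeps the maximum across rounds. A small self-contained sketch of that arithmetic with made-up element counts and bitwidths (not the real layouts or Triton APIs):

#include <algorithm>
#include <cstdint>
#include <vector>

// Hedged sketch of the scratch-size bookkeeping above, with made-up numbers
// instead of real layouts: per round of the reduction we sum the scratch
// bytes of every operand, and the allocation is the max over rounds.
int64_t reduceScratchBytes(
    const std::vector<std::vector<int64_t>> &elemsPerRound, // [round][operand]
    const std::vector<int64_t> &bitwidths) {                // per operand
  int64_t best = 0;
  for (const auto &round : elemsPerRound) {
    int64_t bytes = 0;
    for (size_t i = 0; i < round.size(); ++i)
      bytes += round[i] * (bitwidths[i] / 8); // nelem * bytes-per-element
    best = std::max(best, bytes);
  }
  return best;
}

int main() {
  // Two operands (f32 and f16). Say round 1 (intra-CTA) needs 256 scratch
  // elements per operand and round 2 (cross-CTA) needs 64.
  // Round 1: 256*4 + 256*2 = 1536 bytes; round 2: 64*4 + 64*2 = 384 bytes.
  // Allocation = max(1536, 384) = 1536 bytes.
  int64_t bytes = reduceScratchBytes({{256, 256}, {64, 64}}, {32, 16});
  (void)bytes; // 1536
}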

lib/Analysis/Utility.cpp

Lines changed: 91 additions & 45 deletions
@@ -155,23 +155,97 @@ LinearLayout ReduceOpHelper::getInterLayout(const LinearLayout &layout,
   auto *ctx = layout.getOutDimNames().begin()->getContext();
   auto kLane = mlir::StringAttr::get(ctx, "lane");
   auto kWarp = mlir::StringAttr::get(ctx, "warp");
+  auto kBlock = mlir::StringAttr::get(ctx, "block");
   auto regBases = layout.getBases();
-  auto linearAttr = triton::gpu::LinearEncodingAttr::get(ctx, layout);
-  int laneBits = layout.getInDimSizeLog2(kLane);
-  int neededLaneBits = llvm::Log2_32(linearAttr.getWarpsPerCTA()[axis]);
-  // TODO move to verifier
-  assert(neededLaneBits <= laneBits && "NYI: more inter-warps than lanes");
-  // Move the warp axis bases we need to reduce into lane bases, while
-  // keeping non-axis components in their original in-dim.
-  auto &laneBases = regBases[kLane];
-  auto &warpBases = regBases[kWarp];
-  int moved = 0;
-  for (auto &warpBasis : warpBases) {
-    if (warpBasis[axis] == 0)
-      continue;
-    assert(moved < neededLaneBits && "unexpected warp axis bases count");
-    std::swap(laneBases[moved], warpBasis);
-    moved++;
+  auto laneIt = regBases.find(kLane);
+  auto warpIt = regBases.find(kWarp);
+  auto blockIt = regBases.find(kBlock);
+  if (laneIt == regBases.end() || warpIt == regBases.end()) {
+    return layout;
+  }
+
+  auto &laneBases = laneIt->second;
+  auto &warpBases = warpIt->second;
+  auto &blockBases = blockIt->second;
+
+  auto collectAxisBases = [&](const std::vector<std::vector<int32_t>> &bases,
+                              SmallVector<unsigned> &out) {
+    for (unsigned i = 0; i < bases.size(); ++i) {
+      if (bases[i][axis] != 0)
+        out.push_back(i);
+    }
+  };
+
+  SmallVector<unsigned> warpAxisBases;
+  collectAxisBases(warpBases, warpAxisBases);
+  SmallVector<unsigned> blockAxisBases;
+  collectAxisBases(blockBases, blockAxisBases);
+
+  SmallVector<unsigned> zeroLaneBases;
+  for (unsigned i = 0; i < laneBases.size(); ++i) {
+    if (llvm::all_of(laneBases[i], [](int32_t v) { return v == 0; }))
+      zeroLaneBases.push_back(i);
+  }
+
+  auto axisSize = to_vector(layout.getOutDimSizes())[axis];
+  auto totalAxisBases = warpAxisBases.size() + blockAxisBases.size();
+
+  // First try to place all warp/block axis bases into lane bases that are
+  // currently zero. If we can do this we will be able to perform the full
+  // reduction with just one convert_layout
+  if (zeroLaneBases.size() >= totalAxisBases) {
+    unsigned laneIdx = 0;
+    for (unsigned idx : warpAxisBases) {
+      std::swap(laneBases[zeroLaneBases[laneIdx]], warpBases[idx]);
+      ++laneIdx;
+    }
+    for (unsigned idx : blockAxisBases) {
+      std::swap(laneBases[zeroLaneBases[laneIdx]], blockBases[idx]);
+      ++laneIdx;
+    }
+    return LinearLayout(std::move(regBases),
+                        to_vector(layout.getOutDimNames()));
+  }
+
+  // If we can fit all the bases inside the lane dimension, we can perform the
+  // reduction with two convert_layouts
+  // The first cvt to move the relevant bases to the lane dimension
+  // The second to move all the bases we moved out of the lane dimension back to
+  // their original positions
+  if (warpAxisBases.size() + blockAxisBases.size() <= laneBases.size()) {
+    assert(totalAxisBases <= laneBases.size() &&
+           "unexpected lane base count for axis layout");
+    unsigned laneIdx = 0;
+    for (unsigned idx : warpAxisBases) {
+      std::swap(laneBases[laneIdx], warpBases[idx]);
+      ++laneIdx;
+    }
+    for (unsigned idx : blockAxisBases) {
+      std::swap(laneBases[laneIdx], blockBases[idx]);
+      ++laneIdx;
+    }
+    return LinearLayout(std::move(regBases),
+                        to_vector(layout.getOutDimNames()));
+  }
+
+  // Assumptions (easily relaxed if AMD needs it)
+  // We assume that
+  // max number of warps * max number of blocks <= (max number of lanes)^2
+  // We check this in logarithmic space (number of bases)
+  // This is true in nvidia as the max numbers are warps=64 ctas=16 so that
+  // 64 * 16 = 1024 = 32 * 32 = laneBases.size() * laneBases.size()
+  // This implies that, even if we have to perform 3 cvt_layouts, we can perform
+  // first one that does not cross CTAs, and then two that may cross CTAs
+  assert(blockBases.size() <= laneBases.size());
+  assert(warpBases.size() + blockBases.size() <= 2 * laneBases.size());
+
+  // Otherwise, fit as many warp bases as possible into the lane dimension
+  unsigned laneIdx = 0;
+  for (unsigned idx : warpAxisBases) {
+    std::swap(laneBases[laneIdx], warpBases[idx]);
+    ++laneIdx;
+    if (laneIdx >= laneBases.size())
+      break;
   }
 
   return LinearLayout(std::move(regBases), to_vector(layout.getOutDimNames()));
@@ -184,9 +258,7 @@ LinearLayout ReduceOpHelper::reducedRegLaneLayout(RankedTensorType srcTy,
   auto kLane = StringAttr::get(ctx, "lane");
   auto kWarp = StringAttr::get(ctx, "warp");
 
-  auto reduced = triton::gpu::toLinearLayout(srcTy);
-  reduced = reduced.sublayout({kReg, kLane, kWarp},
-                              to_vector(reduced.getOutDimNames()));
+  auto reduced = toLinearLayout(srcTy);
   reduced = actionRemoveBroadcastedRegs(reduced).apply(reduced);
   reduced = makeAxisContiguous(reduced, axis).apply(reduced);
   reduced = zeroBasesAlongDimAndReorder(reduced, axis, kReg);
@@ -195,32 +267,6 @@ LinearLayout ReduceOpHelper::reducedRegLaneLayout(RankedTensorType srcTy,
   return reduced;
 }
 
-SmallVector<unsigned>
-ReduceOpHelper::getScratchBytesForCvt(const LinearLayout &srcLayout,
-                                      const LinearLayout &dstLayout) {
-  SmallVector<unsigned> bytes(srcElementTypes.size(), 0);
-  auto *ctx = op.getContext();
-  SmallVector<int64_t> shape;
-  shape.reserve(srcLayout.getNumOutDims());
-  for (auto dim : srcLayout.getOutDimNames()) {
-    shape.push_back(srcLayout.getOutDimSize(dim));
-  }
-  auto srcEnc = triton::gpu::LinearEncodingAttr::get(ctx, srcLayout);
-  auto dstEnc = triton::gpu::LinearEncodingAttr::get(ctx, dstLayout);
-  for (unsigned i = 0; i < srcElementTypes.size(); ++i) {
-    auto elemTy = srcElementTypes[i];
-    if (elemTy.isIntOrFloat() && elemTy.getIntOrFloatBitWidth() < 8)
-      elemTy = IntegerType::get(ctx, 8);
-    auto srcTy = RankedTensorType::get(shape, elemTy, srcEnc);
-    auto dstTy = RankedTensorType::get(shape, elemTy, dstEnc);
-    if (!cvtNeedsSharedMemory(srcTy, dstTy))
-      continue;
-    auto elems = getNumScratchElemsSwizzledCvt(srcTy, dstTy);
-    bytes[i] = elems * getBitwidth(srcTy) / 8;
-  }
-  return bytes;
-}
-
 ScanLoweringHelper::ScanLoweringHelper(triton::ScanOp op) : scanOp(op) {
   auto firstTy = cast<RankedTensorType>(op.getOperands()[0].getType());
   srcShape = firstTy.getShape();
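The new `getInterLayout` picks an intermediate layout by moving the warp/block bases that act on the reduced axis into the lane dimension, preferring lane slots whose basis is already zero so that a single convert_layout suffices; only when that is impossible does it fall back to two (or, under the warp/block-count assumption stated in the comments, three) convert_layouts. Below is a toy sketch of the single-convert_layout case of that swap on plain vectors of bases, with made-up values and without Triton's LinearLayout types:

#include <cstdint>
#include <vector>

// Hedged sketch of the base-swapping idea in getInterLayout, on plain vectors
// rather than Triton's LinearLayout. Warp bases with a nonzero component
// along the reduced axis are swapped into lane slots whose basis is currently
// all-zero, so the existing lane mapping is preserved and one convert_layout
// can realise the move. The real code also handles block bases and fallbacks.
using Basis = std::vector<int32_t>;

void swapAxisBasesIntoLanes(std::vector<Basis> &laneBases,
                            std::vector<Basis> &warpBases, int axis) {
  // Lane slots whose basis is all zeros (free slots).
  std::vector<size_t> freeLanes;
  for (size_t i = 0; i < laneBases.size(); ++i) {
    bool allZero = true;
    for (int32_t v : laneBases[i])
      if (v != 0)
        allZero = false;
    if (allZero)
      freeLanes.push_back(i);
  }
  size_t next = 0;
  for (auto &warpBasis : warpBases) {
    if (warpBasis[axis] == 0 || next >= freeLanes.size())
      continue; // not on the reduced axis, or no free lane slot left
    std::swap(laneBases[freeLanes[next]], warpBasis);
    ++next;
  }
}

int main() {
  // Reduce along axis 0 of a 2-D tensor. One lane basis is zero (broadcast),
  // and one warp basis moves along axis 0; after the swap, the lanes cover
  // that part of the axis and the warp basis becomes zero.
  std::vector<Basis> lane = {{0, 1}, {0, 2}, {0, 0}};
  std::vector<Basis> warp = {{4, 0}, {0, 4}};
  swapAxisBasesIntoLanes(lane, warp, /*axis=*/0);
  // lane is now {{0,1}, {0,2}, {4,0}} and warp is {{0,0}, {0,4}}.
}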
