Skip to content

Commit 4c14120

Browse files
committed
[BACKEND] Improve and simplify ReduceOp's lowering
We implement a LinearLayout-based `ReduceOp` lowering. This has a number of benefits: - The logic is noticeably simpler as we barely have to implement anything. ConvertLayout and some LL helpers do all the heavy lifting - We get shmem swizzling for free - We sometimes save a shmem round-trip (before we did it unconditionally) - It is now clear that we have a `tmpLl` variable we can carefully choose (we'll do so in a future PR) - It opens the door to returning an arbitrary layout (fusing a `convert_layout` into this op) - It is now really simple to generalise this op to perform cross-cluster reductions, provided that `convert_layout` supports them. - We fix some latent issues the previous implementation had when run on arbitrary linear layouts. We add a funky regression test that used to fail and now passes. - All this while being LOC-neutral! In future PRs we will improve the choice of `tmpLl` to avoid the last `convert_layout` in many cases, and we will pack the inputs in shmem to be able to vectorize the load/stores for full reductions with multiple inputs. This PR was the result of quite a long (but rather successful) vibe-coding session together with `gpt-5.2-codex`. I found it particularly useful to be able to emit a ConvertLayout within this lowering rather than having to call the lowering of the function manually. This simplifies the code quite a bit and I would have struggled to convince MLIR to do so myself.
1 parent 6457061 commit 4c14120

14 files changed

Lines changed: 511 additions & 424 deletions

File tree

include/triton/Analysis/Utility.h

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -42,30 +42,34 @@ class ReduceOpHelper {
4242
}
4343
}
4444

45-
ArrayRef<int64_t> getSrcShape() { return srcShape; }
46-
47-
Attribute getSrcLayout() { return srcEncoding; }
48-
49-
triton::ReduceOp getOperation() { return op; }
50-
51-
unsigned getThreadOffsetOnReductionAxis();
45+
RankedTensorType getSrcTy() { return srcTy; }
5246

5347
bool isWarpSynchronous();
5448

5549
unsigned getInterWarpSizeWithUniqueData();
5650

5751
unsigned getIntraWarpSizeWithUniqueData();
5852

59-
// The shape of the shared memory space needed for the reduction.
60-
SmallVector<unsigned> getScratchRepShape();
53+
bool isReduceWithinCTA();
54+
55+
bool isAssociative();
6156

62-
SmallVector<unsigned> getOrderWithAxisAtBeginning();
57+
static triton::ColumnAction
58+
makeAxisContiguous(const triton::LinearLayout &layout, int axis);
6359

64-
unsigned getScratchSizeInBytes();
60+
static triton::LinearLayout
61+
zeroBasesAlongDimAndReorder(const triton::LinearLayout &layout, unsigned axis,
62+
mlir::StringAttr dim);
6563

66-
bool isReduceWithinCTA();
64+
static triton::LinearLayout getInterLayout(const triton::LinearLayout &layout,
65+
unsigned axis);
6766

68-
bool isAssociative();
67+
static triton::LinearLayout reducedRegLaneLayout(RankedTensorType srcTy,
68+
unsigned axis);
69+
70+
SmallVector<unsigned>
71+
getScratchBytesForCvt(const triton::LinearLayout &srcLayout,
72+
const triton::LinearLayout &dstLayout);
6973

7074
private:
7175
triton::ReduceOp op;

include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETINFOBASE_H
33

44
#include "triton/Conversion/MLIRTypes.h"
5+
#include "llvm/ADT/ArrayRef.h"
56

67
namespace mlir::triton {
78
enum class ProgramIDDim : uint32_t;
@@ -63,8 +64,7 @@ class TargetInfoBase {
6364

6465
virtual bool warpReduce(RewriterBase &rewriter, Location loc,
6566
SmallVector<Value> &acc, triton::ReduceOp op,
66-
unsigned numLaneToReduce,
67-
unsigned interleave) const = 0;
67+
unsigned activeLanes) const = 0;
6868

6969
virtual std::string getMulhiFuncName(Type resultElementTy) const = 0;
7070
// Emits LLVM code with |rewriter| to print a message following the given

lib/Analysis/Allocation.cpp

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include <algorithm>
44
#include <limits>
5+
#include <numeric>
56

67
#include "mlir/Analysis/Liveness.h"
78
#include "mlir/Support/LLVM.h"
@@ -14,6 +15,7 @@
1415
#include "triton/Tools/LayoutUtils.h"
1516
#include "llvm/ADT/SmallVector.h"
1617
#include "llvm/Support/Debug.h"
18+
#include "llvm/Support/MathExtras.h"
1719
#include "llvm/Support/raw_ostream.h"
1820

1921
#define DEBUG_TYPE "allocation-shared-memory"
@@ -42,6 +44,36 @@ unsigned getNumScratchElemsSwizzledCvt(RankedTensorType srcTy,
4244
return smem.getTotalOutDimSize() / reps;
4345
}
4446

47+
namespace {
48+
constexpr int64_t kReduceScratchAlign = 16;
49+
50+
Type getReduceMemElemTy(Type elemTy, MLIRContext *ctx) {
51+
if (elemTy.isIntOrFloat() && elemTy.getIntOrFloatBitWidth() < 8)
52+
return IntegerType::get(ctx, 8);
53+
return elemTy;
54+
}
55+
56+
int64_t getReduceScratchSizeBytes(triton::ReduceOp op,
57+
ArrayRef<unsigned> bytesPerOperand) {
58+
std::vector<unsigned> indices(op.getNumOperands());
59+
std::iota(indices.begin(), indices.end(), 0);
60+
auto *ctx = op.getContext();
61+
std::sort(indices.begin(), indices.end(), [&](unsigned i, unsigned j) {
62+
auto lhsTy = getReduceMemElemTy(op.getElementTypes()[i], ctx);
63+
auto rhsTy = getReduceMemElemTy(op.getElementTypes()[j], ctx);
64+
return getIntOrFloatOrPtrBitWidth(lhsTy) >
65+
getIntOrFloatOrPtrBitWidth(rhsTy);
66+
});
67+
// Align to 16 bytes to allow for vectorisation
68+
int64_t offset = 0;
69+
for (unsigned idx : indices) {
70+
offset += llvm::alignTo(bytesPerOperand[idx], kReduceScratchAlign);
71+
}
72+
return offset;
73+
}
74+
75+
} // namespace
76+
4577
// Both `atomic_cas` and `atomic_rmw` may need scratch memory to store values
4678
// because Triton's block-based programming model ensures that
4779
// all threads sharing the same partition of the tensor see the same values,
@@ -69,7 +101,14 @@ static SmallVector<unsigned> getRepShapeForAtomic(Value result) {
69101
unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op) {
70102
if (auto reduceOp = dyn_cast<ReduceOp>(op)) {
71103
ReduceOpHelper helper(reduceOp);
72-
return helper.getScratchSizeInBytes();
104+
if (helper.isWarpSynchronous())
105+
return 0;
106+
107+
auto regLl = ReduceOpHelper::reducedRegLaneLayout(helper.getSrcTy(),
108+
reduceOp.getAxis());
109+
auto tmpLl = ReduceOpHelper::getInterLayout(regLl, reduceOp.getAxis());
110+
auto bytesRegToTmp = helper.getScratchBytesForCvt(regLl, tmpLl);
111+
return getReduceScratchSizeBytes(reduceOp, bytesRegToTmp);
73112
}
74113
if (auto scanOp = dyn_cast<ScanOp>(op)) {
75114
ScanLoweringHelper helper(scanOp);

lib/Analysis/Utility.cpp

Lines changed: 114 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -23,33 +23,6 @@ namespace mlir {
2323
using namespace triton;
2424
using namespace triton::gpu;
2525

26-
SmallVector<unsigned> ReduceOpHelper::getOrderWithAxisAtBeginning() {
27-
auto order = toLinearEncoding(srcTy).getOrder();
28-
auto it = std::find(order.begin(), order.end(), axis);
29-
// delete the axis from order
30-
order.erase(it);
31-
// insert axis at the beginning of order
32-
order.insert(order.begin(), axis);
33-
return order;
34-
}
35-
36-
// Thread offset is the thread index offset of two adjacent threads on the
37-
// reduction axis within the warp.
38-
unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() {
39-
auto *ctx = srcEncoding.getContext();
40-
auto linearLayout = toLinearLayout(srcTy);
41-
auto kLane = mlir::StringAttr::get(ctx, "lane");
42-
const auto &bases = linearLayout.getBases();
43-
const auto &lanes = bases.find(kLane)->second;
44-
auto offset = 1;
45-
for (const auto &lane : lanes) {
46-
if (lane[axis] != 0)
47-
break;
48-
offset *= 2;
49-
}
50-
return offset;
51-
}
52-
5326
// Cases where distributed shared memory is not required in ConvertLayout:
5427
// (1) numCTAs == 1
5528
// (2) numCTAs > 1 but srcCGALayout == dstCGALayout
@@ -107,29 +80,6 @@ bool ReduceOpHelper::isWarpSynchronous() {
10780
return getWarpsPerCTA(srcEncoding, srcShape)[axis] == 1;
10881
}
10982

110-
SmallVector<unsigned> ReduceOpHelper::getScratchRepShape() {
111-
SmallVector<unsigned> smemShape;
112-
// This case doesn't need inter-warp communication
113-
if (isWarpSynchronous())
114-
return {0, 0};
115-
116-
smemShape = convertType<unsigned>(srcShape);
117-
smemShape[axis] = getInterWarpSizeWithUniqueData();
118-
119-
return smemShape;
120-
}
121-
122-
unsigned ReduceOpHelper::getScratchSizeInBytes() {
123-
auto smemShape = getScratchRepShape();
124-
auto elems = product<unsigned>(smemShape);
125-
126-
unsigned bytesPerElem = 0;
127-
for (const auto &ty : srcElementTypes) {
128-
bytesPerElem += ceil<unsigned>(ty.getIntOrFloatBitWidth(), 8);
129-
}
130-
return bytesPerElem * elems;
131-
}
132-
13383
bool ReduceOpHelper::isReduceWithinCTA() {
13484
// TODO: Support reduce across CTAS
13585
// Layout optimization passes such as PlanCTAPass and
@@ -157,6 +107,120 @@ bool ReduceOpHelper::isAssociative() {
157107
return !hasNoAssociativeOp;
158108
}
159109

110+
ColumnAction ReduceOpHelper::makeAxisContiguous(const LinearLayout &layout,
111+
int axis) {
112+
auto *ctx = layout.getOutDimNames().begin()->getContext();
113+
auto kReg = StringAttr::get(ctx, "register");
114+
const auto &bases = layout.getBases().lookup(kReg);
115+
SmallVector<size_t> perm;
116+
SmallVector<size_t> back;
117+
for (size_t i = 0; i < bases.size(); ++i) {
118+
if (bases[i][axis] != 0)
119+
perm.push_back(i);
120+
else
121+
back.push_back(i);
122+
}
123+
perm.append(back.begin(), back.end());
124+
return ColumnAction(perm, kReg, bases.size());
125+
}
126+
127+
LinearLayout
128+
ReduceOpHelper::zeroBasesAlongDimAndReorder(const LinearLayout &layout,
129+
unsigned axis, StringAttr dim) {
130+
LinearLayout::BasesT newBases;
131+
for (auto [inDim, bases] : layout.getBases()) {
132+
std::vector<std::vector<int32_t>> newInBases = bases;
133+
if (inDim == dim) {
134+
for (auto &basis : newInBases)
135+
basis[axis] = 0;
136+
}
137+
newBases[inDim] = std::move(newInBases);
138+
}
139+
140+
int32_t nextAxisBase = 1;
141+
for (auto &[inDim, inDimBases] : newBases) {
142+
for (auto &basis : inDimBases) {
143+
if (basis[axis] == 0)
144+
continue;
145+
basis[axis] = nextAxisBase;
146+
nextAxisBase *= 2;
147+
}
148+
}
149+
150+
return LinearLayout(std::move(newBases), to_vector(layout.getOutDimNames()));
151+
}
152+
153+
LinearLayout ReduceOpHelper::getInterLayout(const LinearLayout &layout,
154+
unsigned axis) {
155+
auto *ctx = layout.getOutDimNames().begin()->getContext();
156+
auto kLane = mlir::StringAttr::get(ctx, "lane");
157+
auto kWarp = mlir::StringAttr::get(ctx, "warp");
158+
auto regBases = layout.getBases();
159+
auto linearAttr = triton::gpu::LinearEncodingAttr::get(ctx, layout);
160+
int laneBits = layout.getInDimSizeLog2(kLane);
161+
int neededLaneBits = llvm::Log2_32(linearAttr.getWarpsPerCTA()[axis]);
162+
// TODO move to verifier
163+
assert(neededLaneBits <= laneBits && "NYI: more inter-warps than lanes");
164+
// Move the warp axis bases we need to reduce into lane bases, while
165+
// keeping non-axis components in their original in-dim.
166+
auto &laneBases = regBases[kLane];
167+
auto &warpBases = regBases[kWarp];
168+
int moved = 0;
169+
for (auto &warpBasis : warpBases) {
170+
if (warpBasis[axis] == 0)
171+
continue;
172+
assert(moved < neededLaneBits && "unexpected warp axis bases count");
173+
std::swap(laneBases[moved], warpBasis);
174+
moved++;
175+
}
176+
177+
return LinearLayout(std::move(regBases), to_vector(layout.getOutDimNames()));
178+
}
179+
180+
LinearLayout ReduceOpHelper::reducedRegLaneLayout(RankedTensorType srcTy,
181+
unsigned axis) {
182+
auto *ctx = srcTy.getContext();
183+
auto kReg = StringAttr::get(ctx, "register");
184+
auto kLane = StringAttr::get(ctx, "lane");
185+
auto kWarp = StringAttr::get(ctx, "warp");
186+
187+
auto reduced = triton::gpu::toLinearLayout(srcTy);
188+
reduced = reduced.sublayout({kReg, kLane, kWarp},
189+
to_vector(reduced.getOutDimNames()));
190+
reduced = actionRemoveBroadcastedRegs(reduced).apply(reduced);
191+
reduced = makeAxisContiguous(reduced, axis).apply(reduced);
192+
reduced = zeroBasesAlongDimAndReorder(reduced, axis, kReg);
193+
reduced = actionRemoveBroadcastedRegs(reduced).apply(reduced);
194+
reduced = zeroBasesAlongDimAndReorder(reduced, axis, kLane);
195+
return reduced;
196+
}
197+
198+
SmallVector<unsigned>
199+
ReduceOpHelper::getScratchBytesForCvt(const LinearLayout &srcLayout,
200+
const LinearLayout &dstLayout) {
201+
SmallVector<unsigned> bytes(srcElementTypes.size(), 0);
202+
auto *ctx = op.getContext();
203+
SmallVector<int64_t> shape;
204+
shape.reserve(srcLayout.getNumOutDims());
205+
for (auto dim : srcLayout.getOutDimNames()) {
206+
shape.push_back(srcLayout.getOutDimSize(dim));
207+
}
208+
auto srcEnc = triton::gpu::LinearEncodingAttr::get(ctx, srcLayout);
209+
auto dstEnc = triton::gpu::LinearEncodingAttr::get(ctx, dstLayout);
210+
for (unsigned i = 0; i < srcElementTypes.size(); ++i) {
211+
auto elemTy = srcElementTypes[i];
212+
if (elemTy.isIntOrFloat() && elemTy.getIntOrFloatBitWidth() < 8)
213+
elemTy = IntegerType::get(ctx, 8);
214+
auto srcTy = RankedTensorType::get(shape, elemTy, srcEnc);
215+
auto dstTy = RankedTensorType::get(shape, elemTy, dstEnc);
216+
if (!cvtNeedsSharedMemory(srcTy, dstTy))
217+
continue;
218+
auto elems = getNumScratchElemsSwizzledCvt(srcTy, dstTy);
219+
bytes[i] = elems * getBitwidth(srcTy) / 8;
220+
}
221+
return bytes;
222+
}
223+
160224
ScanLoweringHelper::ScanLoweringHelper(triton::ScanOp op) : scanOp(op) {
161225
auto firstTy = cast<RankedTensorType>(op.getOperands()[0].getType());
162226
srcShape = firstTy.getShape();

0 commit comments

Comments
 (0)