Skip to content

Commit 4b0fd06

Browse files
committed
[BACKEND] Improve and simplify ReduceOp's lowering
We implement a LinearLayout-based `ReduceOp` lowering. This has a number of benefits: - The logic is noticeably simpler as we barely have to implement anything. ConvertLayout and some LL helpers do all the heavy lifting - We get shmem swizzling for free - We sometimes save a shmem round-trip (before we did it unconditionally) - It is now clear that we have a `tmpLl` variable we can carefully choose (we'll do so in a future PR) - It opens the door to returning an arbitrary layout (fusing a `convert_layout` into this op) - It is now really simple to generalise this op to perform cross-cluster reductions, provided that `convert_layout` supports them. - We fix some latent issues the previous implementation had when run on arbitrary linear layouts. We add a funky regression test that used to fail and now passes. - All this while being LOC-neutral! In future PRs we will improve the choice of `tmpLl` to avoid in many cases the last `convert_layout`, and we will pack the inputs in shmem to be able to vectorize the load/stores for full reductions with multiple inputs. This PR was the result of quite a long (but rather successful) vibe-coding session together with `gpt-5.2-codex`. I found it particularly useful being able to emit a ConvertLayout within this lowering rather than having to call the lowering of the function manually. This simplifies the code quite a bit and I would have struggled to convince MLIR to do so myself. stack-info: PR: #9219, branch: lezcano/stack/6
1 parent 5b856e6 commit 4b0fd06

14 files changed

Lines changed: 521 additions & 420 deletions

File tree

include/triton/Analysis/Utility.h

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -42,30 +42,34 @@ class ReduceOpHelper {
4242
}
4343
}
4444

45-
ArrayRef<int64_t> getSrcShape() { return srcShape; }
46-
47-
Attribute getSrcLayout() { return srcEncoding; }
48-
49-
triton::ReduceOp getOperation() { return op; }
50-
51-
unsigned getThreadOffsetOnReductionAxis();
45+
RankedTensorType getSrcTy() { return srcTy; }
5246

5347
bool isWarpSynchronous();
5448

5549
unsigned getInterWarpSizeWithUniqueData();
5650

5751
unsigned getIntraWarpSizeWithUniqueData();
5852

59-
// The shape of the shared memory space needed for the reduction.
60-
SmallVector<unsigned> getScratchRepShape();
53+
bool isReduceWithinCTA();
54+
55+
bool isAssociative();
6156

62-
SmallVector<unsigned> getOrderWithAxisAtBeginning();
57+
static triton::ColumnAction
58+
moveAxisBasesToFront(const triton::LinearLayout &layout, int axis);
6359

64-
unsigned getScratchSizeInBytes();
60+
static triton::LinearLayout
61+
zeroBasesAlongDimAndReorder(const triton::LinearLayout &layout, unsigned axis,
62+
mlir::StringAttr dim);
6563

66-
bool isReduceWithinCTA();
64+
static triton::LinearLayout getInterLayout(const triton::LinearLayout &layout,
65+
unsigned axis);
6766

68-
bool isAssociative();
67+
static triton::LinearLayout reducedRegLaneLayout(RankedTensorType srcTy,
68+
unsigned axis);
69+
70+
SmallVector<unsigned>
71+
getScratchBytesForCvt(const triton::LinearLayout &srcLayout,
72+
const triton::LinearLayout &dstLayout);
6973

7074
private:
7175
triton::ReduceOp op;

include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETINFOBASE_H
33

44
#include "triton/Conversion/MLIRTypes.h"
5+
#include "llvm/ADT/ArrayRef.h"
56

67
namespace mlir::triton {
78
enum class ProgramIDDim : uint32_t;
@@ -66,8 +67,7 @@ class TargetInfoBase {
6667

6768
virtual bool warpReduce(RewriterBase &rewriter, Location loc,
6869
SmallVector<Value> &acc, triton::ReduceOp op,
69-
unsigned numLaneToReduce,
70-
unsigned interleave) const = 0;
70+
unsigned reduceLaneIdMask) const = 0;
7171

7272
virtual std::string getMulhiFuncName(Type resultElementTy) const = 0;
7373
// Emits LLVM code with |rewriter| to print a message following the given

lib/Analysis/Allocation.cpp

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include <algorithm>
44
#include <limits>
5+
#include <numeric>
56

67
#include "mlir/Analysis/Liveness.h"
78
#include "mlir/Support/LLVM.h"
@@ -14,6 +15,7 @@
1415
#include "triton/Tools/LayoutUtils.h"
1516
#include "llvm/ADT/SmallVector.h"
1617
#include "llvm/Support/Debug.h"
18+
#include "llvm/Support/MathExtras.h"
1719
#include "llvm/Support/raw_ostream.h"
1820

1921
#define DEBUG_TYPE "allocation-shared-memory"
@@ -42,6 +44,36 @@ unsigned getNumScratchElemsSwizzledCvt(RankedTensorType srcTy,
4244
return smem.getTotalOutDimSize() / reps;
4345
}
4446

47+
namespace {
48+
constexpr int64_t kReduceScratchAlign = 16;
49+
50+
Type getReduceMemElemTy(Type elemTy, MLIRContext *ctx) {
51+
if (elemTy.isIntOrFloat() && elemTy.getIntOrFloatBitWidth() < 8)
52+
return IntegerType::get(ctx, 8);
53+
return elemTy;
54+
}
55+
56+
int64_t getReduceScratchSizeBytes(triton::ReduceOp op,
57+
ArrayRef<unsigned> bytesPerOperand) {
58+
std::vector<unsigned> indices(op.getNumOperands());
59+
std::iota(indices.begin(), indices.end(), 0);
60+
auto *ctx = op.getContext();
61+
std::sort(indices.begin(), indices.end(), [&](unsigned i, unsigned j) {
62+
auto lhsTy = getReduceMemElemTy(op.getElementTypes()[i], ctx);
63+
auto rhsTy = getReduceMemElemTy(op.getElementTypes()[j], ctx);
64+
return getIntOrFloatOrPtrBitWidth(lhsTy) >
65+
getIntOrFloatOrPtrBitWidth(rhsTy);
66+
});
67+
// Align to 16 bytes to allow for vectorisation
68+
int64_t offset = 0;
69+
for (unsigned idx : indices) {
70+
offset += llvm::alignTo(bytesPerOperand[idx], kReduceScratchAlign);
71+
}
72+
return offset;
73+
}
74+
75+
} // namespace
76+
4577
// Both `atomic_cas` and `atomic_rmw` may need scratch memory to store values
4678
// because Triton's block-based programming model ensures that
4779
// all threads sharing the same partition of the tensor see the same values,
@@ -69,7 +101,14 @@ static SmallVector<unsigned> getRepShapeForAtomic(Value result) {
69101
unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op) {
70102
if (auto reduceOp = dyn_cast<ReduceOp>(op)) {
71103
ReduceOpHelper helper(reduceOp);
72-
return helper.getScratchSizeInBytes();
104+
if (helper.isWarpSynchronous())
105+
return 0;
106+
107+
auto regLl = ReduceOpHelper::reducedRegLaneLayout(helper.getSrcTy(),
108+
reduceOp.getAxis());
109+
auto tmpLl = ReduceOpHelper::getInterLayout(regLl, reduceOp.getAxis());
110+
auto bytesRegToTmp = helper.getScratchBytesForCvt(regLl, tmpLl);
111+
return getReduceScratchSizeBytes(reduceOp, bytesRegToTmp);
73112
}
74113
if (auto scanOp = dyn_cast<ScanOp>(op)) {
75114
ScanLoweringHelper helper(scanOp);

lib/Analysis/Utility.cpp

Lines changed: 119 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -23,33 +23,6 @@ namespace mlir {
2323
using namespace triton;
2424
using namespace triton::gpu;
2525

26-
SmallVector<unsigned> ReduceOpHelper::getOrderWithAxisAtBeginning() {
27-
auto order = toLinearEncoding(srcTy).getOrder();
28-
auto it = std::find(order.begin(), order.end(), axis);
29-
// delete the axis from order
30-
order.erase(it);
31-
// insert axis at the beginning of order
32-
order.insert(order.begin(), axis);
33-
return order;
34-
}
35-
36-
// Thread offset is the thread index offset of two adjacent threads on the
37-
// reduction axis within the warp.
38-
unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() {
39-
auto *ctx = srcEncoding.getContext();
40-
auto linearLayout = toLinearLayout(srcTy);
41-
auto kLane = mlir::StringAttr::get(ctx, "lane");
42-
const auto &bases = linearLayout.getBases();
43-
const auto &lanes = bases.find(kLane)->second;
44-
auto offset = 1;
45-
for (const auto &lane : lanes) {
46-
if (lane[axis] != 0)
47-
break;
48-
offset *= 2;
49-
}
50-
return offset;
51-
}
52-
5326
// Cases where distributed shared memory is not required in ConvertLayout:
5427
// (1) numCTAs == 1
5528
// (2) numCTAs > 1 but srcCGALayout == dstCGALayout
@@ -107,29 +80,6 @@ bool ReduceOpHelper::isWarpSynchronous() {
10780
return getWarpsPerCTA(srcEncoding, srcShape)[axis] == 1;
10881
}
10982

110-
SmallVector<unsigned> ReduceOpHelper::getScratchRepShape() {
111-
SmallVector<unsigned> smemShape;
112-
// This case doesn't need inter-warp communication
113-
if (isWarpSynchronous())
114-
return {0, 0};
115-
116-
smemShape = convertType<unsigned>(srcShape);
117-
smemShape[axis] = getInterWarpSizeWithUniqueData();
118-
119-
return smemShape;
120-
}
121-
122-
unsigned ReduceOpHelper::getScratchSizeInBytes() {
123-
auto smemShape = getScratchRepShape();
124-
auto elems = product<unsigned>(smemShape);
125-
126-
unsigned bytesPerElem = 0;
127-
for (const auto &ty : srcElementTypes) {
128-
bytesPerElem += ceil<unsigned>(ty.getIntOrFloatBitWidth(), 8);
129-
}
130-
return bytesPerElem * elems;
131-
}
132-
13383
bool ReduceOpHelper::isReduceWithinCTA() {
13484
// TODO: Support reduce across CTAS
13585
// Layout optimization passes such as PlanCTAPass and
@@ -157,6 +107,125 @@ bool ReduceOpHelper::isAssociative() {
157107
return !hasNoAssociativeOp;
158108
}
159109

110+
ColumnAction ReduceOpHelper::moveAxisBasesToFront(const LinearLayout &layout,
111+
int axis) {
112+
auto *ctx = layout.getOutDimNames().begin()->getContext();
113+
auto kReg = StringAttr::get(ctx, "register");
114+
const auto &bases = layout.getBases().lookup(kReg);
115+
SmallVector<size_t> perm;
116+
SmallVector<size_t> back;
117+
for (size_t i = 0; i < bases.size(); ++i) {
118+
if (bases[i][axis] != 0)
119+
perm.push_back(i);
120+
else
121+
back.push_back(i);
122+
}
123+
perm.append(back.begin(), back.end());
124+
return ColumnAction(perm, kReg, bases.size());
125+
}
126+
127+
LinearLayout
128+
ReduceOpHelper::zeroBasesAlongDimAndReorder(const LinearLayout &layout,
129+
unsigned axis, StringAttr dim) {
130+
// Zeros out the basis along the specified axis in the given hardware
131+
// dimension, and reindexes the remaining bases along axis so that each
132+
// element is in linearly increasing order from the hardware's perspective.
133+
// Note that for this reordering we need the operator to be commutative, but
134+
// it's the only way to have a performant lowering.
135+
LinearLayout::BasesT newBases;
136+
for (auto [inDim, bases] : layout.getBases()) {
137+
std::vector<std::vector<int32_t>> newInBases = bases;
138+
if (inDim == dim) {
139+
for (auto &basis : newInBases)
140+
basis[axis] = 0;
141+
}
142+
newBases[inDim] = std::move(newInBases);
143+
}
144+
145+
int32_t nextAxisBase = 1;
146+
for (auto &[inDim, inDimBases] : newBases) {
147+
for (auto &basis : inDimBases) {
148+
if (basis[axis] == 0)
149+
continue;
150+
basis[axis] = nextAxisBase;
151+
nextAxisBase *= 2;
152+
}
153+
}
154+
155+
return LinearLayout(std::move(newBases), to_vector(layout.getOutDimNames()));
156+
}
157+
158+
LinearLayout ReduceOpHelper::getInterLayout(const LinearLayout &layout,
159+
unsigned axis) {
160+
auto *ctx = layout.getOutDimNames().begin()->getContext();
161+
auto kLane = mlir::StringAttr::get(ctx, "lane");
162+
auto kWarp = mlir::StringAttr::get(ctx, "warp");
163+
auto regBases = layout.getBases();
164+
auto linearAttr = triton::gpu::LinearEncodingAttr::get(ctx, layout);
165+
int laneBits = layout.getInDimSizeLog2(kLane);
166+
int neededLaneBits = llvm::Log2_32(linearAttr.getWarpsPerCTA()[axis]);
167+
// TODO move to verifier
168+
assert(neededLaneBits <= laneBits && "NYI: more inter-warps than lanes");
169+
// Move the warp axis bases we need to reduce into lane bases, while
170+
// keeping non-axis components in their original in-dim.
171+
auto &laneBases = regBases[kLane];
172+
auto &warpBases = regBases[kWarp];
173+
int moved = 0;
174+
for (auto &warpBasis : warpBases) {
175+
if (warpBasis[axis] == 0)
176+
continue;
177+
assert(moved < neededLaneBits && "unexpected warp axis bases count");
178+
std::swap(laneBases[moved], warpBasis);
179+
moved++;
180+
}
181+
182+
return LinearLayout(std::move(regBases), to_vector(layout.getOutDimNames()));
183+
}
184+
185+
LinearLayout ReduceOpHelper::reducedRegLaneLayout(RankedTensorType srcTy,
186+
unsigned axis) {
187+
auto *ctx = srcTy.getContext();
188+
auto kReg = StringAttr::get(ctx, "register");
189+
auto kLane = StringAttr::get(ctx, "lane");
190+
auto kWarp = StringAttr::get(ctx, "warp");
191+
192+
auto reduced = triton::gpu::toLinearLayout(srcTy);
193+
reduced = reduced.sublayout({kReg, kLane, kWarp},
194+
to_vector(reduced.getOutDimNames()));
195+
reduced = actionRemoveBroadcastedRegs(reduced).apply(reduced);
196+
reduced = moveAxisBasesToFront(reduced, axis).apply(reduced);
197+
reduced = zeroBasesAlongDimAndReorder(reduced, axis, kReg);
198+
reduced = actionRemoveBroadcastedRegs(reduced).apply(reduced);
199+
reduced = zeroBasesAlongDimAndReorder(reduced, axis, kLane);
200+
return reduced;
201+
}
202+
203+
SmallVector<unsigned>
204+
ReduceOpHelper::getScratchBytesForCvt(const LinearLayout &srcLayout,
205+
const LinearLayout &dstLayout) {
206+
SmallVector<unsigned> bytes(srcElementTypes.size(), 0);
207+
auto *ctx = op.getContext();
208+
SmallVector<int64_t> shape;
209+
shape.reserve(srcLayout.getNumOutDims());
210+
for (auto dim : srcLayout.getOutDimNames()) {
211+
shape.push_back(srcLayout.getOutDimSize(dim));
212+
}
213+
auto srcEnc = triton::gpu::LinearEncodingAttr::get(ctx, srcLayout);
214+
auto dstEnc = triton::gpu::LinearEncodingAttr::get(ctx, dstLayout);
215+
for (unsigned i = 0; i < srcElementTypes.size(); ++i) {
216+
auto elemTy = srcElementTypes[i];
217+
if (elemTy.isIntOrFloat() && elemTy.getIntOrFloatBitWidth() < 8)
218+
elemTy = IntegerType::get(ctx, 8);
219+
auto srcTy = RankedTensorType::get(shape, elemTy, srcEnc);
220+
auto dstTy = RankedTensorType::get(shape, elemTy, dstEnc);
221+
if (!cvtNeedsSharedMemory(srcTy, dstTy))
222+
continue;
223+
auto elems = getNumScratchElemsSwizzledCvt(srcTy, dstTy);
224+
bytes[i] = elems * getBitwidth(srcTy) / 8;
225+
}
226+
return bytes;
227+
}
228+
160229
ScanLoweringHelper::ScanLoweringHelper(triton::ScanOp op) : scanOp(op) {
161230
auto firstTy = cast<RankedTensorType>(op.getOperands()[0].getType());
162231
srcShape = firstTy.getShape();

0 commit comments

Comments
 (0)