Skip to content

Commit d3adbad

Browse files
committed
[BACKEND] Perform tree reductions on in-thread values
We generate ternary trees for suitable integer ops and binary trees for everything else. We manually generate `{add,mul}.{f16,f32}x2` ops. This brings a speed-up to some gluon attention kernels. stack-info: PR: #9220, branch: lezcano/stack/7
1 parent 0ca2260 commit d3adbad

1 file changed

Lines changed: 149 additions & 4 deletions

File tree

lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp

Lines changed: 149 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
#include "ReduceScanCommon.h"
22

3+
#include <memory>
34
#include <tuple>
45
#include <utility>
56

7+
#include "mlir/Dialect/Arith/IR/Arith.h"
8+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
69
#include "mlir/Support/LLVM.h"
710
#include "triton/Analysis/Utility.h"
811
#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
@@ -110,6 +113,142 @@ struct ReduceOpConversion
110113

111114
private:
112115
const TargetInfoBase &targetInfo;
116+
117+
// Decide whether the in-thread reduction of `op` should be emitted as a
// ternary tree instead of a binary one.
//
// Returns true only when all of the following hold:
//   * the reduction has exactly one operand,
//   * that operand's element type is an integer type,
//   * `op.getSingleCombiner()` yields an operation, and
//   * that operation is one of arith.andi / arith.ori / arith.xori /
//     arith.addi.
static bool useTernaryTreeReduction(triton::ReduceOp op) {
  // Short-circuiting keeps the element type from being queried unless there
  // is exactly one operand.
  const bool singleIntegerOperand =
      op.getNumOperands() == 1 && isa<IntegerType>(op.getElementTypes()[0]);
  if (!singleIntegerOperand)
    return false;

  Operation *singleCombiner = op.getSingleCombiner();
  // `isa` must not see a null pointer, hence the explicit null test first.
  return singleCombiner &&
         isa<arith::AndIOp, arith::OrIOp, arith::XOrIOp, arith::AddIOp>(
             singleCombiner);
}
129+
130+
std::unique_ptr<Region>
131+
maybePackValuesf32x2(Location loc, ConversionPatternRewriter &rewriter,
132+
Region &combineOp,
133+
SmallVector<SmallVector<Value>> &values) const {
134+
if (values.size() < 2 || (values.size() % 2) != 0)
135+
return nullptr;
136+
if (values.front().size() != 1)
137+
return nullptr;
138+
auto elemTy = values.front().front().getType();
139+
if (!(elemTy.isF16() || elemTy.isBF16() || elemTy.isF32()))
140+
return nullptr;
141+
Operation *combiner = nullptr;
142+
if (!combineOp.empty()) {
143+
auto &block = combineOp.front();
144+
if (block.getOperations().size() == 2)
145+
combiner = &block.front();
146+
}
147+
if (!combiner)
148+
return nullptr;
149+
if (!isa<arith::AddFOp, arith::MulFOp>(combiner))
150+
return nullptr;
151+
bool isMul = isa<arith::MulFOp>(combiner);
152+
// Pack the values into 2-element vectors
153+
SmallVector<SmallVector<Value>> packed;
154+
for (size_t i = 0; i < values.size(); i += 2) {
155+
SmallVector<Value> vecTuple(values.front().size());
156+
for (unsigned opIdx = 0; opIdx < values.front().size(); ++opIdx) {
157+
vecTuple[opIdx] = packLLVector(
158+
loc, {values[i][opIdx], values[i + 1][opIdx]}, rewriter);
159+
}
160+
packed.push_back(std::move(vecTuple));
161+
}
162+
values = std::move(packed);
163+
// Create a new region that takes 2-element vectors as inputs and returns a
164+
// 2-element vector
165+
auto region = std::make_unique<Region>();
166+
auto *block = new Block();
167+
region->push_back(block);
168+
auto vecTy = vec_ty(elemTy, 2);
169+
block->addArgument(vecTy, loc);
170+
block->addArgument(vecTy, loc);
171+
auto *ctx = rewriter.getContext();
172+
OpBuilder builder(ctx);
173+
builder.setInsertionPointToStart(block);
174+
Value lhs = block->getArgument(0);
175+
Value rhs = block->getArgument(1);
176+
Value result =
177+
isMul ? LLVM::FMulOp::create(builder, loc, lhs, rhs).getResult()
178+
: LLVM::FAddOp::create(builder, loc, lhs, rhs).getResult();
179+
triton::ReduceReturnOp::create(builder, loc, ValueRange{result});
180+
return region;
181+
}
182+
183+
// Undo the 2-element packing on the final reduction result, if any.
//
// `values` is left alone unless it holds exactly one value whose type is a
// vector type. In that case the vector is split into its two elements, which
// are combined once more with the original scalar `combineOp`; `values` then
// becomes the single scalar result.
void maybeUnpackValuesf32x2(Location loc, ConversionPatternRewriter &rewriter,
                            Region &combineOp,
                            SmallVector<Value> &values) const {
  // A reduction with more than one output is never vectorized.
  // (One could imagine vectorizing any region made only of arith ops with a
  // dedicated pass, but that is left for another day.)
  if (values.size() != 1)
    return;

  Value packedResult = values.front();
  auto packedTy = dyn_cast<VectorType>(packedResult.getType());
  if (!packedTy)
    return;

  // Sanity checks on the shape of the packed value.
  assert(packedTy.getNumElements() == 2);
  auto laneTy = packedTy.getElementType();
  assert(laneTy.isF16() || laneTy.isBF16() || laneTy.isF32());

  // Final, non-vectorized combine of the two lanes.
  auto lanes = unpackLLVector(loc, packedResult, rewriter);
  SmallVector<Value> scalarAcc;
  scalarAcc.push_back(lanes[0]);
  accumulate(loc, rewriter, combineOp, scalarAcc, {lanes[1]});
  values = std::move(scalarAcc);
}
203+
204+
// Reduce `values` to a single tuple with a ternary tree.
//
// Each round folds consecutive groups of three entries into one with two
// `accumulate` calls on `combineOp`. A trailing group of two is folded with
// one call, and a trailing single entry is carried over unchanged. Rounds
// repeat until one entry remains, which is returned.
//
// `values` must be non-empty; each entry holds one value per reduce operand.
SmallVector<Value>
treeReduceTernary(Location loc, ConversionPatternRewriter &rewriter,
                  Region &combineOp,
                  SmallVector<SmallVector<Value>> values) const {
  // `values.front()` below is undefined on an empty list; fail loudly instead
  // (treeReduceBinary enforces the same precondition).
  assert(!values.empty() && "cannot reduce an empty list of values");

  while (values.size() > 1) {
    SmallVector<SmallVector<Value>> next;
    next.reserve((values.size() + 2) / 3);

    size_t i = 0;
    // Full groups of three.
    for (; i + 2 < values.size(); i += 3) {
      // `values` is discarded at the end of the round, so the group's seed
      // can be moved out instead of copied.
      SmallVector<Value> acc = std::move(values[i]);
      accumulate(loc, rewriter, combineOp, acc, values[i + 1]);
      accumulate(loc, rewriter, combineOp, acc, values[i + 2]);
      next.push_back(std::move(acc));
    }

    // Tail: zero, one or two entries are left over.
    const size_t remaining = values.size() - i;
    if (remaining == 1) {
      next.push_back(std::move(values[i]));
    } else if (remaining == 2) {
      SmallVector<Value> acc = std::move(values[i]);
      accumulate(loc, rewriter, combineOp, acc, values[i + 1]);
      next.push_back(std::move(acc));
    }

    values = std::move(next);
  }
  return std::move(values.front());
}
229+
230+
// Reduce `values` to a single tuple with a binary tree.
//
// The inputs are first offered to `maybePackValuesf32x2`. If it packs them,
// the tree is built over the packed values with the region it returns, and
// the two lanes of the final vector are combined with the original
// `combineOp` at the end. Otherwise the tree is built directly with
// `combineOp`.
//
// `values` must hold a non-zero power-of-two number of entries; each entry
// holds one value per reduce operand.
SmallVector<Value>
treeReduceBinary(Location loc, ConversionPatternRewriter &rewriter,
                 Region &combineOp,
                 SmallVector<SmallVector<Value>> values) const {
  // The number of elements is always a power of two.
  assert(!values.empty() && llvm::isPowerOf2_64(values.size()) &&
         "expected a non-empty power-of-two number of values");

  // May replace `values` with packed 2-element vectors; null means the
  // inputs were left as they were.
  auto vectorCombine = maybePackValuesf32x2(loc, rewriter, combineOp, values);
  Region &accumulateRegion = vectorCombine ? *vectorCombine : combineOp;

  while (values.size() > 1) {
    SmallVector<SmallVector<Value>> next;
    next.reserve(values.size() / 2);
    for (size_t i = 0; i + 1 < values.size(); i += 2) {
      // `values` is discarded at the end of the round, so the pair's seed
      // can be moved out instead of copied.
      SmallVector<Value> acc = std::move(values[i]);
      accumulate(loc, rewriter, accumulateRegion, acc, values[i + 1]);
      next.push_back(std::move(acc));
    }
    values = std::move(next);
  }

  SmallVector<Value> val = std::move(values.front());
  // Only unpack when packing actually happened. Deciding purely from the
  // result type would also split a value that was a vector to begin with and
  // fold it a second time.
  if (vectorCombine)
    maybeUnpackValuesf32x2(loc, rewriter, combineOp, val);
  return val;
}
251+
113252
void accumulate(Location loc, ConversionPatternRewriter &rewriter,
114253
Region &combineOp, SmallVector<Value> &acc, ValueRange cur,
115254
Value pred = {}) const {
@@ -154,6 +293,7 @@ struct ReduceOpConversion
154293
return {std::move(layout), std::move(accs)};
155294
}
156295

296+
bool useTernary = useTernaryTreeReduction(op);
157297
// Bring the registers that move the axis to the front
158298
auto perm = ReduceOpHelper::makeAxisContiguous(layout, op.getAxis());
159299
if (!perm.isIdentity()) {
@@ -163,19 +303,24 @@ struct ReduceOpConversion
163303
}
164304
}
165305

166-
// Reduce linearly
167-
// TODO Perform a tree reduction
306+
// Reduce with a tree. Use a ternary tree when it can map to 3-input SASS
307+
// ops (IADD3/LOP3); otherwise use a binary tree.
168308
SmallVector<SmallVector<Value>> reduced(op.getNumOperands());
169309
for (unsigned regBase = 0; regBase < layout.getInDimSize(kReg);
170310
regBase += axisPack) {
171-
SmallVector<Value> acc;
311+
SmallVector<SmallVector<Value>> vals;
172312
for (unsigned i = 0; i < axisPack; ++i) {
173313
SmallVector<Value> cur(op.getNumOperands());
174314
for (unsigned opIdx = 0; opIdx < op.getNumOperands(); ++opIdx) {
175315
cur[opIdx] = accs[opIdx][regBase + i];
176316
}
177-
accumulate(op.getLoc(), rewriter, op.getCombineOp(), acc, cur);
317+
vals.push_back(std::move(cur));
178318
}
319+
auto acc = useTernary
320+
? treeReduceTernary(op.getLoc(), rewriter,
321+
op.getCombineOp(), std::move(vals))
322+
: treeReduceBinary(op.getLoc(), rewriter,
323+
op.getCombineOp(), std::move(vals));
179324
for (unsigned opIdx = 0; opIdx < op.getNumOperands(); ++opIdx) {
180325
reduced[opIdx].push_back(acc[opIdx]);
181326
}

0 commit comments

Comments
 (0)