Skip to content

Commit 147422e

Browse files
committed
[BACKEND] Perform tree reductions on in-thread values
We generate ternary trees for suitable integer ops and binary trees for everything else. We manually generate `{add,mul}.{f16,f32}x2` ops. This brings a speed-up to some gluon attention kernels.
1 parent 4c14120 commit 147422e

1 file changed

Lines changed: 147 additions & 4 deletions

File tree

lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp

Lines changed: 147 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
#include "ReduceScanCommon.h"
22

3+
#include <memory>
34
#include <tuple>
45
#include <utility>
56

7+
#include "mlir/Dialect/Arith/IR/Arith.h"
8+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
69
#include "mlir/Support/LLVM.h"
710
#include "triton/Analysis/Utility.h"
811
#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
@@ -110,6 +113,140 @@ struct ReduceOpConversion
110113

111114
private:
112115
const TargetInfoBase &targetInfo;
116+
117+
/// Decide whether the in-thread reduction for `op` should be built as a
/// ternary tree instead of a binary one.
///
/// A ternary tree is chosen only when all of the following hold:
///   * the reduce op has exactly one operand,
///   * that operand's element type is an integer type,
///   * the op has a single combiner operation, and it is one of
///     arith.andi / arith.ori / arith.xori / arith.addi.
/// (The call site's comment states these are the ops expected to map to
/// 3-input SASS instructions such as IADD3/LOP3.)
static bool useTernaryTreeReduction(triton::ReduceOp op) {
  // Short-circuiting keeps the element-type lookup from running unless the
  // operand count has already been validated.
  const bool singleIntegerOperand =
      op.getNumOperands() == 1 && isa<IntegerType>(op.getElementTypes()[0]);
  if (!singleIntegerOperand)
    return false;

  Operation *soleCombiner = op.getSingleCombiner();
  return soleCombiner != nullptr &&
         isa<arith::AndIOp, arith::OrIOp, arith::XOrIOp, arith::AddIOp>(
             soleCombiner);
}
129+
130+
std::unique_ptr<Region>
131+
maybePackValuesf32x2(Location loc, ConversionPatternRewriter &rewriter,
132+
Region &combineOp,
133+
SmallVector<SmallVector<Value>> &values) const {
134+
if (values.size() < 2 || (values.size() % 2) != 0)
135+
return nullptr;
136+
if (values.front().size() != 1)
137+
return nullptr;
138+
auto elemTy = values.front().front().getType();
139+
if (!(elemTy.isF16() || elemTy.isBF16() || elemTy.isF32()))
140+
return nullptr;
141+
Operation *combiner = nullptr;
142+
if (!combineOp.empty()) {
143+
auto &block = combineOp.front();
144+
if (block.getOperations().size() == 2)
145+
combiner = &block.front();
146+
}
147+
if (!combiner)
148+
return nullptr;
149+
if (!isa<arith::AddFOp, arith::MulFOp>(combiner))
150+
return nullptr;
151+
bool isMul = isa<arith::MulFOp>(combiner);
152+
// Pack the values into 2-element vectors
153+
SmallVector<SmallVector<Value>> packed;
154+
for (size_t i = 0; i < values.size(); i += 2) {
155+
SmallVector<Value> vecTuple(values.front().size());
156+
for (unsigned opIdx = 0; opIdx < values.front().size(); ++opIdx) {
157+
vecTuple[opIdx] = packLLVector(
158+
loc, {values[i][opIdx], values[i + 1][opIdx]}, rewriter);
159+
}
160+
packed.push_back(std::move(vecTuple));
161+
}
162+
values = std::move(packed);
163+
// Create a new region that takes 2-element vectors as inputs and returns a
164+
// 2-element vector
165+
auto region = std::make_unique<Region>();
166+
auto *block = new Block();
167+
region->push_back(block);
168+
auto vecTy = vec_ty(elemTy, 2);
169+
block->addArgument(vecTy, loc);
170+
block->addArgument(vecTy, loc);
171+
auto *ctx = rewriter.getContext();
172+
OpBuilder builder(ctx);
173+
builder.setInsertionPointToStart(block);
174+
Value lhs = block->getArgument(0);
175+
Value rhs = block->getArgument(1);
176+
Value result =
177+
isMul ? LLVM::FMulOp::create(builder, loc, lhs, rhs).getResult()
178+
: LLVM::FAddOp::create(builder, loc, lhs, rhs).getResult();
179+
triton::ReduceReturnOp::create(builder, loc, ValueRange{result});
180+
return region;
181+
}
182+
183+
/// Undo the 2-element vectorization of a reduction result, if there is one.
///
/// When `values` holds exactly one value and that value is a 2-element vector
/// of f16, bf16 or f32, the two lanes are extracted and combined with the
/// scalar `combineOp`; `values` is then replaced by the single scalar result.
/// In every other case `values` is left untouched.
///
/// bf16 is accepted because maybePackValuesf32x2 also packs bf16 inputs;
/// without it a packed bf16 reduction would return a 2-element vector instead
/// of a scalar.
///
/// \param loc       Location attached to every op created here.
/// \param rewriter  Rewriter used to extract the lanes and emit the combine.
/// \param combineOp The original scalar combine region of the reduction.
/// \param values    In/out result tuple of the tree reduction.
void maybeUnpackValuesf32x2(Location loc, ConversionPatternRewriter &rewriter,
                            Region &combineOp,
                            SmallVector<Value> &values) const {
  // If it has more than one output, it's not vectorized.
  // There is a world where we just check whether the region has all arith ops
  // and we vectorize them all with a pass... but that's for another day.
  if (values.size() != 1)
    return;
  auto vecTy = dyn_cast<VectorType>(values.front().getType());
  if (!vecTy || vecTy.getNumElements() != 2)
    return;
  // Must accept every element type that the packing helper accepts.
  auto laneTy = vecTy.getElementType();
  if (!(laneTy.isF16() || laneTy.isBF16() || laneTy.isF32()))
    return;
  // Perform the last (non-vectorized) combine operation on the two lanes.
  auto elems = unpackLLVector(loc, values.front(), rewriter);
  SmallVector<Value> acc = {elems[0]};
  accumulate(loc, rewriter, combineOp, acc, {elems[1]});
  values = std::move(acc);
}
201+
202+
/// Reduce `values` to a single tuple by repeatedly combining groups of up to
/// three neighbouring tuples with `combineOp`.
///
/// Each pass walks the current level left to right. A full group
/// (v0, v1, v2) becomes combine(combine(v0, v1), v2); a trailing group of two
/// becomes combine(v0, v1); a trailing single tuple is forwarded unchanged.
/// Passes repeat until one tuple is left, which is returned.
///
/// \param loc       Location attached to the emitted combine ops.
/// \param rewriter  Rewriter handed through to `accumulate`.
/// \param combineOp Combine region applied by `accumulate`.
/// \param values    Tuples to reduce (taken by value and consumed).
SmallVector<Value>
treeReduceTernary(Location loc, ConversionPatternRewriter &rewriter,
                  Region &combineOp,
                  SmallVector<SmallVector<Value>> values) const {
  while (values.size() > 1) {
    const size_t count = values.size();
    SmallVector<SmallVector<Value>> level;
    for (size_t start = 0; start < count; start += 3) {
      // A group spans up to three tuples; the last one may be shorter.
      const size_t stop = (count - start < 3) ? count : start + 3;
      // The current level is discarded after this pass, so the group's first
      // tuple can be taken over as the running accumulator.
      SmallVector<Value> partial = std::move(values[start]);
      for (size_t j = start + 1; j < stop; ++j)
        accumulate(loc, rewriter, combineOp, partial, values[j]);
      level.push_back(std::move(partial));
    }
    values = std::move(level);
  }
  SmallVector<Value> root = std::move(values.front());
  return root;
}
227+
228+
/// Reduce `values` to a single tuple with a balanced binary tree.
///
/// The inputs are first offered to maybePackValuesf32x2. If it returns a
/// region, `values` has been replaced by packed 2-element vectors and that
/// region is used for the tree levels; otherwise the original `combineOp` is
/// used. After the tree collapses to one tuple, maybeUnpackValuesf32x2 is
/// given the original `combineOp` to fold a packed result back to a scalar.
///
/// \param loc       Location attached to the emitted ops.
/// \param rewriter  Rewriter handed through to the helpers.
/// \param combineOp The reduction's scalar combine region.
/// \param values    Tuples to reduce (taken by value and consumed); the count
///                  is asserted to be a power of two.
SmallVector<Value>
treeReduceBinary(Location loc, ConversionPatternRewriter &rewriter,
                 Region &combineOp,
                 SmallVector<SmallVector<Value>> values) const {
  // The number of elements is always a power of two
  assert(llvm::isPowerOf2_64(values.size()) && !values.empty());

  // Owns the temporary vector combine region (if any) for the duration of
  // the tree construction.
  std::unique_ptr<Region> packedCombine =
      maybePackValuesf32x2(loc, rewriter, combineOp, values);
  Region *levelCombine = &combineOp;
  if (packedCombine)
    levelCombine = packedCombine.get();

  while (values.size() > 1) {
    SmallVector<SmallVector<Value>> level;
    // Combine neighbours (0,1), (2,3), ... of the current level.
    for (size_t left = 0, right = 1; right < values.size();
         left += 2, right += 2) {
      SmallVector<Value> merged = values[left];
      accumulate(loc, rewriter, *levelCombine, merged, values[right]);
      level.push_back(std::move(merged));
    }
    values = std::move(level);
  }

  SmallVector<Value> root = std::move(values.front());
  maybeUnpackValuesf32x2(loc, rewriter, combineOp, root);
  return root;
}
249+
113250
void accumulate(Location loc, ConversionPatternRewriter &rewriter,
114251
Region &combineOp, SmallVector<Value> &acc, ValueRange cur,
115252
Value pred = {}) const {
@@ -148,6 +285,7 @@ struct ReduceOpConversion
148285
return {std::move(layout), std::move(accs)};
149286
}
150287

288+
bool useTernary = useTernaryTreeReduction(op);
151289
// Bring the registers that move the axis to the front
152290
auto perm = ReduceOpHelper::makeAxisContiguous(layout, op.getAxis());
153291
if (!perm.isIdentity()) {
@@ -157,19 +295,24 @@ struct ReduceOpConversion
157295
}
158296
}
159297

160-
// Reduce linearly
161-
// TODO Perform a tree reduction
298+
// Reduce with a tree. Use a ternary tree when it can map to 3-input SASS
299+
// ops (IADD3/LOP3); otherwise use a binary tree.
162300
SmallVector<SmallVector<Value>> reduced(op.getNumOperands());
163301
for (unsigned regBase = 0; regBase < layout.getInDimSize(kReg);
164302
regBase += axisPack) {
165-
SmallVector<Value> acc;
303+
SmallVector<SmallVector<Value>> vals;
166304
for (unsigned i = 0; i < axisPack; ++i) {
167305
SmallVector<Value> cur(op.getNumOperands());
168306
for (unsigned opIdx = 0; opIdx < op.getNumOperands(); ++opIdx) {
169307
cur[opIdx] = accs[opIdx][regBase + i];
170308
}
171-
accumulate(op.getLoc(), rewriter, op.getCombineOp(), acc, cur);
309+
vals.push_back(std::move(cur));
172310
}
311+
auto acc = useTernary
312+
? treeReduceTernary(op.getLoc(), rewriter,
313+
op.getCombineOp(), std::move(vals))
314+
: treeReduceBinary(op.getLoc(), rewriter,
315+
op.getCombineOp(), std::move(vals));
173316
for (unsigned opIdx = 0; opIdx < op.getNumOperands(); ++opIdx) {
174317
reduced[opIdx].push_back(acc[opIdx]);
175318
}

0 commit comments

Comments
 (0)