11#include " ReduceScanCommon.h"
22
3+ #include < memory>
34#include < tuple>
45#include < utility>
56
7+ #include " mlir/Dialect/Arith/IR/Arith.h"
8+ #include " mlir/Dialect/LLVMIR/LLVMDialect.h"
69#include " mlir/Support/LLVM.h"
710#include " triton/Analysis/Utility.h"
811#include " triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
@@ -110,6 +113,140 @@ struct ReduceOpConversion
110113
111114private:
112115 const TargetInfoBase &targetInfo;
116+
117+ static bool useTernaryTreeReduction (triton::ReduceOp op) {
118+ if (op.getNumOperands () != 1 )
119+ return false ;
120+ auto elemTy = op.getElementTypes ()[0 ];
121+ if (!isa<IntegerType>(elemTy))
122+ return false ;
123+ Operation *combiner = op.getSingleCombiner ();
124+ if (!combiner)
125+ return false ;
126+ return isa<arith::AndIOp, arith::OrIOp, arith::XOrIOp, arith::AddIOp>(
127+ combiner);
128+ }
129+
130+ std::unique_ptr<Region>
131+ maybePackValuesf32x2 (Location loc, ConversionPatternRewriter &rewriter,
132+ Region &combineOp,
133+ SmallVector<SmallVector<Value>> &values) const {
134+ if (values.size () < 2 || (values.size () % 2 ) != 0 )
135+ return nullptr ;
136+ if (values.front ().size () != 1 )
137+ return nullptr ;
138+ auto elemTy = values.front ().front ().getType ();
139+ if (!(elemTy.isF16 () || elemTy.isBF16 () || elemTy.isF32 ()))
140+ return nullptr ;
141+ Operation *combiner = nullptr ;
142+ if (!combineOp.empty ()) {
143+ auto &block = combineOp.front ();
144+ if (block.getOperations ().size () == 2 )
145+ combiner = &block.front ();
146+ }
147+ if (!combiner)
148+ return nullptr ;
149+ if (!isa<arith::AddFOp, arith::MulFOp>(combiner))
150+ return nullptr ;
151+ bool isMul = isa<arith::MulFOp>(combiner);
152+ // Pack the values into 2-element vectors
153+ SmallVector<SmallVector<Value>> packed;
154+ for (size_t i = 0 ; i < values.size (); i += 2 ) {
155+ SmallVector<Value> vecTuple (values.front ().size ());
156+ for (unsigned opIdx = 0 ; opIdx < values.front ().size (); ++opIdx) {
157+ vecTuple[opIdx] = packLLVector (
158+ loc, {values[i][opIdx], values[i + 1 ][opIdx]}, rewriter);
159+ }
160+ packed.push_back (std::move (vecTuple));
161+ }
162+ values = std::move (packed);
163+ // Create a new region that takes 2-element vectors as inputs and returns a
164+ // 2-element vector
165+ auto region = std::make_unique<Region>();
166+ auto *block = new Block ();
167+ region->push_back (block);
168+ auto vecTy = vec_ty (elemTy, 2 );
169+ block->addArgument (vecTy, loc);
170+ block->addArgument (vecTy, loc);
171+ auto *ctx = rewriter.getContext ();
172+ OpBuilder builder (ctx);
173+ builder.setInsertionPointToStart (block);
174+ Value lhs = block->getArgument (0 );
175+ Value rhs = block->getArgument (1 );
176+ Value result =
177+ isMul ? LLVM::FMulOp::create (builder, loc, lhs, rhs).getResult ()
178+ : LLVM::FAddOp::create (builder, loc, lhs, rhs).getResult ();
179+ triton::ReduceReturnOp::create (builder, loc, ValueRange{result});
180+ return region;
181+ }
182+
183+ void maybeUnpackValuesf32x2 (Location loc, ConversionPatternRewriter &rewriter,
184+ Region &combineOp,
185+ SmallVector<Value> &values) const {
186+ // If it has more than one output, it's not vectorized
187+ // There is a world where we just check whether the region has all arith ops
188+ // and we vectorize them all with a pass... but that's for another day.
189+ if (values.size () != 1 )
190+ return ;
191+ auto vecTy = dyn_cast<VectorType>(values.front ().getType ());
192+ if (!vecTy || vecTy.getNumElements () != 2 ||
193+ (!vecTy.getElementType ().isF16 () && !vecTy.getElementType ().isF32 ()))
194+ return ;
195+ // Perform the last (non-vectorized) combine operation.
196+ auto elems = unpackLLVector (loc, values.front (), rewriter);
197+ SmallVector<Value> acc = {elems[0 ]};
198+ accumulate (loc, rewriter, combineOp, acc, {elems[1 ]});
199+ values = std::move (acc);
200+ }
201+
202+ SmallVector<Value>
203+ treeReduceTernary (Location loc, ConversionPatternRewriter &rewriter,
204+ Region &combineOp,
205+ SmallVector<SmallVector<Value>> values) const {
206+ while (values.size () > 1 ) {
207+ SmallVector<SmallVector<Value>> next;
208+ size_t i = 0 ;
209+ for (; i + 2 < values.size (); i += 3 ) {
210+ SmallVector<Value> acc = values[i];
211+ accumulate (loc, rewriter, combineOp, acc, values[i + 1 ]);
212+ accumulate (loc, rewriter, combineOp, acc, values[i + 2 ]);
213+ next.push_back (std::move (acc));
214+ }
215+ // Process tail
216+ if (values.size () - i == 1 ) {
217+ next.push_back (std::move (values[i]));
218+ } else if (values.size () - i == 2 ) {
219+ SmallVector<Value> acc = values[i];
220+ accumulate (loc, rewriter, combineOp, acc, values[i + 1 ]);
221+ next.push_back (std::move (acc));
222+ }
223+ values = std::move (next);
224+ }
225+ return std::move (values.front ());
226+ }
227+
228+ SmallVector<Value>
229+ treeReduceBinary (Location loc, ConversionPatternRewriter &rewriter,
230+ Region &combineOp,
231+ SmallVector<SmallVector<Value>> values) const {
232+ // The number of elements is always a power of two
233+ assert (llvm::isPowerOf2_64 (values.size ()) && !values.empty ());
234+ auto vectorCombine = maybePackValuesf32x2 (loc, rewriter, combineOp, values);
235+ Region &accumulateRegion = vectorCombine ? *vectorCombine : combineOp;
236+ while (values.size () > 1 ) {
237+ SmallVector<SmallVector<Value>> next;
238+ for (size_t i = 0 ; i + 1 < values.size (); i += 2 ) {
239+ SmallVector<Value> acc = values[i];
240+ accumulate (loc, rewriter, accumulateRegion, acc, values[i + 1 ]);
241+ next.push_back (std::move (acc));
242+ }
243+ values = std::move (next);
244+ }
245+ SmallVector<Value> val = std::move (values.front ());
246+ maybeUnpackValuesf32x2 (loc, rewriter, combineOp, val);
247+ return val;
248+ }
249+
113250 void accumulate (Location loc, ConversionPatternRewriter &rewriter,
114251 Region &combineOp, SmallVector<Value> &acc, ValueRange cur,
115252 Value pred = {}) const {
@@ -148,6 +285,7 @@ struct ReduceOpConversion
148285 return {std::move (layout), std::move (accs)};
149286 }
150287
288+ bool useTernary = useTernaryTreeReduction (op);
151289 // Bring the registers that move the axis to the front
152290 auto perm = ReduceOpHelper::makeAxisContiguous (layout, op.getAxis ());
153291 if (!perm.isIdentity ()) {
@@ -157,19 +295,24 @@ struct ReduceOpConversion
157295 }
158296 }
159297
160- // Reduce linearly
161- // TODO Perform a tree reduction
298+ // Reduce with a tree. Use a ternary tree when it can map to 3-input SASS
299+ // ops (IADD3/LOP3); otherwise use a binary tree.
162300 SmallVector<SmallVector<Value>> reduced (op.getNumOperands ());
163301 for (unsigned regBase = 0 ; regBase < layout.getInDimSize (kReg );
164302 regBase += axisPack) {
165- SmallVector<Value> acc ;
303+ SmallVector<SmallVector< Value>> vals ;
166304 for (unsigned i = 0 ; i < axisPack; ++i) {
167305 SmallVector<Value> cur (op.getNumOperands ());
168306 for (unsigned opIdx = 0 ; opIdx < op.getNumOperands (); ++opIdx) {
169307 cur[opIdx] = accs[opIdx][regBase + i];
170308 }
171- accumulate (op. getLoc (), rewriter, op. getCombineOp (), acc, cur);
309+ vals. push_back ( std::move ( cur) );
172310 }
311+ auto acc = useTernary
312+ ? treeReduceTernary (op.getLoc (), rewriter,
313+ op.getCombineOp (), std::move (vals))
314+ : treeReduceBinary (op.getLoc (), rewriter,
315+ op.getCombineOp (), std::move (vals));
173316 for (unsigned opIdx = 0 ; opIdx < op.getNumOperands (); ++opIdx) {
174317 reduced[opIdx].push_back (acc[opIdx]);
175318 }
0 commit comments