11#include " ReduceScanCommon.h"
22
3+ #include < memory>
34#include < tuple>
45#include < utility>
56
7+ #include " mlir/Dialect/Arith/IR/Arith.h"
8+ #include " mlir/Dialect/LLVMIR/LLVMDialect.h"
69#include " mlir/Support/LLVM.h"
710#include " triton/Analysis/Utility.h"
811#include " triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
@@ -110,6 +113,142 @@ struct ReduceOpConversion
110113
111114private:
112115 const TargetInfoBase &targetInfo;
116+
117+ static bool useTernaryTreeReduction (triton::ReduceOp op) {
118+ if (op.getNumOperands () != 1 )
119+ return false ;
120+ auto elemTy = op.getElementTypes ()[0 ];
121+ if (!isa<IntegerType>(elemTy))
122+ return false ;
123+ Operation *combiner = op.getSingleCombiner ();
124+ if (!combiner)
125+ return false ;
126+ return isa<arith::AndIOp, arith::OrIOp, arith::XOrIOp, arith::AddIOp>(
127+ combiner);
128+ }
129+
130+ std::unique_ptr<Region>
131+ maybePackValuesf32x2 (Location loc, ConversionPatternRewriter &rewriter,
132+ Region &combineOp,
133+ SmallVector<SmallVector<Value>> &values) const {
134+ if (values.size () < 2 || (values.size () % 2 ) != 0 )
135+ return nullptr ;
136+ if (values.front ().size () != 1 )
137+ return nullptr ;
138+ auto elemTy = values.front ().front ().getType ();
139+ if (!(elemTy.isF16 () || elemTy.isBF16 () || elemTy.isF32 ()))
140+ return nullptr ;
141+ Operation *combiner = nullptr ;
142+ if (!combineOp.empty ()) {
143+ auto &block = combineOp.front ();
144+ if (block.getOperations ().size () == 2 )
145+ combiner = &block.front ();
146+ }
147+ if (!combiner)
148+ return nullptr ;
149+ if (!isa<arith::AddFOp, arith::MulFOp>(combiner))
150+ return nullptr ;
151+ bool isMul = isa<arith::MulFOp>(combiner);
152+ // Pack the values into 2-element vectors
153+ SmallVector<SmallVector<Value>> packed;
154+ for (size_t i = 0 ; i < values.size (); i += 2 ) {
155+ SmallVector<Value> vecTuple (values.front ().size ());
156+ for (unsigned opIdx = 0 ; opIdx < values.front ().size (); ++opIdx) {
157+ vecTuple[opIdx] = packLLVector (
158+ loc, {values[i][opIdx], values[i + 1 ][opIdx]}, rewriter);
159+ }
160+ packed.push_back (std::move (vecTuple));
161+ }
162+ values = std::move (packed);
163+ // Create a new region that takes 2-element vectors as inputs and returns a
164+ // 2-element vector
165+ auto region = std::make_unique<Region>();
166+ auto *block = new Block ();
167+ region->push_back (block);
168+ auto vecTy = vec_ty (elemTy, 2 );
169+ block->addArgument (vecTy, loc);
170+ block->addArgument (vecTy, loc);
171+ auto *ctx = rewriter.getContext ();
172+ OpBuilder builder (ctx);
173+ builder.setInsertionPointToStart (block);
174+ Value lhs = block->getArgument (0 );
175+ Value rhs = block->getArgument (1 );
176+ Value result =
177+ isMul ? LLVM::FMulOp::create (builder, loc, lhs, rhs).getResult ()
178+ : LLVM::FAddOp::create (builder, loc, lhs, rhs).getResult ();
179+ triton::ReduceReturnOp::create (builder, loc, ValueRange{result});
180+ return region;
181+ }
182+
183+ void maybeUnpackValuesf32x2 (Location loc, ConversionPatternRewriter &rewriter,
184+ Region &combineOp,
185+ SmallVector<Value> &values) const {
186+ // If it has more than one output, it's not vectorized
187+ // There is a world where we just check whether the region has all arith ops
188+ // and we vectorize them all with a pass... but that's for another day.
189+ if (values.size () != 1 )
190+ return ;
191+ auto vecTy = dyn_cast<VectorType>(values.front ().getType ());
192+ if (!vecTy)
193+ return ;
194+ assert (vecTy.getNumElements () == 2 );
195+ auto elemTy = vecTy.getElementType ();
196+ assert (elemTy.isF16 () || elemTy.isBF16 () || elemTy.isF32 ());
197+ // Perform the last (non-vectorized) combine operation.
198+ auto elems = unpackLLVector (loc, values.front (), rewriter);
199+ SmallVector<Value> acc = {elems[0 ]};
200+ accumulate (loc, rewriter, combineOp, acc, {elems[1 ]});
201+ values = std::move (acc);
202+ }
203+
204+ SmallVector<Value>
205+ treeReduceTernary (Location loc, ConversionPatternRewriter &rewriter,
206+ Region &combineOp,
207+ SmallVector<SmallVector<Value>> values) const {
208+ while (values.size () > 1 ) {
209+ SmallVector<SmallVector<Value>> next;
210+ size_t i = 0 ;
211+ for (; i + 2 < values.size (); i += 3 ) {
212+ SmallVector<Value> acc = values[i];
213+ accumulate (loc, rewriter, combineOp, acc, values[i + 1 ]);
214+ accumulate (loc, rewriter, combineOp, acc, values[i + 2 ]);
215+ next.push_back (std::move (acc));
216+ }
217+ // Process tail
218+ if (values.size () - i == 1 ) {
219+ next.push_back (std::move (values[i]));
220+ } else if (values.size () - i == 2 ) {
221+ SmallVector<Value> acc = values[i];
222+ accumulate (loc, rewriter, combineOp, acc, values[i + 1 ]);
223+ next.push_back (std::move (acc));
224+ }
225+ values = std::move (next);
226+ }
227+ return std::move (values.front ());
228+ }
229+
230+ SmallVector<Value>
231+ treeReduceBinary (Location loc, ConversionPatternRewriter &rewriter,
232+ Region &combineOp,
233+ SmallVector<SmallVector<Value>> values) const {
234+ // The number of elements is always a power of two
235+ assert (llvm::isPowerOf2_64 (values.size ()) && !values.empty ());
236+ auto vectorCombine = maybePackValuesf32x2 (loc, rewriter, combineOp, values);
237+ Region &accumulateRegion = vectorCombine ? *vectorCombine : combineOp;
238+ while (values.size () > 1 ) {
239+ SmallVector<SmallVector<Value>> next;
240+ for (size_t i = 0 ; i + 1 < values.size (); i += 2 ) {
241+ SmallVector<Value> acc = values[i];
242+ accumulate (loc, rewriter, accumulateRegion, acc, values[i + 1 ]);
243+ next.push_back (std::move (acc));
244+ }
245+ values = std::move (next);
246+ }
247+ SmallVector<Value> val = std::move (values.front ());
248+ maybeUnpackValuesf32x2 (loc, rewriter, combineOp, val);
249+ return val;
250+ }
251+
113252 void accumulate (Location loc, ConversionPatternRewriter &rewriter,
114253 Region &combineOp, SmallVector<Value> &acc, ValueRange cur,
115254 Value pred = {}) const {
@@ -154,6 +293,7 @@ struct ReduceOpConversion
154293 return {std::move (layout), std::move (accs)};
155294 }
156295
296+ bool useTernary = useTernaryTreeReduction (op);
157297 // Bring the registers that move the axis to the front
158298 auto perm = ReduceOpHelper::makeAxisContiguous (layout, op.getAxis ());
159299 if (!perm.isIdentity ()) {
@@ -163,19 +303,24 @@ struct ReduceOpConversion
163303 }
164304 }
165305
166- // Reduce linearly
167- // TODO Perform a tree reduction
306+ // Reduce with a tree. Use a ternary tree when it can map to 3-input SASS
307+ // ops (IADD3/LOP3); otherwise use a binary tree.
168308 SmallVector<SmallVector<Value>> reduced (op.getNumOperands ());
169309 for (unsigned regBase = 0 ; regBase < layout.getInDimSize (kReg );
170310 regBase += axisPack) {
171- SmallVector<Value> acc ;
311+ SmallVector<SmallVector< Value>> vals ;
172312 for (unsigned i = 0 ; i < axisPack; ++i) {
173313 SmallVector<Value> cur (op.getNumOperands ());
174314 for (unsigned opIdx = 0 ; opIdx < op.getNumOperands (); ++opIdx) {
175315 cur[opIdx] = accs[opIdx][regBase + i];
176316 }
177- accumulate (op. getLoc (), rewriter, op. getCombineOp (), acc, cur);
317+ vals. push_back ( std::move ( cur) );
178318 }
319+ auto acc = useTernary
320+ ? treeReduceTernary (op.getLoc (), rewriter,
321+ op.getCombineOp (), std::move (vals))
322+ : treeReduceBinary (op.getLoc (), rewriter,
323+ op.getCombineOp (), std::move (vals));
179324 for (unsigned opIdx = 0 ; opIdx < op.getNumOperands (); ++opIdx) {
180325 reduced[opIdx].push_back (acc[opIdx]);
181326 }
0 commit comments