@@ -76,10 +76,6 @@ unsigned ReduceOpHelper::getIntraWarpSizeWithUniqueData() {
7676 return getThreadsPerWarp (srcEncoding, srcShape)[axis];
7777}
7878
79- bool ReduceOpHelper::isWarpSynchronous () {
80- return getWarpsPerCTA (srcEncoding, srcShape)[axis] == 1 ;
81- }
82-
8379bool ReduceOpHelper::isReduceWithinCTA () {
8480 // TODO: Support reduce across CTAS
8581 // Layout optimization passes such as PlanCTAPass and
@@ -155,23 +151,97 @@ LinearLayout ReduceOpHelper::getInterLayout(const LinearLayout &layout,
155151 auto *ctx = layout.getOutDimNames ().begin ()->getContext ();
156152 auto kLane = mlir::StringAttr::get (ctx, " lane" );
157153 auto kWarp = mlir::StringAttr::get (ctx, " warp" );
154+ auto kBlock = mlir::StringAttr::get (ctx, " block" );
158155 auto regBases = layout.getBases ();
159- auto linearAttr = triton::gpu::LinearEncodingAttr::get (ctx, layout);
160- int laneBits = layout.getInDimSizeLog2 (kLane );
161- int neededLaneBits = llvm::Log2_32 (linearAttr.getWarpsPerCTA ()[axis]);
162- // TODO move to verifier
163- assert (neededLaneBits <= laneBits && " NYI: more inter-warps than lanes" );
164- // Move the warp axis bases we need to reduce into lane bases, while
165- // keeping non-axis components in their original in-dim.
166- auto &laneBases = regBases[kLane ];
167- auto &warpBases = regBases[kWarp ];
168- int moved = 0 ;
169- for (auto &warpBasis : warpBases) {
170- if (warpBasis[axis] == 0 )
171- continue ;
172- assert (moved < neededLaneBits && " unexpected warp axis bases count" );
173- std::swap (laneBases[moved], warpBasis);
174- moved++;
156+ auto laneIt = regBases.find (kLane );
157+ auto warpIt = regBases.find (kWarp );
158+ auto blockIt = regBases.find (kBlock );
159+ if (laneIt == regBases.end () || warpIt == regBases.end ()) {
160+ return layout;
161+ }
162+
163+ auto &laneBases = laneIt->second ;
164+ auto &warpBases = warpIt->second ;
165+ auto &blockBases = blockIt->second ;
166+
167+ auto collectAxisBases = [&](const std::vector<std::vector<int32_t >> &bases,
168+ SmallVector<unsigned > &out) {
169+ for (unsigned i = 0 ; i < bases.size (); ++i) {
170+ if (bases[i][axis] != 0 )
171+ out.push_back (i);
172+ }
173+ };
174+
175+ SmallVector<unsigned > warpAxisBases;
176+ collectAxisBases (warpBases, warpAxisBases);
177+ SmallVector<unsigned > blockAxisBases;
178+ collectAxisBases (blockBases, blockAxisBases);
179+
180+ SmallVector<unsigned > zeroLaneBases;
181+ for (unsigned i = 0 ; i < laneBases.size (); ++i) {
182+ if (llvm::all_of (laneBases[i], [](int32_t v) { return v == 0 ; }))
183+ zeroLaneBases.push_back (i);
184+ }
185+
186+ auto axisSize = to_vector (layout.getOutDimSizes ())[axis];
187+ auto totalAxisBases = warpAxisBases.size () + blockAxisBases.size ();
188+
189+ // First try to place all warp/block axis bases into lane bases that are
190+ // currently zero. If we can do this we will be able to perform the full
191+ // reduction with just one convert_layout
192+ if (zeroLaneBases.size () >= totalAxisBases) {
193+ unsigned laneIdx = 0 ;
194+ for (unsigned idx : warpAxisBases) {
195+ std::swap (laneBases[zeroLaneBases[laneIdx]], warpBases[idx]);
196+ ++laneIdx;
197+ }
198+ for (unsigned idx : blockAxisBases) {
199+ std::swap (laneBases[zeroLaneBases[laneIdx]], blockBases[idx]);
200+ ++laneIdx;
201+ }
202+ return LinearLayout (std::move (regBases),
203+ to_vector (layout.getOutDimNames ()));
204+ }
205+
206+ // If all the axis bases fit inside the lane dimension, we can perform the
207+ // reduction with two convert_layouts:
208+ // - the first moves the relevant bases into the lane dimension,
209+ // - the second moves the bases displaced from the lane dimension back to
210+ // their original positions.
+ // The branch condition itself guarantees laneIdx stays below
+ // laneBases.size () in both loops.
211+ if (totalAxisBases <= laneBases.size ()) {
214+ unsigned laneIdx = 0 ;
215+ for (unsigned idx : warpAxisBases) {
216+ std::swap (laneBases[laneIdx], warpBases[idx]);
217+ ++laneIdx;
218+ }
219+ for (unsigned idx : blockAxisBases) {
220+ std::swap (laneBases[laneIdx], blockBases[idx]);
221+ ++laneIdx;
222+ }
223+ return LinearLayout (std::move (regBases),
224+ to_vector (layout.getOutDimNames ()));
225+ }
226+
227+ // Assumptions (easily relaxed if AMD needs it)
228+ // We assume that
229+ // max number of warps * max number of blocks <= (max number of lanes)^2
230+ // We check this in logarithmic space (number of bases)
231+ // This is true in nvidia as the max numbers are warps=64 ctas=16 so that
232+ // 64 * 16 = 1024 = 32 * 32 = laneBases.size() * laneBases.size()
233+ // This implies that, even if we have to perform 3 cvt_layouts, we can perform
234+ // first one that does not cross CTAs, and then two that may cross CTAs
235+ assert (blockBases.size () <= laneBases.size ());
236+ assert (warpBases.size () + blockBases.size () <= 2 * laneBases.size ());
237+
238+ // Otherwise, fit as many warp bases as possible into the lane dimension.
239+ unsigned laneIdx = 0 ;
240+ for (unsigned idx : warpAxisBases) {
+ // Check the bound before indexing: once every lane slot is taken we stop,
+ // and an empty lane dimension is never indexed even when the asserts above
+ // are compiled out.
241+ if (laneIdx >= laneBases.size ())
242+ break ;
243+ std::swap (laneBases[laneIdx], warpBases[idx]);
244+ ++laneIdx;
175245 }
176246
177247 return LinearLayout (std::move (regBases), to_vector (layout.getOutDimNames ()));
@@ -184,9 +254,7 @@ LinearLayout ReduceOpHelper::reducedRegLaneLayout(RankedTensorType srcTy,
184254 auto kLane = StringAttr::get (ctx, " lane" );
185255 auto kWarp = StringAttr::get (ctx, " warp" );
186256
187- auto reduced = triton::gpu::toLinearLayout (srcTy);
188- reduced = reduced.sublayout ({kReg , kLane , kWarp },
189- to_vector (reduced.getOutDimNames ()));
257+ auto reduced = toLinearLayout (srcTy);
190258 reduced = actionRemoveBroadcastedRegs (reduced).apply (reduced);
191259 reduced = makeAxisContiguous (reduced, axis).apply (reduced);
192260 reduced = zeroBasesAlongDimAndReorder (reduced, axis, kReg );
@@ -195,32 +263,6 @@ LinearLayout ReduceOpHelper::reducedRegLaneLayout(RankedTensorType srcTy,
195263 return reduced;
196264}
197265
198- SmallVector<unsigned >
199- ReduceOpHelper::getScratchBytesForCvt (const LinearLayout &srcLayout,
200- const LinearLayout &dstLayout) {
201- SmallVector<unsigned > bytes (srcElementTypes.size (), 0 );
202- auto *ctx = op.getContext ();
203- SmallVector<int64_t > shape;
204- shape.reserve (srcLayout.getNumOutDims ());
205- for (auto dim : srcLayout.getOutDimNames ()) {
206- shape.push_back (srcLayout.getOutDimSize (dim));
207- }
208- auto srcEnc = triton::gpu::LinearEncodingAttr::get (ctx, srcLayout);
209- auto dstEnc = triton::gpu::LinearEncodingAttr::get (ctx, dstLayout);
210- for (unsigned i = 0 ; i < srcElementTypes.size (); ++i) {
211- auto elemTy = srcElementTypes[i];
212- if (elemTy.isIntOrFloat () && elemTy.getIntOrFloatBitWidth () < 8 )
213- elemTy = IntegerType::get (ctx, 8 );
214- auto srcTy = RankedTensorType::get (shape, elemTy, srcEnc);
215- auto dstTy = RankedTensorType::get (shape, elemTy, dstEnc);
216- if (!cvtNeedsSharedMemory (srcTy, dstTy))
217- continue ;
218- auto elems = getNumScratchElemsSwizzledCvt (srcTy, dstTy);
219- bytes[i] = elems * getBitwidth (srcTy) / 8 ;
220- }
221- return bytes;
222- }
223-
224266ScanLoweringHelper::ScanLoweringHelper (triton::ScanOp op) : scanOp(op) {
225267 auto firstTy = cast<RankedTensorType>(op.getOperands ()[0 ].getType ());
226268 srcShape = firstTy.getShape ();
0 commit comments