@@ -155,23 +155,97 @@ LinearLayout ReduceOpHelper::getInterLayout(const LinearLayout &layout,
155155 auto *ctx = layout.getOutDimNames ().begin ()->getContext ();
156156 auto kLane = mlir::StringAttr::get (ctx, " lane" );
157157 auto kWarp = mlir::StringAttr::get (ctx, " warp" );
158+ auto kBlock = mlir::StringAttr::get (ctx, " block" );
158159 auto regBases = layout.getBases ();
159- auto linearAttr = triton::gpu::LinearEncodingAttr::get (ctx, layout);
160- int laneBits = layout.getInDimSizeLog2 (kLane );
161- int neededLaneBits = llvm::Log2_32 (linearAttr.getWarpsPerCTA ()[axis]);
162- // TODO move to verifier
163- assert (neededLaneBits <= laneBits && " NYI: more inter-warps than lanes" );
164- // Move the warp axis bases we need to reduce into lane bases, while
165- // keeping non-axis components in their original in-dim.
166- auto &laneBases = regBases[kLane ];
167- auto &warpBases = regBases[kWarp ];
168- int moved = 0 ;
169- for (auto &warpBasis : warpBases) {
170- if (warpBasis[axis] == 0 )
171- continue ;
172- assert (moved < neededLaneBits && " unexpected warp axis bases count" );
173- std::swap (laneBases[moved], warpBasis);
174- moved++;
160+ auto laneIt = regBases.find (kLane );
161+ auto warpIt = regBases.find (kWarp );
162+ auto blockIt = regBases.find (kBlock );
163+ if (laneIt == regBases.end () || warpIt == regBases.end ()) {
164+ return layout;
165+ }
166+
167+ auto &laneBases = laneIt->second ;
168+ auto &warpBases = warpIt->second ;
169+ auto &blockBases = blockIt->second ;
170+
171+ auto collectAxisBases = [&](const std::vector<std::vector<int32_t >> &bases,
172+ SmallVector<unsigned > &out) {
173+ for (unsigned i = 0 ; i < bases.size (); ++i) {
174+ if (bases[i][axis] != 0 )
175+ out.push_back (i);
176+ }
177+ };
178+
179+ SmallVector<unsigned > warpAxisBases;
180+ collectAxisBases (warpBases, warpAxisBases);
181+ SmallVector<unsigned > blockAxisBases;
182+ collectAxisBases (blockBases, blockAxisBases);
183+
184+ SmallVector<unsigned > zeroLaneBases;
185+ for (unsigned i = 0 ; i < laneBases.size (); ++i) {
186+ if (llvm::all_of (laneBases[i], [](int32_t v) { return v == 0 ; }))
187+ zeroLaneBases.push_back (i);
188+ }
189+
190+ auto axisSize = to_vector (layout.getOutDimSizes ())[axis];
191+ auto totalAxisBases = warpAxisBases.size () + blockAxisBases.size ();
192+
193+ // First try to place all warp/block axis bases into lane bases that are
194+ // currently zero. If we can do this we will be able to perform the full
195+ // reduction with just one convert_layout
196+ if (zeroLaneBases.size () >= totalAxisBases) {
197+ unsigned laneIdx = 0 ;
198+ for (unsigned idx : warpAxisBases) {
199+ std::swap (laneBases[zeroLaneBases[laneIdx]], warpBases[idx]);
200+ ++laneIdx;
201+ }
202+ for (unsigned idx : blockAxisBases) {
203+ std::swap (laneBases[zeroLaneBases[laneIdx]], blockBases[idx]);
204+ ++laneIdx;
205+ }
206+ return LinearLayout (std::move (regBases),
207+ to_vector (layout.getOutDimNames ()));
208+ }
209+
210+ // If we can fit all the bases inside the lane dimension, we can perform the
211+ // reduction with two convert_layouts
212+ // The first cvt to move the relevant bases to the lane dimension
213+ // The second to move all the bases we moved out of the lane dimension back to
214+ // their original positions
215+ if (warpAxisBases.size () + blockAxisBases.size () <= laneBases.size ()) {
216+ assert (totalAxisBases <= laneBases.size () &&
217+ " unexpected lane base count for axis layout" );
218+ unsigned laneIdx = 0 ;
219+ for (unsigned idx : warpAxisBases) {
220+ std::swap (laneBases[laneIdx], warpBases[idx]);
221+ ++laneIdx;
222+ }
223+ for (unsigned idx : blockAxisBases) {
224+ std::swap (laneBases[laneIdx], blockBases[idx]);
225+ ++laneIdx;
226+ }
227+ return LinearLayout (std::move (regBases),
228+ to_vector (layout.getOutDimNames ()));
229+ }
230+
231+ // Assumptions (easily relaxed if AMD needs it)
232+ // We assume that
233+ // max number of warps * max number of blocks <= (max number of lanes)^2
234+ // We check this in logarithmic space (number of bases)
235+ // This is true in nvidia as the max numbers are warps=64 ctas=16 so that
236+ // 64 * 16 = 1024 = 32 * 32 = laneBases.size() * laneBases.size()
237+ // This implies that, even if we have to perform 3 cvt_layouts, we can perform
238+ // first one that does not cross CTAs, and then two that may cross CTAs
239+ assert (blockBases.size () <= laneBases.size ());
240+ assert (warpBases.size () + blockBases.size () <= 2 * laneBases.size ());
241+
242+ // Otherwise, fit as many warp bases as possible into the lane dimension
243+ unsigned laneIdx = 0 ;
244+ for (unsigned idx : warpAxisBases) {
245+ std::swap (laneBases[laneIdx], warpBases[idx]);
246+ ++laneIdx;
247+ if (laneIdx >= laneBases.size ())
248+ break ;
175249 }
176250
177251 return LinearLayout (std::move (regBases), to_vector (layout.getOutDimNames ()));
@@ -184,9 +258,7 @@ LinearLayout ReduceOpHelper::reducedRegLaneLayout(RankedTensorType srcTy,
184258 auto kLane = StringAttr::get (ctx, " lane" );
185259 auto kWarp = StringAttr::get (ctx, " warp" );
186260
187- auto reduced = triton::gpu::toLinearLayout (srcTy);
188- reduced = reduced.sublayout ({kReg , kLane , kWarp },
189- to_vector (reduced.getOutDimNames ()));
261+ auto reduced = toLinearLayout (srcTy);
190262 reduced = actionRemoveBroadcastedRegs (reduced).apply (reduced);
191263 reduced = makeAxisContiguous (reduced, axis).apply (reduced);
192264 reduced = zeroBasesAlongDimAndReorder (reduced, axis, kReg );
@@ -195,32 +267,6 @@ LinearLayout ReduceOpHelper::reducedRegLaneLayout(RankedTensorType srcTy,
195267 return reduced;
196268}
197269
198- SmallVector<unsigned >
199- ReduceOpHelper::getScratchBytesForCvt (const LinearLayout &srcLayout,
200- const LinearLayout &dstLayout) {
201- SmallVector<unsigned > bytes (srcElementTypes.size (), 0 );
202- auto *ctx = op.getContext ();
203- SmallVector<int64_t > shape;
204- shape.reserve (srcLayout.getNumOutDims ());
205- for (auto dim : srcLayout.getOutDimNames ()) {
206- shape.push_back (srcLayout.getOutDimSize (dim));
207- }
208- auto srcEnc = triton::gpu::LinearEncodingAttr::get (ctx, srcLayout);
209- auto dstEnc = triton::gpu::LinearEncodingAttr::get (ctx, dstLayout);
210- for (unsigned i = 0 ; i < srcElementTypes.size (); ++i) {
211- auto elemTy = srcElementTypes[i];
212- if (elemTy.isIntOrFloat () && elemTy.getIntOrFloatBitWidth () < 8 )
213- elemTy = IntegerType::get (ctx, 8 );
214- auto srcTy = RankedTensorType::get (shape, elemTy, srcEnc);
215- auto dstTy = RankedTensorType::get (shape, elemTy, dstEnc);
216- if (!cvtNeedsSharedMemory (srcTy, dstTy))
217- continue ;
218- auto elems = getNumScratchElemsSwizzledCvt (srcTy, dstTy);
219- bytes[i] = elems * getBitwidth (srcTy) / 8 ;
220- }
221- return bytes;
222- }
223-
224270ScanLoweringHelper::ScanLoweringHelper (triton::ScanOp op) : scanOp(op) {
225271 auto firstTy = cast<RankedTensorType>(op.getOperands ()[0 ].getType ());
226272 srcShape = firstTy.getShape ();
0 commit comments