@@ -78,10 +78,6 @@ unsigned ReduceOpHelper::getIntraWarpSizeWithUniqueData() {
7878 return getThreadsPerWarp (srcEncoding, srcShape)[axis];
7979}
8080
81- bool ReduceOpHelper::isWarpSynchronous () {
82- return getWarpsPerCTA (srcEncoding, srcShape)[axis] == 1 ;
83- }
84-
8581bool ReduceOpHelper::isReduceWithinCTA () {
8682 // TODO: Support reduce across CTAS
8783 // Layout optimization passes such as PlanCTAPass and
@@ -109,6 +105,35 @@ bool ReduceOpHelper::isAssociative() {
109105 return !hasNoAssociativeOp;
110106}
111107
// Computes the scratch-buffer size (in bytes) needed to lower this reduction
// via a sequence of layout conversions. Starting from the register/lane
// layout of the source tensor (with the reduced axis already collapsed within
// registers and lanes), it repeatedly builds the next intermediate layout
// with getInterLayout() until the reduced axis has extent 1, and returns the
// maximum per-round byte requirement across all conversion rounds.
// NOTE(review): the scratch is presumably shared memory used by the
// convert_layout lowering (getNumScratchElemsSwizzledCvt) — confirm against
// the lowering code.
unsigned ReduceOpHelper::getScratchSizeInBytes() {
  auto kLane = StringAttr::get(op.getContext(), "lane");

  // The reduction is complete once the output extent along `axis` is 1.
  auto isReduced = [axis = axis](const LinearLayout &layout) {
    return layout.getOutDimSizes().begin()[axis] == 1;
  };
  auto regLl = reducedRegLaneLayout(srcTy, axis);

  // All the inputs have the same layout so, since we order them from largest
  // bitsize to smallest, and the first one is aligned, by induction, they are
  // all aligned, so we don't need to align the byte numbers returned here.
  unsigned bytesRegToTmp = 0;
  while (!isReduced(regLl)) {
    auto tmpLl = getInterLayout(regLl, axis);
    // We take the maximum of the elements and multiply by the total bitwidth.
    // We do this as otherwise it's quite tricky to find the correct
    // BaseOffsets in the lowering.
    // NOTE(review): the code below actually sums bytes over all inputs for
    // this round, then takes the max across rounds — confirm the comment
    // above matches the intended accounting.
    int bytes = 0;
    for (auto inputTy : op.getInputTypes()) {
      auto nelem =
          getNumScratchElemsSwizzledCvt(regLl, tmpLl, getBitwidth(inputTy));
      // NOTE(review): assumes getBitwidth(inputTy) >= 8; a sub-byte element
      // type would make this contribution 0 — verify sub-byte inputs are
      // widened before this point (the removed getScratchBytesForCvt
      // promoted <8-bit types to i8).
      bytes += nelem * (getBitwidth(inputTy) / 8);
    }
    bytesRegToTmp = std::max<unsigned>(bytesRegToTmp, bytes);
    // Collapse the axis bases we just reduced and continue with the next
    // round's register/lane layout.
    regLl = zeroBasesAlongDimAndReorder(tmpLl, axis, kLane);
  }
  return bytesRegToTmp;
}
112137ReduceOpHelper::InThreadVectorizeOpKind
113138ReduceOpHelper::getInThreadVectorizeOpKind (unsigned axisPack) {
114139 Operation *reduceOperation = op.getOperation ();
@@ -291,26 +316,90 @@ LinearLayout ReduceOpHelper::getInterLayout(const LinearLayout &layout,
291316 auto *ctx = layout.getOutDimNames ().begin ()->getContext ();
292317 auto kLane = mlir::StringAttr::get (ctx, " lane" );
293318 auto kWarp = mlir::StringAttr::get (ctx, " warp" );
294- auto regBases = layout.getBases ();
295- auto linearAttr = triton::gpu::LinearEncodingAttr::get (ctx, layout);
296- int laneBits = layout.getInDimSizeLog2 (kLane );
297- int neededLaneBits = llvm::Log2_32 (linearAttr.getWarpsPerCTA ()[axis]);
298- // TODO move to verifier
299- assert (neededLaneBits <= laneBits && " NYI: more inter-warps than lanes" );
300- // Move the warp axis bases we need to reduce into lane bases, while
301- // keeping non-axis components in their original in-dim.
302- auto &laneBases = regBases[kLane ];
303- auto &warpBases = regBases[kWarp ];
304- int moved = 0 ;
305- for (auto &warpBasis : warpBases) {
306- if (warpBasis[axis] == 0 )
307- continue ;
308- assert (moved < neededLaneBits && " unexpected warp axis bases count" );
309- std::swap (laneBases[moved], warpBasis);
310- moved++;
319+ auto kBlock = mlir::StringAttr::get (ctx, " block" );
320+ auto bases = layout.getBases ();
321+ auto &laneBases = bases[kLane ];
322+ auto &warpBases = bases[kWarp ];
323+ auto &blockBases = bases[kBlock ];
324+
325+ auto collectAxisBases = [&](ArrayRef<std::vector<int32_t >> bases) {
326+ SmallVector<unsigned > out;
327+ for (unsigned i = 0 ; i < bases.size (); ++i) {
328+ if (bases[i][axis] != 0 )
329+ out.push_back (i);
330+ }
331+ return out;
332+ };
333+
334+ SmallVector<unsigned > warpAxisBases = collectAxisBases (warpBases);
335+ SmallVector<unsigned > blockAxisBases = collectAxisBases (blockBases);
336+
337+ SmallVector<unsigned > zeroLaneBases;
338+ for (unsigned i = 0 ; i < laneBases.size (); ++i) {
339+ if (llvm::all_of (laneBases[i], [](int32_t v) { return v == 0 ; }))
340+ zeroLaneBases.push_back (i);
311341 }
312342
313- return LinearLayout (std::move (regBases), to_vector (layout.getOutDimNames ()));
343+ auto axisSize = to_vector (layout.getOutDimSizes ())[axis];
344+ auto totalAxisBases = warpAxisBases.size () + blockAxisBases.size ();
345+
346+ // First try to place all warp/block axis bases into lane bases that are
347+ // currently zero. If we can do this we will be able to perform the full
348+ // reduction with just one convert_layout
349+ if (zeroLaneBases.size () >= totalAxisBases) {
350+ unsigned laneIdx = 0 ;
351+ for (unsigned idx : warpAxisBases) {
352+ std::swap (laneBases[zeroLaneBases[laneIdx]], warpBases[idx]);
353+ ++laneIdx;
354+ }
355+ for (unsigned idx : blockAxisBases) {
356+ std::swap (laneBases[zeroLaneBases[laneIdx]], blockBases[idx]);
357+ ++laneIdx;
358+ }
359+ return LinearLayout (std::move (bases), to_vector (layout.getOutDimNames ()));
360+ }
361+
362+ // If we can fit all the bases inside the lane dimension, we can perform the
363+ // reduction with two convert_layouts
364+ // The first cvt to move the relevant bases to the lane dimension
365+ // The second to move all the bases we moved out of the lane dimension back to
366+ // their original positions
367+ if (warpAxisBases.size () + blockAxisBases.size () <= laneBases.size ()) {
368+ assert (totalAxisBases <= laneBases.size () &&
369+ " unexpected lane base count for axis layout" );
370+ unsigned laneIdx = 0 ;
371+ for (unsigned idx : warpAxisBases) {
372+ std::swap (laneBases[laneIdx], warpBases[idx]);
373+ ++laneIdx;
374+ }
375+ for (unsigned idx : blockAxisBases) {
376+ std::swap (laneBases[laneIdx], blockBases[idx]);
377+ ++laneIdx;
378+ }
379+ return LinearLayout (std::move (bases), to_vector (layout.getOutDimNames ()));
380+ }
381+
382+ // Assumptions (easily relaxed if AMD needs it)
383+ // We assume that
384+ // max number of warps * max number of blocks <= (max number of lanes)^2
385+ // We check this in logarithmic space (number of bases)
386+ // This is true in nvidia as the max numbers are warps=64 ctas=16 so that
387+ // 64 * 16 = 1024 = 32 * 32 = laneBases.size() * laneBases.size()
388+ // This implies that, even if we have to perform 3 cvt_layouts, we can perform
389+ // first one that does not cross CTAs, and then two that may cross CTAs
390+ assert (blockBases.size () <= laneBases.size ());
391+ assert (warpBases.size () + blockBases.size () <= 2 * laneBases.size ());
392+
393+ // Otherwise, fit as many warp bases as possible into the lane dimension
394+ unsigned laneIdx = 0 ;
395+ for (unsigned idx : warpAxisBases) {
396+ std::swap (laneBases[laneIdx], warpBases[idx]);
397+ ++laneIdx;
398+ if (laneIdx >= laneBases.size ())
399+ break ;
400+ }
401+
402+ return LinearLayout (std::move (bases), to_vector (layout.getOutDimNames ()));
314403}
315404
316405LinearLayout ReduceOpHelper::reducedRegLaneLayout (RankedTensorType srcTy,
@@ -320,9 +409,7 @@ LinearLayout ReduceOpHelper::reducedRegLaneLayout(RankedTensorType srcTy,
320409 auto kLane = StringAttr::get (ctx, " lane" );
321410 auto kWarp = StringAttr::get (ctx, " warp" );
322411
323- auto reduced = triton::gpu::toLinearLayout (srcTy);
324- reduced = reduced.sublayout ({kReg , kLane , kWarp },
325- to_vector (reduced.getOutDimNames ()));
412+ auto reduced = toLinearLayout (srcTy);
326413 reduced = actionRemoveBroadcastedRegs (reduced).apply (reduced);
327414
328415 reduced = moveAxisBasesToFront (reduced, axis).apply (reduced);
@@ -332,32 +419,6 @@ LinearLayout ReduceOpHelper::reducedRegLaneLayout(RankedTensorType srcTy,
332419 return reduced;
333420}
334421
335- SmallVector<unsigned >
336- ReduceOpHelper::getScratchBytesForCvt (const LinearLayout &srcLayout,
337- const LinearLayout &dstLayout) {
338- SmallVector<unsigned > bytes (srcElementTypes.size (), 0 );
339- auto *ctx = op.getContext ();
340- SmallVector<int64_t > shape;
341- shape.reserve (srcLayout.getNumOutDims ());
342- for (auto dim : srcLayout.getOutDimNames ()) {
343- shape.push_back (srcLayout.getOutDimSize (dim));
344- }
345- auto srcEnc = triton::gpu::LinearEncodingAttr::get (ctx, srcLayout);
346- auto dstEnc = triton::gpu::LinearEncodingAttr::get (ctx, dstLayout);
347- for (unsigned i = 0 ; i < srcElementTypes.size (); ++i) {
348- auto elemTy = srcElementTypes[i];
349- if (elemTy.isIntOrFloat () && elemTy.getIntOrFloatBitWidth () < 8 )
350- elemTy = IntegerType::get (ctx, 8 );
351- auto srcTy = RankedTensorType::get (shape, elemTy, srcEnc);
352- auto dstTy = RankedTensorType::get (shape, elemTy, dstEnc);
353- if (!cvtNeedsSharedMemory (srcTy, dstTy))
354- continue ;
355- auto elems = getNumScratchElemsSwizzledCvt (srcTy, dstTy);
356- bytes[i] = elems * getBitwidth (srcTy) / 8 ;
357- }
358- return bytes;
359- }
360-
361422ScanLoweringHelper::ScanLoweringHelper (triton::ScanOp op) : scanOp(op) {
362423 auto firstTy = cast<RankedTensorType>(op.getOperands ()[0 ].getType ());
363424 srcShape = firstTy.getShape ();
0 commit comments