@@ -78,10 +78,6 @@ unsigned ReduceOpHelper::getIntraWarpSizeWithUniqueData() {
7878 return getThreadsPerWarp (srcEncoding, srcShape)[axis];
7979}
8080
81- bool ReduceOpHelper::isWarpSynchronous () {
82- return getWarpsPerCTA (srcEncoding, srcShape)[axis] == 1 ;
83- }
84-
8581bool ReduceOpHelper::isReduceWithinCTA () {
8682 // TODO: Support reduce across CTAS
8783 // Layout optimization passes such as PlanCTAPass and
@@ -109,6 +105,35 @@ bool ReduceOpHelper::isAssociative() {
109105 return !hasNoAssociativeOp;
110106}
111107
108+ unsigned ReduceOpHelper::getScratchSizeInBytes () {
109+ auto kLane = StringAttr::get (op.getContext (), " lane" );
110+
111+ auto isReduced = [axis = axis](const LinearLayout &layout) {
112+ return layout.getOutDimSizes ().begin ()[axis] == 1 ;
113+ };
114+ auto regLl = reducedRegLaneLayout (srcTy, axis);
115+
116+ // All the inputs have the same layout so, since we order them from largest
117+ // bitsize to smallest, and the first one is aligned, by induction, they are
118+ // all aligned, so we don't need to align the byte numbers returned here.
119+ unsigned bytesRegToTmp = 0 ;
120+ while (!isReduced (regLl)) {
121+ auto tmpLl = getInterLayout (regLl, axis);
122+ // We take the maximum of the elements and multiply by the total bitwidth.
123+ // We do this as otherwise it's quite tricky to find the correct
124+ // BaseOffsets in the lowering.
125+ int bytes = 0 ;
126+ for (auto inputTy : op.getInputTypes ()) {
127+ auto nelem =
128+ getNumScratchElemsSwizzledCvt (regLl, tmpLl, getBitwidth (inputTy));
129+ bytes += nelem * (getBitwidth (inputTy) / 8 );
130+ }
131+ bytesRegToTmp = std::max<unsigned >(bytesRegToTmp, bytes);
132+ regLl = zeroBasesAlongDimAndReorder (tmpLl, axis, kLane );
133+ }
134+ return bytesRegToTmp;
135+ }
136+
112137ReduceOpHelper::InThreadVectorizeOpKind
113138ReduceOpHelper::getInThreadVectorizeOpKind (unsigned axisPack,
114139 bool supportBitwidth16Elementwise,
@@ -298,26 +323,90 @@ LinearLayout ReduceOpHelper::getInterLayout(const LinearLayout &layout,
298323 auto *ctx = layout.getOutDimNames ().begin ()->getContext ();
299324 auto kLane = mlir::StringAttr::get (ctx, " lane" );
300325 auto kWarp = mlir::StringAttr::get (ctx, " warp" );
301- auto regBases = layout.getBases ();
302- auto linearAttr = triton::gpu::LinearEncodingAttr::get (ctx, layout);
303- int laneBits = layout.getInDimSizeLog2 (kLane );
304- int neededLaneBits = llvm::Log2_32 (linearAttr.getWarpsPerCTA ()[axis]);
305- // TODO move to verifier
306- assert (neededLaneBits <= laneBits && " NYI: more inter-warps than lanes" );
307- // Move the warp axis bases we need to reduce into lane bases, while
308- // keeping non-axis components in their original in-dim.
309- auto &laneBases = regBases[kLane ];
310- auto &warpBases = regBases[kWarp ];
311- int moved = 0 ;
312- for (auto &warpBasis : warpBases) {
313- if (warpBasis[axis] == 0 )
314- continue ;
315- assert (moved < neededLaneBits && " unexpected warp axis bases count" );
316- std::swap (laneBases[moved], warpBasis);
317- moved++;
326+ auto kBlock = mlir::StringAttr::get (ctx, " block" );
327+ auto bases = layout.getBases ();
328+ auto &laneBases = bases[kLane ];
329+ auto &warpBases = bases[kWarp ];
330+ auto &blockBases = bases[kBlock ];
331+
332+ auto collectAxisBases = [&](ArrayRef<std::vector<int32_t >> bases) {
333+ SmallVector<unsigned > out;
334+ for (unsigned i = 0 ; i < bases.size (); ++i) {
335+ if (bases[i][axis] != 0 )
336+ out.push_back (i);
337+ }
338+ return out;
339+ };
340+
341+ SmallVector<unsigned > warpAxisBases = collectAxisBases (warpBases);
342+ SmallVector<unsigned > blockAxisBases = collectAxisBases (blockBases);
343+
344+ SmallVector<unsigned > zeroLaneBases;
345+ for (unsigned i = 0 ; i < laneBases.size (); ++i) {
346+ if (llvm::all_of (laneBases[i], [](int32_t v) { return v == 0 ; }))
347+ zeroLaneBases.push_back (i);
318348 }
319349
320- return LinearLayout (std::move (regBases), to_vector (layout.getOutDimNames ()));
350+ auto axisSize = to_vector (layout.getOutDimSizes ())[axis];
351+ auto totalAxisBases = warpAxisBases.size () + blockAxisBases.size ();
352+
353+ // First try to place all warp/block axis bases into lane bases that are
354+ // currently zero. If we can do this we will be able to perform the full
355+ // reduction with just one convert_layout
356+ if (zeroLaneBases.size () >= totalAxisBases) {
357+ unsigned laneIdx = 0 ;
358+ for (unsigned idx : warpAxisBases) {
359+ std::swap (laneBases[zeroLaneBases[laneIdx]], warpBases[idx]);
360+ ++laneIdx;
361+ }
362+ for (unsigned idx : blockAxisBases) {
363+ std::swap (laneBases[zeroLaneBases[laneIdx]], blockBases[idx]);
364+ ++laneIdx;
365+ }
366+ return LinearLayout (std::move (bases), to_vector (layout.getOutDimNames ()));
367+ }
368+
369+ // If we can fit all the bases inside the lane dimension, we can perform the
370+ // reduction with two convert_layouts
371+ // The first cvt to move the relevant bases to the lane dimension
372+ // The second to move all the bases we moved out of the lane dimension back to
373+ // their original positions
374+ if (warpAxisBases.size () + blockAxisBases.size () <= laneBases.size ()) {
375+ assert (totalAxisBases <= laneBases.size () &&
376+ " unexpected lane base count for axis layout" );
377+ unsigned laneIdx = 0 ;
378+ for (unsigned idx : warpAxisBases) {
379+ std::swap (laneBases[laneIdx], warpBases[idx]);
380+ ++laneIdx;
381+ }
382+ for (unsigned idx : blockAxisBases) {
383+ std::swap (laneBases[laneIdx], blockBases[idx]);
384+ ++laneIdx;
385+ }
386+ return LinearLayout (std::move (bases), to_vector (layout.getOutDimNames ()));
387+ }
388+
389+ // Assumptions (easily relaxed if AMD needs it)
390+ // We assume that
391+ // max number of warps * max number of blocks <= (max number of lanes)^2
392+ // We check this in logarithmic space (number of bases)
393+ // This is true in nvidia as the max numbers are warps=64 ctas=16 so that
394+ // 64 * 16 = 1024 = 32 * 32 = laneBases.size() * laneBases.size()
395+ // This implies that, even if we have to perform 3 cvt_layouts, we can perform
396+ // first one that does not cross CTAs, and then two that may cross CTAs
397+ assert (blockBases.size () <= laneBases.size ());
398+ assert (warpBases.size () + blockBases.size () <= 2 * laneBases.size ());
399+
400+ // Otherwise, fit as many warp bases as possible into the lane dimension
401+ unsigned laneIdx = 0 ;
402+ for (unsigned idx : warpAxisBases) {
403+ std::swap (laneBases[laneIdx], warpBases[idx]);
404+ ++laneIdx;
405+ if (laneIdx >= laneBases.size ())
406+ break ;
407+ }
408+
409+ return LinearLayout (std::move (bases), to_vector (layout.getOutDimNames ()));
321410}
322411
323412LinearLayout ReduceOpHelper::reducedRegLaneLayout (RankedTensorType srcTy,
@@ -327,9 +416,7 @@ LinearLayout ReduceOpHelper::reducedRegLaneLayout(RankedTensorType srcTy,
327416 auto kLane = StringAttr::get (ctx, " lane" );
328417 auto kWarp = StringAttr::get (ctx, " warp" );
329418
330- auto reduced = triton::gpu::toLinearLayout (srcTy);
331- reduced = reduced.sublayout ({kReg , kLane , kWarp },
332- to_vector (reduced.getOutDimNames ()));
419+ auto reduced = toLinearLayout (srcTy);
333420 reduced = actionRemoveBroadcastedRegs (reduced).apply (reduced);
334421
335422 reduced = moveAxisBasesToFront (reduced, axis).apply (reduced);
@@ -339,32 +426,6 @@ LinearLayout ReduceOpHelper::reducedRegLaneLayout(RankedTensorType srcTy,
339426 return reduced;
340427}
341428
342- SmallVector<unsigned >
343- ReduceOpHelper::getScratchBytesForCvt (const LinearLayout &srcLayout,
344- const LinearLayout &dstLayout) {
345- SmallVector<unsigned > bytes (srcElementTypes.size (), 0 );
346- auto *ctx = op.getContext ();
347- SmallVector<int64_t > shape;
348- shape.reserve (srcLayout.getNumOutDims ());
349- for (auto dim : srcLayout.getOutDimNames ()) {
350- shape.push_back (srcLayout.getOutDimSize (dim));
351- }
352- auto srcEnc = triton::gpu::LinearEncodingAttr::get (ctx, srcLayout);
353- auto dstEnc = triton::gpu::LinearEncodingAttr::get (ctx, dstLayout);
354- for (unsigned i = 0 ; i < srcElementTypes.size (); ++i) {
355- auto elemTy = srcElementTypes[i];
356- if (elemTy.isIntOrFloat () && elemTy.getIntOrFloatBitWidth () < 8 )
357- elemTy = IntegerType::get (ctx, 8 );
358- auto srcTy = RankedTensorType::get (shape, elemTy, srcEnc);
359- auto dstTy = RankedTensorType::get (shape, elemTy, dstEnc);
360- if (!cvtNeedsSharedMemory (srcTy, dstTy))
361- continue ;
362- auto elems = getNumScratchElemsSwizzledCvt (srcTy, dstTy);
363- bytes[i] = elems * getBitwidth (srcTy) / 8 ;
364- }
365- return bytes;
366- }
367-
368429ScanLoweringHelper::ScanLoweringHelper (triton::ScanOp op) : scanOp(op) {
369430 auto firstTy = cast<RankedTensorType>(op.getOperands ()[0 ].getType ());
370431 srcShape = firstTy.getShape ();
0 commit comments