-
Notifications
You must be signed in to change notification settings - Fork 2.8k
[BACKEND] Improve and simplify ReduceOp's lowering #9219
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,33 +23,6 @@ namespace mlir { | |
| using namespace triton; | ||
| using namespace triton::gpu; | ||
|
|
||
| SmallVector<unsigned> ReduceOpHelper::getOrderWithAxisAtBeginning() { | ||
| auto order = toLinearEncoding(srcTy).getOrder(); | ||
| auto it = std::find(order.begin(), order.end(), axis); | ||
| // delete the axis from order | ||
| order.erase(it); | ||
| // insert axis at the beginning of order | ||
| order.insert(order.begin(), axis); | ||
| return order; | ||
| } | ||
|
|
||
| // Thread offset is the thread index offset of two adjacent threads on the | ||
| // reduction axis within the warp. | ||
| unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() { | ||
| auto *ctx = srcEncoding.getContext(); | ||
| auto linearLayout = toLinearLayout(srcTy); | ||
| auto kLane = mlir::StringAttr::get(ctx, "lane"); | ||
| const auto &bases = linearLayout.getBases(); | ||
| const auto &lanes = bases.find(kLane)->second; | ||
| auto offset = 1; | ||
| for (const auto &lane : lanes) { | ||
| if (lane[axis] != 0) | ||
| break; | ||
| offset *= 2; | ||
| } | ||
| return offset; | ||
| } | ||
|
|
||
| // Cases where distributed shared memory is not required in ConvertLayout: | ||
| // (1) numCTAs == 1 | ||
| // (2) numCTAs > 1 but srcCGALayout == dstCGALayout | ||
|
|
@@ -107,29 +80,6 @@ bool ReduceOpHelper::isWarpSynchronous() { | |
| return getWarpsPerCTA(srcEncoding, srcShape)[axis] == 1; | ||
| } | ||
|
|
||
| SmallVector<unsigned> ReduceOpHelper::getScratchRepShape() { | ||
| SmallVector<unsigned> smemShape; | ||
| // This case doesn't need inter-warp communication | ||
| if (isWarpSynchronous()) | ||
| return {0, 0}; | ||
|
|
||
| smemShape = convertType<unsigned>(srcShape); | ||
| smemShape[axis] = getInterWarpSizeWithUniqueData(); | ||
|
|
||
| return smemShape; | ||
| } | ||
|
|
||
| unsigned ReduceOpHelper::getScratchSizeInBytes() { | ||
| auto smemShape = getScratchRepShape(); | ||
| auto elems = product<unsigned>(smemShape); | ||
|
|
||
| unsigned bytesPerElem = 0; | ||
| for (const auto &ty : srcElementTypes) { | ||
| bytesPerElem += ceil<unsigned>(ty.getIntOrFloatBitWidth(), 8); | ||
| } | ||
| return bytesPerElem * elems; | ||
| } | ||
|
|
||
| bool ReduceOpHelper::isReduceWithinCTA() { | ||
| // TODO: Support reduce across CTAS | ||
| // Layout optimization passes such as PlanCTAPass and | ||
|
|
@@ -157,6 +107,125 @@ bool ReduceOpHelper::isAssociative() { | |
| return !hasNoAssociativeOp; | ||
| } | ||
|
|
||
| ColumnAction ReduceOpHelper::moveAxisBasesToFront(const LinearLayout &layout, | ||
| int axis) { | ||
| auto *ctx = layout.getOutDimNames().begin()->getContext(); | ||
| auto kReg = StringAttr::get(ctx, "register"); | ||
| const auto &bases = layout.getBases().lookup(kReg); | ||
| SmallVector<size_t> perm; | ||
| SmallVector<size_t> back; | ||
| for (size_t i = 0; i < bases.size(); ++i) { | ||
| if (bases[i][axis] != 0) | ||
| perm.push_back(i); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. NIT: This doesn't make it contiguous, i.e. strictly linearly ordered. I would say it's more like "group elements by axis".
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. renamed to |
||
| else | ||
| back.push_back(i); | ||
| } | ||
| perm.append(back.begin(), back.end()); | ||
| return ColumnAction(perm, kReg, bases.size()); | ||
| } | ||
|
|
||
| LinearLayout | ||
| ReduceOpHelper::zeroBasesAlongDimAndReorder(const LinearLayout &layout, | ||
| unsigned axis, StringAttr dim) { | ||
| // Zeros out the basis along the specified axis in the given hardware | ||
| // dimension, and reindexes the remaining bases along axis so that each | ||
| // element is in linearly increasing order from the hardware's perspective. | ||
| // Note that for this reordering we need the operator to be commutative, but | ||
| // it's the only way to have a performant lowering. | ||
| LinearLayout::BasesT newBases; | ||
| for (auto [inDim, bases] : layout.getBases()) { | ||
| std::vector<std::vector<int32_t>> newInBases = bases; | ||
| if (inDim == dim) { | ||
| for (auto &basis : newInBases) | ||
| basis[axis] = 0; | ||
| } | ||
| newBases[inDim] = std::move(newInBases); | ||
| } | ||
|
|
||
| int32_t nextAxisBase = 1; | ||
| for (auto &[inDim, inDimBases] : newBases) { | ||
| for (auto &basis : inDimBases) { | ||
| if (basis[axis] == 0) | ||
| continue; | ||
| basis[axis] = nextAxisBase; | ||
| nextAxisBase *= 2; | ||
| } | ||
| } | ||
|
|
||
| return LinearLayout(std::move(newBases), to_vector(layout.getOutDimNames())); | ||
|
peterbell10 marked this conversation as resolved.
|
||
| } | ||
|
|
||
| LinearLayout ReduceOpHelper::getInterLayout(const LinearLayout &layout, | ||
| unsigned axis) { | ||
| auto *ctx = layout.getOutDimNames().begin()->getContext(); | ||
| auto kLane = mlir::StringAttr::get(ctx, "lane"); | ||
| auto kWarp = mlir::StringAttr::get(ctx, "warp"); | ||
| auto regBases = layout.getBases(); | ||
| auto linearAttr = triton::gpu::LinearEncodingAttr::get(ctx, layout); | ||
| int laneBits = layout.getInDimSizeLog2(kLane); | ||
| int neededLaneBits = llvm::Log2_32(linearAttr.getWarpsPerCTA()[axis]); | ||
| // TODO move to verifier | ||
| assert(neededLaneBits <= laneBits && "NYI: more inter-warps than lanes"); | ||
|
peterbell10 marked this conversation as resolved.
|
||
| // Move the warp axis bases we need to reduce into lane bases, while | ||
| // keeping non-axis components in their original in-dim. | ||
| auto &laneBases = regBases[kLane]; | ||
| auto &warpBases = regBases[kWarp]; | ||
| int moved = 0; | ||
| for (auto &warpBasis : warpBases) { | ||
| if (warpBasis[axis] == 0) | ||
| continue; | ||
| assert(moved < neededLaneBits && "unexpected warp axis bases count"); | ||
| std::swap(laneBases[moved], warpBasis); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. NIT: the swap is a bit confusing. I think this is more explicit: laneBases[moved][axis] = warpBasis[axis];
warpBasis[axis] = 0;
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This code will be rewritten in the 3rd PR of this stack as well |
||
| moved++; | ||
|
lezcano marked this conversation as resolved.
|
||
| } | ||
|
|
||
| return LinearLayout(std::move(regBases), to_vector(layout.getOutDimNames())); | ||
| } | ||
|
|
||
| LinearLayout ReduceOpHelper::reducedRegLaneLayout(RankedTensorType srcTy, | ||
| unsigned axis) { | ||
| auto *ctx = srcTy.getContext(); | ||
| auto kReg = StringAttr::get(ctx, "register"); | ||
| auto kLane = StringAttr::get(ctx, "lane"); | ||
| auto kWarp = StringAttr::get(ctx, "warp"); | ||
|
|
||
| auto reduced = triton::gpu::toLinearLayout(srcTy); | ||
| reduced = reduced.sublayout({kReg, kLane, kWarp}, | ||
| to_vector(reduced.getOutDimNames())); | ||
| reduced = actionRemoveBroadcastedRegs(reduced).apply(reduced); | ||
| reduced = moveAxisBasesToFront(reduced, axis).apply(reduced); | ||
| reduced = zeroBasesAlongDimAndReorder(reduced, axis, kReg); | ||
| reduced = actionRemoveBroadcastedRegs(reduced).apply(reduced); | ||
| reduced = zeroBasesAlongDimAndReorder(reduced, axis, kLane); | ||
| return reduced; | ||
| } | ||
|
|
||
| SmallVector<unsigned> | ||
| ReduceOpHelper::getScratchBytesForCvt(const LinearLayout &srcLayout, | ||
| const LinearLayout &dstLayout) { | ||
| SmallVector<unsigned> bytes(srcElementTypes.size(), 0); | ||
| auto *ctx = op.getContext(); | ||
| SmallVector<int64_t> shape; | ||
| shape.reserve(srcLayout.getNumOutDims()); | ||
| for (auto dim : srcLayout.getOutDimNames()) { | ||
| shape.push_back(srcLayout.getOutDimSize(dim)); | ||
| } | ||
| auto srcEnc = triton::gpu::LinearEncodingAttr::get(ctx, srcLayout); | ||
| auto dstEnc = triton::gpu::LinearEncodingAttr::get(ctx, dstLayout); | ||
| for (unsigned i = 0; i < srcElementTypes.size(); ++i) { | ||
| auto elemTy = srcElementTypes[i]; | ||
| if (elemTy.isIntOrFloat() && elemTy.getIntOrFloatBitWidth() < 8) | ||
| elemTy = IntegerType::get(ctx, 8); | ||
| auto srcTy = RankedTensorType::get(shape, elemTy, srcEnc); | ||
| auto dstTy = RankedTensorType::get(shape, elemTy, dstEnc); | ||
| if (!cvtNeedsSharedMemory(srcTy, dstTy)) | ||
| continue; | ||
| auto elems = getNumScratchElemsSwizzledCvt(srcTy, dstTy); | ||
| bytes[i] = elems * getBitwidth(srcTy) / 8; | ||
| } | ||
| return bytes; | ||
| } | ||
|
|
||
| ScanLoweringHelper::ScanLoweringHelper(triton::ScanOp op) : scanOp(op) { | ||
| auto firstTy = cast<RankedTensorType>(op.getOperands()[0].getType()); | ||
| srcShape = firstTy.getShape(); | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.