Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 17 additions & 13 deletions include/triton/Analysis/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,30 +42,34 @@ class ReduceOpHelper {
}
}

ArrayRef<int64_t> getSrcShape() { return srcShape; }

Attribute getSrcLayout() { return srcEncoding; }

triton::ReduceOp getOperation() { return op; }

unsigned getThreadOffsetOnReductionAxis();
RankedTensorType getSrcTy() { return srcTy; }

bool isWarpSynchronous();

unsigned getInterWarpSizeWithUniqueData();

unsigned getIntraWarpSizeWithUniqueData();

// The shape of the shared memory space needed for the reduction.
SmallVector<unsigned> getScratchRepShape();
bool isReduceWithinCTA();

bool isAssociative();

SmallVector<unsigned> getOrderWithAxisAtBeginning();
static triton::ColumnAction
moveAxisBasesToFront(const triton::LinearLayout &layout, int axis);

unsigned getScratchSizeInBytes();
static triton::LinearLayout
zeroBasesAlongDimAndReorder(const triton::LinearLayout &layout, unsigned axis,
mlir::StringAttr dim);

bool isReduceWithinCTA();
static triton::LinearLayout getInterLayout(const triton::LinearLayout &layout,
unsigned axis);

bool isAssociative();
static triton::LinearLayout reducedRegLaneLayout(RankedTensorType srcTy,
unsigned axis);

SmallVector<unsigned>
getScratchBytesForCvt(const triton::LinearLayout &srcLayout,
const triton::LinearLayout &dstLayout);

private:
triton::ReduceOp op;
Expand Down
4 changes: 2 additions & 2 deletions include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETINFOBASE_H

#include "triton/Conversion/MLIRTypes.h"
#include "llvm/ADT/ArrayRef.h"

namespace mlir::triton {
enum class ProgramIDDim : uint32_t;
Expand Down Expand Up @@ -66,8 +67,7 @@ class TargetInfoBase {

virtual bool warpReduce(RewriterBase &rewriter, Location loc,
SmallVector<Value> &acc, triton::ReduceOp op,
unsigned numLaneToReduce,
unsigned interleave) const = 0;
unsigned reduceLaneIdMask) const = 0;

virtual std::string getMulhiFuncName(Type resultElementTy) const = 0;
// Emits LLVM code with |rewriter| to print a message following the given
Expand Down
41 changes: 40 additions & 1 deletion lib/Analysis/Allocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <algorithm>
#include <limits>
#include <numeric>

#include "mlir/Analysis/Liveness.h"
#include "mlir/Support/LLVM.h"
Expand All @@ -14,6 +15,7 @@
#include "triton/Tools/LayoutUtils.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "allocation-shared-memory"
Expand Down Expand Up @@ -49,6 +51,36 @@ unsigned getNumScratchElemsSwizzledCvt(RankedTensorType srcTy,
return smem.getTotalOutDimSize() / reps;
}

namespace {
constexpr int64_t kReduceScratchAlign = 16;

// Element type actually stored in the reduction scratch buffer: sub-byte
// integer/float element types are widened to i8; everything else is kept
// as-is.
Type getReduceMemElemTy(Type elemTy, MLIRContext *ctx) {
  bool isSubByte =
      elemTy.isIntOrFloat() && elemTy.getIntOrFloatBitWidth() < 8;
  return isSubByte ? IntegerType::get(ctx, 8) : elemTy;
}

// Total shared-memory scratch (in bytes) needed for the inter-warp step of
// |op|. |bytesPerOperand| gives each operand's raw scratch byte count, in
// operand order; each operand's region is padded up to kReduceScratchAlign.
int64_t getReduceScratchSizeBytes(triton::ReduceOp op,
                                  ArrayRef<unsigned> bytesPerOperand) {
  // Visit operands in decreasing (memory) element bitwidth.
  // NOTE(review): because each operand's size is aligned independently
  // below, the summed total does not depend on this order — presumably the
  // sort anticipates assigning per-operand offsets in this order later;
  // confirm before relying on it or removing it.
  std::vector<unsigned> indices(op.getNumOperands());
  std::iota(indices.begin(), indices.end(), 0);
  auto *ctx = op.getContext();
  std::sort(indices.begin(), indices.end(), [&](unsigned i, unsigned j) {
    auto lhsTy = getReduceMemElemTy(op.getElementTypes()[i], ctx);
    auto rhsTy = getReduceMemElemTy(op.getElementTypes()[j], ctx);
    return getIntOrFloatOrPtrBitWidth(lhsTy) >
           getIntOrFloatOrPtrBitWidth(rhsTy);
  });
  // Align to 16 bytes to allow for vectorisation.
  int64_t offset = 0;
  for (unsigned idx : indices) {
    offset += llvm::alignTo(bytesPerOperand[idx], kReduceScratchAlign);
  }
  return offset;
}

} // namespace

// Both `atomic_cas` and `atomic_rmw` may need scratch memory to store values
// because Triton's block-based programming model ensures that
// all threads sharing the same partition of the tensor see the same values,
Expand Down Expand Up @@ -76,7 +108,14 @@ static SmallVector<unsigned> getRepShapeForAtomic(Value result) {
unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op) {
if (auto reduceOp = dyn_cast<ReduceOp>(op)) {
ReduceOpHelper helper(reduceOp);
return helper.getScratchSizeInBytes();
if (helper.isWarpSynchronous())
return 0;

auto regLl = ReduceOpHelper::reducedRegLaneLayout(helper.getSrcTy(),
reduceOp.getAxis());
auto tmpLl = ReduceOpHelper::getInterLayout(regLl, reduceOp.getAxis());
auto bytesRegToTmp = helper.getScratchBytesForCvt(regLl, tmpLl);
return getReduceScratchSizeBytes(reduceOp, bytesRegToTmp);
Comment thread
lezcano marked this conversation as resolved.
Comment thread
lezcano marked this conversation as resolved.
Comment thread
lezcano marked this conversation as resolved.
}
if (auto scanOp = dyn_cast<ScanOp>(op)) {
ScanLoweringHelper helper(scanOp);
Expand Down
169 changes: 119 additions & 50 deletions lib/Analysis/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,33 +23,6 @@ namespace mlir {
using namespace triton;
using namespace triton::gpu;

// Returns the tensor's dimension order with the reduction axis moved to the
// front; the relative order of the remaining dimensions is preserved.
SmallVector<unsigned> ReduceOpHelper::getOrderWithAxisAtBeginning() {
  auto order = toLinearEncoding(srcTy).getOrder();
  auto pos = std::find(order.begin(), order.end(), axis);
  // Rotate [begin, pos] one step right so that `axis` becomes the first
  // element.
  std::rotate(order.begin(), pos, pos + 1);
  return order;
}

// Thread offset is the thread index offset of two adjacent threads on the
// reduction axis within the warp.
unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() {
  auto *ctx = srcEncoding.getContext();
  auto linearLayout = toLinearLayout(srcTy);
  auto kLane = mlir::StringAttr::get(ctx, "lane");
  const auto &laneBases = linearLayout.getBases().find(kLane)->second;
  // Count leading lane bases that do not move along the reduction axis;
  // the first axis-moving basis fixes the stride between adjacent threads
  // on that axis.
  unsigned nonAxisLaneBits = 0;
  for (const auto &basis : laneBases) {
    if (basis[axis] != 0)
      break;
    ++nonAxisLaneBits;
  }
  return 1u << nonAxisLaneBits;
}

// Cases where distributed shared memory is not required in ConvertLayout:
// (1) numCTAs == 1
// (2) numCTAs > 1 but srcCGALayout == dstCGALayout
Expand Down Expand Up @@ -107,29 +80,6 @@ bool ReduceOpHelper::isWarpSynchronous() {
return getWarpsPerCTA(srcEncoding, srcShape)[axis] == 1;
}

// Shape of the shared-memory buffer used for the inter-warp reduction: the
// source shape with the reduction axis shrunk to the number of warps that
// hold unique data along it.
SmallVector<unsigned> ReduceOpHelper::getScratchRepShape() {
  // Warp-synchronous reductions need no inter-warp communication, hence no
  // scratch space. NOTE(review): the {0, 0} sentinel assumes rank 2 —
  // confirm callers only consume its product.
  if (isWarpSynchronous())
    return {0, 0};

  auto repShape = convertType<unsigned>(srcShape);
  repShape[axis] = getInterWarpSizeWithUniqueData();
  return repShape;
}

unsigned ReduceOpHelper::getScratchSizeInBytes() {
auto smemShape = getScratchRepShape();
auto elems = product<unsigned>(smemShape);

unsigned bytesPerElem = 0;
for (const auto &ty : srcElementTypes) {
bytesPerElem += ceil<unsigned>(ty.getIntOrFloatBitWidth(), 8);
}
return bytesPerElem * elems;
}

bool ReduceOpHelper::isReduceWithinCTA() {
// TODO: Support reduce across CTAS
// Layout optimization passes such as PlanCTAPass and
Expand Down Expand Up @@ -157,6 +107,125 @@ bool ReduceOpHelper::isAssociative() {
return !hasNoAssociativeOp;
}

// Builds a register permutation that groups the register bases which move
// along `axis` before all other register bases; the relative order within
// each group is preserved.
ColumnAction ReduceOpHelper::moveAxisBasesToFront(const LinearLayout &layout,
                                                  int axis) {
  auto *ctx = layout.getOutDimNames().begin()->getContext();
  auto kReg = StringAttr::get(ctx, "register");
  const auto &regBases = layout.getBases().lookup(kReg);
  // Stable partition of basis indices: axis-moving first, the rest after.
  SmallVector<size_t> order;
  SmallVector<size_t> untouched;
  for (size_t idx = 0; idx != regBases.size(); ++idx) {
    if (regBases[idx][axis] != 0)
      order.push_back(idx);
    else
      untouched.push_back(idx);
  }
  order.append(untouched.begin(), untouched.end());
  return ColumnAction(order, kReg, regBases.size());
}

// Zeros out the basis along the specified axis in the given hardware
// dimension, and reindexes the remaining bases along axis so that each
// element is in linearly increasing order from the hardware's perspective.
// Note that for this reordering we need the operator to be commutative, but
// it's the only way to have a performant lowering.
LinearLayout
ReduceOpHelper::zeroBasesAlongDimAndReorder(const LinearLayout &layout,
                                            unsigned axis, StringAttr dim) {
  // Copy the bases, clearing the axis component of every basis in `dim`.
  LinearLayout::BasesT newBases;
  for (const auto &[inDim, inDimBases] : layout.getBases()) {
    auto copied = inDimBases;
    if (inDim == dim) {
      for (auto &basis : copied)
        basis[axis] = 0;
    }
    newBases[inDim] = std::move(copied);
  }

  // Renumber the surviving axis components to 1, 2, 4, ... in the bases'
  // iteration order so the axis coordinates become dense again.
  int32_t stride = 1;
  for (auto &[inDim, inDimBases] : newBases) {
    for (auto &basis : inDimBases) {
      if (basis[axis] == 0)
        continue;
      basis[axis] = stride;
      stride *= 2;
    }
  }

  return LinearLayout(std::move(newBases), to_vector(layout.getOutDimNames()));
}

// Derives the layout for the inter-warp step: the warp bases that move along
// the reduction axis are exchanged with leading lane bases, so the data to
// be reduced across warps becomes addressable by lane.
LinearLayout ReduceOpHelper::getInterLayout(const LinearLayout &layout,
                                            unsigned axis) {
  auto *ctx = layout.getOutDimNames().begin()->getContext();
  auto kLane = mlir::StringAttr::get(ctx, "lane");
  auto kWarp = mlir::StringAttr::get(ctx, "warp");
  auto newBases = layout.getBases();
  auto linearAttr = triton::gpu::LinearEncodingAttr::get(ctx, layout);
  int laneBits = layout.getInDimSizeLog2(kLane);
  int neededLaneBits = llvm::Log2_32(linearAttr.getWarpsPerCTA()[axis]);
  // TODO move to verifier
  assert(neededLaneBits <= laneBits && "NYI: more inter-warps than lanes");
  // Move the warp axis bases we need to reduce into lane bases, while
  // keeping non-axis components in their original in-dim.
  auto &laneBases = newBases[kLane];
  auto &warpBases = newBases[kWarp];
  int swapped = 0;
  for (auto &warpBasis : warpBases) {
    if (warpBasis[axis] == 0)
      continue;
    assert(swapped < neededLaneBits && "unexpected warp axis bases count");
    std::swap(laneBases[swapped], warpBasis);
    ++swapped;
  }

  return LinearLayout(std::move(newBases), to_vector(layout.getOutDimNames()));
}

// Builds the linear layout describing the data after the register- and
// lane-level reduction steps: the source layout restricted to
// {register, lane, warp}, with the reduction-axis components of first the
// register bases and then the lane bases zeroed out and renumbered.
LinearLayout ReduceOpHelper::reducedRegLaneLayout(RankedTensorType srcTy,
                                                  unsigned axis) {
  auto *ctx = srcTy.getContext();
  auto kReg = StringAttr::get(ctx, "register");
  auto kLane = StringAttr::get(ctx, "lane");
  auto kWarp = StringAttr::get(ctx, "warp");

  auto reduced = triton::gpu::toLinearLayout(srcTy);
  // Restrict to the register/lane/warp input dims, keeping all output dims.
  reduced = reduced.sublayout({kReg, kLane, kWarp},
                              to_vector(reduced.getOutDimNames()));
  // Drop broadcasted registers so every register maps to unique data.
  reduced = actionRemoveBroadcastedRegs(reduced).apply(reduced);
  // Group the axis-moving register bases first, then zero them out and
  // renumber the remaining axis components.
  reduced = moveAxisBasesToFront(reduced, axis).apply(reduced);
  reduced = zeroBasesAlongDimAndReorder(reduced, axis, kReg);
  // Zeroing register bases may reintroduce broadcasting; remove it again
  // before handling the lane bases the same way.
  reduced = actionRemoveBroadcastedRegs(reduced).apply(reduced);
  reduced = zeroBasesAlongDimAndReorder(reduced, axis, kLane);
  return reduced;
}

// Per-operand scratch bytes needed to convert from srcLayout to dstLayout.
// Entries are 0 for operands whose conversion needs no shared memory.
SmallVector<unsigned>
ReduceOpHelper::getScratchBytesForCvt(const LinearLayout &srcLayout,
                                      const LinearLayout &dstLayout) {
  auto *ctx = op.getContext();
  // Logical tensor shape spanned by the layouts' output dims.
  SmallVector<int64_t> shape;
  shape.reserve(srcLayout.getNumOutDims());
  for (auto dim : srcLayout.getOutDimNames())
    shape.push_back(srcLayout.getOutDimSize(dim));

  auto srcEnc = triton::gpu::LinearEncodingAttr::get(ctx, srcLayout);
  auto dstEnc = triton::gpu::LinearEncodingAttr::get(ctx, dstLayout);
  SmallVector<unsigned> bytes(srcElementTypes.size(), 0);
  for (unsigned idx = 0; idx < srcElementTypes.size(); ++idx) {
    auto elemTy = srcElementTypes[idx];
    // Sub-byte elements are stored widened to i8.
    if (elemTy.isIntOrFloat() && elemTy.getIntOrFloatBitWidth() < 8)
      elemTy = IntegerType::get(ctx, 8);
    // Note: local names avoid shadowing the srcTy member.
    auto cvtSrcTy = RankedTensorType::get(shape, elemTy, srcEnc);
    auto cvtDstTy = RankedTensorType::get(shape, elemTy, dstEnc);
    if (cvtNeedsSharedMemory(cvtSrcTy, cvtDstTy)) {
      auto elems = getNumScratchElemsSwizzledCvt(cvtSrcTy, cvtDstTy);
      bytes[idx] = elems * getBitwidth(cvtSrcTy) / 8;
    }
  }
  return bytes;
}

ScanLoweringHelper::ScanLoweringHelper(triton::ScanOp op) : scanOp(op) {
auto firstTy = cast<RankedTensorType>(op.getOperands()[0].getType());
srcShape = firstTy.getShape();
Expand Down
Loading
Loading