Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions include/triton/Dialect/Triton/IR/Utility.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#ifndef TRITON_IR_UTILITY_H_
#define TRITON_IR_UTILITY_H_

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include <algorithm>
#include <numeric>
Expand All @@ -10,6 +12,14 @@ namespace mlir {
// Bitwidth of pointers
constexpr int kPtrBitWidth = 64;

// Returns the bit width of a type, treating pointer-like types as 64-bit.
// This handles LLVM dialect pointer types.
inline int getIntOrFloatOrPtrBitWidth(Type type) {
  // Pointer types (Triton or LLVM dialect) carry no intrinsic bit width in
  // MLIR, so report the canonical 64-bit pointer width for them.
  const bool isPtrLike = isa<LLVM::LLVMPointerType, triton::PointerType>(type);
  return isPtrLike ? kPtrBitWidth : type.getIntOrFloatBitWidth();
}

template <typename T, typename U> SmallVector<T> convertType(ArrayRef<U> in) {
SmallVector<T> out;
for (const auto &i : in)
Expand Down
7 changes: 3 additions & 4 deletions include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,9 @@ LinearLayout chooseScaledMfmaScaleLayout(MLIRContext *ctx, int dotOperandIdx,
ArrayRef<unsigned> tilesPerWarp,
ArrayRef<unsigned> warpsPerCTA);

LinearLayout chooseScaledWmmaScaleLayout(
MLIRContext *ctx, int dotOperandIdx,
const std::vector<std::vector<int32_t>> &dotOperandWarpBasis,
ArrayRef<int64_t> dotOperandShape);
LinearLayout chooseScaledWmmaScaleLayout(MLIRContext *ctx, int dotOperandIdx,
ArrayRef<unsigned> warpsPerCTA,
ArrayRef<int64_t> dotOperandShape);

LinearLayout getSM120DotScaledScaleLayout(MLIRContext *ctx,
ArrayRef<int64_t> shape, int opIdx,
Expand Down
3 changes: 2 additions & 1 deletion lib/Analysis/Allocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,8 @@ class AllocationAnalysis {
auto shapePerCTA = gpu::getAllocationShapePerCTA(allocType);
numElems = product<int64_t>(shapePerCTA);
}
int64_t bytes = numElems * allocType.getElementTypeBitWidth() / 8;
int64_t bytes =
numElems * getIntOrFloatOrPtrBitWidth(allocType.getElementType()) / 8;

auto alignment = alloc.getAlignmentOrDefault();
allocation->addBuffer<BufferT::BufferKind::Explicit>(alloc, bytes,
Expand Down
132 changes: 74 additions & 58 deletions lib/Analysis/AxisInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,23 +91,26 @@ class BinaryOpVisitorImpl : public AxisInfoVisitorImpl<OpTy> {
auto lhsInfo = operands[0]->getValue();
auto rhsInfo = operands[1]->getValue();
auto rank = lhsInfo.getRank();
assert(isa<RankedTensorType>(op.getType()) ||
rank == 1 && "Expected ranked tensor or scalar");
assert(operands.size() == 2 && "Expected two operands");
auto constantValue = getConstantValue(op, lhsInfo, rhsInfo);
if (constantValue.has_value()) {
auto resTy = dyn_cast<RankedTensorType>(op.getType());
AxisInfo::DimVectorT constancy =
resTy ? to_vector(resTy.getShape()) : AxisInfo::DimVectorT(rank, 1);
AxisInfo::DimVectorT contiguity(rank, 1);
AxisInfo::DimVectorT divisibility(
rank, highestPowOf2Divisor<int64_t>(constantValue.value()));
return AxisInfo(contiguity, divisibility, constancy, constantValue);
}
AxisInfo::DimVectorT contiguity;
AxisInfo::DimVectorT divisibility;
AxisInfo::DimVectorT constancy;
auto constantValue = getConstantValue(op, lhsInfo, rhsInfo);
for (auto d = 0; d < rank; ++d) {
if (constantValue.has_value()) {
contiguity.push_back(1);
constancy.push_back(
std::max(lhsInfo.getConstancy(d), rhsInfo.getConstancy(d)));
divisibility.push_back(
highestPowOf2Divisor<int64_t>(constantValue.value()));
} else {
contiguity.push_back(getContiguity(op, lhsInfo, rhsInfo, d));
constancy.push_back(getConstancy(op, lhsInfo, rhsInfo, d));
divisibility.push_back(getDivisibility(op, lhsInfo, rhsInfo, d));
}
contiguity.push_back(getContiguity(op, lhsInfo, rhsInfo, d));
constancy.push_back(getConstancy(op, lhsInfo, rhsInfo, d));
divisibility.push_back(getDivisibility(op, lhsInfo, rhsInfo, d));
}
return AxisInfo(contiguity, divisibility, constancy, constantValue);
}
Expand All @@ -125,9 +128,8 @@ class BinaryOpVisitorImpl : public AxisInfoVisitorImpl<OpTy> {

virtual int64_t getConstancy(OpTy op, const AxisInfo &lhs,
const AxisInfo &rhs, int dim) {
return 1;
return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
}

virtual std::optional<int64_t> getConstantValue(OpTy op, const AxisInfo &lhs,
const AxisInfo &rhs) {
return {};
Expand Down Expand Up @@ -192,6 +194,26 @@ class CastOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
}
};

// Forwards the operand's AxisInfo through unrealized_conversion_cast ops,
// which appear transiently while a partial dialect conversion is in flight.
class UnrealizedConversionCastOpAxisInfoVisitor final
    : public AxisInfoVisitorImpl<mlir::UnrealizedConversionCastOp> {
public:
  using AxisInfoVisitorImpl<
      mlir::UnrealizedConversionCastOp>::AxisInfoVisitorImpl;

  AxisInfo
  getAxisInfo(mlir::UnrealizedConversionCastOp op,
              ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
    const AxisInfo &srcInfo = operands[0]->getValue();
    // Only propagate the operand's AxisInfo when its rank agrees with the
    // result type; a rank mismatch can crash later visitor applications.
    if (auto resultTy = dyn_cast<RankedTensorType>(op.getResultTypes()[0]))
      if (resultTy.getRank() != srcInfo.getRank())
        return AxisInfo::getPessimisticValueState(op->getResult(0));
    return srcInfo;
  }
};

class MakeRangeOpAxisInfoVisitor final
: public AxisInfoVisitorImpl<triton::MakeRangeOp> {
public:
Expand Down Expand Up @@ -308,11 +330,6 @@ class AddSubOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
return gcd(lhs.getDivisibility(dim), rhsDivisibility);
}

int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
int dim) override {
return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
}

std::optional<int64_t> getConstantValue(OpTy op, const AxisInfo &lhs,
const AxisInfo &rhs) override {
if (lhs.getConstantValue().has_value() &&
Expand Down Expand Up @@ -355,11 +372,6 @@ class MulIOpAxisInfoVisitor final : public BinaryOpVisitorImpl<arith::MulIOp> {
return std::max(lhsContiguity, rhsContiguity);
}

int64_t getConstancy(arith::MulIOp op, const AxisInfo &lhs,
const AxisInfo &rhs, int dim) override {
return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
}

int64_t getDivisibility(arith::MulIOp op, const AxisInfo &lhs,
const AxisInfo &rhs, int dim) override {
auto lhsDivisibility = lhs.getDivisibility(dim);
Expand All @@ -379,9 +391,13 @@ class MulIOpAxisInfoVisitor final : public BinaryOpVisitorImpl<arith::MulIOp> {

std::optional<int64_t> getConstantValue(arith::MulIOp op, const AxisInfo &lhs,
const AxisInfo &rhs) override {
if (lhs.getConstantValue().has_value() &&
rhs.getConstantValue().has_value())
return {lhs.getConstantValue().value() * rhs.getConstantValue().value()};
auto lhsConst = lhs.getConstantValue();
auto rhsConst = rhs.getConstantValue();
if (lhsConst.has_value() && rhsConst.has_value())
return {lhsConst.value() * rhsConst.value()};
if ((lhsConst.has_value() && lhsConst.value() == 0) ||
(rhsConst.has_value() && rhsConst.value() == 0))
return 0;
return {};
}
};
Expand All @@ -404,12 +420,11 @@ class DivOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
int dim) override {
auto resTy = dyn_cast<RankedTensorType>(op.getType());
auto constancy = BinaryOpVisitorImpl<OpTy>::getConstancy(op, lhs, rhs, dim);
if (!resTy)
return BinaryOpVisitorImpl<OpTy>::getConstancy(op, lhs, rhs, dim);
return constancy;
auto shape = resTy.getShape();
// Case 1: both lhs and rhs are constants.
auto constancy = gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
// Case 2: lhs contiguous, rhs constant.
// Case: lhs contiguous, rhs constant.
// lhs: d_lhs * k, d_lhs * k + 1, ..., d_lhs * k + n
// rhs: d_rhs * p, d_rhs * p, ..., d_rhs * p
// lhs / rhs = d_lhs * k / (d_rhs * p), (d_lhs * k + 1) / (d_rhs * p),
Expand Down Expand Up @@ -506,15 +521,15 @@ class RemOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {

int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
int dim) override {
auto constancy = BinaryOpVisitorImpl<OpTy>::getConstancy(op, lhs, rhs, dim);
auto resTy = dyn_cast<RankedTensorType>(op.getType());
if (!resTy)
return BinaryOpVisitorImpl<OpTy>::getConstancy(op, lhs, rhs, dim);
auto shape = resTy.getShape();
// lhs % 1 = 0
return rhs.getConstantValue().has_value() &&
rhs.getConstantValue().value() == 1
? shape[dim]
: gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
return constancy;
// Case: lhs % 1 = 0
if (rhs.getConstantValue().has_value() &&
rhs.getConstantValue().value() == 1)
return resTy.getDimSize(dim);
return constancy;
}

std::optional<int64_t> getConstantValue(OpTy op, const AxisInfo &lhs,
Expand Down Expand Up @@ -669,7 +684,7 @@ class CmpOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
int64_t constHint = 1;
if (lhsInfo.getConstantValue().has_value() &&
rhsInfo.getConstantValue().has_value()) {
constHint = lhsInfo.getConstancy(d);
constHint = shape[d];
constantValue =
compare(getPredicate(op), lhsInfo.getConstantValue().value(),
rhsInfo.getConstantValue().value())
Expand Down Expand Up @@ -828,6 +843,13 @@ class SelectOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
rhsInfo.getConstantValue().has_value() &&
lhsInfo.getConstantValue() == rhsInfo.getConstantValue())
constantValue = lhsInfo.getConstantValue();

if (constantValue.has_value()) {
auto resTy = dyn_cast<RankedTensorType>(op.getType());
assert(resTy || rank == 1);
constancy =
resTy ? to_vector(resTy.getShape()) : AxisInfo::DimVectorT(rank, 1);
}
}

return AxisInfo(contiguity, divisibility, constancy, constantValue);
Expand All @@ -840,11 +862,6 @@ class LogicalOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
using BinaryOpVisitorImpl<OpTy>::BinaryOpVisitorImpl;

private:
int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
int dim) override {
return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
}

std::optional<int64_t> getConstantValue(OpTy op, const AxisInfo &lhs,
const AxisInfo &rhs) override {
if (lhs.getConstantValue().has_value() &&
Expand Down Expand Up @@ -890,11 +907,6 @@ class ShLIOpAxisInfoVisitor final : public BinaryOpVisitorImpl<arith::ShLIOp> {
return multiplyDivisor(lhsDivisibility, 1ll << shift);
}

int64_t getConstancy(arith::ShLIOp op, const AxisInfo &lhs,
const AxisInfo &rhs, int dim) override {
return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
}

std::optional<int64_t> getConstantValue(arith::ShLIOp op, const AxisInfo &lhs,
const AxisInfo &rhs) override {
if (lhs.getConstantValue().has_value() &&
Expand Down Expand Up @@ -932,11 +944,6 @@ class ShROpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
return std::max<int64_t>(1, lhsDivisibility / (int64_t(1) << shift));
}

int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
int dim) override {
return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
}

std::optional<int64_t> getConstantValue(OpTy op, const AxisInfo &lhs,
const AxisInfo &rhs) override {
if (lhs.getConstantValue().has_value() &&
Expand Down Expand Up @@ -969,9 +976,15 @@ class MaxMinOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
constantValue = {std::min(lhsInfo.getConstantValue().value(),
rhsInfo.getConstantValue().value())};
}
auto resTy = dyn_cast<RankedTensorType>(op.getType());
assert(resTy || rank == 1);
AxisInfo::DimVectorT constancy =
resTy ? to_vector(resTy.getShape()) : AxisInfo::DimVectorT(rank, 1);
AxisInfo::DimVectorT divisibility(
rank, highestPowOf2Divisor<int64_t>(constantValue.value()));
return AxisInfo(/*knownContiguity=*/AxisInfo::DimVectorT(rank, 1),
/*knownDivisibility=*/AxisInfo::DimVectorT(rank, 1),
/*knownConstancy=*/AxisInfo::DimVectorT(rank, 1),
/*knownDivisibility=*/divisibility,
/*knownConstancy=*/constancy,
/*constantValue=*/constantValue);
} else {
AxisInfo::DimVectorT contiguity, divisibility, constancy;
Expand Down Expand Up @@ -1029,11 +1042,11 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver,
// This is needed by TritonGPUToLLVM, to get AxisInfo when the graph is
// in the process of a PartialConversion, where UnrealizedConversionCast
// may exist
visitors.append<UnrealizedConversionCastOpAxisInfoVisitor>();
visitors.append<CastOpAxisInfoVisitor<arith::ExtSIOp>,
CastOpAxisInfoVisitor<arith::ExtUIOp>,
CastOpAxisInfoVisitor<arith::TruncIOp>,
CastOpAxisInfoVisitor<triton::gpu::ConvertLayoutOp>,
CastOpAxisInfoVisitor<mlir::UnrealizedConversionCastOp>,
CastOpAxisInfoVisitor<triton::BitcastOp>>();
visitors.append<MakeRangeOpAxisInfoVisitor>();
visitors.append<PoisonOpAxisInfoVisitor>();
Expand Down Expand Up @@ -1384,7 +1397,10 @@ void ModuleAxisInfoAnalysis::update(CallOpInterface callOp,
callee.setArgAttr(index, attrName, attr);
};
auto axisInfo = axisInfoMap->lookup(value);
assert(axisInfo.getRank() == 1 && "only scalar arguments are supported");
// Only scalar arguments are supported. Do not forward multi-dimensional
// AxisInfo to the callee.
if (axisInfo.getRank() != 1)
continue;
setAttrFn("tt.contiguity", axisInfo.getContiguity(0));
setAttrFn("tt.divisibility", axisInfo.getDivisibility(0));
setAttrFn("tt.constancy", axisInfo.getConstancy(0));
Expand Down
5 changes: 1 addition & 4 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ struct ConvertLayoutOpConversion
: public ConvertOpToLLVMPattern<ConvertLayoutOp> {
const TargetInfoBase &targetInfo;

// Set benefit to 2 so that this pattern applies before other convert-layout
// conversions. TODO(jlebar): Eventually we want this to be the only pattern.
explicit ConvertLayoutOpConversion(LLVMTypeConverter &typeConverter,
const TargetInfoBase &targetInfo,
PatternBenefit benefit = 1)
Expand Down Expand Up @@ -277,8 +275,7 @@ struct ConvertLayoutOpConversion
StringAttr kReg = str_attr("register");
StringAttr kLane = str_attr("lane");
auto elemTy = getTypeConverter()->convertType(srcTy.getElementType());
int bitwidth =
elemTy.isIntOrFloat() ? elemTy.getIntOrFloatBitWidth() : kPtrBitWidth;
int bitwidth = getIntOrFloatOrPtrBitWidth(elemTy);

auto factors = getWarpLayoutConvertDecomposition(srcTy, dstTy, bitwidth);
auto &[pReg, pLane, mixedTranspositions, nPack] = factors;
Expand Down
2 changes: 1 addition & 1 deletion lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ struct ElementwiseInlineAsmOpConversion
auto ty = getTypeConverter()->convertType(getElementType(result));

// Pack return elements into 32-bits.
unsigned bitWidth = ty.isIntOrFloat() ? ty.getIntOrFloatBitWidth() : 64;
unsigned bitWidth = getIntOrFloatOrPtrBitWidth(ty);
unsigned numElemsPerReg =
std::min(std::max(32 / bitWidth, 1u), op.getPackedElement());
assert(op.getPackedElement() % numElemsPerReg == 0);
Expand Down
4 changes: 2 additions & 2 deletions lib/Conversion/TritonGPUToLLVM/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,7 @@ SmallVector<Value> lowerLdSt(
auto kLane = str_attr("lane");
auto kWarp = str_attr("warp");
auto kOffset = str_attr("offset");
auto bitwidth = llvmElemTy.getIntOrFloatBitWidth();
auto bitwidth = getIntOrFloatOrPtrBitWidth(llvmElemTy);

auto [elemsPerVec, permutation] =
largestVectorisation(ctx, cvt, bitwidth, maybeMaxVecElems);
Expand Down Expand Up @@ -625,7 +625,7 @@ lowerLocalLdSt(Location loc, MLIRContext *ctx,
assert(*cvt.getOutDimNames().begin() == str_attr("offset"));
auto calcPaddedOffset = [&](Value smemOffset) {
TritonLLVMOpBuilder b(loc, rewriter);
auto bitwidth = llvmElemTy.getIntOrFloatBitWidth();
auto bitwidth = getIntOrFloatOrPtrBitWidth(llvmElemTy);
if (auto paddedEnc = dyn_cast<triton::gpu::PaddedSharedEncodingAttr>(
srcTy.getEncoding())) {
// Apply the offset needed for padding.
Expand Down
14 changes: 12 additions & 2 deletions lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,16 @@ using namespace mlir::triton;
using namespace mlir::triton::gpu;
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
namespace {

// Casts `val` to `type`, using ptrtoint when converting a pointer value to a
// non-pointer type (a plain bitcast is invalid for that case) and a bitcast
// otherwise.
Value bitOrPtrCast(Value val, Type type, TritonLLVMOpBuilder &b) {
  const bool srcIsPtr = isa<LLVM::LLVMPointerType>(val.getType());
  const bool dstIsPtr = isa<LLVM::LLVMPointerType>(type);
  return (srcIsPtr && !dstIsPtr) ? b.ptrtoint(type, val)
                                 : b.bitcast(val, type);
}

struct SplatOpConversion : public ConvertOpToLLVMPattern<triton::SplatOp> {
using ConvertOpToLLVMPattern<triton::SplatOp>::ConvertOpToLLVMPattern;
// Convert SplatOp or arith::ConstantOp with SplatElementsAttr to a
Expand Down Expand Up @@ -39,13 +49,13 @@ struct SplatOpConversion : public ConvertOpToLLVMPattern<triton::SplatOp> {
unsigned ratio = srcBitWidth / cstBitWidth;
Type intTy = IntegerType::get(elemType.getContext(), cstBitWidth);
VectorType vecType = VectorType::get(ratio, intTy);
Value intCst = b.bitcast(constVal, intTy);
Value intCst = bitOrPtrCast(constVal, intTy, b);
Value vec = b.undef(vecType);
for (unsigned i = 0; i < ratio; ++i)
vec = b.insert_element(vecType, vec, intCst, b.int_val(32, i));
constVal = vec;
}
auto llSrc = b.bitcast(constVal, srcType);
Value llSrc = bitOrPtrCast(constVal, srcType, b);
size_t elemsPerThread = getTotalElemsPerThread(tensorTy);
llvm::SmallVector<Value> elems(elemsPerThread, llSrc);
return packLLElements(loc, typeConverter, elems, rewriter, resType);
Expand Down
Loading
Loading