Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 139 additions & 37 deletions third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,62 @@ cvtScalePkUpcastFromFp8(Location loc, ConversionPatternRewriter &rewriter,
return ret;
}

// Convert eight OCP Fp8/Bf8 values to Fp16/Bf16/Fp32 on CDNA4 with a single
// packed conversion op.
//
// ConvertOp selects the result element type and must be one of the
// ROCDL::CvtPkScalePk8{F32,F16,Bf16}{Fp8,Bf8}Op ops; any other type is rejected
// at compile time.
//
// `v` must hold exactly eight values (asserted). They are inserted into an
// 8 x i8 vector, bitcast to a vector of i32, and handed to ConvertOp. The
// returned vector holds the eight converted scalars in input order.
template <typename ConvertOp>
static SmallVector<Value>
cvtScalePk8UpcastFromFp8(Location loc, ConversionPatternRewriter &rewriter,
                         const SmallVector<Value> &v) {
  constexpr bool producesF32 =
      std::is_same_v<ConvertOp, ROCDL::CvtPkScalePk8F32Fp8Op> ||
      std::is_same_v<ConvertOp, ROCDL::CvtPkScalePk8F32Bf8Op>;
  constexpr bool producesF16 =
      std::is_same_v<ConvertOp, ROCDL::CvtPkScalePk8F16Fp8Op> ||
      std::is_same_v<ConvertOp, ROCDL::CvtPkScalePk8F16Bf8Op>;
  constexpr bool producesBf16 =
      std::is_same_v<ConvertOp, ROCDL::CvtPkScalePk8Bf16Fp8Op> ||
      std::is_same_v<ConvertOp, ROCDL::CvtPkScalePk8Bf16Bf8Op>;
  // Fail the build instead of relying on a debug-only runtime assert: with
  // asserts disabled an unsupported op would otherwise reach vec_ty() with a
  // null element type.
  static_assert(producesF32 || producesF16 || producesBf16,
                "cvtScalePk8UpcastFromFp8: unsupported ConvertOp");

  assert(v.size() == 8);
  const size_t inSize = v.size();
  auto b = TritonLLVMOpBuilder(loc, rewriter);

  // The i8 inputs are repacked into as many i32 lanes as their bits fill.
  const size_t intVecSize = (inSize * i8_ty.getWidth()) / i32_ty.getWidth();
  Type vInTy = vec_ty(i32_ty, intVecSize);

  Type resTy;
  if constexpr (producesF32) {
    resTy = f32_ty;
  } else if constexpr (producesF16) {
    resTy = f16_ty;
  } else {
    resTy = bf16_ty;
  }
  Type vResTy = vec_ty(resTy, inSize);

  auto vI8InTy = vec_ty(i8_ty, inSize);

  // Build the packed i8 vector; the index constants are kept so the same
  // values can be reused when extracting the results.
  Value vI8In = b.undef(vI8InTy);
  SmallVector<Value, 8> idx;
  for (size_t i = 0; i < inSize; ++i) {
    idx.push_back(b.i32_val(i));
    vI8In = b.insert_element(vI8InTy, vI8In, v[i], idx[i]);
  }
  auto vIn = b.bitcast(vI8In, vInTy);

  // NOTE(review): 127 is presumably the biased exponent encoding a scale of
  // 1.0, and 0b1000 the scale-select operand expected by the op -- confirm
  // against the ROCDL op definition.
  Value scale = b.i32_val(127);
  IntegerAttr opscale = rewriter.getI32IntegerAttr(0b1000);

  auto result = ConvertOp::create(rewriter, loc, vResTy, vIn, scale, opscale);

  SmallVector<Value> ret(inSize);
  for (size_t i = 0; i < inSize; ++i) {
    ret[i] = b.extract_element(resTy, result, idx[i]);
  }

  return ret;
}

// Convert Fp16/Bf16/Fp32 to OCP Fp8/Bf8 on CDNA4
template <typename ConvertOp>
static SmallVector<Value>
Expand Down Expand Up @@ -734,19 +790,27 @@ static SmallVector<Value> cvtPkFp32ToF8(Location loc,
return ret;
}

// Convert OCP Fp8 to Fp32 on CDNA4+
// Eight inputs are converted by the pk8 scaled upcast; any other input must
// consist of exactly four values (asserted) and uses the pk scaled upcast.
static SmallVector<Value> Fp8E4M3FN_to_Fp32(Location loc,
                                            ConversionPatternRewriter &rewriter,
                                            const SmallVector<Value> &v) {
  const bool hasEightInputs = v.size() == 8;
  if (!hasEightInputs) {
    assert(v.size() == 4);
    return cvtScalePkUpcastFromFp8<ROCDL::CvtScaleF32PkF32Fp8Op>(loc, rewriter,
                                                                 v);
  }
  return cvtScalePk8UpcastFromFp8<ROCDL::CvtPkScalePk8F32Fp8Op>(loc, rewriter,
                                                                v);
}

// Convert OCP Bf8 to Fp32 on CDNA4
// Convert OCP Bf8 to Fp32 on CDNA4+
static SmallVector<Value> Fp8E5M2_to_Fp32(Location loc,
ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
if (v.size() == 8) {
return cvtScalePk8UpcastFromFp8<ROCDL::CvtPkScalePk8F32Bf8Op>(loc, rewriter,
v);
}
assert(v.size() == 4);
return cvtScalePkUpcastFromFp8<ROCDL::CvtScaleF32PkF32Bf8Op>(loc, rewriter,
v);
Expand Down Expand Up @@ -926,15 +990,15 @@ static ConverterT Fp32_to_Fp8E4M3FNUZ(AMD::ISAFamily isaFamily) {
return isCDNA4(isaFamily) ? Fp32_to_Fp8E4M3FNUZ_SW : Fp32_to_Fp8E4M3FNUZ_HW;
}

// Nanoo Bf8 -> Fp32 on CDNA3+
// Expects exactly four input values (asserted) and forwards them to
// cvtPkF8ToFp32 instantiated with ROCDL::CvtPkF32Bf8Op.
static SmallVector<Value>
Fp8E5M2FNUZ_to_Fp32(Location loc, ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
assert(v.size() == 4);
return cvtPkF8ToFp32<ROCDL::CvtPkF32Bf8Op>(loc, rewriter, v);
}

// Nanoo Fp8 -> Fp32 on CDNA3
// Nanoo Fp8 -> Fp32 on CDNA3+
static SmallVector<Value>
Fp8E4M3FNUZ_to_Fp32(Location loc, ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
Expand Down Expand Up @@ -1023,13 +1087,18 @@ Fp8E4M3FN_to_Fp16_SW(Location loc, ConversionPatternRewriter &rewriter,
// Hardware conversion of Fp8E4M3FN values to Fp16.
// Eight inputs are converted by the pk8 scaled upcast; any other input must
// consist of exactly four values (asserted) and uses the pk scaled upcast.
static SmallVector<Value>
Fp8E4M3FN_to_Fp16_HW(Location loc, ConversionPatternRewriter &rewriter,
                     const SmallVector<Value> &v) {
  const bool hasEightInputs = v.size() == 8;
  if (!hasEightInputs) {
    assert(v.size() == 4);
    return cvtScalePkUpcastFromFp8<ROCDL::CvtScaleF32PkF16Fp8Op>(loc, rewriter,
                                                                 v);
  }
  return cvtScalePk8UpcastFromFp8<ROCDL::CvtPkScalePk8F16Fp8Op>(loc, rewriter,
                                                                v);
}

// Pick the Fp8E4M3FN -> Fp16 converter for `isaFamily`: the hardware variant
// when isCDNA4OrHigher(isaFamily) holds, the software variant otherwise.
// A stale duplicate `return` using the narrower isCDNA4() check preceded this
// one and made it unreachable; only the current selection is kept.
ConverterT Fp8E4M3FN_to_Fp16(AMD::ISAFamily isaFamily) {
  return isCDNA4OrHigher(isaFamily) ? Fp8E4M3FN_to_Fp16_HW
                                    : Fp8E4M3FN_to_Fp16_SW;
}

// Ocp Bf8->Fp16
Expand Down Expand Up @@ -1064,13 +1133,17 @@ Fp8E5M2_to_Fp16_SW(Location loc, ConversionPatternRewriter &rewriter,
// Hardware conversion of Fp8E5M2 values to Fp16.
// Eight inputs are converted by the pk8 scaled upcast; any other input must
// consist of exactly four values (asserted) and uses the pk scaled upcast.
static SmallVector<Value>
Fp8E5M2_to_Fp16_HW(Location loc, ConversionPatternRewriter &rewriter,
                   const SmallVector<Value> &v) {
  const bool hasEightInputs = v.size() == 8;
  if (!hasEightInputs) {
    assert(v.size() == 4);
    return cvtScalePkUpcastFromFp8<ROCDL::CvtScaleF32PkF16Bf8Op>(loc, rewriter,
                                                                 v);
  }
  return cvtScalePk8UpcastFromFp8<ROCDL::CvtPkScalePk8F16Bf8Op>(loc, rewriter,
                                                                v);
}

// Pick the Fp8E5M2 -> Fp16 converter for `isaFamily`: the hardware variant
// when isCDNA4OrHigher(isaFamily) holds, the software variant otherwise.
// A stale duplicate `return` using the narrower isCDNA4() check preceded this
// one and made it unreachable; only the current selection is kept.
ConverterT Fp8E5M2_to_Fp16(AMD::ISAFamily isaFamily) {
  return isCDNA4OrHigher(isaFamily) ? Fp8E5M2_to_Fp16_HW : Fp8E5M2_to_Fp16_SW;
}

static SmallVector<Value>
Expand Down Expand Up @@ -1182,7 +1255,7 @@ static SmallVector<Value> Fp32_to_F16_RTNE(Location loc,
MultipleOperandsRange operands,
AMD::ISAFamily isaFamily) {
// For CDNA4 we can potentially use packed v_cvt_pk_[b]f16_f32 instructions.
if (isCDNA4(isaFamily)) {
if (isCDNA4OrHigher(isaFamily)) {
SmallVector<Value> inVals;
size_t numElem = std::min(size_t(2), operands.size());
inVals.reserve(numElem);
Expand Down Expand Up @@ -1351,13 +1424,17 @@ Fp8E5M2_to_Bf16_SW(Location loc, ConversionPatternRewriter &rewriter,
// Hardware conversion of Fp8E5M2 values to Bf16.
// Eight inputs are converted by the pk8 scaled upcast; any other input must
// consist of exactly four values (asserted) and uses the pk scaled upcast.
static SmallVector<Value>
Fp8E5M2_to_Bf16_HW(Location loc, ConversionPatternRewriter &rewriter,
                   const SmallVector<Value> &v) {
  const bool hasEightInputs = v.size() == 8;
  if (!hasEightInputs) {
    assert(v.size() == 4);
    return cvtScalePkUpcastFromFp8<ROCDL::CvtScaleF32PkBf16Bf8Op>(loc, rewriter,
                                                                  v);
  }
  return cvtScalePk8UpcastFromFp8<ROCDL::CvtPkScalePk8Bf16Bf8Op>(loc,
                                                                 rewriter, v);
}

// Pick the Fp8E5M2 -> Bf16 converter for `isaFamily`: the hardware variant
// when isCDNA4OrHigher(isaFamily) holds, the software variant otherwise.
// A stale duplicate `return` using the narrower isCDNA4() check preceded this
// one and made it unreachable; only the current selection is kept.
ConverterT Fp8E5M2_to_Bf16(AMD::ISAFamily isaFamily) {
  return isCDNA4OrHigher(isaFamily) ? Fp8E5M2_to_Bf16_HW : Fp8E5M2_to_Bf16_SW;
}

// Bf16 -> OCP Bf8
Expand Down Expand Up @@ -1492,13 +1569,18 @@ Fp8E4M3FN_to_Bf16_SW(Location loc, ConversionPatternRewriter &rewriter,
// Hardware conversion of Fp8E4M3FN values to Bf16.
// Eight inputs are converted by the pk8 scaled upcast; any other input must
// consist of exactly four values (asserted) and uses the pk scaled upcast.
static SmallVector<Value>
Fp8E4M3FN_to_Bf16_HW(Location loc, ConversionPatternRewriter &rewriter,
                     const SmallVector<Value> &v) {
  const bool hasEightInputs = v.size() == 8;
  if (!hasEightInputs) {
    assert(v.size() == 4);
    return cvtScalePkUpcastFromFp8<ROCDL::CvtScaleF32PkBf16Fp8Op>(loc, rewriter,
                                                                  v);
  }
  return cvtScalePk8UpcastFromFp8<ROCDL::CvtPkScalePk8Bf16Fp8Op>(loc,
                                                                 rewriter, v);
}

// Pick the Fp8E4M3FN -> Bf16 converter for `isaFamily`: the hardware variant
// when isCDNA4OrHigher(isaFamily) holds, the software variant otherwise.
// A stale duplicate `return` using the narrower isCDNA4() check preceded this
// one and made it unreachable; only the current selection is kept.
ConverterT Fp8E4M3FN_to_Bf16(AMD::ISAFamily isaFamily) {
  return isCDNA4OrHigher(isaFamily) ? Fp8E4M3FN_to_Bf16_HW
                                    : Fp8E4M3FN_to_Bf16_SW;
}

// fp8e4m3fnuz to bf16
Expand Down Expand Up @@ -1837,6 +1919,48 @@ struct FpToFpOpConversion
return srcMap.lookup(key);
}

// Returns the element count (2, 8, or the default 4) that createDestOps uses
// as `numElements` for a conversion from `srcElementType` to `dstElementType`
// with the given optional rounding mode on this pattern's `isaFamily`.
//
// NOTE(review): the 8-element cases are described as "GFX1250+" but the check
// is an exact match on AMD::ISAFamily::GFX1250 -- confirm which is intended.
int getNumElements(
    Type srcElementType, Type dstElementType,
    std::optional<::mlir::triton::RoundingMode> roundingMode) const {
  const bool wantsRTZ = roundingMode == RoundingMode::RTZ;
  const bool wantsRTNE = roundingMode == RoundingMode::RTNE;
  const bool onGfx1250 = isaFamily == AMD::ISAFamily::GFX1250;

  // --- Cases handled two elements at a time ---
  // fp32 -> fp16 with RTZ
  const bool fp32ToFp16Rtz = isa<Float32Type>(srcElementType) &&
                             isa<Float16Type>(dstElementType) && wantsRTZ;
  // fp32/fp16 -> nanoo fp8/bf8 on non-CDNA3
  const bool floatToNanooOffCdna3 =
      isa<Float32Type, Float16Type>(srcElementType) &&
      isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(dstElementType) &&
      isaFamily != AMD::ISAFamily::CDNA3;
  // special upcast: nanoo fp8 -> bf16 on CDNA4
  const bool nanooFp8ToBf16OnCdna4 =
      isaFamily == AMD::ISAFamily::CDNA4 &&
      isa<Float8E4M3FNUZType>(srcElementType) && dstElementType.isBF16();
  if (fp32ToFp16Rtz || floatToNanooOffCdna3 || nanooFp8ToBf16OnCdna4)
    return 2;

  // --- Cases handled eight elements at a time (GFX1250) ---
  // downcast: fp32/fp16/bf16 -> OCP fp8/bf8 with RTNE
  const bool ocpDowncastRtne =
      onGfx1250 &&
      isa<Float32Type, Float16Type, BFloat16Type>(srcElementType) &&
      isa<Float8E4M3FNType, Float8E5M2Type>(dstElementType) && wantsRTNE;
  // upcast: OCP fp8/bf8 -> fp16/bf16/fp32
  const bool ocpUpcast =
      onGfx1250 && isa<Float8E5M2Type, Float8E4M3FNType>(srcElementType) &&
      isa<Float16Type, BFloat16Type, Float32Type>(dstElementType);
  if (ocpDowncastRtne || ocpUpcast)
    return 8;

  // Everything else uses the default group size.
  return 4;
}

SmallVector<Value> createDestOps(triton::FpToFpOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
Type elemTy, MultipleOperandsRange operands,
Expand All @@ -1860,20 +1984,8 @@ struct FpToFpOpConversion
convertFp32ToBf16(loc, rewriter, operands[0][0], RoundingMode::RTZ)};
}

size_t numElements = 4;
// numElements = 2 for :
// fp32 -> fp16 with RTZ
// fp32/fp16 -> nanoo fp8/bf8 on non-CDNA3
// nanoo fp8 -> bf16 on CDNA4
if ((llvm::isa<Float32Type>(srcElementType) &&
llvm::isa<Float16Type>(dstElementType) &&
roundingMode == RoundingMode::RTZ) ||
(llvm::isa<Float32Type, Float16Type>(srcElementType) &&
llvm::isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(dstElementType) &&
isaFamily != AMD::ISAFamily::CDNA3) ||
(llvm::isa<Float8E4M3FNUZType>(srcElementType) &&
dstElementType.isBF16() && isCDNA4(isaFamily)))
numElements = 2;
size_t numElements =
getNumElements(srcElementType, dstElementType, roundingMode);

// fp32 -> fp8 with rtne can be done in two steps:
// - fp32 -> fp16 with rtne and
Expand All @@ -1884,39 +1996,29 @@ struct FpToFpOpConversion
// 3. fp32 -> ocp fp8/bf8 on non-CDNA4: has software support
bool useFP16IntermediateSrc =
srcElementType.isF32() && !dstElementType.isF16() &&
!(isCDNA4(isaFamily) &&
!(isCDNA4OrHigher(isaFamily) &&
(llvm::isa<Float8E4M3FNType, Float8E4M3FNUZType, Float8E5M2Type,
Float8E5M2FNUZType>(dstElementType))) &&
!(isaFamily == AMD::ISAFamily::CDNA3 &&
(llvm::isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(
dstElementType))) &&
!(!isCDNA4(isaFamily) &&
!(!isCDNA4OrHigher(isaFamily) &&
(llvm::isa<Float8E5M2Type, Float8E4M3FNType>(dstElementType)));

if ((isaFamily == AMD::ISAFamily::GFX1250) &&
((llvm::isa<Float32Type>(srcElementType)) ||
(llvm::isa<Float16Type>(srcElementType)) ||
(llvm::isa<BFloat16Type>(srcElementType))) &&
((llvm::isa<Float8E4M3FNType>(dstElementType)) ||
(llvm::isa<Float8E5M2Type>(dstElementType))) &&
((roundingMode.has_value()) && (*roundingMode != RoundingMode::RTZ))) {
numElements = 8;
useFP16IntermediateSrc = false;
}

// fp8/bf8->f32, if neither nanoo fp8/bf8 on CDNA3 nor ocp fp8/bf8 on CDNA4,
// is done in two steps: fp8/bf8->fp16 and fp16->fp32
bool isDstFP32 = dstElementType.isF32();
bool useFP16IntermediateDst =
(isDstFP32 &&
!(isCDNA4(isaFamily) &&
!(isCDNA4OrHigher(isaFamily) &&
(llvm::isa<Float8E4M3FNType, Float8E5M2Type>(srcElementType))) &&
!(isaFamily == AMD::ISAFamily::CDNA3 &&
(llvm::isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(
srcElementType))));

Type srcType = useFP16IntermediateSrc ? f16_ty : srcElementType;
Type dstType = useFP16IntermediateDst ? f16_ty : dstElementType;

SmallVector<Value> inVals;
inVals.reserve(std::min(numElements, operands.size()));
for (unsigned i = 0; i < std::min(numElements, operands.size()); i++) {
Expand Down
Loading