Skip to content

Commit f95bb15

Browse files
authored
[AMD] Added hw FP upcast conversions for gfx1250 (#9449)
The PR adds missing upcast HW conversion for GFX1250. The PR also fixes some FP truncation for GFX1250.
1 parent df4d375 commit f95bb15

1 file changed

Lines changed: 139 additions & 37 deletions

File tree

third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 139 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,62 @@ cvtScalePkUpcastFromFp8(Location loc, ConversionPatternRewriter &rewriter,
269269
return ret;
270270
}
271271

272+
// Convert Ocp Fp8/Bf8 to Fp16/Bf16/Fp32 on CDNA4
273+
template <typename ConvertOp>
274+
static SmallVector<Value>
275+
cvtScalePk8UpcastFromFp8(Location loc, ConversionPatternRewriter &rewriter,
276+
const SmallVector<Value> &v) {
277+
assert(v.size() == 8);
278+
const size_t inSize = v.size();
279+
auto b = TritonLLVMOpBuilder(loc, rewriter);
280+
281+
Type vInTy = nullptr;
282+
Type resTy = nullptr;
283+
Type vResTy = nullptr;
284+
285+
const size_t intVecSize = (inSize * i8_ty.getWidth()) / i32_ty.getWidth();
286+
vInTy = vec_ty(i32_ty, intVecSize);
287+
288+
if constexpr ((std::is_same_v<ConvertOp, ROCDL::CvtPkScalePk8F32Fp8Op>) ||
289+
(std::is_same_v<ConvertOp, ROCDL::CvtPkScalePk8F32Bf8Op>)) {
290+
resTy = f32_ty;
291+
} else if constexpr ((std::is_same_v<ConvertOp,
292+
ROCDL::CvtPkScalePk8F16Bf8Op>) ||
293+
(std::is_same_v<ConvertOp,
294+
ROCDL::CvtPkScalePk8F16Fp8Op>)) {
295+
resTy = f16_ty;
296+
} else if constexpr ((std::is_same_v<ConvertOp,
297+
ROCDL::CvtPkScalePk8Bf16Bf8Op>) ||
298+
(std::is_same_v<ConvertOp,
299+
ROCDL::CvtPkScalePk8Bf16Fp8Op>)) {
300+
resTy = bf16_ty;
301+
}
302+
assert(resTy != nullptr);
303+
304+
vResTy = vec_ty(resTy, inSize);
305+
306+
auto vI8InTy = vec_ty(i8_ty, inSize);
307+
308+
Value vI8In = b.undef(vI8InTy);
309+
SmallVector<Value, 8> idx;
310+
for (size_t i = 0; i < inSize; ++i) {
311+
idx.push_back(b.i32_val(i));
312+
vI8In = b.insert_element(vI8InTy, vI8In, v[i], idx[i]);
313+
}
314+
auto vIn = b.bitcast(vI8In, vInTy);
315+
316+
Value scale = b.i32_val(127);
317+
IntegerAttr opscale = rewriter.getI32IntegerAttr(0b1000);
318+
319+
auto result = ConvertOp::create(rewriter, loc, vResTy, vIn, scale, opscale);
320+
SmallVector<Value> ret(inSize);
321+
for (auto [i, value] : llvm::enumerate(ret)) {
322+
value = b.extract_element(resTy, result, idx[i]);
323+
}
324+
325+
return ret;
326+
}
327+
272328
// Convert Fp16/Bf16/Fp32 to OCP Fp8/Bf8 on CDNA4
273329
template <typename ConvertOp>
274330
static SmallVector<Value>
@@ -734,19 +790,27 @@ static SmallVector<Value> cvtPkFp32ToF8(Location loc,
734790
return ret;
735791
}
736792

737-
// Convert OCP Fp8 to Fp32 on CDNA4
793+
// Convert OCP Fp8 to Fp32 on CDNA4+
738794
static SmallVector<Value> Fp8E4M3FN_to_Fp32(Location loc,
739795
ConversionPatternRewriter &rewriter,
740796
const SmallVector<Value> &v) {
797+
if (v.size() == 8) {
798+
return cvtScalePk8UpcastFromFp8<ROCDL::CvtPkScalePk8F32Fp8Op>(loc, rewriter,
799+
v);
800+
}
741801
assert(v.size() == 4);
742802
return cvtScalePkUpcastFromFp8<ROCDL::CvtScaleF32PkF32Fp8Op>(loc, rewriter,
743803
v);
744804
}
745805

746-
// Convert OCP Bf8 to Fp32 on CDNA4
806+
// Convert OCP Bf8 to Fp32 on CDNA4+
747807
static SmallVector<Value> Fp8E5M2_to_Fp32(Location loc,
748808
ConversionPatternRewriter &rewriter,
749809
const SmallVector<Value> &v) {
810+
if (v.size() == 8) {
811+
return cvtScalePk8UpcastFromFp8<ROCDL::CvtPkScalePk8F32Bf8Op>(loc, rewriter,
812+
v);
813+
}
750814
assert(v.size() == 4);
751815
return cvtScalePkUpcastFromFp8<ROCDL::CvtScaleF32PkF32Bf8Op>(loc, rewriter,
752816
v);
@@ -926,15 +990,15 @@ static ConverterT Fp32_to_Fp8E4M3FNUZ(AMD::ISAFamily isaFamily) {
926990
return isCDNA4(isaFamily) ? Fp32_to_Fp8E4M3FNUZ_SW : Fp32_to_Fp8E4M3FNUZ_HW;
927991
}
928992

929-
// Nanoo Bf8 -> Fp32 on CDNA3
993+
// Nanoo Bf8 -> Fp32 on CDNA3+
930994
static SmallVector<Value>
931995
Fp8E5M2FNUZ_to_Fp32(Location loc, ConversionPatternRewriter &rewriter,
932996
const SmallVector<Value> &v) {
933997
assert(v.size() == 4);
934998
return cvtPkF8ToFp32<ROCDL::CvtPkF32Bf8Op>(loc, rewriter, v);
935999
}
9361000

937-
// Nanoo Fp8 -> Fp32 on CDNA3
1001+
// Nanoo Fp8 -> Fp32 on CDNA3+
9381002
static SmallVector<Value>
9391003
Fp8E4M3FNUZ_to_Fp32(Location loc, ConversionPatternRewriter &rewriter,
9401004
const SmallVector<Value> &v) {
@@ -1023,13 +1087,18 @@ Fp8E4M3FN_to_Fp16_SW(Location loc, ConversionPatternRewriter &rewriter,
10231087
static SmallVector<Value>
10241088
Fp8E4M3FN_to_Fp16_HW(Location loc, ConversionPatternRewriter &rewriter,
10251089
const SmallVector<Value> &v) {
1090+
if (v.size() == 8) {
1091+
return cvtScalePk8UpcastFromFp8<ROCDL::CvtPkScalePk8F16Fp8Op>(loc, rewriter,
1092+
v);
1093+
}
10261094
assert(v.size() == 4);
10271095
return cvtScalePkUpcastFromFp8<ROCDL::CvtScaleF32PkF16Fp8Op>(loc, rewriter,
10281096
v);
10291097
}
10301098

10311099
ConverterT Fp8E4M3FN_to_Fp16(AMD::ISAFamily isaFamily) {
1032-
return isCDNA4(isaFamily) ? Fp8E4M3FN_to_Fp16_HW : Fp8E4M3FN_to_Fp16_SW;
1100+
return isCDNA4OrHigher(isaFamily) ? Fp8E4M3FN_to_Fp16_HW
1101+
: Fp8E4M3FN_to_Fp16_SW;
10331102
}
10341103

10351104
// Ocp Bf8->Fp16
@@ -1064,13 +1133,17 @@ Fp8E5M2_to_Fp16_SW(Location loc, ConversionPatternRewriter &rewriter,
10641133
static SmallVector<Value>
10651134
Fp8E5M2_to_Fp16_HW(Location loc, ConversionPatternRewriter &rewriter,
10661135
const SmallVector<Value> &v) {
1136+
if (v.size() == 8) {
1137+
return cvtScalePk8UpcastFromFp8<ROCDL::CvtPkScalePk8F16Bf8Op>(loc, rewriter,
1138+
v);
1139+
}
10671140
assert(v.size() == 4);
10681141
return cvtScalePkUpcastFromFp8<ROCDL::CvtScaleF32PkF16Bf8Op>(loc, rewriter,
10691142
v);
10701143
}
10711144

10721145
ConverterT Fp8E5M2_to_Fp16(AMD::ISAFamily isaFamily) {
1073-
return isCDNA4(isaFamily) ? Fp8E5M2_to_Fp16_HW : Fp8E5M2_to_Fp16_SW;
1146+
return isCDNA4OrHigher(isaFamily) ? Fp8E5M2_to_Fp16_HW : Fp8E5M2_to_Fp16_SW;
10741147
}
10751148

10761149
static SmallVector<Value>
@@ -1182,7 +1255,7 @@ static SmallVector<Value> Fp32_to_F16_RTNE(Location loc,
11821255
MultipleOperandsRange operands,
11831256
AMD::ISAFamily isaFamily) {
11841257
// For CDNA4 we can potentially use packed v_cvt_pk_[b]f16_f32 instructions.
1185-
if (isCDNA4(isaFamily)) {
1258+
if (isCDNA4OrHigher(isaFamily)) {
11861259
SmallVector<Value> inVals;
11871260
size_t numElem = std::min(size_t(2), operands.size());
11881261
inVals.reserve(numElem);
@@ -1351,13 +1424,17 @@ Fp8E5M2_to_Bf16_SW(Location loc, ConversionPatternRewriter &rewriter,
13511424
static SmallVector<Value>
13521425
Fp8E5M2_to_Bf16_HW(Location loc, ConversionPatternRewriter &rewriter,
13531426
const SmallVector<Value> &v) {
1427+
if (v.size() == 8) {
1428+
return cvtScalePk8UpcastFromFp8<ROCDL::CvtPkScalePk8Bf16Bf8Op>(loc,
1429+
rewriter, v);
1430+
}
13541431
assert(v.size() == 4);
13551432
return cvtScalePkUpcastFromFp8<ROCDL::CvtScaleF32PkBf16Bf8Op>(loc, rewriter,
13561433
v);
13571434
}
13581435

13591436
ConverterT Fp8E5M2_to_Bf16(AMD::ISAFamily isaFamily) {
1360-
return isCDNA4(isaFamily) ? Fp8E5M2_to_Bf16_HW : Fp8E5M2_to_Bf16_SW;
1437+
return isCDNA4OrHigher(isaFamily) ? Fp8E5M2_to_Bf16_HW : Fp8E5M2_to_Bf16_SW;
13611438
}
13621439

13631440
// Bf16 -> OCP Bf8
@@ -1492,13 +1569,18 @@ Fp8E4M3FN_to_Bf16_SW(Location loc, ConversionPatternRewriter &rewriter,
14921569
static SmallVector<Value>
14931570
Fp8E4M3FN_to_Bf16_HW(Location loc, ConversionPatternRewriter &rewriter,
14941571
const SmallVector<Value> &v) {
1572+
if (v.size() == 8) {
1573+
return cvtScalePk8UpcastFromFp8<ROCDL::CvtPkScalePk8Bf16Fp8Op>(loc,
1574+
rewriter, v);
1575+
}
14951576
assert(v.size() == 4);
14961577
return cvtScalePkUpcastFromFp8<ROCDL::CvtScaleF32PkBf16Fp8Op>(loc, rewriter,
14971578
v);
14981579
}
14991580

15001581
ConverterT Fp8E4M3FN_to_Bf16(AMD::ISAFamily isaFamily) {
1501-
return isCDNA4(isaFamily) ? Fp8E4M3FN_to_Bf16_HW : Fp8E4M3FN_to_Bf16_SW;
1582+
return isCDNA4OrHigher(isaFamily) ? Fp8E4M3FN_to_Bf16_HW
1583+
: Fp8E4M3FN_to_Bf16_SW;
15021584
}
15031585

15041586
// fp8e4m3fnuz to bf16
@@ -1837,6 +1919,48 @@ struct FpToFpOpConversion
18371919
return srcMap.lookup(key);
18381920
}
18391921

1922+
int getNumElements(
1923+
Type srcElementType, Type dstElementType,
1924+
std::optional<::mlir::triton::RoundingMode> roundingMode) const {
1925+
const bool isRTZ = roundingMode == RoundingMode::RTZ;
1926+
const bool isRTNE = roundingMode == RoundingMode::RTNE;
1927+
1928+
// numElements = 2 for :
1929+
// fp32 -> fp16 with RTZ
1930+
// fp32/fp16 -> nanoo fp8/bf8 on non-CDNA3
1931+
if ((isa<Float32Type>(srcElementType) && isa<Float16Type>(dstElementType) &&
1932+
isRTZ) ||
1933+
(isa<Float32Type, Float16Type>(srcElementType) &&
1934+
isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(dstElementType) &&
1935+
isaFamily != AMD::ISAFamily::CDNA3)) {
1936+
return 2;
1937+
}
1938+
1939+
// special upcast for CDNA4
1940+
// nanoo fp8 -> bf16 on CDNA4 (numElements = 2)
1941+
if ((isaFamily == AMD::ISAFamily::CDNA4) &&
1942+
isa<Float8E4M3FNUZType>(srcElementType) && dstElementType.isBF16()) {
1943+
return 2;
1944+
}
1945+
1946+
// special downcast cases for GFX1250+
1947+
if ((isaFamily == AMD::ISAFamily::GFX1250) &&
1948+
((isa<Float32Type, Float16Type, BFloat16Type>(srcElementType))) &&
1949+
((isa<Float8E4M3FNType, Float8E5M2Type>(dstElementType))) && isRTNE) {
1950+
return 8;
1951+
}
1952+
1953+
// special upcast cases for GFX1250+
1954+
if ((isaFamily == AMD::ISAFamily::GFX1250) &&
1955+
((isa<Float8E5M2Type, Float8E4M3FNType>(srcElementType))) &&
1956+
((isa<Float16Type, BFloat16Type, Float32Type>(dstElementType)))) {
1957+
return 8;
1958+
}
1959+
1960+
// return default value
1961+
return 4;
1962+
}
1963+
18401964
SmallVector<Value> createDestOps(triton::FpToFpOp op, OpAdaptor adaptor,
18411965
ConversionPatternRewriter &rewriter,
18421966
Type elemTy, MultipleOperandsRange operands,
@@ -1860,20 +1984,8 @@ struct FpToFpOpConversion
18601984
convertFp32ToBf16(loc, rewriter, operands[0][0], RoundingMode::RTZ)};
18611985
}
18621986

1863-
size_t numElements = 4;
1864-
// numElements = 2 for :
1865-
// fp32 -> fp16 with RTZ
1866-
// fp32/fp16 -> nanoo fp8/bf8 on non-CDNA3
1867-
// nanoo fp8 -> bf16 on CDNA4
1868-
if ((llvm::isa<Float32Type>(srcElementType) &&
1869-
llvm::isa<Float16Type>(dstElementType) &&
1870-
roundingMode == RoundingMode::RTZ) ||
1871-
(llvm::isa<Float32Type, Float16Type>(srcElementType) &&
1872-
llvm::isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(dstElementType) &&
1873-
isaFamily != AMD::ISAFamily::CDNA3) ||
1874-
(llvm::isa<Float8E4M3FNUZType>(srcElementType) &&
1875-
dstElementType.isBF16() && isCDNA4(isaFamily)))
1876-
numElements = 2;
1987+
size_t numElements =
1988+
getNumElements(srcElementType, dstElementType, roundingMode);
18771989

18781990
// fp32 -> fp8 with rtne can be done in two steps:
18791991
// - fp32 -> fp16 with rtne and
@@ -1884,39 +1996,29 @@ struct FpToFpOpConversion
18841996
// 3. fp32 -> ocp fp8/bf8 on non-CDNA4: has software support
18851997
bool useFP16IntermediateSrc =
18861998
srcElementType.isF32() && !dstElementType.isF16() &&
1887-
!(isCDNA4(isaFamily) &&
1999+
!(isCDNA4OrHigher(isaFamily) &&
18882000
(llvm::isa<Float8E4M3FNType, Float8E4M3FNUZType, Float8E5M2Type,
18892001
Float8E5M2FNUZType>(dstElementType))) &&
18902002
!(isaFamily == AMD::ISAFamily::CDNA3 &&
18912003
(llvm::isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(
18922004
dstElementType))) &&
1893-
!(!isCDNA4(isaFamily) &&
2005+
!(!isCDNA4OrHigher(isaFamily) &&
18942006
(llvm::isa<Float8E5M2Type, Float8E4M3FNType>(dstElementType)));
18952007

1896-
if ((isaFamily == AMD::ISAFamily::GFX1250) &&
1897-
((llvm::isa<Float32Type>(srcElementType)) ||
1898-
(llvm::isa<Float16Type>(srcElementType)) ||
1899-
(llvm::isa<BFloat16Type>(srcElementType))) &&
1900-
((llvm::isa<Float8E4M3FNType>(dstElementType)) ||
1901-
(llvm::isa<Float8E5M2Type>(dstElementType))) &&
1902-
((roundingMode.has_value()) && (*roundingMode != RoundingMode::RTZ))) {
1903-
numElements = 8;
1904-
useFP16IntermediateSrc = false;
1905-
}
1906-
19072008
// fp8/bf8->f32, if neither nanoo fp8/bf8 on CDNA3 nor ocp fp8/bf8 on CDNA4,
19082009
// is done in two steps: fp8/bf8->fp16 and fp16->fp32
19092010
bool isDstFP32 = dstElementType.isF32();
19102011
bool useFP16IntermediateDst =
19112012
(isDstFP32 &&
1912-
!(isCDNA4(isaFamily) &&
2013+
!(isCDNA4OrHigher(isaFamily) &&
19132014
(llvm::isa<Float8E4M3FNType, Float8E5M2Type>(srcElementType))) &&
19142015
!(isaFamily == AMD::ISAFamily::CDNA3 &&
19152016
(llvm::isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(
19162017
srcElementType))));
19172018

19182019
Type srcType = useFP16IntermediateSrc ? f16_ty : srcElementType;
19192020
Type dstType = useFP16IntermediateDst ? f16_ty : dstElementType;
2021+
19202022
SmallVector<Value> inVals;
19212023
inVals.reserve(std::min(numElements, operands.size()));
19222024
for (unsigned i = 0; i < std::min(numElements, operands.size()); i++) {

0 commit comments

Comments
 (0)