@@ -269,6 +269,62 @@ cvtScalePkUpcastFromFp8(Location loc, ConversionPatternRewriter &rewriter,
269269 return ret;
270270}
271271
272+ // Convert Ocp Fp8/Bf8 to Fp16/Bf16/Fp32 on CDNA4
273+ template <typename ConvertOp>
274+ static SmallVector<Value>
275+ cvtScalePk8UpcastFromFp8 (Location loc, ConversionPatternRewriter &rewriter,
276+ const SmallVector<Value> &v) {
277+ assert (v.size () == 8 );
278+ const size_t inSize = v.size ();
279+ auto b = TritonLLVMOpBuilder (loc, rewriter);
280+
281+ Type vInTy = nullptr ;
282+ Type resTy = nullptr ;
283+ Type vResTy = nullptr ;
284+
285+ const size_t intVecSize = (inSize * i8_ty.getWidth ()) / i32_ty.getWidth ();
286+ vInTy = vec_ty (i32_ty, intVecSize);
287+
288+ if constexpr ((std::is_same_v<ConvertOp, ROCDL::CvtPkScalePk8F32Fp8Op>) ||
289+ (std::is_same_v<ConvertOp, ROCDL::CvtPkScalePk8F32Bf8Op>)) {
290+ resTy = f32_ty;
291+ } else if constexpr ((std::is_same_v<ConvertOp,
292+ ROCDL::CvtPkScalePk8F16Bf8Op>) ||
293+ (std::is_same_v<ConvertOp,
294+ ROCDL::CvtPkScalePk8F16Fp8Op>)) {
295+ resTy = f16_ty;
296+ } else if constexpr ((std::is_same_v<ConvertOp,
297+ ROCDL::CvtPkScalePk8Bf16Bf8Op>) ||
298+ (std::is_same_v<ConvertOp,
299+ ROCDL::CvtPkScalePk8Bf16Fp8Op>)) {
300+ resTy = bf16_ty;
301+ }
302+ assert (resTy != nullptr );
303+
304+ vResTy = vec_ty (resTy, inSize);
305+
306+ auto vI8InTy = vec_ty (i8_ty, inSize);
307+
308+ Value vI8In = b.undef (vI8InTy);
309+ SmallVector<Value, 8 > idx;
310+ for (size_t i = 0 ; i < inSize; ++i) {
311+ idx.push_back (b.i32_val (i));
312+ vI8In = b.insert_element (vI8InTy, vI8In, v[i], idx[i]);
313+ }
314+ auto vIn = b.bitcast (vI8In, vInTy);
315+
316+ Value scale = b.i32_val (127 );
317+ IntegerAttr opscale = rewriter.getI32IntegerAttr (0b1000 );
318+
319+ auto result = ConvertOp::create (rewriter, loc, vResTy, vIn, scale, opscale);
320+ SmallVector<Value> ret (inSize);
321+ for (auto [i, value] : llvm::enumerate (ret)) {
322+ value = b.extract_element (resTy, result, idx[i]);
323+ }
324+
325+ return ret;
326+ }
327+
272328// Convert Fp16/Bf16/Fp32 to OCP Fp8/Bf8 on CDNA4
273329template <typename ConvertOp>
274330static SmallVector<Value>
@@ -734,19 +790,27 @@ static SmallVector<Value> cvtPkFp32ToF8(Location loc,
734790 return ret;
735791}
736792
737- // Convert OCP Fp8 to Fp32 on CDNA4
793+ // Convert OCP Fp8 to Fp32 on CDNA4+
738794static SmallVector<Value> Fp8E4M3FN_to_Fp32 (Location loc,
739795 ConversionPatternRewriter &rewriter,
740796 const SmallVector<Value> &v) {
797+ if (v.size () == 8 ) {
798+ return cvtScalePk8UpcastFromFp8<ROCDL::CvtPkScalePk8F32Fp8Op>(loc, rewriter,
799+ v);
800+ }
741801 assert (v.size () == 4 );
742802 return cvtScalePkUpcastFromFp8<ROCDL::CvtScaleF32PkF32Fp8Op>(loc, rewriter,
743803 v);
744804}
745805
746- // Convert OCP Bf8 to Fp32 on CDNA4
806+ // Convert OCP Bf8 to Fp32 on CDNA4+
747807static SmallVector<Value> Fp8E5M2_to_Fp32 (Location loc,
748808 ConversionPatternRewriter &rewriter,
749809 const SmallVector<Value> &v) {
810+ if (v.size () == 8 ) {
811+ return cvtScalePk8UpcastFromFp8<ROCDL::CvtPkScalePk8F32Bf8Op>(loc, rewriter,
812+ v);
813+ }
750814 assert (v.size () == 4 );
751815 return cvtScalePkUpcastFromFp8<ROCDL::CvtScaleF32PkF32Bf8Op>(loc, rewriter,
752816 v);
@@ -926,15 +990,15 @@ static ConverterT Fp32_to_Fp8E4M3FNUZ(AMD::ISAFamily isaFamily) {
926990 return isCDNA4 (isaFamily) ? Fp32_to_Fp8E4M3FNUZ_SW : Fp32_to_Fp8E4M3FNUZ_HW;
927991}
928992
929- // Nanoo Bf8 -> Fp32 on CDNA3
993+ // Nanoo Bf8 -> Fp32 on CDNA3+
930994static SmallVector<Value>
931995Fp8E5M2FNUZ_to_Fp32 (Location loc, ConversionPatternRewriter &rewriter,
932996 const SmallVector<Value> &v) {
933997 assert (v.size () == 4 );
934998 return cvtPkF8ToFp32<ROCDL::CvtPkF32Bf8Op>(loc, rewriter, v);
935999}
9361000
937- // Nanoo Fp8 -> Fp32 on CDNA3
1001+ // Nanoo Fp8 -> Fp32 on CDNA3+
9381002static SmallVector<Value>
9391003Fp8E4M3FNUZ_to_Fp32 (Location loc, ConversionPatternRewriter &rewriter,
9401004 const SmallVector<Value> &v) {
@@ -1023,13 +1087,18 @@ Fp8E4M3FN_to_Fp16_SW(Location loc, ConversionPatternRewriter &rewriter,
10231087static SmallVector<Value>
10241088Fp8E4M3FN_to_Fp16_HW (Location loc, ConversionPatternRewriter &rewriter,
10251089 const SmallVector<Value> &v) {
1090+ if (v.size () == 8 ) {
1091+ return cvtScalePk8UpcastFromFp8<ROCDL::CvtPkScalePk8F16Fp8Op>(loc, rewriter,
1092+ v);
1093+ }
10261094 assert (v.size () == 4 );
10271095 return cvtScalePkUpcastFromFp8<ROCDL::CvtScaleF32PkF16Fp8Op>(loc, rewriter,
10281096 v);
10291097}
10301098
10311099ConverterT Fp8E4M3FN_to_Fp16 (AMD::ISAFamily isaFamily) {
1032- return isCDNA4 (isaFamily) ? Fp8E4M3FN_to_Fp16_HW : Fp8E4M3FN_to_Fp16_SW;
1100+ return isCDNA4OrHigher (isaFamily) ? Fp8E4M3FN_to_Fp16_HW
1101+ : Fp8E4M3FN_to_Fp16_SW;
10331102}
10341103
10351104// Ocp Bf8->Fp16
@@ -1064,13 +1133,17 @@ Fp8E5M2_to_Fp16_SW(Location loc, ConversionPatternRewriter &rewriter,
10641133static SmallVector<Value>
10651134Fp8E5M2_to_Fp16_HW (Location loc, ConversionPatternRewriter &rewriter,
10661135 const SmallVector<Value> &v) {
1136+ if (v.size () == 8 ) {
1137+ return cvtScalePk8UpcastFromFp8<ROCDL::CvtPkScalePk8F16Bf8Op>(loc, rewriter,
1138+ v);
1139+ }
10671140 assert (v.size () == 4 );
10681141 return cvtScalePkUpcastFromFp8<ROCDL::CvtScaleF32PkF16Bf8Op>(loc, rewriter,
10691142 v);
10701143}
10711144
10721145ConverterT Fp8E5M2_to_Fp16 (AMD::ISAFamily isaFamily) {
1073- return isCDNA4 (isaFamily) ? Fp8E5M2_to_Fp16_HW : Fp8E5M2_to_Fp16_SW;
1146+ return isCDNA4OrHigher (isaFamily) ? Fp8E5M2_to_Fp16_HW : Fp8E5M2_to_Fp16_SW;
10741147}
10751148
10761149static SmallVector<Value>
@@ -1182,7 +1255,7 @@ static SmallVector<Value> Fp32_to_F16_RTNE(Location loc,
11821255 MultipleOperandsRange operands,
11831256 AMD::ISAFamily isaFamily) {
11841257 // For CDNA4 we can potentially use packed v_cvt_pk_[b]f16_f32 instructions.
1185- if (isCDNA4 (isaFamily)) {
1258+ if (isCDNA4OrHigher (isaFamily)) {
11861259 SmallVector<Value> inVals;
11871260 size_t numElem = std::min (size_t (2 ), operands.size ());
11881261 inVals.reserve (numElem);
@@ -1351,13 +1424,17 @@ Fp8E5M2_to_Bf16_SW(Location loc, ConversionPatternRewriter &rewriter,
13511424static SmallVector<Value>
13521425Fp8E5M2_to_Bf16_HW (Location loc, ConversionPatternRewriter &rewriter,
13531426 const SmallVector<Value> &v) {
1427+ if (v.size () == 8 ) {
1428+ return cvtScalePk8UpcastFromFp8<ROCDL::CvtPkScalePk8Bf16Bf8Op>(loc,
1429+ rewriter, v);
1430+ }
13541431 assert (v.size () == 4 );
13551432 return cvtScalePkUpcastFromFp8<ROCDL::CvtScaleF32PkBf16Bf8Op>(loc, rewriter,
13561433 v);
13571434}
13581435
13591436ConverterT Fp8E5M2_to_Bf16 (AMD::ISAFamily isaFamily) {
1360- return isCDNA4 (isaFamily) ? Fp8E5M2_to_Bf16_HW : Fp8E5M2_to_Bf16_SW;
1437+ return isCDNA4OrHigher (isaFamily) ? Fp8E5M2_to_Bf16_HW : Fp8E5M2_to_Bf16_SW;
13611438}
13621439
13631440// Bf16 -> OCP Bf8
@@ -1492,13 +1569,18 @@ Fp8E4M3FN_to_Bf16_SW(Location loc, ConversionPatternRewriter &rewriter,
14921569static SmallVector<Value>
14931570Fp8E4M3FN_to_Bf16_HW (Location loc, ConversionPatternRewriter &rewriter,
14941571 const SmallVector<Value> &v) {
1572+ if (v.size () == 8 ) {
1573+ return cvtScalePk8UpcastFromFp8<ROCDL::CvtPkScalePk8Bf16Fp8Op>(loc,
1574+ rewriter, v);
1575+ }
14951576 assert (v.size () == 4 );
14961577 return cvtScalePkUpcastFromFp8<ROCDL::CvtScaleF32PkBf16Fp8Op>(loc, rewriter,
14971578 v);
14981579}
14991580
15001581ConverterT Fp8E4M3FN_to_Bf16 (AMD::ISAFamily isaFamily) {
1501- return isCDNA4 (isaFamily) ? Fp8E4M3FN_to_Bf16_HW : Fp8E4M3FN_to_Bf16_SW;
1582+ return isCDNA4OrHigher (isaFamily) ? Fp8E4M3FN_to_Bf16_HW
1583+ : Fp8E4M3FN_to_Bf16_SW;
15021584}
15031585
15041586// fp8e4m3fnuz to bf16
@@ -1837,6 +1919,48 @@ struct FpToFpOpConversion
18371919 return srcMap.lookup (key);
18381920 }
18391921
1922+ int getNumElements (
1923+ Type srcElementType, Type dstElementType,
1924+ std::optional<::mlir::triton::RoundingMode> roundingMode) const {
1925+ const bool isRTZ = roundingMode == RoundingMode::RTZ;
1926+ const bool isRTNE = roundingMode == RoundingMode::RTNE;
1927+
1928+ // numElements = 2 for :
1929+ // fp32 -> fp16 with RTZ
1930+ // fp32/fp16 -> nanoo fp8/bf8 on non-CDNA3
1931+ if ((isa<Float32Type>(srcElementType) && isa<Float16Type>(dstElementType) &&
1932+ isRTZ) ||
1933+ (isa<Float32Type, Float16Type>(srcElementType) &&
1934+ isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(dstElementType) &&
1935+ isaFamily != AMD::ISAFamily::CDNA3)) {
1936+ return 2 ;
1937+ }
1938+
1939+ // special upcast for CDNA4
1940+ // nanoo fp8 -> bf16 on CDNA4 (numElements = 2)
1941+ if ((isaFamily == AMD::ISAFamily::CDNA4) &&
1942+ isa<Float8E4M3FNUZType>(srcElementType) && dstElementType.isBF16 ()) {
1943+ return 2 ;
1944+ }
1945+
1946+ // special downcast cases for GFX1250+
1947+ if ((isaFamily == AMD::ISAFamily::GFX1250) &&
1948+ ((isa<Float32Type, Float16Type, BFloat16Type>(srcElementType))) &&
1949+ ((isa<Float8E4M3FNType, Float8E5M2Type>(dstElementType))) && isRTNE) {
1950+ return 8 ;
1951+ }
1952+
1953+ // special upcast cases for GFX1250+
1954+ if ((isaFamily == AMD::ISAFamily::GFX1250) &&
1955+ ((isa<Float8E5M2Type, Float8E4M3FNType>(srcElementType))) &&
1956+ ((isa<Float16Type, BFloat16Type, Float32Type>(dstElementType)))) {
1957+ return 8 ;
1958+ }
1959+
1960+ // return default value
1961+ return 4 ;
1962+ }
1963+
18401964 SmallVector<Value> createDestOps (triton::FpToFpOp op, OpAdaptor adaptor,
18411965 ConversionPatternRewriter &rewriter,
18421966 Type elemTy, MultipleOperandsRange operands,
@@ -1860,20 +1984,8 @@ struct FpToFpOpConversion
18601984 convertFp32ToBf16 (loc, rewriter, operands[0 ][0 ], RoundingMode::RTZ)};
18611985 }
18621986
1863- size_t numElements = 4 ;
1864- // numElements = 2 for :
1865- // fp32 -> fp16 with RTZ
1866- // fp32/fp16 -> nanoo fp8/bf8 on non-CDNA3
1867- // nanoo fp8 -> bf16 on CDNA4
1868- if ((llvm::isa<Float32Type>(srcElementType) &&
1869- llvm::isa<Float16Type>(dstElementType) &&
1870- roundingMode == RoundingMode::RTZ) ||
1871- (llvm::isa<Float32Type, Float16Type>(srcElementType) &&
1872- llvm::isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(dstElementType) &&
1873- isaFamily != AMD::ISAFamily::CDNA3) ||
1874- (llvm::isa<Float8E4M3FNUZType>(srcElementType) &&
1875- dstElementType.isBF16 () && isCDNA4 (isaFamily)))
1876- numElements = 2 ;
1987+ size_t numElements =
1988+ getNumElements (srcElementType, dstElementType, roundingMode);
18771989
18781990 // fp32 -> fp8 with rtne can be done in two steps:
18791991 // - fp32 -> fp16 with rtne and
@@ -1884,39 +1996,29 @@ struct FpToFpOpConversion
18841996 // 3. fp32 -> ocp fp8/bf8 on non-CDNA4: has software support
18851997 bool useFP16IntermediateSrc =
18861998 srcElementType.isF32 () && !dstElementType.isF16 () &&
1887- !(isCDNA4 (isaFamily) &&
1999+ !(isCDNA4OrHigher (isaFamily) &&
18882000 (llvm::isa<Float8E4M3FNType, Float8E4M3FNUZType, Float8E5M2Type,
18892001 Float8E5M2FNUZType>(dstElementType))) &&
18902002 !(isaFamily == AMD::ISAFamily::CDNA3 &&
18912003 (llvm::isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(
18922004 dstElementType))) &&
1893- !(!isCDNA4 (isaFamily) &&
2005+ !(!isCDNA4OrHigher (isaFamily) &&
18942006 (llvm::isa<Float8E5M2Type, Float8E4M3FNType>(dstElementType)));
18952007
1896- if ((isaFamily == AMD::ISAFamily::GFX1250) &&
1897- ((llvm::isa<Float32Type>(srcElementType)) ||
1898- (llvm::isa<Float16Type>(srcElementType)) ||
1899- (llvm::isa<BFloat16Type>(srcElementType))) &&
1900- ((llvm::isa<Float8E4M3FNType>(dstElementType)) ||
1901- (llvm::isa<Float8E5M2Type>(dstElementType))) &&
1902- ((roundingMode.has_value ()) && (*roundingMode != RoundingMode::RTZ))) {
1903- numElements = 8 ;
1904- useFP16IntermediateSrc = false ;
1905- }
1906-
19072008 // fp8/bf8->f32, if neither nanoo fp8/bf8 on CDNA3 nor ocp fp8/bf8 on CDNA4,
19082009 // is done in two steps: fp8/bf8->fp16 and fp16->fp32
19092010 bool isDstFP32 = dstElementType.isF32 ();
19102011 bool useFP16IntermediateDst =
19112012 (isDstFP32 &&
1912- !(isCDNA4 (isaFamily) &&
2013+ !(isCDNA4OrHigher (isaFamily) &&
19132014 (llvm::isa<Float8E4M3FNType, Float8E5M2Type>(srcElementType))) &&
19142015 !(isaFamily == AMD::ISAFamily::CDNA3 &&
19152016 (llvm::isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(
19162017 srcElementType))));
19172018
19182019 Type srcType = useFP16IntermediateSrc ? f16_ty : srcElementType;
19192020 Type dstType = useFP16IntermediateDst ? f16_ty : dstElementType;
2021+
19202022 SmallVector<Value> inVals;
19212023 inVals.reserve (std::min (numElements, operands.size ()));
19222024 for (unsigned i = 0 ; i < std::min (numElements, operands.size ()); i++) {
0 commit comments