diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp index 79116ac799ad..ac0f7cdad8fc 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -269,6 +269,62 @@ cvtScalePkUpcastFromFp8(Location loc, ConversionPatternRewriter &rewriter, return ret; } +// Convert 8 packed OCP Fp8/Bf8 values to Fp16/Bf16/Fp32 on GFX1250 +template +static SmallVector +cvtScalePk8UpcastFromFp8(Location loc, ConversionPatternRewriter &rewriter, + const SmallVector &v) { + assert(v.size() == 8); + const size_t inSize = v.size(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + Type vInTy = nullptr; + Type resTy = nullptr; + Type vResTy = nullptr; + + const size_t intVecSize = (inSize * i8_ty.getWidth()) / i32_ty.getWidth(); + vInTy = vec_ty(i32_ty, intVecSize); + + if constexpr ((std::is_same_v) || + (std::is_same_v)) { + resTy = f32_ty; + } else if constexpr ((std::is_same_v) || + (std::is_same_v)) { + resTy = f16_ty; + } else if constexpr ((std::is_same_v) || + (std::is_same_v)) { + resTy = bf16_ty; + } + assert(resTy != nullptr); + + vResTy = vec_ty(resTy, inSize); + + auto vI8InTy = vec_ty(i8_ty, inSize); + + Value vI8In = b.undef(vI8InTy); + SmallVector idx; + for (size_t i = 0; i < inSize; ++i) { + idx.push_back(b.i32_val(i)); + vI8In = b.insert_element(vI8InTy, vI8In, v[i], idx[i]); + } + auto vIn = b.bitcast(vI8In, vInTy); + + Value scale = b.i32_val(127); + IntegerAttr opscale = rewriter.getI32IntegerAttr(0b1000); + + auto result = ConvertOp::create(rewriter, loc, vResTy, vIn, scale, opscale); + SmallVector ret(inSize); + for (auto [i, value] : llvm::enumerate(ret)) { + value = b.extract_element(resTy, result, idx[i]); + } + + return ret; +} + // Convert Fp16/Bf16/Fp32 to OCP Fp8/Bf8 on CDNA4 template static SmallVector @@ -734,19 +790,27 @@ static SmallVector cvtPkFp32ToF8(Location loc, return ret; } -// Convert OCP 
Fp8 to Fp32 on CDNA4 +// Convert OCP Fp8 to Fp32 on CDNA4+ static SmallVector Fp8E4M3FN_to_Fp32(Location loc, ConversionPatternRewriter &rewriter, const SmallVector &v) { + if (v.size() == 8) { + return cvtScalePk8UpcastFromFp8(loc, rewriter, + v); + } assert(v.size() == 4); return cvtScalePkUpcastFromFp8(loc, rewriter, v); } -// Convert OCP Bf8 to Fp32 on CDNA4 +// Convert OCP Bf8 to Fp32 on CDNA4+ static SmallVector Fp8E5M2_to_Fp32(Location loc, ConversionPatternRewriter &rewriter, const SmallVector &v) { + if (v.size() == 8) { + return cvtScalePk8UpcastFromFp8(loc, rewriter, + v); + } assert(v.size() == 4); return cvtScalePkUpcastFromFp8(loc, rewriter, v); @@ -926,7 +990,7 @@ static ConverterT Fp32_to_Fp8E4M3FNUZ(AMD::ISAFamily isaFamily) { return isCDNA4(isaFamily) ? Fp32_to_Fp8E4M3FNUZ_SW : Fp32_to_Fp8E4M3FNUZ_HW; } -// Nanoo Bf8 -> Fp32 on CDNA3 +// Nanoo Bf8 -> Fp32 on CDNA3+ static SmallVector Fp8E5M2FNUZ_to_Fp32(Location loc, ConversionPatternRewriter &rewriter, const SmallVector &v) { @@ -934,7 +998,7 @@ Fp8E5M2FNUZ_to_Fp32(Location loc, ConversionPatternRewriter &rewriter, return cvtPkF8ToFp32(loc, rewriter, v); } -// Nanoo Fp8 -> Fp32 on CDNA3 +// Nanoo Fp8 -> Fp32 on CDNA3+ static SmallVector Fp8E4M3FNUZ_to_Fp32(Location loc, ConversionPatternRewriter &rewriter, const SmallVector &v) { @@ -1023,13 +1087,18 @@ Fp8E4M3FN_to_Fp16_SW(Location loc, ConversionPatternRewriter &rewriter, static SmallVector Fp8E4M3FN_to_Fp16_HW(Location loc, ConversionPatternRewriter &rewriter, const SmallVector &v) { + if (v.size() == 8) { + return cvtScalePk8UpcastFromFp8(loc, rewriter, + v); + } assert(v.size() == 4); return cvtScalePkUpcastFromFp8(loc, rewriter, v); } ConverterT Fp8E4M3FN_to_Fp16(AMD::ISAFamily isaFamily) { - return isCDNA4(isaFamily) ? Fp8E4M3FN_to_Fp16_HW : Fp8E4M3FN_to_Fp16_SW; + return isCDNA4OrHigher(isaFamily) ? 
Fp8E4M3FN_to_Fp16_HW + : Fp8E4M3FN_to_Fp16_SW; } // Ocp Bf8->Fp16 @@ -1064,13 +1133,17 @@ Fp8E5M2_to_Fp16_SW(Location loc, ConversionPatternRewriter &rewriter, static SmallVector Fp8E5M2_to_Fp16_HW(Location loc, ConversionPatternRewriter &rewriter, const SmallVector &v) { + if (v.size() == 8) { + return cvtScalePk8UpcastFromFp8(loc, rewriter, + v); + } assert(v.size() == 4); return cvtScalePkUpcastFromFp8(loc, rewriter, v); } ConverterT Fp8E5M2_to_Fp16(AMD::ISAFamily isaFamily) { - return isCDNA4(isaFamily) ? Fp8E5M2_to_Fp16_HW : Fp8E5M2_to_Fp16_SW; + return isCDNA4OrHigher(isaFamily) ? Fp8E5M2_to_Fp16_HW : Fp8E5M2_to_Fp16_SW; } static SmallVector @@ -1182,7 +1255,7 @@ static SmallVector Fp32_to_F16_RTNE(Location loc, MultipleOperandsRange operands, AMD::ISAFamily isaFamily) { // For CDNA4 we can potentially use packed v_cvt_pk_[b]f16_f32 instructions. - if (isCDNA4(isaFamily)) { + if (isCDNA4OrHigher(isaFamily)) { SmallVector inVals; size_t numElem = std::min(size_t(2), operands.size()); inVals.reserve(numElem); @@ -1351,13 +1424,17 @@ Fp8E5M2_to_Bf16_SW(Location loc, ConversionPatternRewriter &rewriter, static SmallVector Fp8E5M2_to_Bf16_HW(Location loc, ConversionPatternRewriter &rewriter, const SmallVector &v) { + if (v.size() == 8) { + return cvtScalePk8UpcastFromFp8(loc, + rewriter, v); + } assert(v.size() == 4); return cvtScalePkUpcastFromFp8(loc, rewriter, v); } ConverterT Fp8E5M2_to_Bf16(AMD::ISAFamily isaFamily) { - return isCDNA4(isaFamily) ? Fp8E5M2_to_Bf16_HW : Fp8E5M2_to_Bf16_SW; + return isCDNA4OrHigher(isaFamily) ? 
Fp8E5M2_to_Bf16_HW : Fp8E5M2_to_Bf16_SW; } // Bf16 -> OCP Bf8 @@ -1492,13 +1569,18 @@ Fp8E4M3FN_to_Bf16_SW(Location loc, ConversionPatternRewriter &rewriter, static SmallVector Fp8E4M3FN_to_Bf16_HW(Location loc, ConversionPatternRewriter &rewriter, const SmallVector &v) { + if (v.size() == 8) { + return cvtScalePk8UpcastFromFp8(loc, + rewriter, v); + } assert(v.size() == 4); return cvtScalePkUpcastFromFp8(loc, rewriter, v); } ConverterT Fp8E4M3FN_to_Bf16(AMD::ISAFamily isaFamily) { - return isCDNA4(isaFamily) ? Fp8E4M3FN_to_Bf16_HW : Fp8E4M3FN_to_Bf16_SW; + return isCDNA4OrHigher(isaFamily) ? Fp8E4M3FN_to_Bf16_HW + : Fp8E4M3FN_to_Bf16_SW; } // fp8e4m3fnuz to bf16 @@ -1837,6 +1919,48 @@ struct FpToFpOpConversion return srcMap.lookup(key); } + int getNumElements( + Type srcElementType, Type dstElementType, + std::optional<::mlir::triton::RoundingMode> roundingMode) const { + const bool isRTZ = roundingMode == RoundingMode::RTZ; + const bool isRTNE = roundingMode == RoundingMode::RTNE; + + // numElements = 2 for : + // fp32 -> fp16 with RTZ + // fp32/fp16 -> nanoo fp8/bf8 on non-CDNA3 + if ((isa(srcElementType) && isa(dstElementType) && + isRTZ) || + (isa(srcElementType) && + isa(dstElementType) && + isaFamily != AMD::ISAFamily::CDNA3)) { + return 2; + } + + // special upcast for CDNA4 + // nanoo fp8 -> bf16 on CDNA4 (numElements = 2) + if ((isaFamily == AMD::ISAFamily::CDNA4) && + isa(srcElementType) && dstElementType.isBF16()) { + return 2; + } + + // special downcast cases for GFX1250+ + if ((isaFamily == AMD::ISAFamily::GFX1250) && + ((isa(srcElementType))) && + ((isa(dstElementType))) && isRTNE) { + return 8; + } + + // special upcast cases for GFX1250+ + if ((isaFamily == AMD::ISAFamily::GFX1250) && + ((isa(srcElementType))) && + ((isa(dstElementType)))) { + return 8; + } + + // return default value + return 4; + } + SmallVector createDestOps(triton::FpToFpOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, Type elemTy, MultipleOperandsRange operands, 
@@ -1860,20 +1984,8 @@ struct FpToFpOpConversion convertFp32ToBf16(loc, rewriter, operands[0][0], RoundingMode::RTZ)}; } - size_t numElements = 4; - // numElements = 2 for : - // fp32 -> fp16 with RTZ - // fp32/fp16 -> nanoo fp8/bf8 on non-CDNA3 - // nanoo fp8 -> bf16 on CDNA4 - if ((llvm::isa(srcElementType) && - llvm::isa(dstElementType) && - roundingMode == RoundingMode::RTZ) || - (llvm::isa(srcElementType) && - llvm::isa(dstElementType) && - isaFamily != AMD::ISAFamily::CDNA3) || - (llvm::isa(srcElementType) && - dstElementType.isBF16() && isCDNA4(isaFamily))) - numElements = 2; + size_t numElements = + getNumElements(srcElementType, dstElementType, roundingMode); // fp32 -> fp8 with rtne can be done in two steps: // - fp32 -> fp16 with rtne and @@ -1884,32 +1996,21 @@ struct FpToFpOpConversion // 3. fp32 -> ocp fp8/bf8 on non-CDNA4: has software support bool useFP16IntermediateSrc = srcElementType.isF32() && !dstElementType.isF16() && - !(isCDNA4(isaFamily) && + !(isCDNA4OrHigher(isaFamily) && (llvm::isa(dstElementType))) && !(isaFamily == AMD::ISAFamily::CDNA3 && (llvm::isa( dstElementType))) && - !(!isCDNA4(isaFamily) && + !(!isCDNA4OrHigher(isaFamily) && (llvm::isa(dstElementType))); - if ((isaFamily == AMD::ISAFamily::GFX1250) && - ((llvm::isa(srcElementType)) || - (llvm::isa(srcElementType)) || - (llvm::isa(srcElementType))) && - ((llvm::isa(dstElementType)) || - (llvm::isa(dstElementType))) && - ((roundingMode.has_value()) && (*roundingMode != RoundingMode::RTZ))) { - numElements = 8; - useFP16IntermediateSrc = false; - } - // fp8/bf8->f32, if neither nanoo fp8/bf8 on CDNA3 nor ocp fp8/bf8 on CDNA4, // is done in two steps: fp8/bf8->fp16 and fp16->fp32 bool isDstFP32 = dstElementType.isF32(); bool useFP16IntermediateDst = (isDstFP32 && - !(isCDNA4(isaFamily) && + !(isCDNA4OrHigher(isaFamily) && (llvm::isa(srcElementType))) && !(isaFamily == AMD::ISAFamily::CDNA3 && (llvm::isa( @@ -1917,6 +2018,7 @@ struct FpToFpOpConversion Type srcType = 
useFP16IntermediateSrc ? f16_ty : srcElementType; Type dstType = useFP16IntermediateDst ? f16_ty : dstElementType; + SmallVector inVals; inVals.reserve(std::min(numElements, operands.size())); for (unsigned i = 0; i < std::min(numElements, operands.size()); i++) {