diff --git a/third_party/hcu/backend/compiler.py b/third_party/hcu/backend/compiler.py index b0dc5b746c..3a542a35e5 100644 --- a/third_party/hcu/backend/compiler.py +++ b/third_party/hcu/backend/compiler.py @@ -38,7 +38,7 @@ class HIPOptions: cluster_dims: tuple = (1, 1, 1) debug: bool = False arch: str = None - supported_fp8_dtypes: Tuple[str] = ("fp8e5", ) + supported_fp8_dtypes: Tuple[str] = ("fp8e5", "fp8e4nv") deprecated_fp8_dtypes: Tuple[str] = () default_dot_input_precision: str = "ieee" allowed_dot_input_precisions: Tuple[str] = ("ieee", ) @@ -261,7 +261,7 @@ def make_llir(src, metadata, options): ## 3. __HIP_FTZ is default to 1 and not exposed as a kernel argument. ## For now it is used as a controller for developers only. __HIP_FTZ = True - if (options.num_stages >= 2): + if (options.num_stages >= 2) and os.environ.get("TRITON_MOVE_LOAD_TOFRONT_DOT", "0") == "1": hcu.passes.ttgpuir.add_move_load_tofront_dot(pm) # hcu.passes.ttgpuir.add_control_fa_fwd_bufferload_cnt(pm, 0) hcu.passes.ttgpuir.add_to_llvmir(pm, options.arch, __HIP_FTZ) diff --git a/third_party/hcu/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/third_party/hcu/include/triton/Conversion/TritonGPUToLLVM/Utility.h index ca29c37d8d..0bd4205fe6 100644 --- a/third_party/hcu/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/third_party/hcu/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -134,6 +134,8 @@ using namespace mlir::triton; #define i64_val(...) 
LLVM::createConstantI64(loc, rewriter, __VA_ARGS__) #define int_val(width, val) \ LLVM::createLLVMIntegerConstant(rewriter, loc, width, val) +#define i8_val(val) int_val(8, val) +#define i16_val(val) int_val(16, val) #define tid_val() getThreadId(rewriter, loc) // Attributes diff --git a/third_party/hcu/lib/TritonHCUGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/hcu/lib/TritonHCUGPUToLLVM/ElementwiseOpToLLVM.cpp index 6642f6990e..52fee8a9c8 100644 --- a/third_party/hcu/lib/TritonHCUGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/third_party/hcu/lib/TritonHCUGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -83,6 +83,183 @@ Fp16_to_Fp8E5M2_RTZ(Location loc, ConversionPatternRewriter &rewriter, extract_element(i8_ty, a1, i32_val(3))}; } +//===----------------===// +/// FP8E4M3 +//===----------------===// +static Value +Fp32_to_Fp8E4M3FN_RTNE_oneValue(Location loc, + ConversionPatternRewriter &rewriter, Value v) { + // == Step 1: Bitcast to i32 and extract components == + Value vi32 = bitcast(v, i32_ty); + Value sign = trunc(i8_ty, lshr(vi32, i32_val(24))); // get sign << 7 + sign = and_(i8_ty, sign, i8_val(0x80)); // keep sign bit + + Value abs = and_(i32_ty, vi32, i32_val(0x7FFFFFFF)); // strip sign + + // == Step 2: NaN check == + Value isNaN = and_( + i1_ty, + icmp_eq(and_(i32_ty, vi32, i32_val(0x7F800000)), + i32_val(0x7F800000)), // exp == 0xFF + icmp_ne(and_(i32_ty, vi32, i32_val(0x007FFFFF)), i32_val(0)) // frac != 0 + ); + + // == Step 3: Rounding (RTNE) == + // bias diff = (127 - 7) << 23 = 120 << 23 = 0x3C000000 + // mantissa alignment: keep top 3 mantissa bits => shift right by 23 - 3 = 20 + constexpr uint32_t baseRoundingBias = + (1 << 19) - 1; // 0x7FFFF = round-to-even + + Value roundBit = lshr(and_(i32_ty, vi32, i32_val(0x00100000)), + i32_val(20)); // bit 20 (for even) + Value roundingBias = add(i32_val(baseRoundingBias), roundBit); + Value vFp8 = add(abs, roundingBias); + + // Clear the rounded-off low bits (sign is OR'd back in at the end) + vFp8 = and_(i32_ty, vFp8, 
i32_val(0xFFFFFF80)); // clear bottom bits + + // == Step 4: Clamp to min normal (smallest FP8 normal in FP32) == + // FP32 representation of 2^-6 = 2.0^(-6) = 0x3C800000 + vFp8 = select(icmp_ult(vFp8, i32_val(0x3C800000)), i32_val(0x3C800000), vFp8); + + // == Step 5: Adjust exponent bias == + // Subtract (127 - 7) << 23 = 0x3C000000 + vFp8 = sub(vFp8, i32_val(0x3C000000)); + + // Shift right to extract FP8 bits: (exp + mant) → shift 20 to fit 3-bit mant + vFp8 = trunc(i8_ty, lshr(vFp8, i32_val(20))); + + // == Step 6: Saturate to FP8 max value (0x7E == 448, before NaN) == + // 0x43E80000 is FP32 464, the 448/480 midpoint; anything above it (incl. + // inf) rounds past the largest finite E4M3FN value and saturates + Value isOverflowOrInf = + icmp_ugt(abs, i32_val(0x43E80000)); // rounds beyond max E4M3 + vFp8 = select(isOverflowOrInf, i8_val(0x7E), vFp8); + + // == Step 7: Handle subnormals via LUT == + constexpr size_t lutSize = 8; + constexpr uint32_t halfwayPointsLUT[lutSize] = { + 0x3A800000, 0x3B400000, 0x3BA00000, 0x3BE00000, + 0x3C100000, 0x3C300000, 0x3C500000, 0x3C700000}; + + for (int i = lutSize - 1; i >= 0; --i) { + Value cmp; + if (i % 2 == 0) { + cmp = icmp_ule(abs, i32_val(halfwayPointsLUT[i])); + } else { + cmp = icmp_ult(abs, i32_val(halfwayPointsLUT[i])); + } + vFp8 = select(cmp, i8_val(i), vFp8); + } + + // == Step 8: Handle NaN == + vFp8 = select(isNaN, i8_val(0x7F), vFp8); + + // == Step 9: Set sign bit == + vFp8 = or_(vFp8, sign); + return vFp8; +} + +static SmallVector<Value> +Fp32_to_Fp8E4M3FN_RTNE(Location loc, ConversionPatternRewriter &rewriter, + const SmallVector<Value> &v) { + assert(v.size() == 4 && "Expected 4 values for FP8E4M3FN conversion"); + SmallVector<Value> result(4); + for (size_t i = 0; i < 4; ++i) { + result[i] = Fp32_to_Fp8E4M3FN_RTNE_oneValue(loc, rewriter, v[i]); + } + return result; +} + +// Cast FP16 to FP8E4M3FN in saturation and round-to-nearest-even mode. 
+// According to +// https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-12-01-pdf-1, +// In saturation mode, inf and out-of-range numbers are converted to the largest +// normal number, i.e. ±448. NaNs are converted to NaNs. +static Value +Fp16_to_Fp8E4M3FN_RTNE_oneValue(Location loc, + ConversionPatternRewriter &rewriter, Value v) { + // StringRef funcName = "llvm.is.fpclass"; + // Value isNaN = LLVM::createLLVMIntrinsicCallOp(rewriter, loc, funcName, + // i1_ty, + // {v, i32_val(0x3)}) + // ->getResult(0); + // Get sign and absolute value + Value vi16 = bitcast(v, i16_ty); + Value isNaN = and_( + icmp_eq(and_(vi16, i16_val(0x7C00)), i16_val(0x7C00)), //; exp == 0x7C00 + icmp_ne(and_(vi16, i16_val(0x03FF)), i16_val(0)) //; frac != 0 + ); + + Value sign = trunc(i8_ty, lshr(and_(vi16, i16_val(0x8000)), i16_val(8))); + vi16 = and_(vi16, i16_val(0x7FFF)); + + // Rounding to nearest even + constexpr uint16_t baseRoundingBias = 0x003F; // 1 << (10 - 3 - 1) - 1 + + // S.EEEEE.MMMMMMMMMM => 0.00000.00M0000000 => 0.00000.000000000M + Value remainingMantissaLSB = lshr(and_(vi16, i16_val(0x0080)), i16_val(7)); + Value roundingBias = add(remainingMantissaLSB, i16_val(baseRoundingBias)); + Value vFp8 = add(vi16, roundingBias); + + // Reduce mantissa to 3 bits + vFp8 = and_(vFp8, i16_val(0xFF80)); // 0xFF80 == 1.11111.1110000000 + + // 0x2400 is the FP16 representation of 2^{-6}, which is the smallest normal + // number in FP8E4M3FN. 
We round numbers smaller than that to 0x2400 to make + // it easier to handle subnormals + vFp8 = umax(vFp8, i16_val(0x2400)); + + // Adjust exponent bias + vFp8 = sub(vFp8, i16_val(0x2000)); // (15 - 7) << 10 + + // Shift right and truncate + vFp8 = trunc(i8_ty, lshr(vFp8, i16_val(7))); // 10 - 3 + + // 0x5F40 == 0.10111.1101000000 is the FP16 value 464, the midpoint between + // the largest FP8 normal (448) and the next step (480); 464 itself + // ties-to-even down to 448 + // + // In saturation mode, numbers larger than that (including infinity) round + // past the max normal number in FP8 and are replaced with max_E4M3, i.e. + // 0x7E === 0.1111.110 + Value isOverflowOrInf = icmp_ugt(vi16, i16_val(0x5F40)); + vFp8 = select(isOverflowOrInf, i8_val(0x7E), vFp8); + + // Round subnormals to nearest even. Ref: + // https://github.com/openxla/xla/blob/f20c6fe2/xla/service/elemental_ir_emitter.cc#L272 + constexpr size_t lutSize = 8; + constexpr uint16_t halfwayPointsLUT[lutSize] = {0x1400, 0x1A00, 0x1D00, 0x1F00, + 0x2080, 0x2180, 0x2280, 0x2380}; + + for (int i = lutSize - 1; i >= 0; i--) { + Value cmp; + if (i % 2 == 0) { + cmp = icmp_ule(vi16, i16_val(halfwayPointsLUT[i])); + } else { + cmp = icmp_ult(vi16, i16_val(halfwayPointsLUT[i])); + } + + vFp8 = select(cmp, i8_val(i), vFp8); + } + + // NaN remains NaN after conversion + vFp8 = select(isNaN, i8_val(0x7F), vFp8); + + // Set sign bit + vFp8 = or_(i8_ty, vFp8, sign); + + return vFp8; +} + +static SmallVector<Value> +Fp16_to_Fp8E4M3FN_RTNE(Location loc, ConversionPatternRewriter &rewriter, + const SmallVector<Value> &v) { + SmallVector<Value> result(v.size()); + for (size_t i = 0; i < v.size(); ++i) { + result[i] = Fp16_to_Fp8E4M3FN_RTNE_oneValue(loc, rewriter, v[i]); + } + return result; +} + static Value cvtFp16ToFp32(Location loc, ConversionPatternRewriter &rewriter, const Value &v) { GCNBuilder builder; @@ -271,6 +448,120 @@ ConverterT Fp16_to_Fp8E5M2FNUZ(HCU::ISAFamily isaFamily) { : Fp16_to_Fp8E5M2FNUZ_SW; } +static Value Fp8E4M3FN_to_Fp32_oneValue(Location loc, + ConversionPatternRewriter 
&rewriter, + Value v) { + // Bitcast FP8 (i8) value + Value vi8 = bitcast(v, i8_ty); + + // Extract sign and absolute value + Value sign = and_(vi8, i8_val(0x80)); // 0b1000'0000 + Value vAbs = and_(vi8, i8_val(0x7F)); // 0b0111'1111 + + // Extract exponent and mantissa + Value exp = and_(vAbs, i8_val(0x78)); // 0b0111'1000 + Value mant = and_(vAbs, i8_val(0x07)); // 0b0000'0111 + + // Right shift exponent to LSB + exp = lshr(exp, i8_val(3)); + + // Detect special cases + Value isZeroOrDenorm = icmp_ult(vAbs, i8_val(0x08)); // < 0b0000'1000 + Value isNaN = icmp_eq(vAbs, i8_val(0x7F)); // == 0b0111'1111 + + // Default normal conversion + // Compose 32-bit float: + // sign << 31 | (exp + 120) << 23 | (mant << 20) + Value sign32 = shl(zext(i32_ty, sign), i32_val(24)); // move sign to bit 31 + Value exp32 = shl(zext(i32_ty, exp), i32_val(23)); // move exp to bit 23 + Value expBias = i32_val(120 << 23); // bias diff (127 - 7) + exp32 = add(exp32, expBias); + + Value mant32 = + shl(zext(i32_ty, mant), i32_val(20)); // move mant to bits 22-20 + + Value combined = or_(or_(sign32, exp32), mant32); + + // Handle NaN: output canonical FP32 NaN 0x7FC00000 + Value nanVal = i32_val(0x7FC00000); + combined = select(isNaN, nanVal, combined); + + // Handle denorm/zero via LUT + // For vAbs in [0x00 ~ 0x07] → use LUT + constexpr int lutSize = 8; + static constexpr uint32_t denormsAndZeroLut[lutSize] = { + 0x00000000, 0x3B000000, 0x3B800000, 0x3BC00000, 0x3C000000, + 0x3C200000, 0x3C400000, 0x3C600000}; // exact FP32 values of i * 2^-9 + + for (int i = 0; i < lutSize; ++i) { + combined = select(icmp_eq(vAbs, i8_val(i)), i32_val(denormsAndZeroLut[i]), + combined); + } + + // Bitcast to float + Value result = bitcast(combined, f32_ty); + return result; +} + +static SmallVector<Value> Fp8E4M3FN_to_Fp32(Location loc, + ConversionPatternRewriter &rewriter, + const SmallVector<Value> &values) { + SmallVector<Value> results(values.size()); + for (size_t i = 0; i < values.size(); i++) + results[i] = 
Fp8E4M3FN_to_Fp32_oneValue(loc, rewriter, values[i]); + return results; +} + +static Value Fp8E4M3FN_to_Fp16_oneValue(Location loc, + ConversionPatternRewriter &rewriter, + Value v) { + auto fp8x2VecTy = vec_ty(i8_ty, 2); + Value a = undef(fp8x2VecTy); + a = insert_element(fp8x2VecTy, a, i8_val(0), i32_val(0)); + a = insert_element(fp8x2VecTy, a, v, i32_val(1)); + a = bitcast(a, i16_ty); + + // Get sign and absolute value + Value sign = and_(a, i16_val(0x8000)); + a = and_(a, i16_val(0x7FFF)); + + // Right shift 1 bit to adjust the positions of exponent and mantissa + a = lshr(a, i16_val(1)); + + // Adjust exponent, (15 - 7) << 10 === 0x2000 + a = add(a, i16_val(0x2000)); + + // Check NaN + Value vAbs = and_(bitcast(v, i8_ty), i8_val(0x7F)); + a = select(icmp_eq(vAbs, i8_val(0x7F)), i16_val(0x7E00), a); + + // Check denorms and zero + // Here we use a LUT to map S.0000.000 ~ S.0000.111 to its corresponding fp16 + // value + constexpr size_t lutSize = 8; + static constexpr int denormsAndZeroLut[lutSize] = { + 0x0000, 0x1800, 0x1C00, 0x1E00, 0x2000, 0x2100, 0x2200, 0x2300}; + + for (int i = 0; i < lutSize; i++) { + a = select(icmp_eq(vAbs, i8_val(i)), i16_val(denormsAndZeroLut[i]), a); + } + + // Set sign + a = or_(a, sign); + a = bitcast(a, f16_ty); + + return a; +} + +static SmallVector Fp8E4M3FN_to_Fp16(Location loc, + ConversionPatternRewriter &rewriter, + const SmallVector &values) { + SmallVector results(values.size()); + for (size_t i = 0; i < values.size(); i++) + results[i] = Fp8E4M3FN_to_Fp16_oneValue(loc, rewriter, values[i]); + return results; +} + static SmallVector Fp8E5M2_to_Fp16(Location loc, ConversionPatternRewriter &rewriter, const SmallVector &v) { @@ -867,10 +1158,13 @@ struct FpToFpOpConversion // F8 -> F16 {{F8E4M3FNUZTyID, F16TyID, undefRounding}, Fp8E4M3FNUZ_to_Fp16(isaFamily)}, + {{F8E4M3FNTyID, F16TyID, undefRounding}, Fp8E4M3FN_to_Fp16}, {{F8E5M2FNUZTyID, F16TyID, undefRounding}, Fp8E5M2FNUZ_to_Fp16(isaFamily)}, {{F8E5M2TyID, F16TyID, 
undefRounding}, Fp8E5M2_to_Fp16}, // F16 -> F8 + {{F16TyID, F8E4M3FNTyID, RoundingMode::RTNE}, + Fp16_to_Fp8E4M3FN_RTNE}, {{F16TyID, F8E5M2FNUZTyID, RoundingMode::RTNE}, Fp16_to_Fp8E5M2FNUZ(isaFamily)}, {{F16TyID, F8E4M3FNUZTyID, RoundingMode::RTNE}, @@ -893,8 +1187,11 @@ struct FpToFpOpConversion Fp32_to_Fp8E4M3FNUZ}, {{F32TyID, F8E5M2FNUZTyID, RoundingMode::RTNE}, Fp32_to_Fp8E5M2FNUZ}, + {{F32TyID, F8E4M3FNTyID, RoundingMode::RTNE}, + Fp32_to_Fp8E4M3FN_RTNE}, {{F8E4M3FNUZTyID, F32TyID, undefRounding}, Fp8E4M3FNUZ_to_Fp32}, {{F8E5M2FNUZTyID, F32TyID, undefRounding}, Fp8E5M2FNUZ_to_Fp32}, + {{F8E4M3FNTyID, F32TyID, undefRounding}, Fp8E4M3FN_to_Fp32}, }; std::tuple<TypeID, TypeID, RoundingMode> key = { srcTy.getTypeID(), dstTy.getTypeID(), diff --git a/third_party/hcu/lib/TritonHCUTransforms/AccelerateHcuFlashAttention.cpp b/third_party/hcu/lib/TritonHCUTransforms/AccelerateHcuFlashAttention.cpp index 39d2267f8b..183e1546fc 100644 --- a/third_party/hcu/lib/TritonHCUTransforms/AccelerateHcuFlashAttention.cpp +++ b/third_party/hcu/lib/TritonHCUTransforms/AccelerateHcuFlashAttention.cpp @@ -114,7 +114,8 @@ class TritonHcuFlashAttention : public mlir::RewritePattern { bool interleave = true; // isSecondDot(dotOp); unsigned mDim = 16, kDim = 16; unsigned nDim = oldBType.getShape()[1] < 32 ? 
16 : 32; - if (bGobalOrder[0] == 0) { + // if (bGobalOrder[0] == 0) { + if (bGobalOrder[0] == 0 || oldBType.getShape()[1] >= 256) { mDim = 16, nDim = 16; } auto newEnc = ttg::HCUMfmaEncodingAttr::get( diff --git a/third_party/hcu/lib/TritonHCUTransforms/ReorderInstructions.cpp b/third_party/hcu/lib/TritonHCUTransforms/ReorderInstructions.cpp index 1a85e3b837..6f0388a6a4 100644 --- a/third_party/hcu/lib/TritonHCUTransforms/ReorderInstructions.cpp +++ b/third_party/hcu/lib/TritonHCUTransforms/ReorderInstructions.cpp @@ -328,7 +328,10 @@ static void sinkSecondLoad(triton::FuncOp funcOp) { triton::DotOp dotOp; for (Operation &op : forOp) { if (auto loadOp = dyn_cast<triton::LoadOp>(&op)) - loadOps.insert(loadOp); + if (isa<RankedTensorType>(loadOp.getType())) { + loadOps.insert(loadOp); + } + if (auto curOp = dyn_cast<triton::DotOp>(&op)) dotOp = curOp; } diff --git a/third_party/hcu/python/triton/__init__.py b/third_party/hcu/python/triton/__init__.py index 031c58fb16..a5f77f91e6 100644 --- a/third_party/hcu/python/triton/__init__.py +++ b/third_party/hcu/python/triton/__init__.py @@ -1,5 +1,5 @@ """isort:skip_file""" -__version__ = '3.0.0' +__version__ = '3.1.0' # --------------------------------------- # Note: import order is significant here. diff --git a/third_party/hcu/python/triton/language/semantic.py b/third_party/hcu/python/triton/language/semantic.py index 27c64337a8..b56909e6c7 100644 --- a/third_party/hcu/python/triton/language/semantic.py +++ b/third_party/hcu/python/triton/language/semantic.py @@ -1529,10 +1529,10 @@ def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, input_precision: Optiona def where(condition: tl.tensor, x: tl.tensor, y: tl.tensor, builder: ir.builder) -> tl.tensor: - if condition.dtype != tl.int1: - warnings.warn( - f"tl.where with a non-boolean condition is deprecated and will error out in a future triton release. 
Got {condition.dtype}" - ) + # if condition.dtype != tl.int1: + # warnings.warn( + # f"tl.where with a non-boolean condition is deprecated and will error out in a future triton release. Got {condition.dtype}" + # ) condition = cast(condition, tl.int1, builder) x, y = binary_op_type_checking_impl(x, y, builder, True, True) # x, y are broadcasted