Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions third_party/hcu/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class HIPOptions:
cluster_dims: tuple = (1, 1, 1)
debug: bool = False
arch: str = None
supported_fp8_dtypes: Tuple[str] = ("fp8e5", )
supported_fp8_dtypes: Tuple[str] = ("fp8e5", "fp8e4nv")
deprecated_fp8_dtypes: Tuple[str] = ()
default_dot_input_precision: str = "ieee"
allowed_dot_input_precisions: Tuple[str] = ("ieee", )
Expand Down Expand Up @@ -261,7 +261,7 @@ def make_llir(src, metadata, options):
## 3. __HIP_FTZ is default to 1 and not exposed as a kernel argument.
## For now it is used as a controller for developers only.
__HIP_FTZ = True
if (options.num_stages >= 2):
if (options.num_stages >= 2) and os.environ.get("TRITON_MOVE_LOAD_TOFRONT_DOT", "0") == "1":
hcu.passes.ttgpuir.add_move_load_tofront_dot(pm)
# hcu.passes.ttgpuir.add_control_fa_fwd_bufferload_cnt(pm, 0)
hcu.passes.ttgpuir.add_to_llvmir(pm, options.arch, __HIP_FTZ)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ using namespace mlir::triton;
#define i64_val(...) LLVM::createConstantI64(loc, rewriter, __VA_ARGS__)
#define int_val(width, val) \
LLVM::createLLVMIntegerConstant(rewriter, loc, width, val)
#define i8_val(val) int_val(8, val)
#define i16_val(val) int_val(16, val)
#define tid_val() getThreadId(rewriter, loc)

// Attributes
Expand Down
297 changes: 297 additions & 0 deletions third_party/hcu/lib/TritonHCUGPUToLLVM/ElementwiseOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,183 @@ Fp16_to_Fp8E5M2_RTZ(Location loc, ConversionPatternRewriter &rewriter,
extract_element(i8_ty, a1, i32_val(3))};
}

//===----------------===//
/// FP8E4M3
//===----------------===//
static Value
Fp32_to_Fp8E4M3FN_RTNE_oneValue(Location loc,
ConversionPatternRewriter &rewriter, Value v) {
// == Step 1: Bitcast to i32 and extract components ==
Value vi32 = bitcast(v, i32_ty);
Value sign = trunc(i8_ty, lshr(vi32, i32_val(24))); // get sign << 7
sign = and_(i8_ty, sign, i8_val(0x80)); // keep sign bit

Value abs = and_(i32_ty, vi32, i32_val(0x7FFFFFFF)); // strip sign

// == Step 2: NaN check ==
Value isNaN = and_(
i1_ty,
icmp_eq(and_(i32_ty, vi32, i32_val(0x7F800000)),
i32_val(0x7F800000)), // exp == 0xFF
icmp_ne(and_(i32_ty, vi32, i32_val(0x007FFFFF)), i32_val(0)) // frac != 0
);

// == Step 3: Rounding (RTNE) ==
// bias diff = (127 - 7) << 23 = 120 << 23 = 0x3C000000
// mantissa alignment: keep top 3 mantissa bits => shift right by 23 - 3 = 20
constexpr uint32_t baseRoundingBias =
(1 << 19) - 1; // 0x7FFFF = round-to-even

Value roundBit = lshr(and_(i32_ty, vi32, i32_val(0x00100000)),
i32_val(20)); // bit 20 (for even)
Value roundingBias = add(i32_val(baseRoundingBias), roundBit);
Value vFp8 = add(vi32, roundingBias);

// Keep top 9 bits (sign | exp(4) | mant(3)) ← trunc later
vFp8 = and_(i32_ty, vFp8, i32_val(0xFFFFFF80)); // clear bottom bits

// == Step 4: Clamp to min normal (smallest FP8 normal in FP32) ==
// FP32 representation of 2^-6 = 2.0^(-6) = 0x38800000
vFp8 = select(icmp_ult(vFp8, i32_val(0x38800000)), i32_val(0x38800000), vFp8);

// == Step 5: Adjust exponent bias ==
// Subtract (127 - 7) << 23 = 0x3C000000
vFp8 = sub(vFp8, i32_val(0x3C000000));

// Shift right to extract FP8 bits: (exp + mant) → shift 20 to fit 3-bit mant
vFp8 = trunc(i8_ty, lshr(vFp8, i32_val(20)));

// == Step 6: Clamp to FP8 max value (before inf) ==
// 0x7E is largest finite E4M3FN value (0.1111.10x)
Value isOverflowOrInf =
icmp_ugt(abs, i32_val(0x7F7FFFFF)); // > max finite FP32
vFp8 = select(isOverflowOrInf, i8_val(0x7E), vFp8);

// == Step 7: Handle subnormals via LUT ==
constexpr size_t lutSize = 8;
constexpr uint32_t halfwayPointsLUT[lutSize] = {
0x33800000, 0x37000000, 0x38800000, 0x39400000,
0x3A000000, 0x3A800000, 0x3B000000, 0x3B800000};

for (int i = lutSize - 1; i >= 0; --i) {
Value cmp;
if (i % 2 == 0) {
cmp = icmp_ule(abs, i32_val(halfwayPointsLUT[i]));
} else {
cmp = icmp_ult(abs, i32_val(halfwayPointsLUT[i]));
}
vFp8 = select(cmp, i8_val(i), vFp8);
}

// == Step 8: Handle NaN ==
vFp8 = select(isNaN, i8_val(0x7F), vFp8);

// == Step 9: Set sign bit ==
vFp8 = or_(vFp8, sign);
return vFp8;
}

static SmallVector<Value>
Fp32_to_Fp8E4M3FN_RTNE(Location loc, ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
assert(v.size() == 4 && "Expected 4 values for FP8E4M3FN conversion");
SmallVector<Value> result(4);
for (size_t i = 0; i < 4; ++i) {
result[i] = Fp32_to_Fp8E4M3FN_RTNE_oneValue(loc, rewriter, v[i]);
}
return result;
}

// Cast FP16 to FP8E4M3FN in saturation and round-to-nearest-even mode.
// According to
// https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-12-01-pdf-1,
// In saturation mode, inf and out-of-range numbers are converted to the largest
// normal number, i.e. ±448. NaNs are converted to NaNs.
static Value
Fp16_to_Fp8E4M3FN_RTNE_oneValue(Location loc,
ConversionPatternRewriter &rewriter, Value v) {
// StringRef funcName = "llvm.is.fpclass";
// Value isNaN = LLVM::createLLVMIntrinsicCallOp(rewriter, loc, funcName,
// i1_ty,
// {v, i32_val(0x3)})
// ->getResult(0);
// Get sign and absolute value
Value vi16 = bitcast(v, i16_ty);
Value isNaN = and_(
icmp_eq(and_(vi16, i16_val(0x7C00)), i16_val(0x7C00)), //; exp == 0x7C00
icmp_ne(and_(vi16, i16_val(0x03FF)), i16_val(0)) //; frac != 0
);

Value sign = trunc(i8_ty, lshr(and_(vi16, i16_val(0x8000)), i16_val(8)));
vi16 = and_(vi16, i16_val(0x7FFF));

// Rounding to nearest even
constexpr uint16_t baseRoundingBias = 0x003F; // 1 << (10 - 3 - 1) - 1

// S.EEEEE.MMMMMMMMMM => 0.00000.00M0000000 => 0.00000.000000000M
Value remainingMantissaLSB = lshr(and_(vi16, i16_val(0x0080)), i16_val(7));
Value roundingBias = add(remainingMantissaLSB, i16_val(baseRoundingBias));
Value vFp8 = add(vi16, roundingBias);

// Reduce mantissa to 3 bits
vFp8 = and_(vFp8, i16_val(0xFF80)); // 0xFF80 == 1.11111.1110000000

// 0x2400 is the FP16 representation of 2^{-6}, which is the smallest normal
// number in FP8E4M3FN. We round numbers smaller than that to 0x2400 to make
// it easier to handle subnormals
vFp8 = umax(vFp8, i16_val(0x2400));

// Adjust exponent bias
vFp8 = sub(vFp8, i16_val(0x2000)); // (15 - 7) << 10

// Shift right and truncate
vFp8 = trunc(i8_ty, lshr(vFp8, i16_val(7))); // 10 - 3

// 0x5F7F == 0.10111.1101111111 is the largest possible normal
// number(including infinity) after rounding in FP8
//
// In saturation mode, numbers larger than the max normal number(including
// infinity) in FP8 after rounding will be replaced with max_E4M3, i.e. 0x7E
// === 0.1111.110
Value isOverflowOrInf = icmp_ugt(vi16, i16_val(0x5F7F));
vFp8 = select(isOverflowOrInf, i8_val(0x7E), vFp8);

// Round subnormals to nearest even. Ref:
// https://github.com/openxla/xla/blob/f20c6fe2/xla/service/elemental_ir_emitter.cc#L272
constexpr size_t lutSize = 8;
constexpr float halfwayPointsLUT[lutSize] = {0x1400, 0x1A00, 0x1D00, 0x1F00,
0x2080, 0x2180, 0x2280, 0x2380};

for (int i = lutSize - 1; i >= 0; i--) {
Value cmp;
if (i % 2 == 0) {
cmp = icmp_ule(vi16, i16_val(halfwayPointsLUT[i]));
} else {
cmp = icmp_ult(vi16, i16_val(halfwayPointsLUT[i]));
}

vFp8 = select(cmp, i8_val(i), vFp8);
}

// NaN remains NaN after conversion
vFp8 = select(isNaN, i8_val(0x7F), vFp8);

// Set sign bit
vFp8 = or_(i8_ty, vFp8, sign);

return vFp8;
}

static SmallVector<Value>
Fp16_to_Fp8E4M3FN_RTNE(Location loc, ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
SmallVector<Value> result(v.size());
for (size_t i = 0; i < v.size(); ++i) {
result[i] = Fp16_to_Fp8E4M3FN_RTNE_oneValue(loc, rewriter, v[i]);
}
return result;
}

static Value cvtFp16ToFp32(Location loc, ConversionPatternRewriter &rewriter,
const Value &v) {
GCNBuilder builder;
Expand Down Expand Up @@ -271,6 +448,120 @@ ConverterT Fp16_to_Fp8E5M2FNUZ(HCU::ISAFamily isaFamily) {
: Fp16_to_Fp8E5M2FNUZ_SW;
}

static Value Fp8E4M3FN_to_Fp32_oneValue(Location loc,
ConversionPatternRewriter &rewriter,
Value v) {
// Bitcast FP8 (i8) value
Value vi8 = bitcast(v, i8_ty);

// Extract sign and absolute value
Value sign = and_(vi8, i8_val(0x80)); // 0b1000'0000
Value vAbs = and_(vi8, i8_val(0x7F)); // 0b0111'1111

// Extract exponent and mantissa
Value exp = and_(vAbs, i8_val(0x78)); // 0b0111'1000
Value mant = and_(vAbs, i8_val(0x07)); // 0b0000'0111

// Right shift exponent to LSB
exp = lshr(exp, i8_val(3));

// Detect special cases
Value isZeroOrDenorm = icmp_ult(vAbs, i8_val(0x08)); // < 0b0000'1000
Value isNaN = icmp_eq(vAbs, i8_val(0x7F)); // == 0b0111'1111

// Default normal conversion
// Compose 32-bit float:
// sign << 31 | (exp + 120) << 23 | (mant << 20)
Value sign32 = shl(zext(i32_ty, sign), i32_val(24)); // move sign to bit 31
Value exp32 = shl(zext(i32_ty, exp), i32_val(23)); // move exp to bit 23
Value expBias = i32_val(120 << 23); // bias diff (127 - 7)
exp32 = add(exp32, expBias);

Value mant32 =
shl(zext(i32_ty, mant), i32_val(20)); // move mant to bits 22-20

Value combined = or_(or_(sign32, exp32), mant32);

// Handle NaN: output canonical FP32 NaN 0x7FC00000
Value nanVal = i32_val(0x7FC00000);
combined = select(isNaN, nanVal, combined);

// Handle denorm/zero via LUT
// For vAbs in [0x00 ~ 0x07] → use LUT
constexpr int lutSize = 8;
static constexpr uint32_t denormsAndZeroLut[lutSize] = {
0x00000000, 0x38800000, 0x39800000, 0x3A000000, 0x3A800000,
0x3B000000, 0x3B400000, 0x3B800000}; // approximated FP32 tiny values

for (int i = 0; i < lutSize; ++i) {
combined = select(icmp_eq(vAbs, i8_val(i)), i32_val(denormsAndZeroLut[i]),
combined);
}

// Bitcast to float
Value result = bitcast(combined, f32_ty);
return result;
}

static SmallVector<Value> Fp8E4M3FN_to_Fp32(Location loc,
ConversionPatternRewriter &rewriter,
const SmallVector<Value> &values) {
SmallVector<Value> results(values.size());
for (size_t i = 0; i < values.size(); i++)
results[i] = Fp8E4M3FN_to_Fp32_oneValue(loc, rewriter, values[i]);
return results;
}

static Value Fp8E4M3FN_to_Fp16_oneValue(Location loc,
ConversionPatternRewriter &rewriter,
Value v) {
auto fp8x2VecTy = vec_ty(i8_ty, 2);
Value a = undef(fp8x2VecTy);
a = insert_element(fp8x2VecTy, a, i8_val(0), i32_val(0));
a = insert_element(fp8x2VecTy, a, v, i32_val(1));
a = bitcast(a, i16_ty);

// Get sign and absolute value
Value sign = and_(a, i16_val(0x8000));
a = and_(a, i16_val(0x7FFF));

// Right shift 1 bit to adjust the positions of exponent and mantissa
a = lshr(a, i16_val(1));

// Adjust exponent, (15 - 7) << 10 === 0x2000
a = add(a, i16_val(0x2000));

// Check NaN
Value vAbs = and_(bitcast(v, i8_ty), i8_val(0x7F));
a = select(icmp_eq(vAbs, i8_val(0x7F)), i16_val(0x7E00), a);

// Check denorms and zero
// Here we use a LUT to map S.0000.000 ~ S.0000.111 to its corresponding fp16
// value
constexpr size_t lutSize = 8;
static constexpr int denormsAndZeroLut[lutSize] = {
0x0000, 0x1800, 0x1C00, 0x1E00, 0x2000, 0x2100, 0x2200, 0x2300};

for (int i = 0; i < lutSize; i++) {
a = select(icmp_eq(vAbs, i8_val(i)), i16_val(denormsAndZeroLut[i]), a);
}

// Set sign
a = or_(a, sign);
a = bitcast(a, f16_ty);

return a;
}

static SmallVector<Value> Fp8E4M3FN_to_Fp16(Location loc,
ConversionPatternRewriter &rewriter,
const SmallVector<Value> &values) {
SmallVector<Value> results(values.size());
for (size_t i = 0; i < values.size(); i++)
results[i] = Fp8E4M3FN_to_Fp16_oneValue(loc, rewriter, values[i]);
return results;
}

static SmallVector<Value> Fp8E5M2_to_Fp16(Location loc,
ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
Expand Down Expand Up @@ -867,10 +1158,13 @@ struct FpToFpOpConversion
// F8 -> F16
{{F8E4M3FNUZTyID, F16TyID, undefRounding},
Fp8E4M3FNUZ_to_Fp16(isaFamily)},
{{F8E4M3FNTyID, F16TyID, undefRounding}, Fp8E4M3FN_to_Fp16},
{{F8E5M2FNUZTyID, F16TyID, undefRounding},
Fp8E5M2FNUZ_to_Fp16(isaFamily)},
{{F8E5M2TyID, F16TyID, undefRounding}, Fp8E5M2_to_Fp16},
// F16 -> F8
{{F16TyID, F8E4M3FNTyID, RoundingMode::RTNE},
Fp16_to_Fp8E4M3FN_RTNE},
{{F16TyID, F8E5M2FNUZTyID, RoundingMode::RTNE},
Fp16_to_Fp8E5M2FNUZ(isaFamily)},
{{F16TyID, F8E4M3FNUZTyID, RoundingMode::RTNE},
Expand All @@ -893,8 +1187,11 @@ struct FpToFpOpConversion
Fp32_to_Fp8E4M3FNUZ},
{{F32TyID, F8E5M2FNUZTyID, RoundingMode::RTNE},
Fp32_to_Fp8E5M2FNUZ},
{{F32TyID, F8E4M3FNTyID, RoundingMode::RTNE},
Fp32_to_Fp8E4M3FN_RTNE},
{{F8E4M3FNUZTyID, F32TyID, undefRounding}, Fp8E4M3FNUZ_to_Fp32},
{{F8E5M2FNUZTyID, F32TyID, undefRounding}, Fp8E5M2FNUZ_to_Fp32},
{{F8E4M3FNTyID, F32TyID, RoundingMode::RTNE}, Fp8E4M3FN_to_Fp32},
};
std::tuple<TypeID, TypeID, RoundingMode> key = {
srcTy.getTypeID(), dstTy.getTypeID(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ class TritonHcuFlashAttention : public mlir::RewritePattern {
bool interleave = true; // isSecondDot(dotOp);
unsigned mDim = 16, kDim = 16;
unsigned nDim = oldBType.getShape()[1] < 32 ? 16 : 32;
if (bGobalOrder[0] == 0) {
// if (bGobalOrder[0] == 0) {
if (bGobalOrder[0] == 0 || oldBType.getShape()[1] >= 256) {
mDim = 16, nDim = 16;
}
auto newEnc = ttg::HCUMfmaEncodingAttr::get(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,10 @@ static void sinkSecondLoad(triton::FuncOp funcOp) {
triton::DotOp dotOp;
for (Operation &op : forOp) {
if (auto loadOp = dyn_cast<triton::LoadOp>(&op))
loadOps.insert(loadOp);
if (isa<RankedTensorType>(loadOp.getType())) {
loadOps.insert(loadOp);
}

if (auto curOp = dyn_cast<triton::DotOp>(&op))
dotOp = curOp;
}
Expand Down
2 changes: 1 addition & 1 deletion third_party/hcu/python/triton/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""isort:skip_file"""
__version__ = '3.0.0'
__version__ = '3.1.0'

# ---------------------------------------
# Note: import order is significant here.
Expand Down
Loading
Loading