Skip to content

Commit ce6e70b

Browse files
authored
Merge branch 'main' into main_add_check-backend-changed
2 parents c869174 + e977827 commit ce6e70b

File tree

7 files changed

+312
-9
lines changed

7 files changed

+312
-9
lines changed

third_party/hcu/backend/compiler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ class HIPOptions:
3838
cluster_dims: tuple = (1, 1, 1)
3939
debug: bool = False
4040
arch: str = None
41-
supported_fp8_dtypes: Tuple[str] = ("fp8e5", )
41+
supported_fp8_dtypes: Tuple[str] = ("fp8e5", "fp8e4nv")
4242
deprecated_fp8_dtypes: Tuple[str] = ()
4343
default_dot_input_precision: str = "ieee"
4444
allowed_dot_input_precisions: Tuple[str] = ("ieee", )
@@ -261,7 +261,7 @@ def make_llir(src, metadata, options):
261261
## 3. __HIP_FTZ is default to 1 and not exposed as a kernel argument.
262262
## For now it is used as a controller for developers only.
263263
__HIP_FTZ = True
264-
if (options.num_stages >= 2):
264+
if (options.num_stages >= 2) and os.environ.get("TRITON_MOVE_LOAD_TOFRONT_DOT", "0") == "1":
265265
hcu.passes.ttgpuir.add_move_load_tofront_dot(pm)
266266
# hcu.passes.ttgpuir.add_control_fa_fwd_bufferload_cnt(pm, 0)
267267
hcu.passes.ttgpuir.add_to_llvmir(pm, options.arch, __HIP_FTZ)

third_party/hcu/include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,8 @@ using namespace mlir::triton;
134134
#define i64_val(...) LLVM::createConstantI64(loc, rewriter, __VA_ARGS__)
135135
#define int_val(width, val) \
136136
LLVM::createLLVMIntegerConstant(rewriter, loc, width, val)
137+
#define i8_val(val) int_val(8, val)
138+
#define i16_val(val) int_val(16, val)
137139
#define tid_val() getThreadId(rewriter, loc)
138140

139141
// Attributes

third_party/hcu/lib/TritonHCUGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 297 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,183 @@ Fp16_to_Fp8E5M2_RTZ(Location loc, ConversionPatternRewriter &rewriter,
8383
extract_element(i8_ty, a1, i32_val(3))};
8484
}
8585

86+
//===----------------===//
87+
/// FP8E4M3
88+
//===----------------===//
89+
/// Convert a single f32 value to FP8 E4M3FN (returned as an i8 bit pattern),
/// rounding to nearest even and saturating.
///
/// Layout reminder: E4M3FN is S.EEEE.MMM with exponent bias 7, no infinities,
/// and a single NaN magnitude 0x7F. The largest finite value is 0x7E (448) and
/// the smallest normal is 0x08 (2^-6); subnormals are k * 2^-9 for k in 1..7.
///
/// \param loc       location attached to every emitted op.
/// \param rewriter  rewriter used to create the ops.
/// \param v         the f32 value to convert.
/// \returns an i8 value holding the E4M3FN encoding. Infinities and finite
///          values that would round above 448 become +/-0x7E; NaN becomes
///          0x7F with the input sign bit or'ed in.
static Value
Fp32_to_Fp8E4M3FN_RTNE_oneValue(Location loc,
                                ConversionPatternRewriter &rewriter, Value v) {
  // == Step 1: split into sign and magnitude ==
  Value vi32 = bitcast(v, i32_ty);
  // Shifting right by 24 moves f32 bit 31 to bit 7, the fp8 sign position.
  Value sign = and_(trunc(i8_ty, lshr(vi32, i32_val(24))), i8_val(0x80));
  // All arithmetic below works on the magnitude so that the unsigned
  // comparisons are not disturbed by the sign bit.
  Value abs = and_(vi32, i32_val(0x7FFFFFFF));

  // == Step 2: NaN check (exponent all ones and non-zero fraction) ==
  Value isNaN = icmp_ugt(abs, i32_val(0x7F800000));

  // == Step 3: round the 23-bit mantissa to 3 bits (RTNE) ==
  // 20 bits are dropped. Adding (half - 1) plus the LSB of the kept mantissa
  // (bit 20) rounds to nearest and breaks ties towards even.
  constexpr uint32_t baseRoundingBias = (1u << 19) - 1; // 0x7FFFF
  Value keptMantissaLSB =
      lshr(and_(abs, i32_val(0x00100000)), i32_val(20));
  Value roundingBias = add(keptMantissaLSB, i32_val(baseRoundingBias));
  Value vFp8 = add(abs, roundingBias);
  // Clear the 20 discarded mantissa bits.
  vFp8 = and_(vFp8, i32_val(0xFFF00000));

  // == Step 4: clamp to the smallest fp8 normal ==
  // 0x3C800000 is the f32 encoding of 2^-6. Anything below it is either
  // rounded up to 2^-6 here or replaced by the subnormal LUT in step 7.
  vFp8 = select(icmp_ult(vFp8, i32_val(0x3C800000)), i32_val(0x3C800000),
                vFp8);

  // == Step 5: re-bias the exponent and narrow ==
  // (127 - 7) << 23 == 0x3C000000
  vFp8 = sub(vFp8, i32_val(0x3C000000));
  // Exponent (4 bits) and mantissa (3 bits) now sit in bits 26..20.
  vFp8 = trunc(i8_ty, lshr(vFp8, i32_val(20)));

  // == Step 6: saturate ==
  // 0x43E80000 is 464.0f, the midpoint between 448 (0x7E) and the
  // unrepresentable 480. 464 itself ties to even, i.e. 448; every larger
  // magnitude (including infinity) is replaced by the max finite value 0x7E.
  Value isOverflowOrInf = icmp_ugt(abs, i32_val(0x43E80000));
  vFp8 = select(isOverflowOrInf, i8_val(0x7E), vFp8);

  // == Step 7: subnormals via LUT ==
  // halfwayPointsLUT[i] is the f32 encoding of (i + 0.5) * 2^-9, the midpoint
  // between fp8 subnormals i and i + 1. Ties go to the even neighbour, hence
  // "<=" for even i and "<" for odd i. The loop runs from high to low so the
  // smallest matching i wins.
  constexpr size_t lutSize = 8;
  constexpr uint32_t halfwayPointsLUT[lutSize] = {
      0x3A800000, 0x3B400000, 0x3BA00000, 0x3BE00000,
      0x3C100000, 0x3C300000, 0x3C500000, 0x3C700000};

  for (int i = lutSize - 1; i >= 0; --i) {
    Value cmp;
    if (i % 2 == 0) {
      cmp = icmp_ule(abs, i32_val(halfwayPointsLUT[i]));
    } else {
      cmp = icmp_ult(abs, i32_val(halfwayPointsLUT[i]));
    }
    vFp8 = select(cmp, i8_val(i), vFp8);
  }

  // == Step 8: NaN stays NaN ==
  vFp8 = select(isNaN, i8_val(0x7F), vFp8);

  // == Step 9: restore the sign bit ==
  vFp8 = or_(vFp8, sign);
  return vFp8;
}
161+
162+
/// Element-wise f32 -> FP8 E4M3FN (RTNE, saturating) conversion.
///
/// Converts every value in \p v with Fp32_to_Fp8E4M3FN_RTNE_oneValue and
/// returns the results in the same order. Any number of inputs is accepted,
/// matching the other E4M3FN conversion wrappers in this file.
static SmallVector<Value>
Fp32_to_Fp8E4M3FN_RTNE(Location loc, ConversionPatternRewriter &rewriter,
                       const SmallVector<Value> &v) {
  SmallVector<Value> result(v.size());
  for (size_t i = 0; i < v.size(); ++i) {
    result[i] = Fp32_to_Fp8E4M3FN_RTNE_oneValue(loc, rewriter, v[i]);
  }
  return result;
}
172+
173+
// Cast FP16 to FP8E4M3FN in saturation and round-to-nearest-even mode.
174+
// According to
175+
// https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-12-01-pdf-1,
176+
// In saturation mode, inf and out-of-range numbers are converted to the largest
177+
// normal number, i.e. ±448. NaNs are converted to NaNs.
178+
static Value
179+
Fp16_to_Fp8E4M3FN_RTNE_oneValue(Location loc,
180+
ConversionPatternRewriter &rewriter, Value v) {
181+
// StringRef funcName = "llvm.is.fpclass";
182+
// Value isNaN = LLVM::createLLVMIntrinsicCallOp(rewriter, loc, funcName,
183+
// i1_ty,
184+
// {v, i32_val(0x3)})
185+
// ->getResult(0);
186+
// Get sign and absolute value
187+
Value vi16 = bitcast(v, i16_ty);
188+
Value isNaN = and_(
189+
icmp_eq(and_(vi16, i16_val(0x7C00)), i16_val(0x7C00)), //; exp == 0x7C00
190+
icmp_ne(and_(vi16, i16_val(0x03FF)), i16_val(0)) //; frac != 0
191+
);
192+
193+
Value sign = trunc(i8_ty, lshr(and_(vi16, i16_val(0x8000)), i16_val(8)));
194+
vi16 = and_(vi16, i16_val(0x7FFF));
195+
196+
// Rounding to nearest even
197+
constexpr uint16_t baseRoundingBias = 0x003F; // 1 << (10 - 3 - 1) - 1
198+
199+
// S.EEEEE.MMMMMMMMMM => 0.00000.00M0000000 => 0.00000.000000000M
200+
Value remainingMantissaLSB = lshr(and_(vi16, i16_val(0x0080)), i16_val(7));
201+
Value roundingBias = add(remainingMantissaLSB, i16_val(baseRoundingBias));
202+
Value vFp8 = add(vi16, roundingBias);
203+
204+
// Reduce mantissa to 3 bits
205+
vFp8 = and_(vFp8, i16_val(0xFF80)); // 0xFF80 == 1.11111.1110000000
206+
207+
// 0x2400 is the FP16 representation of 2^{-6}, which is the smallest normal
208+
// number in FP8E4M3FN. We round numbers smaller than that to 0x2400 to make
209+
// it easier to handle subnormals
210+
vFp8 = umax(vFp8, i16_val(0x2400));
211+
212+
// Adjust exponent bias
213+
vFp8 = sub(vFp8, i16_val(0x2000)); // (15 - 7) << 10
214+
215+
// Shift right and truncate
216+
vFp8 = trunc(i8_ty, lshr(vFp8, i16_val(7))); // 10 - 3
217+
218+
// 0x5F7F == 0.10111.1101111111 is the largest possible normal
219+
// number(including infinity) after rounding in FP8
220+
//
221+
// In saturation mode, numbers larger than the max normal number(including
222+
// infinity) in FP8 after rounding will be replaced with max_E4M3, i.e. 0x7E
223+
// === 0.1111.110
224+
Value isOverflowOrInf = icmp_ugt(vi16, i16_val(0x5F7F));
225+
vFp8 = select(isOverflowOrInf, i8_val(0x7E), vFp8);
226+
227+
// Round subnormals to nearest even. Ref:
228+
// https://github.com/openxla/xla/blob/f20c6fe2/xla/service/elemental_ir_emitter.cc#L272
229+
constexpr size_t lutSize = 8;
230+
constexpr float halfwayPointsLUT[lutSize] = {0x1400, 0x1A00, 0x1D00, 0x1F00,
231+
0x2080, 0x2180, 0x2280, 0x2380};
232+
233+
for (int i = lutSize - 1; i >= 0; i--) {
234+
Value cmp;
235+
if (i % 2 == 0) {
236+
cmp = icmp_ule(vi16, i16_val(halfwayPointsLUT[i]));
237+
} else {
238+
cmp = icmp_ult(vi16, i16_val(halfwayPointsLUT[i]));
239+
}
240+
241+
vFp8 = select(cmp, i8_val(i), vFp8);
242+
}
243+
244+
// NaN remains NaN after conversion
245+
vFp8 = select(isNaN, i8_val(0x7F), vFp8);
246+
247+
// Set sign bit
248+
vFp8 = or_(i8_ty, vFp8, sign);
249+
250+
return vFp8;
251+
}
252+
253+
/// Element-wise f16 -> FP8 E4M3FN (RTNE, saturating) conversion: applies
/// Fp16_to_Fp8E4M3FN_RTNE_oneValue to each input and returns the results in
/// input order.
static SmallVector<Value>
Fp16_to_Fp8E4M3FN_RTNE(Location loc, ConversionPatternRewriter &rewriter,
                       const SmallVector<Value> &v) {
  SmallVector<Value> converted;
  converted.reserve(v.size());
  for (const Value &elem : v)
    converted.push_back(Fp16_to_Fp8E4M3FN_RTNE_oneValue(loc, rewriter, elem));
  return converted;
}
262+
86263
static Value cvtFp16ToFp32(Location loc, ConversionPatternRewriter &rewriter,
87264
const Value &v) {
88265
GCNBuilder builder;
@@ -271,6 +448,120 @@ ConverterT Fp16_to_Fp8E5M2FNUZ(HCU::ISAFamily isaFamily) {
271448
: Fp16_to_Fp8E5M2FNUZ_SW;
272449
}
273450

451+
/// Convert a single FP8 E4M3FN value (S.EEEE.MMM, exponent bias 7) to f32.
///
/// \param loc       location attached to every emitted op.
/// \param rewriter  rewriter used to create the ops.
/// \param v         the fp8 value; it is bitcast to i8 before decoding.
/// \returns an f32 value. Normal numbers are re-biased, zero and subnormals
///          are taken from a lookup table, the sign is applied to all of
///          them, and the NaN encoding (magnitude 0x7F) maps to the canonical
///          f32 NaN 0x7FC00000.
static Value Fp8E4M3FN_to_Fp32_oneValue(Location loc,
                                        ConversionPatternRewriter &rewriter,
                                        Value v) {
  // Bitcast FP8 (i8) value
  Value vi8 = bitcast(v, i8_ty);

  // Extract sign and absolute value
  Value sign = and_(vi8, i8_val(0x80)); // 0b1000'0000
  Value vAbs = and_(vi8, i8_val(0x7F)); // 0b0111'1111

  // Extract exponent and mantissa
  Value exp = lshr(and_(vAbs, i8_val(0x78)), i8_val(3)); // 0b0111'1000
  Value mant = and_(vAbs, i8_val(0x07));                 // 0b0000'0111

  Value isNaN = icmp_eq(vAbs, i8_val(0x7F)); // == 0b0111'1111

  // Normal numbers: magnitude = (exp + 120) << 23 | mant << 20
  Value exp32 = shl(zext(i32_ty, exp), i32_val(23));
  exp32 = add(exp32, i32_val(120 << 23)); // bias diff (127 - 7)
  Value mant32 = shl(zext(i32_ty, mant), i32_val(20)); // bits 22..20
  Value combined = or_(exp32, mant32);

  // Zero and subnormals (vAbs in 0x00..0x07) have no implicit leading one, so
  // the formula above does not apply. Entry k is the f32 encoding of
  // k * 2^-9, which is exactly representable as an f32 normal.
  constexpr int lutSize = 8;
  static constexpr uint32_t denormsAndZeroLut[lutSize] = {
      0x00000000, 0x3B000000, 0x3B800000, 0x3BC00000,
      0x3C000000, 0x3C200000, 0x3C400000, 0x3C600000};

  for (int i = 0; i < lutSize; ++i) {
    combined = select(icmp_eq(vAbs, i8_val(i)), i32_val(denormsAndZeroLut[i]),
                      combined);
  }

  // Apply the sign after the LUT so that negative subnormals and -0 keep it.
  // Shifting by 24 moves fp8 bit 7 to f32 bit 31.
  Value sign32 = shl(zext(i32_ty, sign), i32_val(24));
  combined = or_(combined, sign32);

  // Handle NaN: output canonical FP32 NaN 0x7FC00000
  combined = select(isNaN, i32_val(0x7FC00000), combined);

  // Bitcast to float
  return bitcast(combined, f32_ty);
}
505+
506+
/// Element-wise FP8 E4M3FN -> f32 conversion: applies
/// Fp8E4M3FN_to_Fp32_oneValue to each input and returns the results in input
/// order.
static SmallVector<Value> Fp8E4M3FN_to_Fp32(Location loc,
                                            ConversionPatternRewriter &rewriter,
                                            const SmallVector<Value> &values) {
  SmallVector<Value> converted;
  converted.reserve(values.size());
  for (const Value &elem : values)
    converted.push_back(Fp8E4M3FN_to_Fp32_oneValue(loc, rewriter, elem));
  return converted;
}
514+
515+
/// Convert a single FP8 E4M3FN value to f16.
///
/// The fp8 byte is widened to 16 bits, shifted so that its exponent and
/// mantissa line up with the f16 fields, and re-biased. NaN, zero and
/// subnormal inputs are patched afterwards, and the sign bit is or'ed back in
/// at the end.
static Value Fp8E4M3FN_to_Fp16_oneValue(Location loc,
                                        ConversionPatternRewriter &rewriter,
                                        Value v) {
  // Build a <2 x i8> of {0, v} and view it as one i16. The masks below treat
  // element 1 as the upper byte of that word.
  auto packedTy = vec_ty(i8_ty, 2);
  Value packed = undef(packedTy);
  packed = insert_element(packedTy, packed, i8_val(0), i32_val(0));
  packed = insert_element(packedTy, packed, v, i32_val(1));
  Value bits = bitcast(packed, i16_ty);

  // Split off the sign; keep working on the magnitude only.
  Value signBit = and_(bits, i16_val(0x8000));
  Value magnitude = and_(bits, i16_val(0x7FFF));

  // One bit to the right puts the 4-bit exponent / 3-bit mantissa where the
  // f16 exponent / mantissa fields start.
  magnitude = lshr(magnitude, i16_val(1));

  // Exponent re-bias: (15 - 7) << 10 === 0x2000
  magnitude = add(magnitude, i16_val(0x2000));

  // The only NaN magnitude in E4M3FN is 0x7F; map it to the f16 NaN 0x7E00.
  Value fp8Abs = and_(bitcast(v, i8_ty), i8_val(0x7F));
  Value isNaN = icmp_eq(fp8Abs, i8_val(0x7F));
  magnitude = select(isNaN, i16_val(0x7E00), magnitude);

  // Zero and subnormals: S.0000.000 ~ S.0000.111 are looked up directly,
  // entry k being the f16 bit pattern used for fp8 magnitude k.
  static constexpr int kZeroAndSubnormalFp16[] = {
      0x0000, 0x1800, 0x1C00, 0x1E00, 0x2000, 0x2100, 0x2200, 0x2300};

  int fp8Code = 0;
  for (int fp16Bits : kZeroAndSubnormalFp16) {
    magnitude = select(icmp_eq(fp8Abs, i8_val(fp8Code)), i16_val(fp16Bits),
                       magnitude);
    ++fp8Code;
  }

  // Re-attach the sign and reinterpret as f16.
  Value signedBits = or_(magnitude, signBit);
  return bitcast(signedBits, f16_ty);
}
555+
556+
/// Element-wise FP8 E4M3FN -> f16 conversion: applies
/// Fp8E4M3FN_to_Fp16_oneValue to each input and returns the results in input
/// order.
static SmallVector<Value> Fp8E4M3FN_to_Fp16(Location loc,
                                            ConversionPatternRewriter &rewriter,
                                            const SmallVector<Value> &values) {
  SmallVector<Value> converted;
  converted.reserve(values.size());
  for (const Value &elem : values)
    converted.push_back(Fp8E4M3FN_to_Fp16_oneValue(loc, rewriter, elem));
  return converted;
}
564+
274565
static SmallVector<Value> Fp8E5M2_to_Fp16(Location loc,
275566
ConversionPatternRewriter &rewriter,
276567
const SmallVector<Value> &v) {
@@ -867,10 +1158,13 @@ struct FpToFpOpConversion
8671158
// F8 -> F16
8681159
{{F8E4M3FNUZTyID, F16TyID, undefRounding},
8691160
Fp8E4M3FNUZ_to_Fp16(isaFamily)},
1161+
{{F8E4M3FNTyID, F16TyID, undefRounding}, Fp8E4M3FN_to_Fp16},
8701162
{{F8E5M2FNUZTyID, F16TyID, undefRounding},
8711163
Fp8E5M2FNUZ_to_Fp16(isaFamily)},
8721164
{{F8E5M2TyID, F16TyID, undefRounding}, Fp8E5M2_to_Fp16},
8731165
// F16 -> F8
1166+
{{F16TyID, F8E4M3FNTyID, RoundingMode::RTNE},
1167+
Fp16_to_Fp8E4M3FN_RTNE},
8741168
{{F16TyID, F8E5M2FNUZTyID, RoundingMode::RTNE},
8751169
Fp16_to_Fp8E5M2FNUZ(isaFamily)},
8761170
{{F16TyID, F8E4M3FNUZTyID, RoundingMode::RTNE},
@@ -893,8 +1187,11 @@ struct FpToFpOpConversion
8931187
Fp32_to_Fp8E4M3FNUZ},
8941188
{{F32TyID, F8E5M2FNUZTyID, RoundingMode::RTNE},
8951189
Fp32_to_Fp8E5M2FNUZ},
1190+
{{F32TyID, F8E4M3FNTyID, RoundingMode::RTNE},
1191+
Fp32_to_Fp8E4M3FN_RTNE},
8961192
{{F8E4M3FNUZTyID, F32TyID, undefRounding}, Fp8E4M3FNUZ_to_Fp32},
8971193
{{F8E5M2FNUZTyID, F32TyID, undefRounding}, Fp8E5M2FNUZ_to_Fp32},
1194+
{{F8E4M3FNTyID, F32TyID, RoundingMode::RTNE}, Fp8E4M3FN_to_Fp32},
8981195
};
8991196
std::tuple<TypeID, TypeID, RoundingMode> key = {
9001197
srcTy.getTypeID(), dstTy.getTypeID(),

third_party/hcu/lib/TritonHCUTransforms/AccelerateHcuFlashAttention.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,8 @@ class TritonHcuFlashAttention : public mlir::RewritePattern {
114114
bool interleave = true; // isSecondDot(dotOp);
115115
unsigned mDim = 16, kDim = 16;
116116
unsigned nDim = oldBType.getShape()[1] < 32 ? 16 : 32;
117-
if (bGobalOrder[0] == 0) {
117+
// if (bGobalOrder[0] == 0) {
118+
if (bGobalOrder[0] == 0 || oldBType.getShape()[1] >= 256) {
118119
mDim = 16, nDim = 16;
119120
}
120121
auto newEnc = ttg::HCUMfmaEncodingAttr::get(

third_party/hcu/lib/TritonHCUTransforms/ReorderInstructions.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,10 @@ static void sinkSecondLoad(triton::FuncOp funcOp) {
328328
triton::DotOp dotOp;
329329
for (Operation &op : forOp) {
330330
if (auto loadOp = dyn_cast<triton::LoadOp>(&op))
331-
loadOps.insert(loadOp);
331+
if (isa<RankedTensorType>(loadOp.getType())) {
332+
loadOps.insert(loadOp);
333+
}
334+
332335
if (auto curOp = dyn_cast<triton::DotOp>(&op))
333336
dotOp = curOp;
334337
}

third_party/hcu/python/triton/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""isort:skip_file"""
2-
__version__ = '3.0.0'
2+
__version__ = '3.1.0'
33

44
# ---------------------------------------
55
# Note: import order is significant here.

0 commit comments

Comments
 (0)