@@ -83,6 +83,183 @@ Fp16_to_Fp8E5M2_RTZ(Location loc, ConversionPatternRewriter &rewriter,
8383 extract_element (i8_ty, a1, i32_val (3 ))};
8484}
8585
//===----------------------------------------------------------------------===//
// FP8E4M3FN
//===----------------------------------------------------------------------===//
89+ static Value
90+ Fp32_to_Fp8E4M3FN_RTNE_oneValue (Location loc,
91+ ConversionPatternRewriter &rewriter, Value v) {
92+ // == Step 1: Bitcast to i32 and extract components ==
93+ Value vi32 = bitcast (v, i32_ty);
94+ Value sign = trunc (i8_ty, lshr (vi32, i32_val (24 ))); // get sign << 7
95+ sign = and_ (i8_ty, sign, i8_val (0x80 )); // keep sign bit
96+
97+ Value abs = and_ (i32_ty, vi32, i32_val (0x7FFFFFFF )); // strip sign
98+
99+ // == Step 2: NaN check ==
100+ Value isNaN = and_ (
101+ i1_ty,
102+ icmp_eq (and_ (i32_ty, vi32, i32_val (0x7F800000 )),
103+ i32_val (0x7F800000 )), // exp == 0xFF
104+ icmp_ne (and_ (i32_ty, vi32, i32_val (0x007FFFFF )), i32_val (0 )) // frac != 0
105+ );
106+
107+ // == Step 3: Rounding (RTNE) ==
108+ // bias diff = (127 - 7) << 23 = 120 << 23 = 0x3C000000
109+ // mantissa alignment: keep top 3 mantissa bits => shift right by 23 - 3 = 20
110+ constexpr uint32_t baseRoundingBias =
111+ (1 << 19 ) - 1 ; // 0x7FFFF = round-to-even
112+
113+ Value roundBit = lshr (and_ (i32_ty, vi32, i32_val (0x00100000 )),
114+ i32_val (20 )); // bit 20 (for even)
115+ Value roundingBias = add (i32_val (baseRoundingBias), roundBit);
116+ Value vFp8 = add (vi32, roundingBias);
117+
118+ // Keep top 9 bits (sign | exp(4) | mant(3)) ← trunc later
119+ vFp8 = and_ (i32_ty, vFp8, i32_val (0xFFFFFF80 )); // clear bottom bits
120+
121+ // == Step 4: Clamp to min normal (smallest FP8 normal in FP32) ==
122+ // FP32 representation of 2^-6 = 2.0^(-6) = 0x38800000
123+ vFp8 = select (icmp_ult (vFp8, i32_val (0x38800000 )), i32_val (0x38800000 ), vFp8);
124+
125+ // == Step 5: Adjust exponent bias ==
126+ // Subtract (127 - 7) << 23 = 0x3C000000
127+ vFp8 = sub (vFp8, i32_val (0x3C000000 ));
128+
129+ // Shift right to extract FP8 bits: (exp + mant) → shift 20 to fit 3-bit mant
130+ vFp8 = trunc (i8_ty, lshr (vFp8, i32_val (20 )));
131+
132+ // == Step 6: Clamp to FP8 max value (before inf) ==
133+ // 0x7E is largest finite E4M3FN value (0.1111.10x)
134+ Value isOverflowOrInf =
135+ icmp_ugt (abs, i32_val (0x7F7FFFFF )); // > max finite FP32
136+ vFp8 = select (isOverflowOrInf, i8_val (0x7E ), vFp8);
137+
138+ // == Step 7: Handle subnormals via LUT ==
139+ constexpr size_t lutSize = 8 ;
140+ constexpr uint32_t halfwayPointsLUT[lutSize] = {
141+ 0x33800000 , 0x37000000 , 0x38800000 , 0x39400000 ,
142+ 0x3A000000 , 0x3A800000 , 0x3B000000 , 0x3B800000 };
143+
144+ for (int i = lutSize - 1 ; i >= 0 ; --i) {
145+ Value cmp;
146+ if (i % 2 == 0 ) {
147+ cmp = icmp_ule (abs, i32_val (halfwayPointsLUT[i]));
148+ } else {
149+ cmp = icmp_ult (abs, i32_val (halfwayPointsLUT[i]));
150+ }
151+ vFp8 = select (cmp, i8_val (i), vFp8);
152+ }
153+
154+ // == Step 8: Handle NaN ==
155+ vFp8 = select (isNaN, i8_val (0x7F ), vFp8);
156+
157+ // == Step 9: Set sign bit ==
158+ vFp8 = or_ (vFp8, sign);
159+ return vFp8;
160+ }
161+
162+ static SmallVector<Value>
163+ Fp32_to_Fp8E4M3FN_RTNE (Location loc, ConversionPatternRewriter &rewriter,
164+ const SmallVector<Value> &v) {
165+ assert (v.size () == 4 && " Expected 4 values for FP8E4M3FN conversion" );
166+ SmallVector<Value> result (4 );
167+ for (size_t i = 0 ; i < 4 ; ++i) {
168+ result[i] = Fp32_to_Fp8E4M3FN_RTNE_oneValue (loc, rewriter, v[i]);
169+ }
170+ return result;
171+ }
172+
173+ // Cast FP16 to FP8E4M3FN in saturation and round-to-nearest-even mode.
174+ // According to
175+ // https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-12-01-pdf-1,
176+ // In saturation mode, inf and out-of-range numbers are converted to the largest
177+ // normal number, i.e. ±448. NaNs are converted to NaNs.
178+ static Value
179+ Fp16_to_Fp8E4M3FN_RTNE_oneValue (Location loc,
180+ ConversionPatternRewriter &rewriter, Value v) {
181+ // StringRef funcName = "llvm.is.fpclass";
182+ // Value isNaN = LLVM::createLLVMIntrinsicCallOp(rewriter, loc, funcName,
183+ // i1_ty,
184+ // {v, i32_val(0x3)})
185+ // ->getResult(0);
186+ // Get sign and absolute value
187+ Value vi16 = bitcast (v, i16_ty);
188+ Value isNaN = and_ (
189+ icmp_eq (and_ (vi16, i16_val (0x7C00 )), i16_val (0x7C00 )), // ; exp == 0x7C00
190+ icmp_ne (and_ (vi16, i16_val (0x03FF )), i16_val (0 )) // ; frac != 0
191+ );
192+
193+ Value sign = trunc (i8_ty, lshr (and_ (vi16, i16_val (0x8000 )), i16_val (8 )));
194+ vi16 = and_ (vi16, i16_val (0x7FFF ));
195+
196+ // Rounding to nearest even
197+ constexpr uint16_t baseRoundingBias = 0x003F ; // 1 << (10 - 3 - 1) - 1
198+
199+ // S.EEEEE.MMMMMMMMMM => 0.00000.00M0000000 => 0.00000.000000000M
200+ Value remainingMantissaLSB = lshr (and_ (vi16, i16_val (0x0080 )), i16_val (7 ));
201+ Value roundingBias = add (remainingMantissaLSB, i16_val (baseRoundingBias));
202+ Value vFp8 = add (vi16, roundingBias);
203+
204+ // Reduce mantissa to 3 bits
205+ vFp8 = and_ (vFp8, i16_val (0xFF80 )); // 0xFF80 == 1.11111.1110000000
206+
207+ // 0x2400 is the FP16 representation of 2^{-6}, which is the smallest normal
208+ // number in FP8E4M3FN. We round numbers smaller than that to 0x2400 to make
209+ // it easier to handle subnormals
210+ vFp8 = umax (vFp8, i16_val (0x2400 ));
211+
212+ // Adjust exponent bias
213+ vFp8 = sub (vFp8, i16_val (0x2000 )); // (15 - 7) << 10
214+
215+ // Shift right and truncate
216+ vFp8 = trunc (i8_ty, lshr (vFp8, i16_val (7 ))); // 10 - 3
217+
218+ // 0x5F7F == 0.10111.1101111111 is the largest possible normal
219+ // number(including infinity) after rounding in FP8
220+ //
221+ // In saturation mode, numbers larger than the max normal number(including
222+ // infinity) in FP8 after rounding will be replaced with max_E4M3, i.e. 0x7E
223+ // === 0.1111.110
224+ Value isOverflowOrInf = icmp_ugt (vi16, i16_val (0x5F7F ));
225+ vFp8 = select (isOverflowOrInf, i8_val (0x7E ), vFp8);
226+
227+ // Round subnormals to nearest even. Ref:
228+ // https://github.com/openxla/xla/blob/f20c6fe2/xla/service/elemental_ir_emitter.cc#L272
229+ constexpr size_t lutSize = 8 ;
230+ constexpr float halfwayPointsLUT[lutSize] = {0x1400 , 0x1A00 , 0x1D00 , 0x1F00 ,
231+ 0x2080 , 0x2180 , 0x2280 , 0x2380 };
232+
233+ for (int i = lutSize - 1 ; i >= 0 ; i--) {
234+ Value cmp;
235+ if (i % 2 == 0 ) {
236+ cmp = icmp_ule (vi16, i16_val (halfwayPointsLUT[i]));
237+ } else {
238+ cmp = icmp_ult (vi16, i16_val (halfwayPointsLUT[i]));
239+ }
240+
241+ vFp8 = select (cmp, i8_val (i), vFp8);
242+ }
243+
244+ // NaN remains NaN after conversion
245+ vFp8 = select (isNaN, i8_val (0x7F ), vFp8);
246+
247+ // Set sign bit
248+ vFp8 = or_ (i8_ty, vFp8, sign);
249+
250+ return vFp8;
251+ }
252+
253+ static SmallVector<Value>
254+ Fp16_to_Fp8E4M3FN_RTNE (Location loc, ConversionPatternRewriter &rewriter,
255+ const SmallVector<Value> &v) {
256+ SmallVector<Value> result (v.size ());
257+ for (size_t i = 0 ; i < v.size (); ++i) {
258+ result[i] = Fp16_to_Fp8E4M3FN_RTNE_oneValue (loc, rewriter, v[i]);
259+ }
260+ return result;
261+ }
262+
86263static Value cvtFp16ToFp32 (Location loc, ConversionPatternRewriter &rewriter,
87264 const Value &v) {
88265 GCNBuilder builder;
@@ -271,6 +448,120 @@ ConverterT Fp16_to_Fp8E5M2FNUZ(HCU::ISAFamily isaFamily) {
271448 : Fp16_to_Fp8E5M2FNUZ_SW;
272449}
273450
451+ static Value Fp8E4M3FN_to_Fp32_oneValue (Location loc,
452+ ConversionPatternRewriter &rewriter,
453+ Value v) {
454+ // Bitcast FP8 (i8) value
455+ Value vi8 = bitcast (v, i8_ty);
456+
457+ // Extract sign and absolute value
458+ Value sign = and_ (vi8, i8_val (0x80 )); // 0b1000'0000
459+ Value vAbs = and_ (vi8, i8_val (0x7F )); // 0b0111'1111
460+
461+ // Extract exponent and mantissa
462+ Value exp = and_ (vAbs, i8_val (0x78 )); // 0b0111'1000
463+ Value mant = and_ (vAbs, i8_val (0x07 )); // 0b0000'0111
464+
465+ // Right shift exponent to LSB
466+ exp = lshr (exp, i8_val (3 ));
467+
468+ // Detect special cases
469+ Value isZeroOrDenorm = icmp_ult (vAbs, i8_val (0x08 )); // < 0b0000'1000
470+ Value isNaN = icmp_eq (vAbs, i8_val (0x7F )); // == 0b0111'1111
471+
472+ // Default normal conversion
473+ // Compose 32-bit float:
474+ // sign << 31 | (exp + 120) << 23 | (mant << 20)
475+ Value sign32 = shl (zext (i32_ty, sign), i32_val (24 )); // move sign to bit 31
476+ Value exp32 = shl (zext (i32_ty, exp), i32_val (23 )); // move exp to bit 23
477+ Value expBias = i32_val (120 << 23 ); // bias diff (127 - 7)
478+ exp32 = add (exp32, expBias);
479+
480+ Value mant32 =
481+ shl (zext (i32_ty, mant), i32_val (20 )); // move mant to bits 22-20
482+
483+ Value combined = or_ (or_ (sign32, exp32), mant32);
484+
485+ // Handle NaN: output canonical FP32 NaN 0x7FC00000
486+ Value nanVal = i32_val (0x7FC00000 );
487+ combined = select (isNaN, nanVal, combined);
488+
489+ // Handle denorm/zero via LUT
490+ // For vAbs in [0x00 ~ 0x07] → use LUT
491+ constexpr int lutSize = 8 ;
492+ static constexpr uint32_t denormsAndZeroLut[lutSize] = {
493+ 0x00000000 , 0x38800000 , 0x39800000 , 0x3A000000 , 0x3A800000 ,
494+ 0x3B000000 , 0x3B400000 , 0x3B800000 }; // approximated FP32 tiny values
495+
496+ for (int i = 0 ; i < lutSize; ++i) {
497+ combined = select (icmp_eq (vAbs, i8_val (i)), i32_val (denormsAndZeroLut[i]),
498+ combined);
499+ }
500+
501+ // Bitcast to float
502+ Value result = bitcast (combined, f32_ty);
503+ return result;
504+ }
505+
506+ static SmallVector<Value> Fp8E4M3FN_to_Fp32 (Location loc,
507+ ConversionPatternRewriter &rewriter,
508+ const SmallVector<Value> &values) {
509+ SmallVector<Value> results (values.size ());
510+ for (size_t i = 0 ; i < values.size (); i++)
511+ results[i] = Fp8E4M3FN_to_Fp32_oneValue (loc, rewriter, values[i]);
512+ return results;
513+ }
514+
515+ static Value Fp8E4M3FN_to_Fp16_oneValue (Location loc,
516+ ConversionPatternRewriter &rewriter,
517+ Value v) {
518+ auto fp8x2VecTy = vec_ty (i8_ty, 2 );
519+ Value a = undef (fp8x2VecTy);
520+ a = insert_element (fp8x2VecTy, a, i8_val (0 ), i32_val (0 ));
521+ a = insert_element (fp8x2VecTy, a, v, i32_val (1 ));
522+ a = bitcast (a, i16_ty);
523+
524+ // Get sign and absolute value
525+ Value sign = and_ (a, i16_val (0x8000 ));
526+ a = and_ (a, i16_val (0x7FFF ));
527+
528+ // Right shift 1 bit to adjust the positions of exponent and mantissa
529+ a = lshr (a, i16_val (1 ));
530+
531+ // Adjust exponent, (15 - 7) << 10 === 0x2000
532+ a = add (a, i16_val (0x2000 ));
533+
534+ // Check NaN
535+ Value vAbs = and_ (bitcast (v, i8_ty), i8_val (0x7F ));
536+ a = select (icmp_eq (vAbs, i8_val (0x7F )), i16_val (0x7E00 ), a);
537+
538+ // Check denorms and zero
539+ // Here we use a LUT to map S.0000.000 ~ S.0000.111 to its corresponding fp16
540+ // value
541+ constexpr size_t lutSize = 8 ;
542+ static constexpr int denormsAndZeroLut[lutSize] = {
543+ 0x0000 , 0x1800 , 0x1C00 , 0x1E00 , 0x2000 , 0x2100 , 0x2200 , 0x2300 };
544+
545+ for (int i = 0 ; i < lutSize; i++) {
546+ a = select (icmp_eq (vAbs, i8_val (i)), i16_val (denormsAndZeroLut[i]), a);
547+ }
548+
549+ // Set sign
550+ a = or_ (a, sign);
551+ a = bitcast (a, f16_ty);
552+
553+ return a;
554+ }
555+
556+ static SmallVector<Value> Fp8E4M3FN_to_Fp16 (Location loc,
557+ ConversionPatternRewriter &rewriter,
558+ const SmallVector<Value> &values) {
559+ SmallVector<Value> results (values.size ());
560+ for (size_t i = 0 ; i < values.size (); i++)
561+ results[i] = Fp8E4M3FN_to_Fp16_oneValue (loc, rewriter, values[i]);
562+ return results;
563+ }
564+
274565static SmallVector<Value> Fp8E5M2_to_Fp16 (Location loc,
275566 ConversionPatternRewriter &rewriter,
276567 const SmallVector<Value> &v) {
@@ -867,10 +1158,13 @@ struct FpToFpOpConversion
8671158 // F8 -> F16
8681159 {{F8E4M3FNUZTyID, F16TyID, undefRounding},
8691160 Fp8E4M3FNUZ_to_Fp16 (isaFamily)},
1161+ {{F8E4M3FNTyID, F16TyID, undefRounding}, Fp8E4M3FN_to_Fp16},
8701162 {{F8E5M2FNUZTyID, F16TyID, undefRounding},
8711163 Fp8E5M2FNUZ_to_Fp16 (isaFamily)},
8721164 {{F8E5M2TyID, F16TyID, undefRounding}, Fp8E5M2_to_Fp16},
8731165 // F16 -> F8
1166+ {{F16TyID, F8E4M3FNTyID, RoundingMode::RTNE},
1167+ Fp16_to_Fp8E4M3FN_RTNE},
8741168 {{F16TyID, F8E5M2FNUZTyID, RoundingMode::RTNE},
8751169 Fp16_to_Fp8E5M2FNUZ (isaFamily)},
8761170 {{F16TyID, F8E4M3FNUZTyID, RoundingMode::RTNE},
@@ -893,8 +1187,11 @@ struct FpToFpOpConversion
8931187 Fp32_to_Fp8E4M3FNUZ},
8941188 {{F32TyID, F8E5M2FNUZTyID, RoundingMode::RTNE},
8951189 Fp32_to_Fp8E5M2FNUZ},
1190+ {{F32TyID, F8E4M3FNTyID, RoundingMode::RTNE},
1191+ Fp32_to_Fp8E4M3FN_RTNE},
8961192 {{F8E4M3FNUZTyID, F32TyID, undefRounding}, Fp8E4M3FNUZ_to_Fp32},
8971193 {{F8E5M2FNUZTyID, F32TyID, undefRounding}, Fp8E5M2FNUZ_to_Fp32},
1194+ {{F8E4M3FNTyID, F32TyID, RoundingMode::RTNE}, Fp8E4M3FN_to_Fp32},
8981195 };
8991196 std::tuple<TypeID, TypeID, RoundingMode> key = {
9001197 srcTy.getTypeID (), dstTy.getTypeID (),
0 commit comments