From 3a25ebfde369198407176bcbf263d3d5de03f071 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Wed, 18 Feb 2026 09:28:10 +0100 Subject: [PATCH 01/89] [SYCL] add new fp8 data types and unit tests --- .../oneapi/experimental/float_8bit/types.hpp | 1758 +++++++++++++++++ sycl/unittests/Extensions/CMakeLists.txt | 1 + sycl/unittests/Extensions/fp8/CMakeLists.txt | 9 + sycl/unittests/Extensions/fp8/fp8_e4m3.cpp | 699 +++++++ sycl/unittests/Extensions/fp8/fp8_e5m2.cpp | 512 +++++ sycl/unittests/Extensions/fp8/fp8_e8m0.cpp | 280 +++ 6 files changed, 3259 insertions(+) create mode 100644 sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp create mode 100644 sycl/unittests/Extensions/fp8/CMakeLists.txt create mode 100644 sycl/unittests/Extensions/fp8/fp8_e4m3.cpp create mode 100644 sycl/unittests/Extensions/fp8/fp8_e5m2.cpp create mode 100644 sycl/unittests/Extensions/fp8/fp8_e8m0.cpp diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp new file mode 100644 index 0000000000000..b030dd1009682 --- /dev/null +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -0,0 +1,1758 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace sycl { +inline namespace _V1 { +namespace ext::oneapi::experimental { + +#ifdef __SYCL_TARGET_INTEL_GPU_CRI__ + +#ifdef __SYCL_DEVICE_ONLY__ + +// New FP8 builtins +extern __DPCPP_SYCL_EXTERNAL sycl::half +__builtin_spirv_ConvertE4M3ToFP16EXT(uint8_t) noexcept; +extern __DPCPP_SYCL_EXTERNAL sycl::half +__builtin_spirv_ConvertE5M2ToFP16EXT(uint8_t) noexcept; +extern __DPCPP_SYCL_EXTERNAL sycl::ext::oneapi::bfloat16 +__builtin_spirv_ConvertE4M3ToBF16EXT(uint8_t) noexcept; +extern __DPCPP_SYCL_EXTERNAL sycl::ext::oneapi::bfloat16 +__builtin_spirv_ConvertE5M2ToBF16EXT(uint8_t) noexcept; +extern __DPCPP_SYCL_EXTERNAL + uint8_t __builtin_spirv_ConvertFP16ToE4M3EXT(sycl::half) noexcept; +extern __DPCPP_SYCL_EXTERNAL + uint8_t __builtin_spirv_ConvertFP16ToE5M2EXT(sycl::half) noexcept; +extern __DPCPP_SYCL_EXTERNAL uint8_t + __builtin_spirv_ConvertBF16ToE4M3EXT(sycl::ext::oneapi::bfloat16) noexcept; +extern __DPCPP_SYCL_EXTERNAL uint8_t + __builtin_spirv_ConvertBF16ToE5M2EXT(sycl::ext::oneapi::bfloat16) noexcept; +extern __DPCPP_SYCL_EXTERNAL + uint8_t __builtin_spirv_ClampConvertFP16ToE4M3INTEL(sycl::half) noexcept; +extern __DPCPP_SYCL_EXTERNAL + uint8_t __builtin_spirv_ClampConvertBF16ToE4M3INTEL( + sycl::ext::oneapi::bfloat16) noexcept; +extern __DPCPP_SYCL_EXTERNAL + uint8_t __builtin_spirv_ClampConvertFP16ToE5M2INTEL(sycl::half) noexcept; +extern __DPCPP_SYCL_EXTERNAL + uint8_t __builtin_spirv_ClampConvertBF16ToE5M2INTEL( + sycl::ext::oneapi::bfloat16) noexcept; +extern __DPCPP_SYCL_EXTERNAL + uint8_t __builtin_spirv_StochasticRoundFP16ToE5M2INTEL(sycl::half) noexcept; +extern __DPCPP_SYCL_EXTERNAL + uint8_t __builtin_spirv_StochasticRoundFP16ToE4M3INTEL(sycl::half) noexcept; +extern __DPCPP_SYCL_EXTERNAL + uint8_t __builtin_spirv_StochasticRoundBF16ToE5M2INTEL( + sycl::ext::oneapi::bfloat16) noexcept; +extern __DPCPP_SYCL_EXTERNAL + uint8_t __builtin_spirv_StochasticRoundBF16ToE4M3INTEL( + sycl::ext::oneapi::bfloat16) noexcept; +extern __DPCPP_SYCL_EXTERNAL uint8_t + __builtin_spirv_ClampStochasticRoundFP16ToE5M2INTEL(sycl::half) noexcept; +extern __DPCPP_SYCL_EXTERNAL uint8_t + __builtin_spirv_ClampStochasticRoundFP16ToE4M3INTEL(sycl::half) noexcept; +extern __DPCPP_SYCL_EXTERNAL + uint8_t __builtin_spirv_ClampStochasticRoundBF16ToE5M2INTEL( + sycl::ext::oneapi::bfloat16) noexcept; +extern __DPCPP_SYCL_EXTERNAL + uint8_t __builtin_spirv_ClampStochasticRoundBF16ToE4M3INTEL( + sycl::ext::oneapi::bfloat16) noexcept; + +#endif // __SYCL_DEVICE_ONLY__ + +enum class saturation { none, finite }; + +enum class rounding { + to_even, + upward, + downward, + toward_zero, + to_away, + stochastic +}; + +struct stochastic_seed { + explicit stochastic_seed(uint32_t *pseed) : pseed(pseed) {} + uint32_t *const pseed; +}; + +static inline uint8_t RneClip(float x, uint8_t max) noexcept { + float f = std::floor(x); + float frac = x - f; + uint8_t i = static_cast(f); + if (frac > 0.5f) + ++i; + else if (frac == 0.5f) + i += (i & 1u); // ties to even + return i > max ? max : i; +} + +static inline uint8_t RoundClip(float x, uint8_t max, rounding R, + uint8_t sign_bit) noexcept { + if (max == 0) { + // No fraction bits (E8M0 path) + if (R == rounding::upward) { + // Any positive residual causes a carry; NaN / non-positive → 0 + if (!std::isnan(x) && x > 0.0f) + return 1u; + return 0u; + } + // Default / to_even + if (std::isnan(x)) + return 0u; + if (x > 0.5f) + return 1u; + if (x == 0.5f) + return 0u; // tie -> even (0) + return 0u; + } + + // Formats with fraction bits (E4M3, E5M2) + if (R == rounding::upward) { + if (sign_bit == 0u) { + // Positive: ceil + uint32_t ci = static_cast(std::ceil(x)); + if (ci > max) + ci = max; + return static_cast(ci); + } else { + // Negative: toward +inf => magnitude decreases -> floor + uint32_t fi = static_cast(std::floor(x)); + if (fi > max) + fi = max; + return static_cast(fi); + } + } + // default: round-to-nearest-even + return RneClip(x, max); +} + +template +static inline ToT ConvertFloatToTarget(float v, rounding R) noexcept { + if constexpr (std::is_same_v || + std::is_same_v) { + ToT cand = static_cast(v); + if (R == rounding::toward_zero) { + // If cast increased magnitude, step the 16-bit encoding toward zero. + float fcand = static_cast(cand); + if (std::fabs(fcand) > std::fabs(v)) { + uint16_t bits = sycl::bit_cast(cand); + // Order-preserving transform: sign-bit mapped to MSB ordering + uint16_t ord = (bits & 0x8000u) ? static_cast(~bits) + : static_cast(bits ^ 0x8000u); + if (v >= 0.0f) { + if (ord != 0u) + --ord; // step toward smaller positive + } else { + if (ord != 0xFFFFu) + ++ord; // step toward smaller magnitude for negative numbers + } + uint16_t newbits = (ord & 0x8000u) + ? static_cast(~ord) + : static_cast(ord ^ 0x8000u); + cand = sycl::bit_cast(newbits); + } + } + return cand; + } else + // For float/double/integral targets just use normal cast + return static_cast(v); +} + +template +static inline ToT ConvertFromFP8_CPU(uint8_t b, + rounding R = rounding::to_even) noexcept { + static_assert((Ebits == 4 && Mbits == 3) || (Ebits == 5 && Mbits == 2) || + (Ebits == 8 && Mbits == 0), + "Unsupported FP8 (Ebits,Mbits) combination"); + + constexpr int Bias = (1 << (Ebits - 1)) - 1; + constexpr int Emin = 1 - Bias; + constexpr uint8_t ExpMaskAll = static_cast((1u << Ebits) - 1u); + constexpr uint32_t FracDen = (Mbits == 0) ? 1u : (1u << Mbits); + constexpr uint8_t MaxFrac = static_cast(FracDen - 1u); + + // Extract fields. + uint8_t sign_bit = (b & 0x80u) ? 1u : 0u; + uint8_t frac = (Mbits == 0) ? 0u : static_cast(b & MaxFrac); + + uint8_t exp = static_cast((b >> Mbits) & ExpMaskAll); + if constexpr (Ebits == 8 && Mbits == 0) { + sign_bit = 0u; + exp = b; + } + + auto make_nan = [&]() -> ToT { + float qn = std::numeric_limits::quiet_NaN(); + return static_cast(qn); + }; + + // Handle exp = all ones (custom finite-only rules). + if (exp == ExpMaskAll) { + if constexpr (Ebits == 4 && Mbits == 3) { + // E4M3: only frac==111 -> NaN, otherwise normal. + if (frac == MaxFrac) + return make_nan(); + // treat as normal finite + } else if constexpr (Ebits == 5 && Mbits == 2) { + // E5M2: NaN when frac in {01,10,11} i.e. frac != 00 + if (frac != 0) + return make_nan(); + // frac==00 -> normal finite + } else // E8M0: exp all ones -> NaN + return make_nan(); + } + + // exp == 0 : zero or subnormal (if Mbits>0) + if (exp == 0) { + if constexpr (Mbits == 0) { + // E8M0: exp==0 is the smallest normal (no subnormals) + int E = -Bias; + float v = std::ldexp(1.0f, E); + return ConvertFloatToTarget(v, R); + } else { + if (frac == 0) { + float zf = std::copysign(0.0f, sign_bit ? -1.0f : 1.0f); + if constexpr (std::is_same_v || + std::is_same_v) + return ConvertFloatToTarget(zf, R); + else + return static_cast(zf); + } + // Subnormal: value = sign * (frac / 2^Mbits) * 2^(Emin) + float m = static_cast(frac) / static_cast(FracDen); + float v = std::ldexp(m, Emin); + return ConvertFloatToTarget((sign_bit ? -v : v), R); + } + } + + // Normal number. + int E = static_cast(exp) - Bias; + float m; + if constexpr (Mbits == 0) + // E8M0: mantissa == 1 always + m = 1.0f; + else + m = 1.0f + static_cast(frac) / static_cast(FracDen); + float v = std::ldexp(m, E); + return ConvertFloatToTarget((sign_bit ? -v : v), R); +} + +/// \brief Converts a given value to fp8 floating point with a rounding +/// mode to_even by default and saturation finite for host code. +/// \param h The input value to be converted. +/// \param R The rounding mode to be used during conversion. +/// \return uint8_t The converted 8-bit floating point value, MSB is sign bit, +/// Ebits bits mantissa, Mbits bits exponent. +template +static inline uint8_t +ConvertToFP8_CPU(T h, rounding R = rounding::to_even) noexcept { + // Specialized implementation for fp8_e8m0 (Ebits=8, Mbits=0) + if constexpr (Ebits == 8 && Mbits == 0) { + // Format characteristics (finite-only, no zero, no infinity): + // - Bias: 127 + // - Exponent field range used for normals: 0 .. 254 (E = ecode - 127 -> + // [-127, +127]) + // - Encoding with exp==255 (0xFF) reserved for NaN (single payload 0xFF) + // - Value encoded when exponent field == 0: +/- 2^{-127} + // - Max normal: +/- 2^{127} (~1.7014118e+38) + // + // Rounding mode: the public API restricts this format to rounding::upward. + // Here we honor upward if passed; any other mode falls back to upward + // behavior. + // + // Note: The format cannot represent zero; inputs with |x| < 2^{-127} map + // to the smallest magnitude normal with the input sign preserved + // (consistent with prior sign-preserving underflow behavior). + // + constexpr int Bias = 127; + constexpr int Emin = -127; + constexpr int Emax = 127; + constexpr uint8_t NaNCode = 0xFF; // 11111111 + constexpr uint8_t MaxExpField = 254; // 255 reserved for NaN + const float min_normal = std::ldexp(1.0f, Emin); // 2^{-127} + const float max_normal = std::ldexp(1.0f, Emax); // 2^{127} + + float x = static_cast(h); + + if (std::isnan(x)) + return NaNCode; + + uint8_t sign = std::signbit(x) ? 0x80 : 0x00; + float ax = std::fabs(x); + + // Handle underflow (|x| < min_normal) and x == 0: encode smallest normal + // with sign. + if (ax == 0.0f || ax < min_normal) + return sign; // exp field = 0 -> E = -127 + + // Handle overflow (|x| >= max_normal * (anything beyond representable)): + if (ax >= max_normal) + return static_cast(sign | (MaxExpField)); // E = +127 + + // Determine exponent E such that 2^E <= ax < 2^{E+1} + int e2; + float m = std::frexp(ax, &e2); // ax = m * 2^{e2}, m in [0.5,1) + int E = e2 - 1; // Now 2^E <= ax < 2^{E+1} + + // Upward rounding semantics: + // - For positive numbers: if not exact power-of-two, round up to next + // power (E+1) if within range. + // - For negative numbers: rounding toward +inf moves value toward zero, so + // keep current E. + // Exact power-of-two: m == 0.5 (since frexp gives m in [0.5,1)) + bool is_exact_power_of_two = (m == 0.5f); + + rounding effR = (R == rounding::upward) ? R : rounding::upward; + + if (effR == rounding::upward) { + if (sign == 0x00) { + if (!is_exact_power_of_two) { + // Round up (increase exponent) if possible. + if (E < Emax) + ++E; + else + E = Emax; + } + } else { + // Negative: leave E as-is (toward +inf reduces magnitude). + } + } + + // Clamp exponent just in case. + if (E < Emin) + E = Emin; + if (E > Emax) + E = Emax; + + uint8_t ecode = static_cast(E + Bias); // 0 .. 254 + // ecode must never be 255 here. + return static_cast(sign | ecode); + } + + constexpr int bias = (1 << (Ebits - 1)) - 1; + // allow the top exponent field (ExpAllOnes) as a normal exponent except when + // frac==MaxFrac (NaN) + int emax = 0; + int emin = 0; + if constexpr (Ebits == 8) + emax = 127; + else { + emax = (1 << Ebits) - 1 - bias; // ExpAllOnes - bias + emin = 1 - bias; + } + constexpr uint8_t ExpAllOnes = static_cast((1 << Ebits) - 1); + constexpr uint8_t MaxFrac = static_cast((1 << Mbits) - 1); + constexpr uint8_t MaxFracForMaxNormal = + (Ebits == 4 && Mbits == 3) ? static_cast(MaxFrac - 1u) : MaxFrac; + constexpr uint8_t MaxExpForMaxNormal = + (Ebits == 5 && Mbits == 2) + ? static_cast(ExpAllOnes - 1u) + : ExpAllOnes; + constexpr uint8_t MaxFracMask = MaxFrac; + + float x = static_cast(h); + uint8_t sign = std::signbit(x) ? 0x80 : 0x00; + if (std::isnan(x)) + return static_cast( + sign | ((ExpAllOnes << Mbits) | MaxFracMask)); // S.1111.111 -> NaN + uint8_t sign_bit = sign ? 1u : 0u; + float ax = std::fabs(x); + + const float max_finite = + (2.0f - std::ldexp(1.0f, 1 - Mbits)) * std::ldexp(1.0f, emax); + const float min_sub = std::ldexp(1.0f, emin - Mbits); + + if (ax > max_finite) { + return static_cast( + sign | ((MaxExpForMaxNormal << Mbits) | MaxFracForMaxNormal)); + } + if (ax >= max_finite) { + return static_cast( + sign | ((MaxExpForMaxNormal << Mbits) | MaxFracForMaxNormal)); + } + + if (ax < min_sub) + return sign; // underflow + + int e2; + float m = std::frexp(ax, &e2); + int E = e2 - 1; + + if (E < emin) { + float scaled = std::ldexp(ax, -emin) * static_cast(1 << Mbits); + uint32_t k = RoundClip(scaled, MaxFrac, R, sign_bit); + if (k == 0) + return sign; + return static_cast(sign | static_cast(k)); + } + + float y = m * 2.0f; + float frac_scaled = (y - 1.0f) * static_cast(1 << Mbits); + uint32_t frac = RoundClip(frac_scaled, MaxFrac, R, sign_bit); + if (frac == (1u << Mbits)) { + frac = 0; + ++E; + } + if (E > emax) { + auto ret = static_cast( + sign | ((MaxExpForMaxNormal << Mbits) | MaxFracForMaxNormal)); + return ret; + } + uint8_t ecode = static_cast(E + bias); + auto ret = static_cast(sign | (ecode << Mbits) | + static_cast(frac)); + return ret; +} + +// Map E4M3 byte to integer +// then "nextUp" in that order, and map back. +// E4M3 finite-only: exp=0xF & frac!=0 => NaN (no Inf). +inline uint8_t nextE4M3(uint8_t b, bool up) { + uint8_t exp = (b >> 3) & 0x0F; + uint8_t frac = b & 0x07; + // NaN -> NaN + if (exp == 0x0F && frac) + return b; + uint8_t ord = + (b & 0x80) ? static_cast(~b) : static_cast(b ^ 0x80); + + if (up) { + if (ord == 0xFF) + return b; + ++ord; + } else { + if (ord == 0x00) + return b; + --ord; + } + return (ord & 0x80) ? static_cast(ord ^ 0x80) + : static_cast(~ord); +} + +template +uint8_t round(rounding r, uint8_t b, sycl::half yi, T vi) { + switch (r) { + case rounding::upward: { + if (yi < vi) + return nextE4M3(b, /*up=*/true); + break; + } + case rounding::downward: { + if (yi > vi) + return nextE4M3(b, /*up=*/false); + break; + } + case rounding::toward_zero: + if (vi > 0.0f && yi > vi) { + return nextE4M3(b, /*up=*/false); + } else if (vi < 0.0f && yi < vi) { + return nextE4M3(b, /*up=*/true); + } + break; + case rounding::to_away: + if (vi > 0.0f && yi < vi) { + return nextE4M3(b, /*up=*/true); + } else if (vi < 0.0f && yi > vi) { + return nextE4M3(b, /*up=*/false); + } + break; + default: + break; + } + return b; +} + +void CheckRoundingConstraints(rounding r) { +#ifdef __SYCL_DEVICE_ONLY__ +#else + if (r != rounding::to_even) + throw std::invalid_argument("Host code supports only rounding to_even"); +#endif +} + +template class fp8_e4m3 { + static constexpr size_t NExpBits = 4; + static constexpr size_t NFracBits = 3; + static constexpr float MaxNormal = 448.0f; + static constexpr float MinSubnormal = 0.001953125f; // 2^-9 + static constexpr uint8_t NaNCode = 0xFF; + static constexpr uint8_t MaxFiniteCode = + 0x7E; // 0.1111.110 (positive max normal) + + template uint8_t ConvertToFP8(T h, rounding r) { + sycl::half hi = static_cast(h); +#ifdef __SYCL_DEVICE_ONLY__ + // TODO: optimize with vectorized builtin calls + const uint8_t sign = std::signbit(hi) ? 0x80u : 0x00u; + const float ax = sycl::fabs(hi); + + if (ax > MaxNormal) + return static_cast(sign | MaxFiniteCode); + + if (ax < MinSubnormal) + return sign; + + uint8_t b = __builtin_spirv_ConvertFP16ToE4M3EXT(h); + if (r == rounding::to_even) + return b; + + const sycl::half yi = __builtin_spirv_ConvertE4M3ToFP16EXT(b); + return round(r, b, yi, hi); + +#else + return ConvertToFP8_CPU<4, 3, sycl::half>(hi, r); +#endif + } + + uint8_t ConvertBF16ToFP8(bfloat16 h, rounding r) { +#ifdef __SYCL_DEVICE_ONLY__ + const uint8_t sign = std::signbit(h) ? 0x80u : 0x00u; + const float ax = sycl::fabs(h); + + if (ax > MaxNormal) + return static_cast(sign | MaxFiniteCode); + + if (ax < MinSubnormal) + return sign; + + uint8_t b = __builtin_spirv_ConvertBF16ToE4M3EXT(h); + if (r == rounding::to_even) + return b; + + const half yi = __builtin_spirv_ConvertBF16ToE4M3EXT(b); + return round(r, b, yi, h); +#else + return ConvertToFP8_CPU<4, 3, bfloat16>(h, r); +#endif + } + + template T ConvertFromFP8(uint8_t v) const { +#ifdef __SYCL_DEVICE_ONLY__ + sycl::half hi = __builtin_spirv_ConvertE4M3ToFP16EXT(v); + return static_cast(hi); +#else + return ConvertFromFP8_CPU<4, 3, T>(v); +#endif + } + + bfloat16 ConvertBF16FromFP8(uint8_t v) const { +#ifdef __SYCL_DEVICE_ONLY__ + return __builtin_spirv_ConvertE4M3ToBF16EXT(v); +#else + return ConvertFromFP8_CPU<4, 3, bfloat16>(v); +#endif + } + +public: + fp8_e4m3() = default; + fp8_e4m3(const fp8_e4m3 &) = default; + + ~fp8_e4m3() = default; + fp8_e4m3 &operator=(const fp8_e4m3 &) = default; + + // Construct from pack of half, float, double. + // Available only when the size of the pack is equal to N. + + template , half> || + std::is_same_v, bfloat16> || + std::is_same_v, float> || + std::is_same_v, double>) && + ...))>> + explicit fp8_e4m3(Types... v) { +#ifdef __SYCL_DEVICE_ONLY__ + static_assert(N == 1 || N == 2, + "fp8_e4m3: Template argument N must be 1 or 2 on device"); +#endif + if constexpr (((std::is_same_v, bfloat16>) && ...)) { + const bfloat16 in[N] = {static_cast(v)...}; + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertBF16ToFP8(in[i], rounding::to_even); + return; + } + const sycl::half in[N] = {v...}; + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertToFP8(in[i], rounding::to_even); + } + + // Construct from an array of half, bfloat16, float, double. + explicit fp8_e4m3(sycl::half const (&v)[N], rounding r = rounding::to_even) { +#ifdef __SYCL_DEVICE_ONLY__ + static_assert(N == 1 || N == 2, + "fp8_e4m3: Template argument N must be 1 or 2 on device"); +#endif + // TODO: optimize with vectorized builtin calls + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertToFP8(v[i], r); + } + + explicit fp8_e4m3(bfloat16 const (&v)[N], rounding r = rounding::to_even) { +#ifdef __SYCL_DEVICE_ONLY__ + static_assert(N == 1 || N == 2, + "fp8_e4m3: Template argument N must be 1 or 2 on device"); +#else + CheckRoundingConstraints(r); +#endif + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertBF16ToFP8(v[i], r); + } + + explicit fp8_e4m3(float const (&v)[N], rounding r = rounding::to_even) { +#ifdef __SYCL_DEVICE_ONLY__ + static_assert(N == 1 || N == 2, + "fp8_e4m3: Template argument N must be 1 or 2 on device"); +#else + CheckRoundingConstraints(r); +#endif + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertToFP8(v[i], r); + } + + explicit fp8_e4m3(double const (&v)[N]) { +#ifdef __SYCL_DEVICE_ONLY__ + static_assert(N == 1 || N == 2, + "fp8_e4m3: Template argument N must be 1 or 2 on device"); +#else + CheckRoundingConstraints(r); +#endif + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertToFP8(v[i], rounding::to_even); + } + + // Construct from an marray of half, bfloat16, float, double. + explicit fp8_e4m3(const sycl::marray &v, + rounding r = rounding::to_even) { +#ifdef __SYCL_DEVICE_ONLY__ + static_assert(N == 1 || N == 2, + "fp8_e4m3: Template argument N must be 1 or 2 on device"); +#else + CheckRoundingConstraints(r); +#endif + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertToFP8(v[i], r); + } + + explicit fp8_e4m3(const sycl::marray &v, + rounding r = rounding::to_even) { +#ifdef __SYCL_DEVICE_ONLY__ + static_assert(N == 1 || N == 2, + "fp8_e4m3: Template argument N must be 1 or 2 on device"); +#else + CheckRoundingConstraints(r); +#endif + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertBF16ToFP8(v[i], r); + } + + explicit fp8_e4m3(const sycl::marray &v, rounding r = rounding::to_even) { +#ifdef __SYCL_DEVICE_ONLY__ + static_assert(N == 1 || N == 2, + "fp8_e4m3: Template argument N must be 1 or 2 on device"); +#else + CheckRoundingConstraints(r); +#endif + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertToFP8(v[i], r); + } + + explicit fp8_e4m3(const sycl::marray &v) { +#ifdef __SYCL_DEVICE_ONLY__ + static_assert(N == 1 || N == 2, + "fp8_e4m3: Template argument N must be 1 or 2 on device"); +#else + CheckRoundingConstraints(r); +#endif + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertToFP8(v[i], rounding::to_even); + } + + // Construct with stochastic rounding with user provided seed from an array of + // half, bfloat16, float. + // Should be removed once docs updated + explicit fp8_e4m3(half const (&vals)[N], const stochastic_seed &seed, + saturation s = saturation::finite); + explicit fp8_e4m3(bfloat16 const (&vals)[N], const stochastic_seed &seed, + saturation s = saturation::finite); + explicit fp8_e4m3(float const (&vals)[N], const stochastic_seed &seed, + saturation s = saturation::finite); + + // Construct with stochastic rounding with user provided seed from an marray + // of half, bfloat16, float. + + // Should be removed once docs updated + explicit fp8_e4m3(const sycl::marray &vals, const stochastic_seed &seed, + saturation s = saturation::finite); + explicit fp8_e4m3(const sycl::marray &vals, + const stochastic_seed &seed, + saturation s = saturation::finite); + explicit fp8_e4m3(const sycl::marray &vals, const stochastic_seed &seed, + saturation s = saturation::finite); + + // Construct from integer types. + // Available only when N==1. + + explicit fp8_e4m3(short val) { + assert(N == 1 && "fp8_e4m3: N must be 1 for short constructor"); + vals[0] = ConvertToFP8(val, rounding::to_even); + } + + explicit fp8_e4m3(int val) { + assert(N == 1 && "fp8_e4m3: N must be 1 for int constructor"); + vals[0] = ConvertToFP8(val, rounding::to_even); + } + + explicit fp8_e4m3(long val) { + assert(N == 1 && "fp8_e4m3: N must be 1 for long constructor"); + vals[0] = ConvertToFP8(val, rounding::to_even); + } + + explicit fp8_e4m3(long long val) { + assert(N == 1 && "fp8_e4m3: N must be 1 for long long constructor"); + vals[0] = ConvertToFP8(val, rounding::to_even); + } + + explicit fp8_e4m3(unsigned short val) { + assert(N == 1 && "fp8_e4m3: N must be 1 for unsigned short constructor"); + vals[0] = ConvertToFP8(val, rounding::to_even); + } + + explicit fp8_e4m3(unsigned int val) { + assert(N == 1 && "fp8_e4m3: N must be 1 for unsigned int constructor"); + vals[0] = ConvertToFP8(val, rounding::to_even); + } + + explicit fp8_e4m3(unsigned long val) { + assert(N == 1 && "fp8_e4m3: N must be 1 for unsigned long constructor"); + vals[0] = ConvertToFP8(val, rounding::to_even); + } + + explicit fp8_e4m3(unsigned long long val) { + assert(N == 1 && + "fp8_e4m3: N must be 1 for unsigned long long constructor"); + vals[0] = ConvertToFP8(val, rounding::to_even); + } + + // Assign (operator) from half, bfloat16, float, double, and integer types. + // Available only when N==1. + + fp8_e4m3 &operator=(sycl::half val) { + assert(N == 1 && "fp8_e4m3: N must be 1 for half assignment operator"); + vals[0] = ConvertToFP8(val, rounding::to_even); + return *this; + } + + fp8_e4m3 &operator=(bfloat16 val) { + assert(N == 1 && "fp8_e4m3: N must be 1 for bfloat16 assignment operator"); + vals[0] = ConvertToFP8(val, rounding::to_even); + return *this; + } + + fp8_e4m3 &operator=(float val) { + assert(N == 1 && "fp8_e4m3: N must be 1 for float assignment operator"); + vals[0] = ConvertToFP8(val, rounding::to_even); + return *this; + } + + fp8_e4m3 &operator=(double val) { + assert(N == 1 && "fp8_e4m3: N must be 1 for double assignment operator"); + vals[0] = ConvertBF16ToFP8(val, rounding::to_even); + return *this; + } + + fp8_e4m3 &operator=(short val) { + assert(N == 1 && "fp8_e4m3: N must be 1 for short assignment operator"); + vals[0] = ConvertToFP8(val, rounding::to_even); + return *this; + } + + fp8_e4m3 &operator=(int val) { + assert(N == 1 && "fp8_e4m3: N must be 1 for int assignment operator"); + vals[0] = ConvertToFP8(val, rounding::to_even); + return *this; + } + + fp8_e4m3 &operator=(long val) { + assert(N == 1 && "fp8_e4m3: N must be 1 for long assignment operator"); + vals[0] = ConvertToFP8(val, rounding::to_even); + return *this; + } + + fp8_e4m3 &operator=(long long val) { + assert(N == 1 && "fp8_e4m3: N must be 1 for long long assignment operator"); + vals[0] = ConvertToFP8(val, rounding::to_even); + return *this; + } + + fp8_e4m3 &operator=(unsigned short val) { + assert(N == 1 && + "fp8_e4m3: N must be 1 for unsigned short assignment operator"); + vals[0] = ConvertToFP8(val, rounding::to_even); + return *this; + } + + fp8_e4m3 &operator=(unsigned int val) { + assert(N == 1 && + "fp8_e4m3: N must be 1 for unsigned int assignment operator"); + vals[0] = ConvertToFP8(val, rounding::to_even); + return *this; + } + + fp8_e4m3 &operator=(unsigned long val) { + assert(N == 1 && + "fp8_e4m3: N must be 1 for unsigned long assignment operator"); + vals[0] = ConvertToFP8(val, rounding::to_even); + return *this; + } + + fp8_e4m3 &operator=(unsigned long long val) { + assert(N == 1 && + "fp8_e4m3: N must be 1 for unsigned long long assignment operator"); + vals[0] = ConvertToFP8(val, rounding::to_even); + return *this; + } + + // Convert to half, bfloat16, float, double. + // Available only when N==1. + + explicit operator half() const { + assert(N == 1 && "fp8_e4m3: N must be 1 for half conversion operator"); + return ConvertFromFP8(vals[0]); + } + + explicit operator bfloat16() const { + assert(N == 1 && "fp8_e4m3: N must be 1 for bfloat16 conversion operator"); + return ConvertBF16FromFP8(vals[0]); + } + explicit operator float() const { + assert(N == 1 && "fp8_e4m3: N must be 1 for float conversion operator"); + return ConvertFromFP8(vals[0]); + } + explicit operator double() const { + assert(N == 1 && "fp8_e4m3: N must be 1 for double conversion operator"); + return ConvertFromFP8(vals[0]); + } + + // Convert to integer types. + // Available only when N==1. + + explicit operator char() const { + assert(N == 1 && "fp8_e4m3: N must be 1 for char conversion operator"); + return ConvertFromFP8(vals[0]); + } + + explicit operator signed char() const { + assert(N == 1 && + "fp8_e4m3: N must be 1 for signed char conversion operator"); + return ConvertFromFP8(vals[0]); + } + + explicit operator short() const { + assert(N == 1 && "fp8_e4m3: N must be 1 for short conversion operator"); + return ConvertFromFP8(vals[0]); + } + + explicit operator int() const { + assert(N == 1 && "fp8_e4m3: N must be 1 for int conversion operator"); + return ConvertFromFP8(vals[0]); + } + + explicit operator long() const { + assert(N == 1 && "fp8_e4m3: N must be 1 for long conversion operator"); + return ConvertFromFP8(vals[0]); + } + + explicit operator long long() const { + assert(N == 1 && "fp8_e4m3: N must be 1 for long long conversion operator"); + return ConvertFromFP8(vals[0]); + } + explicit operator unsigned char() const { + assert(N == 1 && + "fp8_e4m3: N must be 1 for unsigned char conversion operator"); + return ConvertFromFP8(vals[0]); + } + + explicit operator unsigned short() const { + assert(N == 1 && + "fp8_e4m3: N must be 1 for unsigned short conversion operator"); + return ConvertFromFP8(vals[0]); + } + + explicit operator unsigned int() const { + assert(N == 1 && + "fp8_e4m3: N must be 1 for unsigned int conversion operator"); + return ConvertFromFP8(vals[0]); + } + + explicit operator unsigned long() const { + assert(N == 1 && + "fp8_e4m3: N must be 1 for unsigned long conversion operator"); + return ConvertFromFP8(vals[0]); + } + + explicit operator unsigned long long() const { + assert(N == 1 && + "fp8_e4m3: N must be 1 for unsigned long long conversion operator"); + return ConvertFromFP8(vals[0]); + } + + // Convert to bool + // Available only when N==1. + + explicit operator bool() const { + static_assert(N == 1, "fp8_e4m3: operator() requires size N=1"); +#ifdef __SYCL_DEVICE_ONLY__ + // detect +0 / -0 + sycl::half h = __builtin_spirv_ConvertE4M3ToFP16EXT(vals[0]); + return h != 0; +#else + // no need to convert, just check sign bit amd 0s + return vals[0] != 0 && vals[0] != 0x80; +#endif + } + + // Convert to marray of half, bfloat16, float + + explicit operator sycl::marray() const { + sycl::marray ret; + for (size_t i = 0; i < N; ++i) + ret[i] = ConvertFromFP8(vals[i]); + return ret; + } + + explicit operator sycl::marray() const { + sycl::marray ret; + for (size_t i = 0; i < N; ++i) + ret[i] = ConvertBF16FromFP8(vals[i]); + return ret; + } + + explicit operator sycl::marray() const { + sycl::marray ret; + for (size_t i = 0; i < N; ++i) + ret[i] = ConvertFromFP8(vals[i]); + return ret; + } + + // Intentionally public to allow access to the raw values. + uint8_t vals[N]; +}; + +template +class fp8_e5m2 { + + uint8_t ConvertToFP8(sycl::half h, rounding r) { +#ifdef __SYCL_DEVICE_ONLY__ + // TODO: optimize with vectorized builtin calls + const uint8_t sign = std::signbit(h) ? 0x80u : 0x00u; + const float ax = sycl::fabs(h); + + if (ax > MaxNormal) + return static_cast(sign | MaxFiniteCode); + + if (ax < MinSubnormal) + return sign; + + uint8_t b = __builtin_spirv_ConvertFP16ToE5M2EXT(h); + if (r == rounding::to_even) + return b; + + const sycl::half yi = __builtin_spirv_ConvertFP16ToE5M2EXT(b); + return round(r, b, yi, h); + +#else + return ConvertToFP8_CPU<5, 2, sycl::half>(h, r); +#endif + } + + uint8_t ConvertBF16ToFP8(bfloat16 h, rounding r) { +#ifdef __SYCL_DEVICE_ONLY__ + const uint8_t sign = std::signbit(h) ? 0x80u : 0x00u; + const float ax = sycl::fabs(h); + + if (ax > MaxNormal) + return static_cast(sign | MaxFiniteCode); + + if (ax < MinSubnormal) + return sign; + + uint8_t b = __builtin_spirv_ConvertBF16ToE5M2EXT(h); + if (r == rounding::to_even) + return b; + + const half yi = __builtin_spirv_ConvertBF16ToE5M2EXT(b); + return round(r, b, yi, h); +#else + return ConvertToFP8_CPU<5, 2, bfloat16>(h, r); +#endif + } + + template T ConvertFromFP8(uint8_t v) const { +#ifdef __SYCL_DEVICE_ONLY__ + sycl::half hi = __builtin_spirv_ConvertE5M2ToFP16EXT(v); + return static_cast(hi); +#else + return ConvertFromFP8_CPU<5, 2, T>(v); +#endif + } + + bfloat16 ConvertFP16FromFP8(uint8_t v) const { +#ifdef __SYCL_DEVICE_ONLY__ + return __builtin_spirv_ConvertE5M2ToBF16EXT(v); +#else + return ConvertFromFP8_CPU<5, 2, bfloat16>(v); +#endif + } + +public: + fp8_e5m2() = default; + fp8_e5m2(const fp8_e5m2 &) = default; + ~fp8_e5m2() = default; + fp8_e5m2 &operator=(const fp8_e5m2 &) = default; + + // Construct from pack of half, bfloat16, float, double. + // Available only when the size of the pack is equal to N. + + // Available only when each type in the pack is half. + + template , half> || + std::is_same_v, bfloat16> || + std::is_same_v, float> || + std::is_same_v, double>) && + ...))>> + explicit fp8_e5m2(Types... v) { +#ifdef __SYCL_DEVICE_ONLY__ + static_assert(N == 1 || N == 2, + "fp8_e5m2: Template argument N must be 1 or 2 on device"); +#endif + if constexpr (((std::is_same_v, bfloat16>) && ...)) { + const bfloat16 in[N] = {static_cast(v)...}; + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertBF16ToFP8(in[i], rounding::to_even); + return; + } + const sycl::half in[N] = {v...}; + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertToFP8(in[i], rounding::to_even); + } + + // Construct from an array of half, bfloat16, float, double. + + explicit fp8_e5m2(half const (&v)[N], rounding r = rounding::to_even) { +#ifdef __SYCL_DEVICE_ONLY__ + static_assert(N == 1 || N == 2, + "fp8_e5m2: Template argument N must be 1 or 2 on device"); +#else + CheckRoundingConstraints(r); +#endif + // TODO: optimize with vectorized builtin calls + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertToFP8(v[i], r); + } + + explicit fp8_e5m2(bfloat16 const (&v)[N], rounding r = rounding::to_even) { +#ifdef __SYCL_DEVICE_ONLY__ + static_assert(N == 1 || N == 2, + "fp8_e5m2: Template argument N must be 1 or 2 on device"); +#else + CheckRoundingConstraints(r); +#endif + // TODO: optimize with vectorized builtin calls + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertBF16ToFP8(v[i], r); + } + + explicit fp8_e5m2(float const (&v)[N], rounding r = rounding::to_even) { +#ifdef __SYCL_DEVICE_ONLY__ + static_assert(N == 1 || N == 2, + "fp8_e5m2: Template argument N must be 1 or 2 on device"); +#else + CheckRoundingConstraints(r); +#endif + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertToFP8(v[i], r); + } + + explicit fp8_e5m2(double const (&v)[N]) { +#ifdef __SYCL_DEVICE_ONLY__ + static_assert(N == 1 || N == 2, + "fp8_e5m2: Template argument N must be 1 or 2 on device"); +#else + CheckRoundingConstraints(r); +#endif + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertToFP8(v[i], rounding::to_even); + } + + // Construct from an marray of half, bfloat16, float, double. + + explicit fp8_e5m2(const sycl::marray &v, + rounding r = rounding::to_even) { +#ifdef __SYCL_DEVICE_ONLY__ + static_assert(N == 1 || N == 2, + "fp8_e5m2: Template argument N must be 1 or 2 on device"); +#else + CheckRoundingConstraints(r); +#endif + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertToFP8(v[i], r); + } + + explicit fp8_e5m2(const sycl::marray &v, + rounding r = rounding::to_even) { +#ifdef __SYCL_DEVICE_ONLY__ + static_assert(N == 1 || N == 2, + "fp8_e5m2: Template argument N must be 1 or 2 on device"); +#else + CheckRoundingConstraints(r); +#endif + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertBF16ToFP8(v[i], r); + } + + explicit fp8_e5m2(const sycl::marray &v, rounding r = rounding::to_even) { +#ifdef __SYCL_DEVICE_ONLY__ + static_assert(N == 1 || N == 2, + "fp8_e5m2: Template argument N must be 1 or 2 on device"); +#else + CheckRoundingConstraints(r); +#endif + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertToFP8(v[i], r); + } + + explicit fp8_e5m2(const sycl::marray &v) { +#ifdef __SYCL_DEVICE_ONLY__ + static_assert(N == 1 || N == 2, + "fp8_e5m2: Template argument N must be 1 or 2 on device"); +#else + CheckRoundingConstraints(r); +#endif + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertToFP8(v[i], rounding::to_even); + } + + // Construct with stochastic rounding with user provided seed from an array of + // half, bfloat16, float. + + // should be removed once docs updated + explicit fp8_e5m2(half const (&vals)[N], const stochastic_seed &seed, + saturation s = saturation::finite); + explicit fp8_e5m2(bfloat16 const (&vals)[N], const stochastic_seed &seed, + saturation s = saturation::finite); + explicit fp8_e5m2(double const (&vals)[N], const stochastic_seed &seed, + saturation s = saturation::finite); + + // Construct with stochastic rounding with user provided seed from an marray + // of half, bfloat16, float. + + // should be removed once docs updated + explicit fp8_e5m2(const sycl::marray &vals, const stochastic_seed &seed, + saturation s = saturation::finite); + explicit fp8_e5m2(const sycl::marray &vals, + const stochastic_seed &seed, + saturation s = saturation::finite); + explicit fp8_e5m2(const sycl::marray &vals, const stochastic_seed &seed, + saturation s = saturation::finite); + + // Construct from integer types. + // Available only when N==1. + + explicit fp8_e5m2(short val) { + assert(N == 1 && "fp8_e5m2: N must be 1 for short constructor"); + vals[0] = ConvertToFP8(val, rounding::to_even); + } + + explicit fp8_e5m2(int val) { + assert(N == 1 && "fp8_e5m2: N must be 1 for int constructor"); + vals[0] = ConvertToFP8(val, rounding::to_even); + } + + explicit fp8_e5m2(long val) { + assert(N == 1 && "fp8_e5m2: N must be 1 for long constructor"); + vals[0] = ConvertToFP8(val, rounding::to_even); + } + + explicit fp8_e5m2(long long val) { + assert(N == 1 && "fp8_e5m2: N must be 1 for long long constructor"); + vals[0] = ConvertToFP8(val, rounding::to_even); + } + + explicit fp8_e5m2(unsigned short val) { + assert(N == 1 && "fp8_e5m2: N must be 1 for unsigned short constructor"); + vals[0] = ConvertToFP8(val, rounding::to_even); + } + + explicit fp8_e5m2(unsigned int val) { + assert(N == 1 && "fp8_e5m2: N must be 1 for unsigned int constructor"); + vals[0] = ConvertToFP8(val, rounding::to_even); + } + + explicit fp8_e5m2(unsigned long val) { + assert(N == 1 && "fp8_e5m2: N must be 1 for unsigned long constructor"); + vals[0] = ConvertToFP8(val, rounding::to_even); + } + + explicit fp8_e5m2(unsigned long long val) { + assert(N == 1 && + "fp8_e5m2: N must be 1 for unsigned long long constructor"); + vals[0] = ConvertToFP8(val, rounding::to_even); + } + + // Assign (operator) from half, bfloat16, float, double, and integer types. + // Available only when N==1. + + fp8_e5m2 &operator=(sycl::half val) { + assert(N == 1 && "fp8_e5m2: N must be 1 for half assignment operator"); + vals[0] = ConvertToFP8(val, rounding::to_even); + return *this; + } + + fp8_e5m2 &operator=(bfloat16 val) { + assert(N == 1 && "fp8_e5m2: N must be 1 for half bfloat16 operator"); + vals[0] = ConvertBF16ToFP8(val, rounding::to_even); + return *this; + } + + fp8_e5m2 &operator=(float val) { + assert(N == 1 && "fp8_e5m2: N must be 1 for float assignment operator"); + vals[0] = ConvertToFP8(val, rounding::to_even); + return *this; + } + + fp8_e5m2 &operator=(double val) { + assert(N == 1 && "fp8_e5m2: N must be 1 for double assignment operator"); + vals[0] = ConvertToFP8(val, rounding::to_even); + return *this; + } + + fp8_e5m2 &operator=(short val) { + assert(N == 1 && "fp8_e5m2: N must be 1 for short assignment operator"); + vals[0] = ConvertToFP8(val, rounding::to_even); + return *this; + } + + fp8_e5m2 &operator=(int val) { + assert(N == 1 && "fp8_e5m2: N must be 1 for int assignment operator"); + vals[0] = ConvertToFP8(val, rounding::to_even); + return *this; + } + + fp8_e5m2 &operator=(long val) { + assert(N == 1 && "fp8_e5m2: N must be 1 for long assignment operator"); + vals[0] = ConvertToFP8(val, rounding::to_even); + return *this; + } + + fp8_e5m2 &operator=(long long val) { + assert(N == 1 && "fp8_e5m2: N must be 1 for long long assignment operator"); + vals[0] = ConvertToFP8(val, rounding::to_even); + return *this; + } + + fp8_e5m2 &operator=(unsigned short val) { + assert(N == 1 && + "fp8_e5m2: N must be 1 for unsigned short assignment operator"); + vals[0] = ConvertToFP8(val, rounding::to_even); + return *this; + } + + fp8_e5m2 &operator=(unsigned int val) { + assert(N == 1 && + "fp8_e5m2: N must be 1 for unsigned int assignment operator"); + vals[0] = ConvertToFP8(val, rounding::to_even); + return *this; + } + + fp8_e5m2 &operator=(unsigned long val) { + assert(N == 1 && + "fp8_e5m2: N must be 1 for unsigned long assignment operator"); + vals[0] = ConvertToFP8(val, rounding::to_even); + return *this; + } + + fp8_e5m2 &operator=(unsigned long long val) { + assert(N == 1 && + "fp8_e5m2: N must be 1 for unsigned long long assignment operator"); + vals[0] = ConvertToFP8(val, rounding::to_even); + return *this; + } + + // Convert to half, bfloat16, float, double. + // Available only when N==1. + + explicit operator half() const { + assert(N == 1 && "fp8_e5m2: N must be 1 for half conversion operator"); + return ConvertFromFP8(vals[0]); + } + + explicit operator bfloat16() const { + assert(N == 1 && "fp8_e5m2: N must be 1 for bfloat16 conversion operator"); + return ConvertFP16FromFP8(vals[0]); + } + + explicit operator float() const { + assert(N == 1 && "fp8_e5m2: N must be 1 for float conversion operator"); + return ConvertFromFP8(vals[0]); + } + + explicit operator double() const { + assert(N == 1 && "fp8_e5m2: N must be 1 for double conversion operator"); + return ConvertFromFP8(vals[0]); + } + + // Convert to integer types. + // Available only when N==1. + + explicit operator char() const { + assert(N == 1 && "fp8_e5m2: N must be 1 for char conversion operator"); + return ConvertFromFP8(vals[0]); + } + + explicit operator signed char() const { + assert(N == 1 && + "fp8_e5m2: N must be 1 for signed char conversion operator"); + return ConvertFromFP8(vals[0]); + } + + explicit operator short() const { + assert(N == 1 && "fp8_e5m2: N must be 1 for short conversion operator"); + return ConvertFromFP8(vals[0]); + } + + explicit operator int() const { + assert(N == 1 && "fp8_e5m2: N must be 1 for int conversion operator"); + return ConvertFromFP8(vals[0]); + } + + explicit operator long() const { + assert(N == 1 && "fp8_e5m2: N must be 1 for long conversion operator"); + return ConvertFromFP8(vals[0]); + } + + explicit operator long long() const { + assert(N == 1 && "fp8_e5m2: N must be 1 for long long conversion operator"); + return ConvertFromFP8(vals[0]); + } + + explicit operator unsigned char() const { + assert(N == 1 && + "fp8_e5m2: N must be 1 for unsigned char conversion operator"); + return ConvertFromFP8(vals[0]); + } + + explicit operator unsigned short() const { + assert(N == 1 && + "fp8_e5m2: N must be 1 for unsigned short conversion operator"); + return ConvertFromFP8(vals[0]); + } + + explicit operator unsigned int() const { + assert(N == 1 && + "fp8_e5m2: N must be 1 for unsigned int conversion operator"); + return ConvertFromFP8(vals[0]); + } + + explicit operator unsigned long() const { + assert(N == 1 && + "fp8_e5m2: N must be 1 for unsigned long conversion operator"); + return ConvertFromFP8(vals[0]); + } + + explicit operator unsigned long long() const { + assert(N == 1 && + "fp8_e5m2: N must be 1 for unsigned long long conversion operator"); + return ConvertFromFP8(vals[0]); + } + + // Convert to bool + // Available only when N==1. + + explicit operator bool() const { + static_assert(N == 1, "fp8_e5m2: operator() requires size N=1"); + // false iff +0 or -0; otherwise true. + return vals[0] != 0x00 && vals[0] != 0x80; + } + + explicit operator sycl::marray() const { + sycl::marray out; + for (size_t i = 0; i < N; ++i) + out[i] = ConvertFromFP8(vals[i]); + return out; + } + explicit operator sycl::marray() const { + sycl::marray out; + for (size_t i = 0; i < N; ++i) + out[i] = ConvertFP16FromFP8(vals[i]); + return out; + } + explicit operator sycl::marray() const { + sycl::marray out; + for (size_t i = 0; i < N; ++i) + out[i] = ConvertFromFP8(vals[i]); + return out; + } + + // Intentionally public to allow access to the raw values. + + uint8_t vals[N]; +}; + +static inline uint8_t ConvertToE8M0_CPU(float x, rounding R, + saturation S) noexcept { + // E8M0: unsigned 8-bit exponent code, bias 127. + // Code 0xFF reserved for NaN. No Inf, no subnormals, no signed zero. + constexpr int Bias = 127; + constexpr int Emin = -127; + constexpr int Emax = 127; + constexpr uint8_t NaNCode = 0xFF; + constexpr uint8_t MaxFiniteCode = 0xFE; + + if (std::isnan(x)) + return NaNCode; + + // No sign bit: negative inputs are treated as their magnitude. + float ax = std::fabs(x); + + // Infinity handling: depends on saturation. + if (std::isinf(ax)) + return (S == saturation::finite) ? MaxFiniteCode : NaNCode; + + // Zero and underflow: map to min normal (code 0). + // Min normal = 2^-127. + const float min_normal = std::ldexp(1.0f, Emin); + if (ax == 0.0f || ax < min_normal) + return 0x00; + + // Overflow and "too large": clamp or NaN depending on saturation. + const float max_normal = std::ldexp(1.0f, Emax); // 2^127 + if (ax >= max_normal) + return (S == saturation::finite) ? MaxFiniteCode : NaNCode; + + // Determine E such that 2^E <= ax < 2^(E+1). + int e2 = 0; + float m = std::frexp(ax, &e2); // ax = m * 2^e2, m in [0.5, 1) + int E = e2 - 1; + + // With no mantissa, representables are exact powers of two. + // Choose between 2^E and 2^(E+1) based on rounding mode. + const bool is_exact_power_of_two = (m == 0.5f); + + switch (R) { + case rounding::upward: + // toward +inf; with no sign, this is "ceil in magnitude". + if (!is_exact_power_of_two && E < Emax) + ++E; + break; + case rounding::downward: + case rounding::toward_zero: + // toward -inf / toward 0: both pick the lower power for non-exact. + break; + case rounding::to_even: + default: { + if (!is_exact_power_of_two) { + // Nearest of {2^E, 2^(E+1)} w/ ties-to-even (even exponent on tie). + float lo = std::ldexp(1.0f, E); + float hi = std::ldexp(1.0f, E + 1); + float dlo = ax - lo; + float dhi = hi - ax; + if (dhi < dlo) { + if (E < Emax) + ++E; + } else if (dhi == dlo) { + // tie -> even exponent + if ((E & 1) != 0 && E < Emax) + ++E; + } + } + break; + } + } + + if (E < Emin) + E = Emin; + if (E > Emax) + E = Emax; + + uint8_t code = static_cast(E + Bias); // 0..254 + return code; +} + +template +static inline ToT ConvertFromE8M0_CPU(uint8_t code) noexcept { + constexpr int Bias = 127; + if (code == 0xFF) { + float qn = std::numeric_limits::quiet_NaN(); + return static_cast(qn); + } + int E = static_cast(code) - Bias; // includes code==0 -> -127 + float v = std::ldexp(1.0f, E); + return ConvertFloatToTarget(v, rounding::to_even); +} + +template +class fp8_e8m0 { +public: + fp8_e8m0() = default; + fp8_e8m0(const fp8_e8m0 &) = default; + ~fp8_e8m0() = default; + fp8_e8m0 &operator=(const fp8_e8m0 &) = default; + + template , half> || + std::is_same_v, bfloat16> || + std::is_same_v, float> || + std::is_same_v, double>) && + ...))>> + explicit fp8_e8m0(Types... v) { +#ifdef __SYCL_DEVICE_ONLY__ + static_assert(N == 1 || N == 2, + "fp8_e8m0: Template argument N must be 1 or 2 on device"); +#endif + using InT = std::common_type_t...>; + const InT in[N] = {v...}; + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertToE8M0_CPU(static_cast(in[i]), rounding::upward, + saturation::finite); + } + + explicit fp8_e8m0(half const (&in)[N], rounding r = rounding::upward, + saturation s = saturation::finite) { +#ifdef __SYCL_DEVICE_ONLY__ + static_assert(N == 1 || N == 2, + "fp8_e8m0: Template argument N must be 1 or 2 on device"); +#endif + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertToE8M0_CPU(static_cast(in[i]), r, s); + } + explicit fp8_e8m0(bfloat16 const (&in)[N], rounding r = rounding::upward, + saturation s = saturation::finite) { +#ifdef __SYCL_DEVICE_ONLY__ + static_assert(N == 1 || N == 2, + "fp8_e8m0: Template argument N must be 1 or 2 on device"); +#endif + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertToE8M0_CPU(static_cast(in[i]), r, s); + } + explicit fp8_e8m0(float const (&in)[N], rounding r = rounding::upward, + saturation s = saturation::finite) { +#ifdef __SYCL_DEVICE_ONLY__ + static_assert(N == 1 || N == 2, + "fp8_e8m0: Template argument N must be 1 or 2 on device"); +#endif + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertToE8M0_CPU(in[i], r, s); + } + explicit fp8_e8m0(double const (&in)[N]) { +#ifdef __SYCL_DEVICE_ONLY__ + static_assert(N == 1 || N == 2, + "fp8_e8m0: Template argument N must be 1 or 2 on device"); +#endif + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertToE8M0_CPU(static_cast(in[i]), + rounding::upward, saturation::finite); + } + + explicit fp8_e8m0(const marray &vals, rounding r = rounding::upward, + saturation s = saturation::finite) { +#ifdef __SYCL_DEVICE_ONLY__ + static_assert(N == 1 || N == 2, + "fp8_e8m0: Template argument N must be 1 or 2 on device"); + assert((r == rounding::upward && s == saturation::finite) && + "fp8_e8m0: device supports rounding::upward and saturation::finite " + "only"); +#endif + for (size_t i = 0; i < N; ++i) + this->vals[i] = ConvertToE8M0_CPU(static_cast(vals[i]), r, s); + } + + explicit fp8_e8m0(const marray &vals, + rounding r = rounding::upward, + saturation s = saturation::finite) { +#ifdef __SYCL_DEVICE_ONLY__ + static_assert(N == 1 || N == 2, + "fp8_e8m0: Template argument N must be 1 or 2 on device"); + assert((r == rounding::upward && s == saturation::finite) && + "fp8_e8m0: device supports rounding::upward and saturation::finite " + "only"); +#endif + for (size_t i = 0; i < N; ++i) + this->vals[i] = ConvertToE8M0_CPU(static_cast(vals[i]), r, s); + } + + explicit fp8_e8m0(const marray &vals, rounding r = rounding::upward, + saturation s = saturation::finite) { +#ifdef __SYCL_DEVICE_ONLY__ + static_assert(N == 1 || N == 2, + "fp8_e8m0: Template argument N must be 1 or 2 on device"); + assert((r == rounding::upward && s == saturation::finite) && + "fp8_e8m0: device supports rounding::upward and saturation::finite " + "only"); +#endif + for (size_t i = 0; i < N; ++i) + this->vals[i] = ConvertToE8M0_CPU(vals[i], r, s); + } + + explicit fp8_e8m0(const marray &vals) { +#ifdef __SYCL_DEVICE_ONLY__ + static_assert(N == 1 || N == 2, + "fp8_e8m0: Template argument N must be 1 or 2 on device"); +#endif + for (size_t i = 0; i < N; ++i) + this->vals[i] = ConvertToE8M0_CPU(static_cast(vals[i]), + rounding::upward, saturation::finite); + } + + // Construct with stochastic rounding with user provided seed from an array of + // half, bfloat16, float. + + // should be removed once docs updated + explicit fp8_e8m0(half const (&vals)[N], const stochastic_seed &seed, + saturation s = saturation::finite); + explicit fp8_e8m0(bfloat16 const (&vals)[N], const stochastic_seed &seed, + saturation s = saturation::finite); + explicit fp8_e8m0(double const (&vals)[N], const stochastic_seed &seed, + saturation s = saturation::finite); + + // Construct with stochastic rounding with user provided seed from an marray + // of half, bfloat16, float. + + // should be removed once docs updated + explicit fp8_e8m0(const sycl::marray &vals, const stochastic_seed &seed, + saturation s = saturation::finite); + explicit fp8_e8m0(const sycl::marray &vals, + const stochastic_seed &seed, + saturation s = saturation::finite); + explicit fp8_e8m0(const sycl::marray &vals, const stochastic_seed &seed, + saturation s = saturation::finite); + + // Construct from integer types. + // Available only when N==1. + + explicit fp8_e8m0(short val) { + assert(N == 1 && "fp8_e8m0: N must be 1 for short constructor"); + vals[0] = ConvertToE8M0_CPU(static_cast(val), rounding::upward, + saturation::finite); + } + explicit fp8_e8m0(int val) : fp8_e8m0(static_cast(val)) {} + explicit fp8_e8m0(long val) : fp8_e8m0(static_cast(val)) {} + explicit fp8_e8m0(long long val) : fp8_e8m0(static_cast(val)) {} + explicit fp8_e8m0(unsigned short val) : fp8_e8m0(static_cast(val)) {} + explicit fp8_e8m0(unsigned int val) : fp8_e8m0(static_cast(val)) {} + explicit fp8_e8m0(unsigned long val) : fp8_e8m0(static_cast(val)) {} + explicit fp8_e8m0(unsigned long long val) : fp8_e8m0(static_cast(val)) {} + + fp8_e8m0 &operator=(half val) { + static_assert(N == 1, "fp8_e8m0: N must be 1 for scalar assignment"); + vals[0] = ConvertToE8M0_CPU(static_cast(val), rounding::upward, + saturation::finite); + return *this; + } + fp8_e8m0 &operator=(bfloat16 val) { + static_assert(N == 1, "fp8_e8m0: N must be 1 for scalar assignment"); + vals[0] = ConvertToE8M0_CPU(static_cast(val), rounding::upward, + saturation::finite); + return *this; + } + fp8_e8m0 &operator=(float val) { + static_assert(N == 1, "fp8_e8m0: N must be 1 for scalar assignment"); + vals[0] = ConvertToE8M0_CPU(val, rounding::upward, saturation::finite); + return *this; + } + fp8_e8m0 &operator=(double val) { return (*this = static_cast(val)); } + fp8_e8m0 &operator=(short val) { return (*this = static_cast(val)); } + fp8_e8m0 &operator=(int val) { return (*this = static_cast(val)); } + fp8_e8m0 &operator=(long val) { return (*this = static_cast(val)); } + fp8_e8m0 &operator=(long long val) { return (*this = static_cast(val)); } + fp8_e8m0 &operator=(unsigned short val) { + return (*this = static_cast(val)); + } + fp8_e8m0 &operator=(unsigned int val) { return (*this = static_cast(val)); } + fp8_e8m0 &operator=(unsigned long val) { + return (*this = static_cast(val)); + } + fp8_e8m0 &operator=(unsigned long long val) { + return (*this = static_cast(val)); + } + + explicit operator half() const { + static_assert(N == 1, "fp8_e8m0: N must be 1 for scalar conversion"); + return ConvertFromE8M0_CPU(vals[0]); + } + explicit operator bfloat16() const { + static_assert(N == 1, "fp8_e8m0: N must be 1 for scalar conversion"); + return ConvertFromE8M0_CPU(vals[0]); + } + explicit operator float() const { + static_assert(N == 1, "fp8_e8m0: N must be 1 for scalar conversion"); + return ConvertFromE8M0_CPU(vals[0]); + } + explicit operator double() const { + static_assert(N == 1, "fp8_e8m0: N must be 1 for scalar conversion"); + return ConvertFromE8M0_CPU(vals[0]); + } + + explicit operator char() const { + static_assert(N == 1, "fp8_e8m0: N must be 1 for scalar conversion"); + return static_cast(static_cast(*this)); + } + explicit operator signed char() const { return static_cast(static_cast(*this)); } + explicit operator short() const { return static_cast(static_cast(*this)); } + explicit operator int() const { return static_cast(static_cast(*this)); } + explicit operator long() const { return static_cast(static_cast(*this)); } + explicit operator long long() const { return static_cast(static_cast(*this)); } + explicit operator unsigned char() const { return static_cast(static_cast(*this)); } + explicit operator unsigned short() const { return static_cast(static_cast(*this)); } + explicit operator unsigned int() const { return static_cast(static_cast(*this)); } + explicit operator unsigned long() const { return static_cast(static_cast(*this)); } + explicit operator unsigned long long() const { return static_cast(static_cast(*this)); } + + explicit operator bool() const { + static_assert(N == 1, "fp8_e8m0: operator bool requires size N=1"); + return true; + } + + explicit operator sycl::marray() const { + sycl::marray out; + for (size_t i = 0; i < N; ++i) + out[i] = ConvertFromE8M0_CPU(vals[i]); + return out; + } + explicit operator sycl::marray() const { + sycl::marray out; + for (size_t i = 0; i < N; ++i) + out[i] = ConvertFromE8M0_CPU(vals[i]); + return out; + } + explicit operator sycl::marray() const { + sycl::marray out; + for (size_t i = 0; i < N; ++i) + out[i] = ConvertFromE8M0_CPU(vals[i]); + return out; + } + + // Intentionally public to allow access to the raw values. + + uint8_t vals[N]; +}; + +#endif // __SYCL_TARGET_INTEL_GPU_CRI__ + +} // namespace ext::oneapi::experimental +} // namespace _V1 +} // namespace sycl \ No newline at end of file diff --git a/sycl/unittests/Extensions/CMakeLists.txt b/sycl/unittests/Extensions/CMakeLists.txt index 63f527b245f48..63730d56ed088 100644 --- a/sycl/unittests/Extensions/CMakeLists.txt +++ b/sycl/unittests/Extensions/CMakeLists.txt @@ -36,3 +36,4 @@ add_subdirectory(FreeFunctionCommands) add_subdirectory(KernelQueries) add_subdirectory(InterProcessCommunication) add_subdirectory(DeviceIndex) +add_subdirectory(fp8) diff --git a/sycl/unittests/Extensions/fp8/CMakeLists.txt b/sycl/unittests/Extensions/fp8/CMakeLists.txt new file mode 100644 index 0000000000000..2d0c53daf4268 --- /dev/null +++ b/sycl/unittests/Extensions/fp8/CMakeLists.txt @@ -0,0 +1,9 @@ +add_sycl_unittest(FP8TypesTests OBJECT + fp8_e4m3.cpp + fp8_e5m2.cpp + fp8_e8m0.cpp +) + +target_compile_options(FP8TypesTests_Preview_Tests PUBLIC -D__SYCL_TARGET_INTEL_GPU_CRI__) +target_compile_options(FP8TypesTests_Non_Preview_Tests PUBLIC -D__SYCL_TARGET_INTEL_GPU_CRI__) + diff --git a/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp b/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp new file mode 100644 index 0000000000000..41cc44e881de3 --- /dev/null +++ b/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp @@ -0,0 +1,699 @@ +#include +#include + +#include +#include +#include + +using namespace sycl::ext::oneapi::experimental; + +TEST(FP8E4M3Test, VariadicConstructorHalf) { + fp8_e4m3<2> a(sycl::half(1.0f), sycl::half(2.0f)); + + EXPECT_EQ(sizeof(a.vals), 2u); + EXPECT_EQ(a.vals[0], 0x38); // 1.0 -> 0b0_0111_000 + EXPECT_EQ(a.vals[1], 0x40); // 2.0 -> 0b0_1000_000 + + fp8_e4m3<1> b(sycl::half(1.1f)); + EXPECT_EQ(sizeof(b.vals), 1u); + EXPECT_EQ(b.vals[0], 0x39); // 1.1 rounds to 1.125 -> frac=1 +} + +TEST(FP8E4M3Test, VariadicConstructorBFloat16) { + fp8_e4m3<2> a(sycl::ext::oneapi::bfloat16(1.0f), sycl::ext::oneapi::bfloat16(2.0f)); + + EXPECT_EQ(sizeof(a.vals), 2u); + EXPECT_EQ(a.vals[0], 0x38); + EXPECT_EQ(a.vals[1], 0x40); + + fp8_e4m3<1> b(sycl::ext::oneapi::bfloat16(1.1f)); + EXPECT_EQ(sizeof(b.vals), 1u); + EXPECT_EQ(b.vals[0], 0x39); +} + +TEST(FP8E4M3Test, VariadicConstructorFloat) { + fp8_e4m3<2> a(1.0f, 2.0f); + + EXPECT_EQ(sizeof(a.vals), 2u); + EXPECT_EQ(a.vals[0], 0x38); + EXPECT_EQ(a.vals[1], 0x40); + + fp8_e4m3<1> b(1.1f); + EXPECT_EQ(sizeof(b.vals), 1u); + EXPECT_EQ(b.vals[0], 0x39); +} + +TEST(FP8E4M3Test, VariadicBoundaryEncodingsFloat) { + // CPU host path: variadic constructors use rounding::to_even and saturation::finite. + fp8_e4m3<6> a( + 448.0f, // max normal -> S.1111.110 + 0.015625f, // min normal -> S.0001.000 (2^-6) + 0.013671875f, // max subnorm -> S.0000.111 (0.875 * 2^-6) + 0.001953125f, // min subnorm -> S.0000.001 (2^-9) + 0.0f, // +0 + -0.0f // -0 + ); + + EXPECT_EQ(sizeof(a.vals), 6u); + + EXPECT_EQ(a.vals[0], 0x7E); // +448.0 -> 0b0_1111_110 + EXPECT_EQ(a.vals[1], 0x08); // +2^-6 -> 0b0_0001_000 + EXPECT_EQ(a.vals[2], 0x07); // +max subnorm -> 0b0_0000_111 + EXPECT_EQ(a.vals[3], 0x01); // +min subnorm -> 0b0_0000_001 + EXPECT_EQ(a.vals[4], 0x00); // +0 -> 0b0_0000_000 + EXPECT_EQ(a.vals[5], 0x80); // -0 -> 0b1_0000_000 +} + +TEST(FP8E4M3Test, VariadicNaNEncodingFloat) { + // NaN is encoded as S.1111.111; sign is permitted. + fp8_e4m3<2> a(std::numeric_limits::quiet_NaN(), + -std::numeric_limits::quiet_NaN()); + + EXPECT_EQ(sizeof(a.vals), 2u); + EXPECT_EQ(a.vals[0], 0x7F); // +NaN -> 0b0_1111_111 + EXPECT_EQ(a.vals[1], 0xFF); // -NaN -> 0b1_1111_111 +} + +TEST(FP8E4M3Test, IntegerToEvenFiniteAndSize) { + // Integer constructors: to_even + finite saturation (CPU). + fp8_e4m3<1> a0(0); + fp8_e4m3<1> a1(1); + fp8_e4m3<1> a2(2); + fp8_e4m3<1> an1(-1); + + EXPECT_EQ(sizeof(a0.vals), 1u); + EXPECT_EQ(sizeof(a1.vals), 1u); + EXPECT_EQ(sizeof(a2.vals), 1u); + EXPECT_EQ(sizeof(an1.vals), 1u); + + EXPECT_EQ(a0.vals[0], 0x00); // +0 + EXPECT_EQ(a1.vals[0], 0x38); // +1.0 -> 0b0_0111_000 + EXPECT_EQ(a2.vals[0], 0x40); // +2.0 -> 0b0_1000_000 + EXPECT_EQ(an1.vals[0], 0xB8); // -1.0 -> sign set: 0b1_0111_000 +} + +TEST(FP8E4M3Test, AssignmentOperatorToEvenFiniteAndSize) { + // operator= from scalar: to_even + finite saturation (CPU). + fp8_e4m3<1> a(0.0f); + EXPECT_EQ(sizeof(a.vals), 1u); + EXPECT_EQ(a.vals[0], 0x00); + + a = 1.0f; + EXPECT_EQ(a.vals[0], 0x38); + + a = -2.0f; + EXPECT_EQ(a.vals[0], 0xC0); // -2.0 -> 0b1_1000_000 + + a = 0.015625f; // min normal + EXPECT_EQ(a.vals[0], 0x08); +} + +TEST(FP8E4M3Test, FloatingPointConversionOperators) { + // Floating-point operators: convert stored fp8 to the respective type. + fp8_e4m3<1> one(1.0f); + fp8_e4m3<1> zero_pos(0.0f); + fp8_e4m3<1> zero_neg(-0.0f); + fp8_e4m3<1> min_norm(0.015625f); + + EXPECT_EQ(sizeof(one.vals), 1u); + EXPECT_EQ(one.vals[0], 0x38); + + float f1 = static_cast(one); + float fz = static_cast(zero_pos); + float fnz = static_cast(zero_neg); + float fmn = static_cast(min_norm); + + EXPECT_EQ(f1, 1.0f); + EXPECT_EQ(fz, 0.0f); + // -0.0 compares equal to +0.0; check signbit to validate negative zero survives. + EXPECT_EQ(fnz, 0.0f); + EXPECT_TRUE(std::signbit(fnz)); + + EXPECT_EQ(fmn, 0.015625f); +} + +TEST(FP8E4M3Test, IntegerConversionOperatorsTowardZero) { + // Integer operators: convert using rounding::toward_zero. + fp8_e4m3<1> p(1.5f); // 1.5 exactly representable: 0b0_0111_100 (0x3C) + fp8_e4m3<1> n(-1.5f); // 0xBC + + EXPECT_EQ(sizeof(p.vals), 1u); + EXPECT_EQ(sizeof(n.vals), 1u); + EXPECT_EQ(p.vals[0], 0x3C); + EXPECT_EQ(n.vals[0], 0xBC); + + int ip = static_cast(p); + int in = static_cast(n); + + EXPECT_EQ(ip, 1); // toward zero + EXPECT_EQ(in, -1); // toward zero +} + +TEST(FP8E4M3Test, BoolOperatorZeroRules) { + // bool operator: false iff +0 or -0; otherwise true. + fp8_e4m3<1> zp(0.0f); + fp8_e4m3<1> zn(-0.0f); + fp8_e4m3<1> one(1.0f); + fp8_e4m3<1> sub(0.001953125f); // min subnormal + + EXPECT_EQ(sizeof(zp.vals), 1u); + EXPECT_EQ(sizeof(zn.vals), 1u); + EXPECT_EQ(sizeof(one.vals), 1u); + EXPECT_EQ(sizeof(sub.vals), 1u); + + EXPECT_FALSE(static_cast(zp)); + EXPECT_FALSE(static_cast(zn)); + EXPECT_TRUE(static_cast(one)); + EXPECT_TRUE(static_cast(sub)); +} + +TEST(FP8E4M3Test, VariadicSaturatesFinite) { + // Variadic constructors: to_even + finite saturation (CPU). + fp8_e4m3<4> a( + 1.0f, + 1000.0f, // above max normal: clamp to +448 + -1000.0f, // clamp to -448 + -0.0f); + + EXPECT_EQ(sizeof(a.vals), 4u); + EXPECT_EQ(a.vals[0], 0x38); + EXPECT_EQ(a.vals[1], 0x7E); // +max normal + EXPECT_EQ(a.vals[2], 0xFE); // -max normal + EXPECT_EQ(a.vals[3], 0x80); // -0 +} + +TEST(FP8E4M3Test, VariadicToEvenTie) { + // Tie case: between 1.0 (0x38) and 1.125 (0x39) is 1.0625 exactly. + // to_even => choose 1.0 because its LSB (fraction) is even (0). + fp8_e4m3<2> a(1.0625f, -1.0625f); + EXPECT_EQ(sizeof(a.vals), 2u); + EXPECT_EQ(a.vals[0], 0x38); + EXPECT_EQ(a.vals[1], 0xB8); +} + +TEST(FP8E4M3Test, CArrayFloatHostToEvenFinite) { + // Host code supports only rounding::to_even and saturation::finite. + const float in[5] = {1.0f, 1.1f, 1.0625f, 1000.0f, -0.0f}; + fp8_e4m3<5> a(in); + + EXPECT_EQ(sizeof(a.vals), 5u); + EXPECT_EQ(a.vals[0], 0x38); // 1.0 + EXPECT_EQ(a.vals[1], 0x39); // 1.1 -> 1.125 + EXPECT_EQ(a.vals[2], 0x38); // tie -> to_even => 1.0 + EXPECT_EQ(a.vals[3], 0x7E); // finite saturation => +448 + EXPECT_EQ(a.vals[4], 0x80); // -0 +} + +TEST(FP8E4M3Test, CArrayDoubleToEvenFinite) { + // Double c-array: to_even + finite saturation. + const double in[6] = {448.0, 449.0, 0.015625, 0.013671875, 0.001953125, std::numeric_limits::quiet_NaN()}; + fp8_e4m3<6> a(in); + + EXPECT_EQ(sizeof(a.vals), 6u); + EXPECT_EQ(a.vals[0], 0x7E); // +448 + EXPECT_EQ(a.vals[1], 0x7E); // 449 -> clamp to +448 + EXPECT_EQ(a.vals[2], 0x08); // min normal + EXPECT_EQ(a.vals[3], 0x07); // max subnormal + EXPECT_EQ(a.vals[4], 0x01); // min subnormal + EXPECT_EQ(a.vals[5], 0x7F); // NaN +} + +TEST(FP8E4M3Test, CArrayHalfHostToEvenFinite) { + // Host code supports only rounding::to_even and saturation::finite. + const sycl::half in[6] = {sycl::half(448.0f), sycl::half(449.0f), + sycl::half(0.015625f), sycl::half(0.013671875f), + sycl::half(0.001953125f), sycl::half(-0.0f)}; + fp8_e4m3<6> a(in); + + EXPECT_EQ(sizeof(a.vals), 6u); + EXPECT_EQ(a.vals[0], 0x7E); // +448 + EXPECT_EQ(a.vals[1], 0x7E); // 449 -> clamp to +448 + EXPECT_EQ(a.vals[2], 0x08); // min normal + EXPECT_EQ(a.vals[3], 0x07); // max subnormal + EXPECT_EQ(a.vals[4], 0x01); // min subnormal + EXPECT_EQ(a.vals[5], 0x80); // -0 +} + +TEST(FP8E4M3Test, CArrayBFloat16HostToEvenFinite) { + // Host code supports only rounding::to_even and saturation::finite. + const sycl::ext::oneapi::bfloat16 in[6] = { + sycl::ext::oneapi::bfloat16(448.0f), + sycl::ext::oneapi::bfloat16(449.0f), + sycl::ext::oneapi::bfloat16(0.015625f), + sycl::ext::oneapi::bfloat16(0.013671875f), + sycl::ext::oneapi::bfloat16(0.001953125f), + sycl::ext::oneapi::bfloat16(-0.0f)}; + fp8_e4m3<6> a(in); + + EXPECT_EQ(sizeof(a.vals), 6u); + EXPECT_EQ(a.vals[0], 0x7E); // +448 + EXPECT_EQ(a.vals[1], 0x7E); // 449 -> clamp to +448 + EXPECT_EQ(a.vals[2], 0x08); // min normal + EXPECT_EQ(a.vals[3], 0x07); // max subnormal + EXPECT_EQ(a.vals[4], 0x01); // min subnormal + EXPECT_EQ(a.vals[5], 0x80); // -0 +} + +TEST(FP8E4M3Test, MarrayAndOperatorsHostAllN) { + // marray constructors/operators: host supports all N. + sycl::marray in = {1.0f, 2.0f, 0.0f, -0.0f, 448.0f, 1000.0f, 0.001953125f, -1.5f}; + fp8_e4m3<8> a(in); + + EXPECT_EQ(sizeof(a.vals), 8u); + EXPECT_EQ(a.vals[0], 0x38); + EXPECT_EQ(a.vals[1], 0x40); + EXPECT_EQ(a.vals[2], 0x00); + EXPECT_EQ(a.vals[3], 0x80); + EXPECT_EQ(a.vals[4], 0x7E); + EXPECT_EQ(a.vals[5], 0x7E); // finite saturation + EXPECT_EQ(a.vals[6], 0x01); + EXPECT_EQ(a.vals[7], 0xBC); // -1.5 + + // marray operator: convert fp8 vector back to marray. + sycl::marray out = static_cast>(a); + EXPECT_EQ(out[0], 1.0f); + EXPECT_EQ(out[1], 2.0f); + EXPECT_EQ(out[2], 0.0f); + EXPECT_EQ(out[3], 0.0f); + EXPECT_TRUE(std::signbit(out[3])); // preserve -0 + EXPECT_EQ(out[4], 448.0f); + EXPECT_EQ(out[5], 448.0f); + EXPECT_EQ(out[6], 0.001953125f); + EXPECT_EQ(out[7], -1.5f); +} + +TEST(FP8E4M3Test, FloatingPointConversionOperatorsMoreTypes) { + fp8_e4m3<1> a(1.0f); + fp8_e4m3<1> b(0.015625f); + fp8_e4m3<1> nanv(std::numeric_limits::quiet_NaN()); + + EXPECT_EQ(sizeof(a.vals), 1u); + EXPECT_EQ(sizeof(b.vals), 1u); + EXPECT_EQ(sizeof(nanv.vals), 1u); + + double da = static_cast(a); + sycl::half ha = static_cast(a); + sycl::ext::oneapi::bfloat16 ba = static_cast(a); + + EXPECT_EQ(da, 1.0); + EXPECT_EQ(static_cast(ha), 1.0f); + EXPECT_EQ(static_cast(ba), 1.0f); + + EXPECT_EQ(static_cast(b), 0.015625f); + + float fn = static_cast(nanv); + EXPECT_TRUE(std::isnan(fn)); +} + +TEST(FP8E4M3Test, IntegerConversionOperatorsMultipleWidthsTowardZero) { + fp8_e4m3<1> p(1.5f); + fp8_e4m3<1> n(-1.5f); + + std::int32_t i32p = static_cast(p); + std::int32_t i32n = static_cast(n); + std::int64_t i64p = static_cast(p); + std::int64_t i64n = static_cast(n); + + EXPECT_EQ(i32p, 1); + EXPECT_EQ(i32n, -1); + EXPECT_EQ(i64p, 1); + EXPECT_EQ(i64n, -1); +} + +TEST(FP8E4M3Test, VariadicHalfBoundaryEncodings) { + fp8_e4m3<4> a(sycl::half(448.0f), sycl::half(0.015625f), sycl::half(0.001953125f), + sycl::half(-0.0f)); + + EXPECT_EQ(sizeof(a.vals), 4u); + EXPECT_EQ(a.vals[0], 0x7E); // +max normal + EXPECT_EQ(a.vals[1], 0x08); // min normal + EXPECT_EQ(a.vals[2], 0x01); // min subnormal + EXPECT_EQ(a.vals[3], 0x80); // -0 +} + +TEST(FP8E4M3Test, VariadicBFloat16BoundaryEncodings) { + fp8_e4m3<4> a(sycl::ext::oneapi::bfloat16(1.0f), + sycl::ext::oneapi::bfloat16(2.0f), + sycl::ext::oneapi::bfloat16(0.001953125f), + sycl::ext::oneapi::bfloat16(-0.0f)); + + EXPECT_EQ(sizeof(a.vals), 4u); + EXPECT_EQ(a.vals[0], 0x38); + EXPECT_EQ(a.vals[1], 0x40); + EXPECT_EQ(a.vals[2], 0x01); + EXPECT_EQ(a.vals[3], 0x80); +} + +TEST(FP8E4M3Test, VariadicDoubleBoundaryEncodingsAndSaturation) { + fp8_e4m3<5> a(448.0, 449.0, 0.013671875, 0.001953125, -1000.0); + + EXPECT_EQ(sizeof(a.vals), 5u); + EXPECT_EQ(a.vals[0], 0x7E); // +448 + EXPECT_EQ(a.vals[1], 0x7E); // clamp to +448 (finite saturation) + EXPECT_EQ(a.vals[2], 0x07); // max subnormal + EXPECT_EQ(a.vals[3], 0x01); // min subnormal + EXPECT_EQ(a.vals[4], 0xFE); // clamp to -448 +} + +TEST(FP8E4M3Test, BoolOperatorWithNaN) { + float pz = 0.0f; + fp8_e4m3<1> zp(pz); + float zv = -0.0f; + fp8_e4m3<1> zn(zv); + float nv = {std::numeric_limits::quiet_NaN()}; + fp8_e4m3<1> nanv(nv); + + EXPECT_EQ(sizeof(zp.vals), 1u); + EXPECT_EQ(sizeof(zn.vals), 1u); + EXPECT_EQ(sizeof(nanv.vals), 1u); + + EXPECT_FALSE(static_cast(zp)); + EXPECT_FALSE(static_cast(zn)); + EXPECT_TRUE(static_cast(nanv)); // not +0 or -0 + EXPECT_EQ(nanv.vals[0], 0x7F); // NaN encoding remains S.1111.111 +} + +TEST(FP8E4M3Test, CArrayFloatRoundingToEven) { + const float in[3] = {0.012f, 1.0625f, 1000.0f}; + fp8_e4m3<3> a(in, rounding::to_even); + + EXPECT_EQ(a.vals[0], 0x06); + EXPECT_EQ(a.vals[1], 0x38); + EXPECT_EQ(a.vals[2], 0x7E); +} + +TEST(FP8E4M3Test, CArrayFloatRoundingUpward) { + const float in[3] = {0.012f, 1.0625f, 1000.0f}; + fp8_e4m3<3> a(in, rounding::upward); + + EXPECT_EQ(a.vals[0], 0x07); + EXPECT_EQ(a.vals[1], 0x39); + EXPECT_EQ(a.vals[2], 0x7E); +} + +TEST(FP8E4M3Test, CArrayFloatRoundingDownward) { + const float in[3] = {0.012f, 1.0625f, 1000.0f}; + fp8_e4m3<3> a(in, rounding::downward); + + EXPECT_EQ(a.vals[0], 0x06); + EXPECT_EQ(a.vals[1], 0x38); + EXPECT_EQ(a.vals[2], 0x7E); +} + +TEST(FP8E4M3Test, CArrayFloatRoundingTowardZero) { + const float in[3] = {0.012f, 1.0625f, 1000.0f}; + fp8_e4m3<3> a(in, rounding::toward_zero); + + EXPECT_EQ(a.vals[0], 0x06); + EXPECT_EQ(a.vals[1], 0x38); + EXPECT_EQ(a.vals[2], 0x7E); +} + +TEST(FP8E4M3Test, CArrayFloatRoundingToAway) { + const float in[3] = {0.012f, 1.0625f, 1000.0f}; + fp8_e4m3<3> a(in, rounding::to_away); + + EXPECT_EQ(a.vals[0], 0x06); + EXPECT_EQ(a.vals[1], 0x38); + EXPECT_EQ(a.vals[2], 0x7E); +} + +TEST(FP8E4M3Test, CArrayHalfRoundingToEven) { + const sycl::half in[3] = {sycl::half(0.012f), sycl::half(1.0625f), + sycl::half(1000.0f)}; + fp8_e4m3<3> a(in, rounding::to_even); + + EXPECT_EQ(a.vals[0], 0x06); + EXPECT_EQ(a.vals[1], 0x38); + EXPECT_EQ(a.vals[2], 0x7E); +} + +TEST(FP8E4M3Test, CArrayHalfRoundingUpward) { + const sycl::half in[3] = {sycl::half(0.012f), sycl::half(1.0625f), + sycl::half(1000.0f)}; + fp8_e4m3<3> a(in, rounding::upward); + + EXPECT_EQ(a.vals[0], 0x07); + EXPECT_EQ(a.vals[1], 0x39); + EXPECT_EQ(a.vals[2], 0x7E); +} + +TEST(FP8E4M3Test, CArrayHalfRoundingDownward) { + const sycl::half in[3] = {sycl::half(0.012f), sycl::half(1.0625f), + sycl::half(1000.0f)}; + fp8_e4m3<3> a(in, rounding::downward); + + EXPECT_EQ(a.vals[0], 0x06); + EXPECT_EQ(a.vals[1], 0x38); + EXPECT_EQ(a.vals[2], 0x7E); +} + +TEST(FP8E4M3Test, CArrayHalfRoundingTowardZero) { + const sycl::half in[3] = {sycl::half(0.012f), sycl::half(1.0625f), + sycl::half(1000.0f)}; + fp8_e4m3<3> a(in, rounding::toward_zero); + + EXPECT_EQ(a.vals[0], 0x06); + EXPECT_EQ(a.vals[1], 0x38); + EXPECT_EQ(a.vals[2], 0x7E); +} + +TEST(FP8E4M3Test, CArrayHalfRoundingToAway) { + const sycl::half in[3] = {sycl::half(0.012f), sycl::half(1.0625f), + sycl::half(1000.0f)}; + fp8_e4m3<3> a(in, rounding::to_away); + + EXPECT_EQ(a.vals[0], 0x06); + EXPECT_EQ(a.vals[1], 0x38); + EXPECT_EQ(a.vals[2], 0x7E); +} + +TEST(FP8E4M3Test, CArrayBFloat16RoundingToEven) { + const sycl::ext::oneapi::bfloat16 in[3] = { + sycl::ext::oneapi::bfloat16(0.012f), + sycl::ext::oneapi::bfloat16(1.0625f), + sycl::ext::oneapi::bfloat16(1000.0f)}; + fp8_e4m3<3> a(in, rounding::to_even); + + EXPECT_EQ(a.vals[0], 0x06); + EXPECT_EQ(a.vals[1], 0x38); + EXPECT_EQ(a.vals[2], 0x7E); +} + +TEST(FP8E4M3Test, CArrayBFloat16RoundingUpward) { + const sycl::ext::oneapi::bfloat16 in[3] = { + sycl::ext::oneapi::bfloat16(0.012f), + sycl::ext::oneapi::bfloat16(1.0625f), + sycl::ext::oneapi::bfloat16(1000.0f)}; + fp8_e4m3<3> a(in, rounding::upward); + + EXPECT_EQ(a.vals[0], 0x07); + EXPECT_EQ(a.vals[1], 0x39); + EXPECT_EQ(a.vals[2], 0x7E); +} + +TEST(FP8E4M3Test, CArrayBFloat16Downward) { + const sycl::ext::oneapi::bfloat16 in[3] = { + sycl::ext::oneapi::bfloat16(0.012f), + sycl::ext::oneapi::bfloat16(1.0625f), + sycl::ext::oneapi::bfloat16(1000.0f)}; + fp8_e4m3<3> a(in, rounding::downward); + + EXPECT_EQ(a.vals[0], 0x06); + EXPECT_EQ(a.vals[1], 0x38); + EXPECT_EQ(a.vals[2], 0x7E); +} + +TEST(FP8E4M3Test, CArrayBFloat16TowardZero) { + const sycl::ext::oneapi::bfloat16 in[3] = { + sycl::ext::oneapi::bfloat16(0.012f), + sycl::ext::oneapi::bfloat16(1.0625f), + sycl::ext::oneapi::bfloat16(1000.0f)}; + fp8_e4m3<3> a(in, rounding::toward_zero); + + EXPECT_EQ(a.vals[0], 0x06); + EXPECT_EQ(a.vals[1], 0x38); + EXPECT_EQ(a.vals[2], 0x7E); +} + +TEST(FP8E4M3Test, CArrayBFloat16ToAway) { + const sycl::ext::oneapi::bfloat16 in[3] = { + sycl::ext::oneapi::bfloat16(0.012f), + sycl::ext::oneapi::bfloat16(1.0625f), + sycl::ext::oneapi::bfloat16(1000.0f)}; + fp8_e4m3<3> a(in, rounding::to_away); + + EXPECT_EQ(a.vals[0], 0x06); + EXPECT_EQ(a.vals[1], 0x38); + EXPECT_EQ(a.vals[2], 0x7E); +} + +TEST(FP8E4M3Test, MarrayHalfRoundingToEven) { + const sycl::marray in = {sycl::half(0.012f), + sycl::half(1.0625f), + sycl::half(1000.0f)}; + fp8_e4m3<3> a(in, rounding::to_even); + + EXPECT_EQ(a.vals[0], 0x06); + EXPECT_EQ(a.vals[1], 0x38); + EXPECT_EQ(a.vals[2], 0x7E); +} + +TEST(FP8E4M3Test, MarrayHalfRoundingUpward) { + const sycl::marray in = {sycl::half(0.012f), + sycl::half(1.0625f), + sycl::half(1000.0f)}; + fp8_e4m3<3> a(in, rounding::upward); + + EXPECT_EQ(a.vals[0], 0x07); + EXPECT_EQ(a.vals[1], 0x39); + EXPECT_EQ(a.vals[2], 0x7E); +} + +TEST(FP8E4M3Test, MarrayHalfRoundingDownward) { + const sycl::marray in = {sycl::half(0.012f), + sycl::half(1.0625f), + sycl::half(1000.0f)}; + fp8_e4m3<3> a(in, rounding::downward); + + EXPECT_EQ(a.vals[0], 0x06); + EXPECT_EQ(a.vals[1], 0x38); + EXPECT_EQ(a.vals[2], 0x7E); +} + +TEST(FP8E4M3Test, MarrayHalfRoundingTowardZero) { + const sycl::marray in = {sycl::half(0.012f), + sycl::half(1.0625f), + sycl::half(1000.0f)}; + fp8_e4m3<3> a(in, rounding::toward_zero); + + EXPECT_EQ(a.vals[0], 0x06); + EXPECT_EQ(a.vals[1], 0x38); + EXPECT_EQ(a.vals[2], 0x7E); +} + +TEST(FP8E4M3Test, MarrayHalfRoundingToAway) { + const sycl::marray in = {sycl::half(0.012f), + sycl::half(1.0625f), + sycl::half(1000.0f)}; + fp8_e4m3<3> a(in, rounding::to_away); + + EXPECT_EQ(a.vals[0], 0x06); + EXPECT_EQ(a.vals[1], 0x38); + EXPECT_EQ(a.vals[2], 0x7E); +} + +TEST(FP8E4M3Test, MarrayBFloat16RoundingToEven) { + const sycl::marray in = { + sycl::ext::oneapi::bfloat16(0.012f), + sycl::ext::oneapi::bfloat16(1.0625f), + sycl::ext::oneapi::bfloat16(1000.0f)}; + fp8_e4m3<3> a(in, rounding::to_even); + + EXPECT_EQ(a.vals[0], 0x06); + EXPECT_EQ(a.vals[1], 0x38); + EXPECT_EQ(a.vals[2], 0x7E); +} + +TEST(FP8E4M3Test, MarrayBFloat16RoundingUpward) { + const sycl::marray in = { + sycl::ext::oneapi::bfloat16(0.012f), + sycl::ext::oneapi::bfloat16(1.0625f), + sycl::ext::oneapi::bfloat16(1000.0f)}; + fp8_e4m3<3> a(in, rounding::upward); + + EXPECT_EQ(a.vals[0], 0x07); + EXPECT_EQ(a.vals[1], 0x39); + EXPECT_EQ(a.vals[2], 0x7E); +} + +TEST(FP8E4M3Test, MarrayBFloat16RoundingDownward) { + const sycl::marray in = { + sycl::ext::oneapi::bfloat16(0.012f), + sycl::ext::oneapi::bfloat16(1.0625f), + sycl::ext::oneapi::bfloat16(1000.0f)}; + fp8_e4m3<3> a(in, rounding::downward); + + EXPECT_EQ(a.vals[0], 0x06); + EXPECT_EQ(a.vals[1], 0x38); + EXPECT_EQ(a.vals[2], 0x7E); +} + +TEST(FP8E4M3Test, MarrayBFloat16RoundingTowardZero) { + const sycl::marray in = { + sycl::ext::oneapi::bfloat16(0.012f), + sycl::ext::oneapi::bfloat16(1.0625f), + sycl::ext::oneapi::bfloat16(1000.0f)}; + fp8_e4m3<3> a(in, rounding::toward_zero); + + EXPECT_EQ(a.vals[0], 0x06); + EXPECT_EQ(a.vals[1], 0x38); + EXPECT_EQ(a.vals[2], 0x7E); +} + +TEST(FP8E4M3Test, MarrayBFloat16RoundingToAway) { + const sycl::marray in = { + sycl::ext::oneapi::bfloat16(0.012f), + sycl::ext::oneapi::bfloat16(1.0625f), + sycl::ext::oneapi::bfloat16(1000.0f)}; + fp8_e4m3<3> a(in, rounding::to_away); + + EXPECT_EQ(a.vals[0], 0x06); + EXPECT_EQ(a.vals[1], 0x38); + EXPECT_EQ(a.vals[2], 0x7E); +} + +TEST(FP8E4M3Test, MarrayFloatRoundingToEven) { + const sycl::marray in = {0.012f, 1.0625f, 1000.0f}; + fp8_e4m3<3> a(in, rounding::to_even); + + EXPECT_EQ(a.vals[0], 0x06); + EXPECT_EQ(a.vals[1], 0x38); + EXPECT_EQ(a.vals[2], 0x7E); +} + +TEST(FP8E4M3Test, MarrayFloatRoundingUpward) { + const sycl::marray in = {0.012f, 1.0625f, 1000.0f}; + fp8_e4m3<3> a(in, rounding::upward); + + EXPECT_EQ(a.vals[0], 0x07); + EXPECT_EQ(a.vals[1], 0x39); + EXPECT_EQ(a.vals[2], 0x7E); +} + +TEST(FP8E4M3Test, MarrayFloatRoundingDownward) { + const sycl::marray in = {0.012f, 1.0625f, 1000.0f}; + fp8_e4m3<3> a(in, rounding::downward); + + EXPECT_EQ(a.vals[0], 0x06); + EXPECT_EQ(a.vals[1], 0x38); + EXPECT_EQ(a.vals[2], 0x7E); +} + +TEST(FP8E4M3Test, MarrayFloatRoundingTowardZero) { + const sycl::marray in = {0.012f, 1.0625f, 1000.0f}; + fp8_e4m3<3> a(in, rounding::toward_zero); + + EXPECT_EQ(a.vals[0], 0x06); + EXPECT_EQ(a.vals[1], 0x38); + EXPECT_EQ(a.vals[2], 0x7E); +} + +TEST(FP8E4M3Test, MarrayFloatRoundingToAway) { + const sycl::marray in = {0.012f, 1.0625f, 1000.0f}; + fp8_e4m3<3> a(in, rounding::to_away); + + EXPECT_EQ(a.vals[0], 0x06); + EXPECT_EQ(a.vals[1], 0x38); + EXPECT_EQ(a.vals[2], 0x7E); +} + + +TEST(FP8E4M3Test, MarrayDoubleToEven) { + const sycl::marray in = {0.012, 1.0625, 1000.0}; + fp8_e4m3<3> a(in); + + EXPECT_EQ(a.vals[0], 0x06); + EXPECT_EQ(a.vals[1], 0x38); + EXPECT_EQ(a.vals[2], 0x7E); +} diff --git a/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp b/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp new file mode 100644 index 0000000000000..8455ba8f93752 --- /dev/null +++ b/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp @@ -0,0 +1,512 @@ +#include +#include + +#include +#include +#include + +using namespace sycl::ext::oneapi::experimental; + +TEST(FP8E5M2Test, VariadicConstructorHalf) { + fp8_e5m2<2> a(sycl::half(1.0f), sycl::half(2.0f)); + + EXPECT_EQ(sizeof(a.vals), 2u); + EXPECT_EQ(a.vals[0], 0x3C); // 1.0 -> 0b0_01111_00 + EXPECT_EQ(a.vals[1], 0x40); // 2.0 -> 0b0_10000_00 + + fp8_e5m2<1> b(sycl::half(1.1f)); + EXPECT_EQ(sizeof(b.vals), 1u); + EXPECT_EQ(b.vals[0], 0x3C); // 1.1 rounds to 1.0 +} + +TEST(FP8E5M2Test, VariadicConstructorBFloat16) { + fp8_e5m2<2> a(sycl::ext::oneapi::bfloat16(1.0f), + sycl::ext::oneapi::bfloat16(2.0f)); + + EXPECT_EQ(sizeof(a.vals), 2u); + EXPECT_EQ(a.vals[0], 0x3C); + EXPECT_EQ(a.vals[1], 0x40); + + fp8_e5m2<1> b(sycl::ext::oneapi::bfloat16(1.1f)); + EXPECT_EQ(sizeof(b.vals), 1u); + EXPECT_EQ(b.vals[0], 0x3C); +} + +TEST(FP8E5M2Test, VariadicConstructorFloat) { + fp8_e5m2<2> a(1.0f, 2.0f); + + EXPECT_EQ(sizeof(a.vals), 2u); + EXPECT_EQ(a.vals[0], 0x3C); + EXPECT_EQ(a.vals[1], 0x40); + + fp8_e5m2<1> b(1.1f); + EXPECT_EQ(sizeof(b.vals), 1u); + EXPECT_EQ(b.vals[0], 0x3C); +} + +TEST(FP8E5M2Test, VariadicConstructorBoundaryEncodingsFloat) { + fp8_e5m2<6> a( + 57344.0f, // max normal -> S.11110.11 + 0.00006103515625f, // min normal -> S.00001.00 (2^-14) + 0.0000457763671875f, // max subnorm -> S.00000.11 (0.75 * 2^-14) + 0.0000152587890625f, // min subnorm -> S.00000.01 (2^-16) + 0.0f, // +0 + -0.0f // -0 + ); + + EXPECT_EQ(sizeof(a.vals), 6u); + + EXPECT_EQ(a.vals[0], 0x7B); // +57344.0 -> 0b0_11110_11 + EXPECT_EQ(a.vals[1], 0x04); // +2^-14 -> 0b0_00001_00 + EXPECT_EQ(a.vals[2], 0x03); // +max subnorm -> 0b0_00000_11 + EXPECT_EQ(a.vals[3], 0x01); // +min subnorm -> 0b0_00000_01 + EXPECT_EQ(a.vals[4], 0x00); // +0 -> 0b0_00000_00 + EXPECT_EQ(a.vals[5], 0x80); // -0 -> 0b1_00000_00 +} + +TEST(FP8E5M2Test, VariadicConstructorNaNEncodingFloat) { + fp8_e5m2<2> a(std::numeric_limits::quiet_NaN(), + -std::numeric_limits::quiet_NaN()); + + EXPECT_EQ(sizeof(a.vals), 2u); + EXPECT_EQ(a.vals[0], 0x7F); // +NaN -> 0b0_11111_11 + EXPECT_EQ(a.vals[1], 0xFF); // -NaN -> 0b1_11111_11 +} + +TEST(FP8E5M2Test, IntegerConstructorToEvenFiniteAndSize) { + fp8_e5m2<1> a0(0); + fp8_e5m2<1> a1(1); + fp8_e5m2<1> a2(2); + fp8_e5m2<1> an1(-1); + + EXPECT_EQ(sizeof(a0.vals), 1u); + EXPECT_EQ(sizeof(a1.vals), 1u); + EXPECT_EQ(sizeof(a2.vals), 1u); + EXPECT_EQ(sizeof(an1.vals), 1u); + + EXPECT_EQ(a0.vals[0], 0x00); // +0 + EXPECT_EQ(a1.vals[0], 0x3C); // +1.0 -> 0b0_01111_00 + EXPECT_EQ(a2.vals[0], 0x40); // +2.0 -> 0b0_10000_00 + EXPECT_EQ(an1.vals[0], 0xBC); // -1.0 -> 0b1_01111_00 +} + +TEST(FP8E5M2Test, AssignmentOperatorToEvenFiniteAndSize) { + fp8_e5m2<1> a(0.0f); + EXPECT_EQ(sizeof(a.vals), 1u); + EXPECT_EQ(a.vals[0], 0x00); + + a = 1.0f; + EXPECT_EQ(a.vals[0], 0x3C); + + a = -2.0f; + EXPECT_EQ(a.vals[0], 0xC0); // -2.0 -> 0b1_10000_00 + + a = 0.00006103515625f; // min normal + EXPECT_EQ(a.vals[0], 0x04); +} + +TEST(FP8E5M2Test, FloatingPointConversionOperators) { + // Floating-point operators: convert stored fp8 to the respective type. + fp8_e5m2<1> one(1.0f); + fp8_e5m2<1> zero_pos(0.0f); + fp8_e5m2<1> zero_neg(-0.0f); + fp8_e5m2<1> min_norm(0.00006103515625f); + + EXPECT_EQ(sizeof(one.vals), 1u); + EXPECT_EQ(one.vals[0], 0x3C); + + float f1 = static_cast(one); + float fz = static_cast(zero_pos); + float fnz = static_cast(zero_neg); + float fmn = static_cast(min_norm); + + EXPECT_EQ(f1, 1.0f); + EXPECT_EQ(fz, 0.0f); + EXPECT_EQ(fnz, 0.0f); + EXPECT_TRUE(std::signbit(fnz)); + + EXPECT_EQ(fmn, 0.00006103515625f); +} + +TEST(FP8E5M2Test, IntegerConversionOperatorsTowardZero) { + // Integer operators: convert using rounding::toward_zero. + fp8_e5m2<1> p(1.5f); // 1.5 exactly representable: 0b0_01111_10 (0x3E) + fp8_e5m2<1> n(-1.5f); // 0xBE + + EXPECT_EQ(sizeof(p.vals), 1u); + EXPECT_EQ(sizeof(n.vals), 1u); + EXPECT_EQ(p.vals[0], 0x3E); + EXPECT_EQ(n.vals[0], 0xBE); + + int ip = static_cast(p); + int in = static_cast(n); + + EXPECT_EQ(ip, 1); // toward zero + EXPECT_EQ(in, -1); // toward zero +} + +TEST(FP8E5M2Test, BoolOperatorZeroRules) { + // bool operator: false iff +0 or -0; otherwise true. + fp8_e5m2<1> zp(0.0f); + fp8_e5m2<1> zn(-0.0f); + fp8_e5m2<1> one(1.0f); + fp8_e5m2<1> sub(0.0000152587890625f); // min subnormal + + EXPECT_EQ(sizeof(zp.vals), 1u); + EXPECT_EQ(sizeof(zn.vals), 1u); + EXPECT_EQ(sizeof(one.vals), 1u); + EXPECT_EQ(sizeof(sub.vals), 1u); + + EXPECT_FALSE(static_cast(zp)); + EXPECT_FALSE(static_cast(zn)); + EXPECT_TRUE(static_cast(one)); + EXPECT_TRUE(static_cast(sub)); +} + +TEST(FP8E5M2Test, VariadicConstructorSaturatesFinite) { + // Variadic constructors: to_even + finite saturation (CPU). + fp8_e5m2<4> a(1.0f, + 100000.0f, // above max normal: clamp to +57344 + -100000.0f, // clamp to -57344 + -0.0f); + + EXPECT_EQ(sizeof(a.vals), 4u); + EXPECT_EQ(a.vals[0], 0x3C); + EXPECT_EQ(a.vals[1], 0x7B); // +max normal + EXPECT_EQ(a.vals[2], 0xFB); // -max normal + EXPECT_EQ(a.vals[3], 0x80); // -0 +} + +TEST(FP8E5M2Test, VariadicConstructorToEvenTie) { + // Tie case: between 1.0 (0x3C) and 1.25 (0x3D) is 1.125 exactly. + // to_even => choose 1.0 because its LSB (fraction) is even (0). + // Tie between 1.25 (0x3D) and 1.5 (0x3E) is 1.375 exactly => choose 1.5. + fp8_e5m2<2> a(1.125f, -1.375f); + EXPECT_EQ(sizeof(a.vals), 2u); + EXPECT_EQ(a.vals[0], 0x3C); + EXPECT_EQ(a.vals[1], 0xBE); +} + +TEST(FP8E5M2Test, CArrayConstructorFloatHostToEvenFinite) { + // Host code supports only rounding::to_even and saturation::finite. + const float in[5] = {1.0f, 1.1f, 1.125f, 100000.0f, -0.0f}; + fp8_e5m2<5> a(in); + + EXPECT_EQ(sizeof(a.vals), 5u); + EXPECT_EQ(a.vals[0], 0x3C); // 1.0 + EXPECT_EQ(a.vals[1], 0x3C); // 1.1 -> 1.0 + EXPECT_EQ(a.vals[2], 0x3C); // tie -> to_even => 1.0 + EXPECT_EQ(a.vals[3], 0x7B); // finite saturation => +57344 + EXPECT_EQ(a.vals[4], 0x80); // -0 +} + +TEST(FP8E5M2Test, CArrayConstructorDoubleToEvenFinite) { + // Double c-array: to_even + finite saturation. + const double in[6] = {57344.0, + 60000.0, + 0.00006103515625, + 0.0000457763671875, + 0.0000152587890625, + std::numeric_limits::quiet_NaN()}; + fp8_e5m2<6> a(in); + + EXPECT_EQ(sizeof(a.vals), 6u); + EXPECT_EQ(a.vals[0], 0x7B); // +57344 + EXPECT_EQ(a.vals[1], 0x7B); // 60000 -> clamp to +57344 + EXPECT_EQ(a.vals[2], 0x04); // min normal + EXPECT_EQ(a.vals[3], 0x03); // max subnormal + EXPECT_EQ(a.vals[4], 0x01); // min subnormal + EXPECT_EQ(a.vals[5], 0x7F); // NaN +} + +TEST(FP8E5M2Test, CArrayConstructorHalfHostToEvenFinite) { + const sycl::half in[4] = {sycl::half(1.0f), sycl::half(2.0f), + sycl::half(1.125f), sycl::half(-0.0f)}; + fp8_e5m2<4> a(in); + + EXPECT_EQ(sizeof(a.vals), 4u); + EXPECT_EQ(a.vals[0], 0x3C); + EXPECT_EQ(a.vals[1], 0x40); + EXPECT_EQ(a.vals[2], 0x3C); // tie -> to_even => 1.0 + EXPECT_EQ(a.vals[3], 0x80); +} + +TEST(FP8E5M2Test, CArrayConstructorBFloat16HostToEvenFinite) { + const sycl::ext::oneapi::bfloat16 in[4] = { + sycl::ext::oneapi::bfloat16(1.0f), sycl::ext::oneapi::bfloat16(2.0f), + sycl::ext::oneapi::bfloat16(1.125f), sycl::ext::oneapi::bfloat16(-0.0f)}; + fp8_e5m2<4> a(in); + + EXPECT_EQ(sizeof(a.vals), 4u); + EXPECT_EQ(a.vals[0], 0x3C); + EXPECT_EQ(a.vals[1], 0x40); + EXPECT_EQ(a.vals[2], 0x3C); // tie -> to_even => 1.0 + EXPECT_EQ(a.vals[3], 0x80); +} + +TEST(FP8E5M2Test, MarrayConstructorAndOperatorsHostAllN) { + // marray constructors/operators: host supports all N. + sycl::marray in = { + 1.0f, 2.0f, 0.0f, -0.0f, 57344.0f, 100000.0f, 0.0000152587890625f, -1.5f}; + fp8_e5m2<8> a(in); + + EXPECT_EQ(sizeof(a.vals), 8u); + EXPECT_EQ(a.vals[0], 0x3C); + EXPECT_EQ(a.vals[1], 0x40); + EXPECT_EQ(a.vals[2], 0x00); + EXPECT_EQ(a.vals[3], 0x80); + EXPECT_EQ(a.vals[4], 0x7B); + EXPECT_EQ(a.vals[5], 0x7B); // finite saturation + EXPECT_EQ(a.vals[6], 0x01); + EXPECT_EQ(a.vals[7], 0xBE); // -1.5 + + sycl::marray out = static_cast>(a); + EXPECT_EQ(out[0], 1.0f); + EXPECT_EQ(out[1], 2.0f); + EXPECT_EQ(out[2], 0.0f); + EXPECT_EQ(out[3], 0.0f); + EXPECT_TRUE(std::signbit(out[3])); + EXPECT_EQ(out[4], 57344.0f); + EXPECT_EQ(out[5], 57344.0f); + EXPECT_EQ(out[6], 0.0000152587890625f); + EXPECT_EQ(out[7], -1.5f); +} + +TEST(FP8E5M2Test, MarrayConstructorHalfBFloat16Double) { + sycl::marray hvals = {sycl::half(1.0f), sycl::half(2.0f), + sycl::half(57344.0f), sycl::half(-0.0f)}; + sycl::marray bvals = { + sycl::ext::oneapi::bfloat16(1.0f), sycl::ext::oneapi::bfloat16(2.0f), + sycl::ext::oneapi::bfloat16(0.0000152587890625f), + sycl::ext::oneapi::bfloat16(-0.0f)}; + sycl::marray dvals = {1.0, 2.0, 57344.0, -0.0}; + + fp8_e5m2<4> ah(hvals); + fp8_e5m2<4> ab(bvals); + fp8_e5m2<4> ad(dvals); + + EXPECT_EQ(sizeof(ah.vals), 4u); + EXPECT_EQ(sizeof(ab.vals), 4u); + EXPECT_EQ(sizeof(ad.vals), 4u); + + EXPECT_EQ(ah.vals[0], 0x3C); + EXPECT_EQ(ah.vals[1], 0x40); + EXPECT_EQ(ah.vals[2], 0x7B); + EXPECT_EQ(ah.vals[3], 0x80); + + EXPECT_EQ(ab.vals[0], 0x3C); + EXPECT_EQ(ab.vals[1], 0x40); + EXPECT_EQ(ab.vals[2], 0x01); + EXPECT_EQ(ab.vals[3], 0x80); + + EXPECT_EQ(ad.vals[0], 0x3C); + EXPECT_EQ(ad.vals[1], 0x40); + EXPECT_EQ(ad.vals[2], 0x7B); + EXPECT_EQ(ad.vals[3], 0x80); +} + +TEST(FP8E5M2Test, FloatingPointConversionOperatorsMoreTypes) { + fp8_e5m2<1> a(1.0f); + fp8_e5m2<1> b(0.00006103515625f); + fp8_e5m2<1> nanv(std::numeric_limits::quiet_NaN()); + + EXPECT_EQ(sizeof(a.vals), 1u); + EXPECT_EQ(sizeof(b.vals), 1u); + EXPECT_EQ(sizeof(nanv.vals), 1u); + + double da = static_cast(a); + sycl::half ha = static_cast(a); + sycl::ext::oneapi::bfloat16 ba = static_cast(a); + + EXPECT_EQ(da, 1.0); + EXPECT_EQ(static_cast(ha), 1.0f); + EXPECT_EQ(static_cast(ba), 1.0f); + + EXPECT_EQ(static_cast(b), 0.00006103515625f); + + float fn = static_cast(nanv); + EXPECT_TRUE(std::isnan(fn)); +} + +TEST(FP8E5M2Test, MarrayConversionOperatorsHalfBFloat16) { + fp8_e5m2<2> a(1.0f, -0.0f); + + EXPECT_EQ(sizeof(a.vals), 2u); + EXPECT_EQ(a.vals[0], 0x3C); + EXPECT_EQ(a.vals[1], 0x80); + + sycl::marray ho = static_cast>(a); + sycl::marray bo = + static_cast>(a); + + EXPECT_EQ(static_cast(ho[0]), 1.0f); + EXPECT_EQ(static_cast(ho[1]), 0.0f); + EXPECT_TRUE(std::signbit(static_cast(ho[1]))); + + EXPECT_EQ(static_cast(bo[0]), 1.0f); + EXPECT_EQ(static_cast(bo[1]), 0.0f); + EXPECT_TRUE(std::signbit(static_cast(bo[1]))); +} + +TEST(FP8E5M2Test, IntegerConversionOperatorsMultipleWidthsTowardZero) { + fp8_e5m2<1> p(1.5f); + fp8_e5m2<1> n(-1.5f); + + std::int32_t i32p = static_cast(p); + std::int32_t i32n = static_cast(n); + std::int64_t i64p = static_cast(p); + std::int64_t i64n = static_cast(n); + + EXPECT_EQ(i32p, 1); + EXPECT_EQ(i32n, -1); + EXPECT_EQ(i64p, 1); + EXPECT_EQ(i64n, -1); +} + +TEST(FP8E5M2Test, IntegerConversionOperatorsAllTypesTowardZero) { + fp8_e5m2<1> p(1.5f); + fp8_e5m2<1> n(-1.5f); + + EXPECT_EQ(sizeof(p.vals), 1u); + EXPECT_EQ(sizeof(n.vals), 1u); + EXPECT_EQ(p.vals[0], 0x3E); + EXPECT_EQ(n.vals[0], 0xBE); + + EXPECT_EQ(static_cast(p), 1); + EXPECT_EQ(static_cast(n), -1); + EXPECT_EQ(static_cast(n), -1); + EXPECT_EQ(static_cast(n), -1); + EXPECT_EQ(static_cast(n), -1); + EXPECT_EQ(static_cast(n), -1); + EXPECT_EQ(static_cast(p), 1u); + EXPECT_EQ(static_cast(p), 1u); + EXPECT_EQ(static_cast(p), 1u); + EXPECT_EQ(static_cast(p), 1u); + EXPECT_EQ(static_cast(p), 1u); +} + +TEST(FP8E5M2Test, VariadicConstructorHalfBoundaryEncodings) { + fp8_e5m2<4> a(sycl::half(57344.0f), sycl::half(0.00006103515625f), + sycl::half(0.0000152587890625f), sycl::half(-0.0f)); + + EXPECT_EQ(sizeof(a.vals), 4u); + EXPECT_EQ(a.vals[0], 0x7B); // +max normal + EXPECT_EQ(a.vals[1], 0x04); // min normal + EXPECT_EQ(a.vals[2], 0x01); // min subnormal + EXPECT_EQ(a.vals[3], 0x80); // -0 +} + +TEST(FP8E5M2Test, VariadicConstructorBFloat16BoundaryEncodings) { + fp8_e5m2<4> a(sycl::ext::oneapi::bfloat16(1.0f), + sycl::ext::oneapi::bfloat16(2.0f), + sycl::ext::oneapi::bfloat16(0.0000152587890625f), + sycl::ext::oneapi::bfloat16(-0.0f)); + + EXPECT_EQ(sizeof(a.vals), 4u); + EXPECT_EQ(a.vals[0], 0x3C); + EXPECT_EQ(a.vals[1], 0x40); + EXPECT_EQ(a.vals[2], 0x01); + EXPECT_EQ(a.vals[3], 0x80); +} + +TEST(FP8E5M2Test, VariadicConstructorDoubleBoundaryEncodingsAndSaturation) { + fp8_e5m2<5> a(57344.0, 60000.0, 0.0000457763671875, 0.0000152587890625, + -100000.0); + + EXPECT_EQ(sizeof(a.vals), 5u); + EXPECT_EQ(a.vals[0], 0x7B); // +57344 + EXPECT_EQ(a.vals[1], 0x7B); // clamp to +57344 (finite saturation) + EXPECT_EQ(a.vals[2], 0x03); // max subnormal + EXPECT_EQ(a.vals[3], 0x01); // min subnormal + EXPECT_EQ(a.vals[4], 0xFB); // clamp to -57344 +} + +TEST(FP8E5M2Test, IntegerConstructorsAllTypes) { + fp8_e5m2<1> s(static_cast(1)); + fp8_e5m2<1> i(static_cast(2)); + fp8_e5m2<1> l(static_cast(3)); + fp8_e5m2<1> ll(static_cast(-1)); + fp8_e5m2<1> us(static_cast(1)); + fp8_e5m2<1> ui(static_cast(2)); + fp8_e5m2<1> ul(static_cast(3)); + fp8_e5m2<1> ull(static_cast(4)); + + EXPECT_EQ(sizeof(s.vals), 1u); + EXPECT_EQ(sizeof(i.vals), 1u); + EXPECT_EQ(sizeof(l.vals), 1u); + EXPECT_EQ(sizeof(ll.vals), 1u); + EXPECT_EQ(sizeof(us.vals), 1u); + EXPECT_EQ(sizeof(ui.vals), 1u); + EXPECT_EQ(sizeof(ul.vals), 1u); + EXPECT_EQ(sizeof(ull.vals), 1u); + + EXPECT_EQ(s.vals[0], 0x3C); + EXPECT_EQ(i.vals[0], 0x40); + EXPECT_EQ(l.vals[0], 0x42); // 3.0 -> 0b0_10000_10 + EXPECT_EQ(ll.vals[0], 0xBC); // -1.0 + EXPECT_EQ(us.vals[0], 0x3C); + EXPECT_EQ(ui.vals[0], 0x40); + EXPECT_EQ(ul.vals[0], 0x42); // 3.0 + EXPECT_EQ(ull.vals[0], 0x44); // 4.0 -> 0b0_10001_00 +} + +TEST(FP8E5M2Test, AssignmentOperatorsAllTypes) { + fp8_e5m2<1> a(0.0f); + + EXPECT_EQ(sizeof(a.vals), 1u); + EXPECT_EQ(a.vals[0], 0x00); + + a = sycl::half(1.0f); + EXPECT_EQ(a.vals[0], 0x3C); + + a = sycl::ext::oneapi::bfloat16(2.0f); + EXPECT_EQ(a.vals[0], 0x40); + + a = 3.0f; + EXPECT_EQ(a.vals[0], 0x42); // 3.0 + + a = 4.0; + EXPECT_EQ(a.vals[0], 0x44); // 4.0 + + a = static_cast(-1); + EXPECT_EQ(a.vals[0], 0xBC); + + a = static_cast(2); + EXPECT_EQ(a.vals[0], 0x40); + + a = static_cast(1); + EXPECT_EQ(a.vals[0], 0x3C); + + a = static_cast(-2); + EXPECT_EQ(a.vals[0], 0xC0); + + a = static_cast(1); + EXPECT_EQ(a.vals[0], 0x3C); + + a = static_cast(2); + EXPECT_EQ(a.vals[0], 0x40); + + a = static_cast(3); + EXPECT_EQ(a.vals[0], 0x42); + + a = static_cast(4); + EXPECT_EQ(a.vals[0], 0x44); +} + +TEST(FP8E5M2Test, BoolOperatorWithNaN) { + float pz = 0.0f; + fp8_e5m2<1> zp(pz); + float zv = -0.0f; + fp8_e5m2<1> zn(zv); + float nv = {std::numeric_limits::quiet_NaN()}; + fp8_e5m2<1> nanv(nv); + + EXPECT_EQ(sizeof(zp.vals), 1u); + EXPECT_EQ(sizeof(zn.vals), 1u); + EXPECT_EQ(sizeof(nanv.vals), 1u); + + EXPECT_FALSE(static_cast(zp)); + EXPECT_FALSE(static_cast(zn)); + EXPECT_TRUE(static_cast(nanv)); // not +0 or -0 + EXPECT_EQ(nanv.vals[0], 0x7F); // NaN encoding remains S.11111.11 +} diff --git a/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp b/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp new file mode 100644 index 0000000000000..7c7de054559b6 --- /dev/null +++ b/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp @@ -0,0 +1,280 @@ +#include +#include + +#include +#include +#include + +using namespace sycl::ext::oneapi::experimental; + +TEST(FP8E8M0Test, VariadicConstructorFloat) { + fp8_e8m0<4> a(1.0f, 2.0f, 1.1f, 0.0f); + + EXPECT_EQ(sizeof(a.vals), 4u); + EXPECT_EQ(a.vals[0], 0x7F); // 1.0 -> exp=127 + EXPECT_EQ(a.vals[1], 0x80); // 2.0 -> exp=128 + EXPECT_EQ(a.vals[2], 0x80); // 1.1 -> upward to 2.0 + EXPECT_EQ(a.vals[3], 0x00); // 0.0 -> min normal +} + +TEST(FP8E8M0Test, VariadicConstructorHalf) { + fp8_e8m0<2> a(sycl::half(1.0f), sycl::half(3.0f)); + + EXPECT_EQ(sizeof(a.vals), 2u); + EXPECT_EQ(a.vals[0], 0x7F); + EXPECT_EQ(a.vals[1], 0x81); // 3.0 -> upward to 4.0 +} + +TEST(FP8E8M0Test, VariadicConstructorBFloat16) { + fp8_e8m0<2> a(sycl::ext::oneapi::bfloat16(1.0f), + sycl::ext::oneapi::bfloat16(2.0f)); + + EXPECT_EQ(sizeof(a.vals), 2u); + EXPECT_EQ(a.vals[0], 0x7F); + EXPECT_EQ(a.vals[1], 0x80); +} + +TEST(FP8E8M0Test, VariadicConstructorDouble) { + fp8_e8m0<2> a(1.0, 3.0); + + EXPECT_EQ(sizeof(a.vals), 2u); + EXPECT_EQ(a.vals[0], 0x7F); + EXPECT_EQ(a.vals[1], 0x81); +} + +TEST(FP8E8M0Test, VariadicConstructorBoundaryEncodings) { + fp8_e8m0<3> a(std::ldexp(1.0f, 127), std::ldexp(1.0f, -127), + std::numeric_limits::quiet_NaN()); + + EXPECT_EQ(sizeof(a.vals), 3u); + EXPECT_EQ(a.vals[0], 0xFE); // max normal + EXPECT_EQ(a.vals[1], 0x00); // min normal + EXPECT_EQ(a.vals[2], 0xFF); // NaN +} + +TEST(FP8E8M0Test, CArrayConstructorFloatHostUpwardFinite) { + const float in[5] = {1.0f, 1.1f, 3.0f, 0.0f, 1000.0f}; + fp8_e8m0<5> a(in, rounding::upward, saturation::finite); + + EXPECT_EQ(sizeof(a.vals), 5u); + EXPECT_EQ(a.vals[0], 0x7F); + EXPECT_EQ(a.vals[1], 0x80); // upward to 2.0 + EXPECT_EQ(a.vals[2], 0x81); // upward to 4.0 + EXPECT_EQ(a.vals[3], 0x00); // min normal + EXPECT_EQ(a.vals[4], 0x89); // upward to 2^10 = 1024 +} + +TEST(FP8E8M0Test, CArrayConstructorHalfHostUpwardFinite) { + const sycl::half in[4] = {sycl::half(1.0f), sycl::half(1.1f), + sycl::half(3.0f), sycl::half(0.0f)}; + fp8_e8m0<4> a(in, rounding::upward, saturation::finite); + + EXPECT_EQ(sizeof(a.vals), 4u); + EXPECT_EQ(a.vals[0], 0x7F); + EXPECT_EQ(a.vals[1], 0x80); + EXPECT_EQ(a.vals[2], 0x81); + EXPECT_EQ(a.vals[3], 0x00); +} + +TEST(FP8E8M0Test, CArrayConstructorBFloat16HostUpwardFinite) { + const sycl::ext::oneapi::bfloat16 in[3] = { + sycl::ext::oneapi::bfloat16(1.0f), + sycl::ext::oneapi::bfloat16(2.0f), + sycl::ext::oneapi::bfloat16(0.0f)}; + fp8_e8m0<3> a(in, rounding::upward, saturation::finite); + + EXPECT_EQ(sizeof(a.vals), 3u); + EXPECT_EQ(a.vals[0], 0x7F); + EXPECT_EQ(a.vals[1], 0x80); + EXPECT_EQ(a.vals[2], 0x00); +} + +TEST(FP8E8M0Test, CArrayConstructorDoubleDefaultUpwardFinite) { + const double in[3] = {1.0, 3.0, 0.0}; + fp8_e8m0<3> a(in); + + EXPECT_EQ(sizeof(a.vals), 3u); + EXPECT_EQ(a.vals[0], 0x7F); + EXPECT_EQ(a.vals[1], 0x81); + EXPECT_EQ(a.vals[2], 0x00); +} + +TEST(FP8E8M0Test, MarrayConstructorAndOperatorsFloat) { + sycl::marray in = {1.0f, 2.0f, 3.0f, 0.0f}; + fp8_e8m0<4> a(in, rounding::upward, saturation::finite); + + EXPECT_EQ(sizeof(a.vals), 4u); + EXPECT_EQ(a.vals[0], 0x7F); + EXPECT_EQ(a.vals[1], 0x80); + EXPECT_EQ(a.vals[2], 0x81); + EXPECT_EQ(a.vals[3], 0x00); + + sycl::marray out = static_cast>(a); + EXPECT_EQ(out[0], 1.0f); + EXPECT_EQ(out[1], 2.0f); + EXPECT_EQ(out[2], 4.0f); + EXPECT_EQ(out[3], std::ldexp(1.0f, -127)); +} + +TEST(FP8E8M0Test, MarrayConstructorHalfBFloat16Double) { + sycl::marray hvals = {sycl::half(1.0f), sycl::half(3.0f)}; + sycl::marray bvals = { + sycl::ext::oneapi::bfloat16(1.0f), + sycl::ext::oneapi::bfloat16(2.0f)}; + sycl::marray dvals = {1.0, 3.0}; + + fp8_e8m0<2> ah(hvals, rounding::upward, saturation::finite); + fp8_e8m0<2> ab(bvals, rounding::upward, saturation::finite); + fp8_e8m0<2> ad(dvals); + + EXPECT_EQ(sizeof(ah.vals), 2u); + EXPECT_EQ(sizeof(ab.vals), 2u); + EXPECT_EQ(sizeof(ad.vals), 2u); + + EXPECT_EQ(ah.vals[0], 0x7F); + EXPECT_EQ(ah.vals[1], 0x81); + EXPECT_EQ(ab.vals[0], 0x7F); + EXPECT_EQ(ab.vals[1], 0x80); + EXPECT_EQ(ad.vals[0], 0x7F); + EXPECT_EQ(ad.vals[1], 0x81); +} + +TEST(FP8E8M0Test, IntegerConstructorsAllTypes) { + fp8_e8m0<1> s(static_cast(1)); + fp8_e8m0<1> i(static_cast(2)); + fp8_e8m0<1> l(static_cast(3)); + fp8_e8m0<1> ll(static_cast(4)); + fp8_e8m0<1> us(static_cast(1)); + fp8_e8m0<1> ui(static_cast(2)); + fp8_e8m0<1> ul(static_cast(3)); + fp8_e8m0<1> ull(static_cast(4)); + + EXPECT_EQ(sizeof(s.vals), 1u); + EXPECT_EQ(sizeof(i.vals), 1u); + EXPECT_EQ(sizeof(l.vals), 1u); + EXPECT_EQ(sizeof(ll.vals), 1u); + EXPECT_EQ(sizeof(us.vals), 1u); + EXPECT_EQ(sizeof(ui.vals), 1u); + EXPECT_EQ(sizeof(ul.vals), 1u); + EXPECT_EQ(sizeof(ull.vals), 1u); + + EXPECT_EQ(s.vals[0], 0x7F); // 1.0 + EXPECT_EQ(i.vals[0], 0x80); // 2.0 + EXPECT_EQ(l.vals[0], 0x81); // 3.0 -> upward to 4.0 + EXPECT_EQ(ll.vals[0], 0x81); // 4.0 + EXPECT_EQ(us.vals[0], 0x7F); + EXPECT_EQ(ui.vals[0], 0x80); + EXPECT_EQ(ul.vals[0], 0x81); + EXPECT_EQ(ull.vals[0], 0x81); +} + +TEST(FP8E8M0Test, AssignmentOperatorsAllTypes) { + fp8_e8m0<1> a(1.0f); + EXPECT_EQ(sizeof(a.vals), 1u); + + a = sycl::half(1.0f); + EXPECT_EQ(a.vals[0], 0x7F); + + a = sycl::ext::oneapi::bfloat16(2.0f); + EXPECT_EQ(a.vals[0], 0x80); + + a = 3.0f; + EXPECT_EQ(a.vals[0], 0x81); + + a = 4.0; + EXPECT_EQ(a.vals[0], 0x81); + + a = static_cast(1); + EXPECT_EQ(a.vals[0], 0x7F); + + a = static_cast(2); + EXPECT_EQ(a.vals[0], 0x80); + + a = static_cast(3); + EXPECT_EQ(a.vals[0], 0x81); + + a = static_cast(4); + EXPECT_EQ(a.vals[0], 0x81); + + a = static_cast(1); + EXPECT_EQ(a.vals[0], 0x7F); + + a = static_cast(2); + EXPECT_EQ(a.vals[0], 0x80); + + a = static_cast(3); + EXPECT_EQ(a.vals[0], 0x81); + + a = static_cast(4); + EXPECT_EQ(a.vals[0], 0x81); +} + +TEST(FP8E8M0Test, FloatingPointConversionOperators) { + fp8_e8m0<1> one(1.0f); + fp8_e8m0<1> max(std::ldexp(1.0f, 127)); + fp8_e8m0<1> min(std::ldexp(1.0f, -127)); + + EXPECT_EQ(sizeof(one.vals), 1u); + EXPECT_EQ(one.vals[0], 0x7F); + EXPECT_EQ(max.vals[0], 0xFE); + EXPECT_EQ(min.vals[0], 0x00); + + float fo = static_cast(one); + double doo = static_cast(one); + sycl::half ho = static_cast(one); + sycl::ext::oneapi::bfloat16 bo = static_cast(one); + + EXPECT_EQ(fo, 1.0f); + EXPECT_EQ(doo, 1.0); + EXPECT_EQ(static_cast(ho), 1.0f); + EXPECT_EQ(static_cast(bo), 1.0f); + + sycl::half hmax = static_cast(max); + EXPECT_TRUE(std::isinf(static_cast(hmax))); + EXPECT_FALSE(std::signbit(static_cast(hmax))); + + EXPECT_EQ(static_cast(min), std::ldexp(1.0f, -127)); +} + +TEST(FP8E8M0Test, UnsignedConversionOperatorsTowardZero) { + fp8_e8m0<1> a(3.0f); // upward to 4.0 + + EXPECT_EQ(sizeof(a.vals), 1u); + EXPECT_EQ(a.vals[0], 0x81); + + EXPECT_EQ(static_cast(a), 4u); + EXPECT_EQ(static_cast(a), 4u); + EXPECT_EQ(static_cast(a), 4u); + EXPECT_EQ(static_cast(a), 4u); + EXPECT_EQ(static_cast(a), 4u); +} + +TEST(FP8E8M0Test, BoolOperatorAlwaysTrue) { + fp8_e8m0<1> min(std::ldexp(1.0f, -127)); + fp8_e8m0<1> nanv(std::numeric_limits::quiet_NaN()); + + EXPECT_TRUE(static_cast(min)); + EXPECT_TRUE(static_cast(nanv)); +} + +TEST(FP8E8M0Test, MarrayConversionOperators) { + fp8_e8m0<3> a(1.0f, 3.0f, std::ldexp(1.0f, 127)); + + sycl::marray ho = static_cast>(a); + sycl::marray bo = + static_cast>(a); + sycl::marray fo = static_cast>(a); + + EXPECT_EQ(static_cast(ho[0]), 1.0f); + EXPECT_EQ(static_cast(ho[1]), 4.0f); + EXPECT_TRUE(std::isinf(static_cast(ho[2]))); + + EXPECT_EQ(static_cast(bo[0]), 1.0f); + EXPECT_EQ(static_cast(bo[1]), 4.0f); + EXPECT_EQ(static_cast(bo[2]), std::ldexp(1.0f, 127)); + + EXPECT_EQ(fo[0], 1.0f); + EXPECT_EQ(fo[1], 4.0f); + EXPECT_EQ(fo[2], std::ldexp(1.0f, 127)); +} + From 94904984f3b464a3a2f9218ccb184b44eb40980d Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Wed, 18 Feb 2026 10:28:48 +0100 Subject: [PATCH 02/89] [SYCL] update fp8 to check constraints --- .../oneapi/experimental/float_8bit/types.hpp | 193 +++++++++++------- sycl/unittests/Extensions/fp8/fp8_e8m0.cpp | 12 +- 2 files changed, 124 insertions(+), 81 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index b030dd1009682..2150faf78264c 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -353,10 +353,9 @@ ConvertToFP8_CPU(T h, rounding R = rounding::to_even) noexcept { constexpr uint8_t MaxFrac = static_cast((1 << Mbits) - 1); constexpr uint8_t MaxFracForMaxNormal = (Ebits == 4 && Mbits == 3) ? static_cast(MaxFrac - 1u) : MaxFrac; - constexpr uint8_t MaxExpForMaxNormal = - (Ebits == 5 && Mbits == 2) - ? static_cast(ExpAllOnes - 1u) - : ExpAllOnes; + constexpr uint8_t MaxExpForMaxNormal = + (Ebits == 5 && Mbits == 2) ? static_cast(ExpAllOnes - 1u) + : ExpAllOnes; constexpr uint8_t MaxFracMask = MaxFrac; float x = static_cast(h); @@ -560,7 +559,7 @@ template class fp8_e4m3 { fp8_e4m3 &operator=(const fp8_e4m3 &) = default; // Construct from pack of half, float, double. - // Available only when the size of the pack is equal to N. + // Available only when the size of the pack is equal to N. template class fp8_e4m3 { vals[i] = ConvertBF16ToFP8(v[i], r); } - explicit fp8_e4m3(const sycl::marray &v, rounding r = rounding::to_even) { + explicit fp8_e4m3(const sycl::marray &v, + rounding r = rounding::to_even) { #ifdef __SYCL_DEVICE_ONLY__ static_assert(N == 1 || N == 2, "fp8_e4m3: Template argument N must be 1 or 2 on device"); @@ -691,12 +691,14 @@ template class fp8_e4m3 { // of half, bfloat16, float. // Should be removed once docs updated - explicit fp8_e4m3(const sycl::marray &vals, const stochastic_seed &seed, + explicit fp8_e4m3(const sycl::marray &vals, + const stochastic_seed &seed, saturation s = saturation::finite); explicit fp8_e4m3(const sycl::marray &vals, const stochastic_seed &seed, saturation s = saturation::finite); - explicit fp8_e4m3(const sycl::marray &vals, const stochastic_seed &seed, + explicit fp8_e4m3(const sycl::marray &vals, + const stochastic_seed &seed, saturation s = saturation::finite); // Construct from integer types. @@ -948,8 +950,7 @@ template class fp8_e4m3 { uint8_t vals[N]; }; -template -class fp8_e5m2 { +template class fp8_e5m2 { uint8_t ConvertToFP8(sycl::half h, rounding r) { #ifdef __SYCL_DEVICE_ONLY__ @@ -1123,7 +1124,8 @@ class fp8_e5m2 { vals[i] = ConvertBF16ToFP8(v[i], r); } - explicit fp8_e5m2(const sycl::marray &v, rounding r = rounding::to_even) { + explicit fp8_e5m2(const sycl::marray &v, + rounding r = rounding::to_even) { #ifdef __SYCL_DEVICE_ONLY__ static_assert(N == 1 || N == 2, "fp8_e5m2: Template argument N must be 1 or 2 on device"); @@ -1160,12 +1162,14 @@ class fp8_e5m2 { // of half, bfloat16, float. // should be removed once docs updated - explicit fp8_e5m2(const sycl::marray &vals, const stochastic_seed &seed, + explicit fp8_e5m2(const sycl::marray &vals, + const stochastic_seed &seed, saturation s = saturation::finite); explicit fp8_e5m2(const sycl::marray &vals, const stochastic_seed &seed, saturation s = saturation::finite); - explicit fp8_e5m2(const sycl::marray &vals, const stochastic_seed &seed, + explicit fp8_e5m2(const sycl::marray &vals, + const stochastic_seed &seed, saturation s = saturation::finite); // Construct from integer types. @@ -1412,7 +1416,7 @@ class fp8_e5m2 { }; static inline uint8_t ConvertToE8M0_CPU(float x, rounding R, - saturation S) noexcept { + saturation S) noexcept { // E8M0: unsigned 8-bit exponent code, bias 127. // Code 0xFF reserved for NaN. No Inf, no subnormals, no signed zero. constexpr int Bias = 127; @@ -1503,8 +1507,7 @@ static inline ToT ConvertFromE8M0_CPU(uint8_t code) noexcept { return ConvertFloatToTarget(v, rounding::to_even); } -template -class fp8_e8m0 { +template class fp8_e8m0 { public: fp8_e8m0() = default; fp8_e8m0(const fp8_e8m0 &) = default; @@ -1531,115 +1534,130 @@ class fp8_e8m0 { saturation::finite); } - explicit fp8_e8m0(half const (&in)[N], rounding r = rounding::upward, - saturation s = saturation::finite) { + explicit fp8_e8m0(half const (&in)[N], rounding r = rounding::upward) { + if (r != rounding::upward && r != rounding::toward_zero) + throw std::invalid_argument( + "fp8_e8m0 supports only rounding upward and toward_zero"); #ifdef __SYCL_DEVICE_ONLY__ static_assert(N == 1 || N == 2, "fp8_e8m0: Template argument N must be 1 or 2 on device"); #endif for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToE8M0_CPU(static_cast(in[i]), r, s); + vals[i] = + ConvertToE8M0_CPU(static_cast(in[i]), r, saturation::finite); } - explicit fp8_e8m0(bfloat16 const (&in)[N], rounding r = rounding::upward, - saturation s = saturation::finite) { + + explicit fp8_e8m0(bfloat16 const (&in)[N], rounding r = rounding::upward) { + if (r != rounding::upward && r != rounding::toward_zero) + throw std::invalid_argument( + "fp8_e8m0 supports only rounding upward and toward_zero"); #ifdef __SYCL_DEVICE_ONLY__ static_assert(N == 1 || N == 2, "fp8_e8m0: Template argument N must be 1 or 2 on device"); #endif for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToE8M0_CPU(static_cast(in[i]), r, s); + vals[i] = + ConvertToE8M0_CPU(static_cast(in[i]), r, saturation::finite); } - explicit fp8_e8m0(float const (&in)[N], rounding r = rounding::upward, - saturation s = saturation::finite) { + + explicit fp8_e8m0(float const (&in)[N], rounding r = rounding::upward) { + if (r != rounding::upward && r != rounding::toward_zero) + throw std::invalid_argument( + "fp8_e8m0 supports only rounding upward and toward_zero"); #ifdef __SYCL_DEVICE_ONLY__ static_assert(N == 1 || N == 2, "fp8_e8m0: Template argument N must be 1 or 2 on device"); #endif for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToE8M0_CPU(in[i], r, s); + vals[i] = ConvertToE8M0_CPU(in[i], r, saturation::finite); } + explicit fp8_e8m0(double const (&in)[N]) { #ifdef __SYCL_DEVICE_ONLY__ static_assert(N == 1 || N == 2, "fp8_e8m0: Template argument N must be 1 or 2 on device"); #endif for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToE8M0_CPU(static_cast(in[i]), - rounding::upward, saturation::finite); + vals[i] = ConvertToE8M0_CPU(static_cast(in[i]), rounding::upward, + saturation::finite); } - explicit fp8_e8m0(const marray &vals, rounding r = rounding::upward, - saturation s = saturation::finite) { + explicit fp8_e8m0(const marray &vals, + rounding r = rounding::upward) { + if (r != rounding::upward && r != rounding::toward_zero) + throw std::invalid_argument( + "fp8_e8m0 supports only rounding upward and toward_zero"); #ifdef __SYCL_DEVICE_ONLY__ static_assert(N == 1 || N == 2, "fp8_e8m0: Template argument N must be 1 or 2 on device"); - assert((r == rounding::upward && s == saturation::finite) && - "fp8_e8m0: device supports rounding::upward and saturation::finite " - "only"); + assert((r == rounding::upward) && + "fp8_e8m0: device supports rounding::upward only"); #endif for (size_t i = 0; i < N; ++i) - this->vals[i] = ConvertToE8M0_CPU(static_cast(vals[i]), r, s); + vals[i] = + ConvertToE8M0_CPU(static_cast(vals[i]), r, saturation::finite); } explicit fp8_e8m0(const marray &vals, - rounding r = rounding::upward, - saturation s = saturation::finite) { + rounding r = rounding::upward) { + if (r != rounding::upward && r != rounding::toward_zero) + throw std::invalid_argument( + "fp8_e8m0 supports only rounding upward and toward_zero"); #ifdef __SYCL_DEVICE_ONLY__ static_assert(N == 1 || N == 2, "fp8_e8m0: Template argument N must be 1 or 2 on device"); - assert((r == rounding::upward && s == saturation::finite) && - "fp8_e8m0: device supports rounding::upward and saturation::finite " - "only"); #endif for (size_t i = 0; i < N; ++i) - this->vals[i] = ConvertToE8M0_CPU(static_cast(vals[i]), r, s); + vals[i] = + ConvertToE8M0_CPU(static_cast(vals[i]), r, saturation::finite); } - explicit fp8_e8m0(const marray &vals, rounding r = rounding::upward, - saturation s = saturation::finite) { + explicit fp8_e8m0(const marray &vals, + rounding r = rounding::upward) { + if (r != rounding::upward && r != rounding::toward_zero) + throw std::invalid_argument( + "fp8_e8m0 supports only rounding upward and toward_zero"); #ifdef __SYCL_DEVICE_ONLY__ static_assert(N == 1 || N == 2, "fp8_e8m0: Template argument N must be 1 or 2 on device"); - assert((r == rounding::upward && s == saturation::finite) && - "fp8_e8m0: device supports rounding::upward and saturation::finite " - "only"); + assert((r == rounding::upward) && + "fp8_e8m0: device supports rounding::upward only"); #endif for (size_t i = 0; i < N; ++i) - this->vals[i] = ConvertToE8M0_CPU(vals[i], r, s); + vals[i] = ConvertToE8M0_CPU(vals[i], r, saturation::finite); } explicit fp8_e8m0(const marray &vals) { + if (r != rounding::upward && r != rounding::toward_zero) + throw std::invalid_argument( + "fp8_e8m0 supports only rounding upward and toward_zero"); #ifdef __SYCL_DEVICE_ONLY__ static_assert(N == 1 || N == 2, "fp8_e8m0: Template argument N must be 1 or 2 on device"); #endif for (size_t i = 0; i < N; ++i) - this->vals[i] = ConvertToE8M0_CPU(static_cast(vals[i]), - rounding::upward, saturation::finite); + vals[i] = ConvertToE8M0_CPU(static_cast(vals[i]), rounding::upward, + saturation::finite); } // Construct with stochastic rounding with user provided seed from an array of // half, bfloat16, float. // should be removed once docs updated - explicit fp8_e8m0(half const (&vals)[N], const stochastic_seed &seed, - saturation s = saturation::finite); - explicit fp8_e8m0(bfloat16 const (&vals)[N], const stochastic_seed &seed, - saturation s = saturation::finite); - explicit fp8_e8m0(double const (&vals)[N], const stochastic_seed &seed, - saturation s = saturation::finite); + explicit fp8_e8m0(half const (&vals)[N], const stochastic_seed &seed); + explicit fp8_e8m0(bfloat16 const (&vals)[N], const stochastic_seed &seed); + explicit fp8_e8m0(double const (&vals)[N], const stochastic_seed &seed); // Construct with stochastic rounding with user provided seed from an marray // of half, bfloat16, float. // should be removed once docs updated - explicit fp8_e8m0(const sycl::marray &vals, const stochastic_seed &seed, - saturation s = saturation::finite); + explicit fp8_e8m0(const sycl::marray &vals, + const stochastic_seed &seed); explicit fp8_e8m0(const sycl::marray &vals, - const stochastic_seed &seed, - saturation s = saturation::finite); - explicit fp8_e8m0(const sycl::marray &vals, const stochastic_seed &seed, - saturation s = saturation::finite); + const stochastic_seed &seed); + explicit fp8_e8m0(const sycl::marray &vals, + const stochastic_seed &seed); // Construct from integer types. // Available only when N==1. @@ -1647,7 +1665,7 @@ class fp8_e8m0 { explicit fp8_e8m0(short val) { assert(N == 1 && "fp8_e8m0: N must be 1 for short constructor"); vals[0] = ConvertToE8M0_CPU(static_cast(val), rounding::upward, - saturation::finite); + saturation::finite); } explicit fp8_e8m0(int val) : fp8_e8m0(static_cast(val)) {} explicit fp8_e8m0(long val) : fp8_e8m0(static_cast(val)) {} @@ -1655,18 +1673,19 @@ class fp8_e8m0 { explicit fp8_e8m0(unsigned short val) : fp8_e8m0(static_cast(val)) {} explicit fp8_e8m0(unsigned int val) : fp8_e8m0(static_cast(val)) {} explicit fp8_e8m0(unsigned long val) : fp8_e8m0(static_cast(val)) {} - explicit fp8_e8m0(unsigned long long val) : fp8_e8m0(static_cast(val)) {} + explicit fp8_e8m0(unsigned long long val) + : fp8_e8m0(static_cast(val)) {} fp8_e8m0 &operator=(half val) { static_assert(N == 1, "fp8_e8m0: N must be 1 for scalar assignment"); vals[0] = ConvertToE8M0_CPU(static_cast(val), rounding::upward, - saturation::finite); + saturation::finite); return *this; } fp8_e8m0 &operator=(bfloat16 val) { static_assert(N == 1, "fp8_e8m0: N must be 1 for scalar assignment"); vals[0] = ConvertToE8M0_CPU(static_cast(val), rounding::upward, - saturation::finite); + saturation::finite); return *this; } fp8_e8m0 &operator=(float val) { @@ -1678,11 +1697,15 @@ class fp8_e8m0 { fp8_e8m0 &operator=(short val) { return (*this = static_cast(val)); } fp8_e8m0 &operator=(int val) { return (*this = static_cast(val)); } fp8_e8m0 &operator=(long val) { return (*this = static_cast(val)); } - fp8_e8m0 &operator=(long long val) { return (*this = static_cast(val)); } + fp8_e8m0 &operator=(long long val) { + return (*this = static_cast(val)); + } fp8_e8m0 &operator=(unsigned short val) { return (*this = static_cast(val)); } - fp8_e8m0 &operator=(unsigned int val) { return (*this = static_cast(val)); } + fp8_e8m0 &operator=(unsigned int val) { + return (*this = static_cast(val)); + } fp8_e8m0 &operator=(unsigned long val) { return (*this = static_cast(val)); } @@ -1711,16 +1734,36 @@ class fp8_e8m0 { static_assert(N == 1, "fp8_e8m0: N must be 1 for scalar conversion"); return static_cast(static_cast(*this)); } - explicit operator signed char() const { return static_cast(static_cast(*this)); } - explicit operator short() const { return static_cast(static_cast(*this)); } - explicit operator int() const { return static_cast(static_cast(*this)); } - explicit operator long() const { return static_cast(static_cast(*this)); } - explicit operator long long() const { return static_cast(static_cast(*this)); } - explicit operator unsigned char() const { return static_cast(static_cast(*this)); } - explicit operator unsigned short() const { return static_cast(static_cast(*this)); } - explicit operator unsigned int() const { return static_cast(static_cast(*this)); } - explicit operator unsigned long() const { return static_cast(static_cast(*this)); } - explicit operator unsigned long long() const { return static_cast(static_cast(*this)); } + explicit operator signed char() const { + return static_cast(static_cast(*this)); + } + explicit operator short() const { + return static_cast(static_cast(*this)); + } + explicit operator int() const { + return static_cast(static_cast(*this)); + } + explicit operator long() const { + return static_cast(static_cast(*this)); + } + explicit operator long long() const { + return static_cast(static_cast(*this)); + } + explicit operator unsigned char() const { + return static_cast(static_cast(*this)); + } + explicit operator unsigned short() const { + return static_cast(static_cast(*this)); + } + explicit operator unsigned int() const { + return static_cast(static_cast(*this)); + } + explicit operator unsigned long() const { + return static_cast(static_cast(*this)); + } + explicit operator unsigned long long() const { + return static_cast(static_cast(*this)); + } explicit operator bool() const { static_assert(N == 1, "fp8_e8m0: operator bool requires size N=1"); diff --git a/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp b/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp index 7c7de054559b6..30313da4b9264 100644 --- a/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp +++ b/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp @@ -54,7 +54,7 @@ TEST(FP8E8M0Test, VariadicConstructorBoundaryEncodings) { TEST(FP8E8M0Test, CArrayConstructorFloatHostUpwardFinite) { const float in[5] = {1.0f, 1.1f, 3.0f, 0.0f, 1000.0f}; - fp8_e8m0<5> a(in, rounding::upward, saturation::finite); + fp8_e8m0<5> a(in, rounding::upward); EXPECT_EQ(sizeof(a.vals), 5u); EXPECT_EQ(a.vals[0], 0x7F); @@ -67,7 +67,7 @@ TEST(FP8E8M0Test, CArrayConstructorFloatHostUpwardFinite) { TEST(FP8E8M0Test, CArrayConstructorHalfHostUpwardFinite) { const sycl::half in[4] = {sycl::half(1.0f), sycl::half(1.1f), sycl::half(3.0f), sycl::half(0.0f)}; - fp8_e8m0<4> a(in, rounding::upward, saturation::finite); + fp8_e8m0<4> a(in, rounding::upward); EXPECT_EQ(sizeof(a.vals), 4u); EXPECT_EQ(a.vals[0], 0x7F); @@ -81,7 +81,7 @@ TEST(FP8E8M0Test, CArrayConstructorBFloat16HostUpwardFinite) { sycl::ext::oneapi::bfloat16(1.0f), sycl::ext::oneapi::bfloat16(2.0f), sycl::ext::oneapi::bfloat16(0.0f)}; - fp8_e8m0<3> a(in, rounding::upward, saturation::finite); + fp8_e8m0<3> a(in, rounding::upward); EXPECT_EQ(sizeof(a.vals), 3u); EXPECT_EQ(a.vals[0], 0x7F); @@ -101,7 +101,7 @@ TEST(FP8E8M0Test, CArrayConstructorDoubleDefaultUpwardFinite) { TEST(FP8E8M0Test, MarrayConstructorAndOperatorsFloat) { sycl::marray in = {1.0f, 2.0f, 3.0f, 0.0f}; - fp8_e8m0<4> a(in, rounding::upward, saturation::finite); + fp8_e8m0<4> a(in, rounding::upward); EXPECT_EQ(sizeof(a.vals), 4u); EXPECT_EQ(a.vals[0], 0x7F); @@ -123,8 +123,8 @@ TEST(FP8E8M0Test, MarrayConstructorHalfBFloat16Double) { sycl::ext::oneapi::bfloat16(2.0f)}; sycl::marray dvals = {1.0, 3.0}; - fp8_e8m0<2> ah(hvals, rounding::upward, saturation::finite); - fp8_e8m0<2> ab(bvals, rounding::upward, saturation::finite); + fp8_e8m0<2> ah(hvals, rounding::upward); + fp8_e8m0<2> ab(bvals, rounding::upward); fp8_e8m0<2> ad(dvals); EXPECT_EQ(sizeof(ah.vals), 2u); From a24ac3bc89f7f02591cc167d3ec036768e08a546 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Thu, 19 Feb 2026 14:07:31 +0100 Subject: [PATCH 03/89] [SYCL] apply new updates from docs and e5m3 data type --- .../oneapi/experimental/float_8bit/types.hpp | 1075 ++++++++++------- sycl/unittests/Extensions/fp8/CMakeLists.txt | 1 + sycl/unittests/Extensions/fp8/fp8_e4m3.cpp | 648 +++------- sycl/unittests/Extensions/fp8/fp8_e5m2.cpp | 353 +++--- sycl/unittests/Extensions/fp8/fp8_e5m3.cpp | 495 ++++++++ sycl/unittests/Extensions/fp8/fp8_e8m0.cpp | 394 +++--- 6 files changed, 1672 insertions(+), 1294 deletions(-) create mode 100644 sycl/unittests/Extensions/fp8/fp8_e5m3.cpp diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index 2150faf78264c..b5c22ee582cdb 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -171,7 +171,7 @@ template static inline ToT ConvertFromFP8_CPU(uint8_t b, rounding R = rounding::to_even) noexcept { static_assert((Ebits == 4 && Mbits == 3) || (Ebits == 5 && Mbits == 2) || - (Ebits == 8 && Mbits == 0), + (Ebits == 5 && Mbits == 3) || (Ebits == 8 && Mbits == 0), "Unsupported FP8 (Ebits,Mbits) combination"); constexpr int Bias = (1 << (Ebits - 1)) - 1; @@ -188,6 +188,9 @@ static inline ToT ConvertFromFP8_CPU(uint8_t b, if constexpr (Ebits == 8 && Mbits == 0) { sign_bit = 0u; exp = b; + } else if constexpr (Ebits == 5 && Mbits == 3) { + // E5M3 is unsigned: MSB belongs to exponent, no sign bit. + sign_bit = 0u; } auto make_nan = [&]() -> ToT { @@ -207,6 +210,11 @@ static inline ToT ConvertFromFP8_CPU(uint8_t b, if (frac != 0) return make_nan(); // frac==00 -> normal finite + } else if constexpr (Ebits == 5 && Mbits == 3) { + // E5M3: only frac==111 -> NaN, otherwise normal. + if (frac == MaxFrac) + return make_nan(); + // treat as normal finite } else // E8M0: exp all ones -> NaN return make_nan(); } @@ -255,7 +263,7 @@ static inline ToT ConvertFromFP8_CPU(uint8_t b, template static inline uint8_t ConvertToFP8_CPU(T h, rounding R = rounding::to_even) noexcept { - // Specialized implementation for fp8_e8m0 (Ebits=8, Mbits=0) + // Specialized implementation for fp8_e8m0_x (Ebits=8, Mbits=0) if constexpr (Ebits == 8 && Mbits == 0) { // Format characteristics (finite-only, no zero, no infinity): // - Bias: 127 @@ -352,7 +360,9 @@ ConvertToFP8_CPU(T h, rounding R = rounding::to_even) noexcept { constexpr uint8_t ExpAllOnes = static_cast((1 << Ebits) - 1); constexpr uint8_t MaxFrac = static_cast((1 << Mbits) - 1); constexpr uint8_t MaxFracForMaxNormal = - (Ebits == 4 && Mbits == 3) ? static_cast(MaxFrac - 1u) : MaxFrac; + (Ebits == 4 && Mbits == 3) || (Ebits == 5 && Mbits == 3) + ? static_cast(MaxFrac - 1u) + : MaxFrac; constexpr uint8_t MaxExpForMaxNormal = (Ebits == 5 && Mbits == 2) ? static_cast(ExpAllOnes - 1u) : ExpAllOnes; @@ -365,6 +375,11 @@ ConvertToFP8_CPU(T h, rounding R = rounding::to_even) noexcept { sign | ((ExpAllOnes << Mbits) | MaxFracMask)); // S.1111.111 -> NaN uint8_t sign_bit = sign ? 1u : 0u; float ax = std::fabs(x); + if constexpr (Ebits == 5 && Mbits == 3) { + // E5M3 is unsigned: ignore sign and treat input as magnitude. + sign = 0x00; + sign_bit = 0u; + } const float max_finite = (2.0f - std::ldexp(1.0f, 1 - Mbits)) * std::ldexp(1.0f, emax); @@ -470,15 +485,7 @@ uint8_t round(rounding r, uint8_t b, sycl::half yi, T vi) { return b; } -void CheckRoundingConstraints(rounding r) { -#ifdef __SYCL_DEVICE_ONLY__ -#else - if (r != rounding::to_even) - throw std::invalid_argument("Host code supports only rounding to_even"); -#endif -} - -template class fp8_e4m3 { +template class fp8_e4m3_x { static constexpr size_t NExpBits = 4; static constexpr size_t NFracBits = 3; static constexpr float MaxNormal = 448.0f; @@ -551,12 +558,20 @@ template class fp8_e4m3 { #endif } + void CheckConstraints(rounding r) const { + static_assert(N == 1 || N == 2, + "fp8_e4m3_x: Template argument N must be 1 or 2"); + if (r != rounding::to_even) + throw std::invalid_argument( + "fp8_e4m3_x: only rounding::to_even is supported"); + } + public: - fp8_e4m3() = default; - fp8_e4m3(const fp8_e4m3 &) = default; + fp8_e4m3_x() = default; + fp8_e4m3_x(const fp8_e4m3_x &) = default; - ~fp8_e4m3() = default; - fp8_e4m3 &operator=(const fp8_e4m3 &) = default; + ~fp8_e4m3_x() = default; + fp8_e4m3_x &operator=(const fp8_e4m3_x &) = default; // Construct from pack of half, float, double. // Available only when the size of the pack is equal to N. @@ -569,11 +584,9 @@ template class fp8_e4m3 { std::is_same_v, float> || std::is_same_v, double>) && ...))>> - explicit fp8_e4m3(Types... v) { -#ifdef __SYCL_DEVICE_ONLY__ + explicit fp8_e4m3_x(Types... v) { static_assert(N == 1 || N == 2, - "fp8_e4m3: Template argument N must be 1 or 2 on device"); -#endif + "fp8_e4m3_x: Template argument N must be 1 or 2"); if constexpr (((std::is_same_v, bfloat16>) && ...)) { const bfloat16 in[N] = {static_cast(v)...}; for (size_t i = 0; i < N; ++i) @@ -586,93 +599,56 @@ template class fp8_e4m3 { } // Construct from an array of half, bfloat16, float, double. - explicit fp8_e4m3(sycl::half const (&v)[N], rounding r = rounding::to_even) { -#ifdef __SYCL_DEVICE_ONLY__ - static_assert(N == 1 || N == 2, - "fp8_e4m3: Template argument N must be 1 or 2 on device"); -#endif + explicit fp8_e4m3_x(sycl::half const (&v)[N], + rounding r = rounding::to_even) { + CheckConstraints(r); // TODO: optimize with vectorized builtin calls for (size_t i = 0; i < N; ++i) vals[i] = ConvertToFP8(v[i], r); } - explicit fp8_e4m3(bfloat16 const (&v)[N], rounding r = rounding::to_even) { -#ifdef __SYCL_DEVICE_ONLY__ - static_assert(N == 1 || N == 2, - "fp8_e4m3: Template argument N must be 1 or 2 on device"); -#else - CheckRoundingConstraints(r); -#endif + explicit fp8_e4m3_x(bfloat16 const (&v)[N], rounding r = rounding::to_even) { + CheckConstraints(r); for (size_t i = 0; i < N; ++i) vals[i] = ConvertBF16ToFP8(v[i], r); } - explicit fp8_e4m3(float const (&v)[N], rounding r = rounding::to_even) { -#ifdef __SYCL_DEVICE_ONLY__ - static_assert(N == 1 || N == 2, - "fp8_e4m3: Template argument N must be 1 or 2 on device"); -#else - CheckRoundingConstraints(r); -#endif + explicit fp8_e4m3_x(float const (&v)[N], rounding r = rounding::to_even) { + CheckConstraints(r); for (size_t i = 0; i < N; ++i) vals[i] = ConvertToFP8(v[i], r); } - explicit fp8_e4m3(double const (&v)[N]) { -#ifdef __SYCL_DEVICE_ONLY__ + explicit fp8_e4m3_x(double const (&v)[N]) { static_assert(N == 1 || N == 2, - "fp8_e4m3: Template argument N must be 1 or 2 on device"); -#else - CheckRoundingConstraints(r); -#endif + "fp8_e4m3_x: Template argument N must be 1 or 2"); for (size_t i = 0; i < N; ++i) vals[i] = ConvertToFP8(v[i], rounding::to_even); } // Construct from an marray of half, bfloat16, float, double. - explicit fp8_e4m3(const sycl::marray &v, - rounding r = rounding::to_even) { -#ifdef __SYCL_DEVICE_ONLY__ - static_assert(N == 1 || N == 2, - "fp8_e4m3: Template argument N must be 1 or 2 on device"); -#else - CheckRoundingConstraints(r); -#endif + explicit fp8_e4m3_x(const sycl::marray &v, + rounding r = rounding::to_even) { + CheckConstraints(r); for (size_t i = 0; i < N; ++i) vals[i] = ConvertToFP8(v[i], r); } - explicit fp8_e4m3(const sycl::marray &v, - rounding r = rounding::to_even) { -#ifdef __SYCL_DEVICE_ONLY__ - static_assert(N == 1 || N == 2, - "fp8_e4m3: Template argument N must be 1 or 2 on device"); -#else - CheckRoundingConstraints(r); -#endif + explicit fp8_e4m3_x(const sycl::marray &v, + rounding r = rounding::to_even) { + CheckConstraints(r); for (size_t i = 0; i < N; ++i) vals[i] = ConvertBF16ToFP8(v[i], r); } - explicit fp8_e4m3(const sycl::marray &v, - rounding r = rounding::to_even) { -#ifdef __SYCL_DEVICE_ONLY__ - static_assert(N == 1 || N == 2, - "fp8_e4m3: Template argument N must be 1 or 2 on device"); -#else - CheckRoundingConstraints(r); -#endif + explicit fp8_e4m3_x(const sycl::marray &v, + rounding r = rounding::to_even) { + CheckConstraints(r); for (size_t i = 0; i < N; ++i) vals[i] = ConvertToFP8(v[i], r); } - explicit fp8_e4m3(const sycl::marray &v) { -#ifdef __SYCL_DEVICE_ONLY__ - static_assert(N == 1 || N == 2, - "fp8_e4m3: Template argument N must be 1 or 2 on device"); -#else - CheckRoundingConstraints(r); -#endif + explicit fp8_e4m3_x(const sycl::marray &v) { for (size_t i = 0; i < N; ++i) vals[i] = ConvertToFP8(v[i], rounding::to_even); } @@ -680,146 +656,149 @@ template class fp8_e4m3 { // Construct with stochastic rounding with user provided seed from an array of // half, bfloat16, float. // Should be removed once docs updated - explicit fp8_e4m3(half const (&vals)[N], const stochastic_seed &seed, - saturation s = saturation::finite); - explicit fp8_e4m3(bfloat16 const (&vals)[N], const stochastic_seed &seed, - saturation s = saturation::finite); - explicit fp8_e4m3(float const (&vals)[N], const stochastic_seed &seed, - saturation s = saturation::finite); + explicit fp8_e4m3_x(half const (&vals)[N], const stochastic_seed &seed, + saturation s = saturation::finite); + explicit fp8_e4m3_x(bfloat16 const (&vals)[N], const stochastic_seed &seed, + saturation s = saturation::finite); + explicit fp8_e4m3_x(float const (&vals)[N], const stochastic_seed &seed, + saturation s = saturation::finite); // Construct with stochastic rounding with user provided seed from an marray // of half, bfloat16, float. // Should be removed once docs updated - explicit fp8_e4m3(const sycl::marray &vals, - const stochastic_seed &seed, - saturation s = saturation::finite); - explicit fp8_e4m3(const sycl::marray &vals, - const stochastic_seed &seed, - saturation s = saturation::finite); - explicit fp8_e4m3(const sycl::marray &vals, - const stochastic_seed &seed, - saturation s = saturation::finite); + explicit fp8_e4m3_x(const sycl::marray &vals, + const stochastic_seed &seed, + saturation s = saturation::finite); + explicit fp8_e4m3_x(const sycl::marray &vals, + const stochastic_seed &seed, + saturation s = saturation::finite); + explicit fp8_e4m3_x(const sycl::marray &vals, + const stochastic_seed &seed, + saturation s = saturation::finite); // Construct from integer types. // Available only when N==1. - explicit fp8_e4m3(short val) { - assert(N == 1 && "fp8_e4m3: N must be 1 for short constructor"); + explicit fp8_e4m3_x(short val) { + assert(N == 1 && "fp8_e4m3_x: N must be 1 for short constructor"); vals[0] = ConvertToFP8(val, rounding::to_even); } - explicit fp8_e4m3(int val) { - assert(N == 1 && "fp8_e4m3: N must be 1 for int constructor"); + explicit fp8_e4m3_x(int val) { + assert(N == 1 && "fp8_e4m3_x: N must be 1 for int constructor"); vals[0] = ConvertToFP8(val, rounding::to_even); } - explicit fp8_e4m3(long val) { - assert(N == 1 && "fp8_e4m3: N must be 1 for long constructor"); + explicit fp8_e4m3_x(long val) { + assert(N == 1 && "fp8_e4m3_x: N must be 1 for long constructor"); vals[0] = ConvertToFP8(val, rounding::to_even); } - explicit fp8_e4m3(long long val) { - assert(N == 1 && "fp8_e4m3: N must be 1 for long long constructor"); + explicit fp8_e4m3_x(long long val) { + assert(N == 1 && "fp8_e4m3_x: N must be 1 for long long constructor"); vals[0] = ConvertToFP8(val, rounding::to_even); } - explicit fp8_e4m3(unsigned short val) { - assert(N == 1 && "fp8_e4m3: N must be 1 for unsigned short constructor"); + explicit fp8_e4m3_x(unsigned short val) { + assert(N == 1 && "fp8_e4m3_x: N must be 1 for unsigned short constructor"); vals[0] = ConvertToFP8(val, rounding::to_even); } - explicit fp8_e4m3(unsigned int val) { - assert(N == 1 && "fp8_e4m3: N must be 1 for unsigned int constructor"); + explicit fp8_e4m3_x(unsigned int val) { + assert(N == 1 && "fp8_e4m3_x: N must be 1 for unsigned int constructor"); vals[0] = ConvertToFP8(val, rounding::to_even); } - explicit fp8_e4m3(unsigned long val) { - assert(N == 1 && "fp8_e4m3: N must be 1 for unsigned long constructor"); + explicit fp8_e4m3_x(unsigned long val) { + assert(N == 1 && "fp8_e4m3_x: N must be 1 for unsigned long constructor"); vals[0] = ConvertToFP8(val, rounding::to_even); } - explicit fp8_e4m3(unsigned long long val) { + explicit fp8_e4m3_x(unsigned long long val) { assert(N == 1 && - "fp8_e4m3: N must be 1 for unsigned long long constructor"); + "fp8_e4m3_x: N must be 1 for unsigned long long constructor"); vals[0] = ConvertToFP8(val, rounding::to_even); } // Assign (operator) from half, bfloat16, float, double, and integer types. // Available only when N==1. - fp8_e4m3 &operator=(sycl::half val) { - assert(N == 1 && "fp8_e4m3: N must be 1 for half assignment operator"); + fp8_e4m3_x &operator=(sycl::half val) { + assert(N == 1 && "fp8_e4m3_x: N must be 1 for half assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } - fp8_e4m3 &operator=(bfloat16 val) { - assert(N == 1 && "fp8_e4m3: N must be 1 for bfloat16 assignment operator"); + fp8_e4m3_x &operator=(bfloat16 val) { + assert(N == 1 && + "fp8_e4m3_x: N must be 1 for bfloat16 assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } - fp8_e4m3 &operator=(float val) { - assert(N == 1 && "fp8_e4m3: N must be 1 for float assignment operator"); + fp8_e4m3_x &operator=(float val) { + assert(N == 1 && "fp8_e4m3_x: N must be 1 for float assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } - fp8_e4m3 &operator=(double val) { - assert(N == 1 && "fp8_e4m3: N must be 1 for double assignment operator"); + fp8_e4m3_x &operator=(double val) { + assert(N == 1 && "fp8_e4m3_x: N must be 1 for double assignment operator"); vals[0] = ConvertBF16ToFP8(val, rounding::to_even); return *this; } - fp8_e4m3 &operator=(short val) { - assert(N == 1 && "fp8_e4m3: N must be 1 for short assignment operator"); + fp8_e4m3_x &operator=(short val) { + assert(N == 1 && "fp8_e4m3_x: N must be 1 for short assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } - fp8_e4m3 &operator=(int val) { - assert(N == 1 && "fp8_e4m3: N must be 1 for int assignment operator"); + fp8_e4m3_x &operator=(int val) { + assert(N == 1 && "fp8_e4m3_x: N must be 1 for int assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } - fp8_e4m3 &operator=(long val) { - assert(N == 1 && "fp8_e4m3: N must be 1 for long assignment operator"); + fp8_e4m3_x &operator=(long val) { + assert(N == 1 && "fp8_e4m3_x: N must be 1 for long assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } - fp8_e4m3 &operator=(long long val) { - assert(N == 1 && "fp8_e4m3: N must be 1 for long long assignment operator"); + fp8_e4m3_x &operator=(long long val) { + assert(N == 1 && + "fp8_e4m3_x: N must be 1 for long long assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } - fp8_e4m3 &operator=(unsigned short val) { + fp8_e4m3_x &operator=(unsigned short val) { assert(N == 1 && - "fp8_e4m3: N must be 1 for unsigned short assignment operator"); + "fp8_e4m3_x: N must be 1 for unsigned short assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } - fp8_e4m3 &operator=(unsigned int val) { + fp8_e4m3_x &operator=(unsigned int val) { assert(N == 1 && - "fp8_e4m3: N must be 1 for unsigned int assignment operator"); + "fp8_e4m3_x: N must be 1 for unsigned int assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } - fp8_e4m3 &operator=(unsigned long val) { + fp8_e4m3_x &operator=(unsigned long val) { assert(N == 1 && - "fp8_e4m3: N must be 1 for unsigned long assignment operator"); + "fp8_e4m3_x: N must be 1 for unsigned long assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } - fp8_e4m3 &operator=(unsigned long long val) { - assert(N == 1 && - "fp8_e4m3: N must be 1 for unsigned long long assignment operator"); + fp8_e4m3_x &operator=(unsigned long long val) { + assert( + N == 1 && + "fp8_e4m3_x: N must be 1 for unsigned long long assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } @@ -828,20 +807,21 @@ template class fp8_e4m3 { // Available only when N==1. explicit operator half() const { - assert(N == 1 && "fp8_e4m3: N must be 1 for half conversion operator"); + assert(N == 1 && "fp8_e4m3_x: N must be 1 for half conversion operator"); return ConvertFromFP8(vals[0]); } explicit operator bfloat16() const { - assert(N == 1 && "fp8_e4m3: N must be 1 for bfloat16 conversion operator"); + assert(N == 1 && + "fp8_e4m3_x: N must be 1 for bfloat16 conversion operator"); return ConvertBF16FromFP8(vals[0]); } explicit operator float() const { - assert(N == 1 && "fp8_e4m3: N must be 1 for float conversion operator"); + assert(N == 1 && "fp8_e4m3_x: N must be 1 for float conversion operator"); return ConvertFromFP8(vals[0]); } explicit operator double() const { - assert(N == 1 && "fp8_e4m3: N must be 1 for double conversion operator"); + assert(N == 1 && "fp8_e4m3_x: N must be 1 for double conversion operator"); return ConvertFromFP8(vals[0]); } @@ -849,62 +829,64 @@ template class fp8_e4m3 { // Available only when N==1. explicit operator char() const { - assert(N == 1 && "fp8_e4m3: N must be 1 for char conversion operator"); + assert(N == 1 && "fp8_e4m3_x: N must be 1 for char conversion operator"); return ConvertFromFP8(vals[0]); } explicit operator signed char() const { assert(N == 1 && - "fp8_e4m3: N must be 1 for signed char conversion operator"); + "fp8_e4m3_x: N must be 1 for signed char conversion operator"); return ConvertFromFP8(vals[0]); } explicit operator short() const { - assert(N == 1 && "fp8_e4m3: N must be 1 for short conversion operator"); + assert(N == 1 && "fp8_e4m3_x: N must be 1 for short conversion operator"); return ConvertFromFP8(vals[0]); } explicit operator int() const { - assert(N == 1 && "fp8_e4m3: N must be 1 for int conversion operator"); + assert(N == 1 && "fp8_e4m3_x: N must be 1 for int conversion operator"); return ConvertFromFP8(vals[0]); } explicit operator long() const { - assert(N == 1 && "fp8_e4m3: N must be 1 for long conversion operator"); + assert(N == 1 && "fp8_e4m3_x: N must be 1 for long conversion operator"); return ConvertFromFP8(vals[0]); } explicit operator long long() const { - assert(N == 1 && "fp8_e4m3: N must be 1 for long long conversion operator"); + assert(N == 1 && + "fp8_e4m3_x: N must be 1 for long long conversion operator"); return ConvertFromFP8(vals[0]); } explicit operator unsigned char() const { assert(N == 1 && - "fp8_e4m3: N must be 1 for unsigned char conversion operator"); + "fp8_e4m3_x: N must be 1 for unsigned char conversion operator"); return ConvertFromFP8(vals[0]); } explicit operator unsigned short() const { assert(N == 1 && - "fp8_e4m3: N must be 1 for unsigned short conversion operator"); + "fp8_e4m3_x: N must be 1 for unsigned short conversion operator"); return ConvertFromFP8(vals[0]); } explicit operator unsigned int() const { assert(N == 1 && - "fp8_e4m3: N must be 1 for unsigned int conversion operator"); + "fp8_e4m3_x: N must be 1 for unsigned int conversion operator"); return ConvertFromFP8(vals[0]); } explicit operator unsigned long() const { assert(N == 1 && - "fp8_e4m3: N must be 1 for unsigned long conversion operator"); + "fp8_e4m3_x: N must be 1 for unsigned long conversion operator"); return ConvertFromFP8(vals[0]); } explicit operator unsigned long long() const { - assert(N == 1 && - "fp8_e4m3: N must be 1 for unsigned long long conversion operator"); + assert( + N == 1 && + "fp8_e4m3_x: N must be 1 for unsigned long long conversion operator"); return ConvertFromFP8(vals[0]); } @@ -912,7 +894,7 @@ template class fp8_e4m3 { // Available only when N==1. explicit operator bool() const { - static_assert(N == 1, "fp8_e4m3: operator() requires size N=1"); + static_assert(N == 1, "fp8_e4m3_x: operator() requires size N=1"); #ifdef __SYCL_DEVICE_ONLY__ // detect +0 / -0 sycl::half h = __builtin_spirv_ConvertE4M3ToFP16EXT(vals[0]); @@ -950,7 +932,7 @@ template class fp8_e4m3 { uint8_t vals[N]; }; -template class fp8_e5m2 { +template class fp8_e5m2_x { uint8_t ConvertToFP8(sycl::half h, rounding r) { #ifdef __SYCL_DEVICE_ONLY__ @@ -1015,11 +997,22 @@ template class fp8_e5m2 { #endif } + void CheckConstraints(rounding r, saturation s) const { + static_assert(N == 1 || N == 2, + "fp8_e5m2_x: Template argument N must be 1 or 2"); + if (r != rounding::to_even) + throw std::invalid_argument( + "fp8_e5m2_x: only rounding::to_even is supported"); + if (s != saturation::finite) + throw std::invalid_argument( + "fp8_e5m2_x: only saturation::finite is supported"); + } + public: - fp8_e5m2() = default; - fp8_e5m2(const fp8_e5m2 &) = default; - ~fp8_e5m2() = default; - fp8_e5m2 &operator=(const fp8_e5m2 &) = default; + fp8_e5m2_x() = default; + fp8_e5m2_x(const fp8_e5m2_x &) = default; + ~fp8_e5m2_x() = default; + fp8_e5m2_x &operator=(const fp8_e5m2_x &) = default; // Construct from pack of half, bfloat16, float, double. // Available only when the size of the pack is equal to N. @@ -1034,11 +1027,9 @@ template class fp8_e5m2 { std::is_same_v, float> || std::is_same_v, double>) && ...))>> - explicit fp8_e5m2(Types... v) { -#ifdef __SYCL_DEVICE_ONLY__ + explicit fp8_e5m2_x(Types... v) { static_assert(N == 1 || N == 2, - "fp8_e5m2: Template argument N must be 1 or 2 on device"); -#endif + "fp8_e5m2_x: Template argument N must be 1 or 2 on device"); if constexpr (((std::is_same_v, bfloat16>) && ...)) { const bfloat16 in[N] = {static_cast(v)...}; for (size_t i = 0; i < N; ++i) @@ -1052,97 +1043,61 @@ template class fp8_e5m2 { // Construct from an array of half, bfloat16, float, double. - explicit fp8_e5m2(half const (&v)[N], rounding r = rounding::to_even) { -#ifdef __SYCL_DEVICE_ONLY__ - static_assert(N == 1 || N == 2, - "fp8_e5m2: Template argument N must be 1 or 2 on device"); -#else - CheckRoundingConstraints(r); -#endif + explicit fp8_e5m2_x(half const (&v)[N], rounding r = rounding::to_even, + saturation s = saturation::finite) { + CheckConstraints(r, s); // TODO: optimize with vectorized builtin calls for (size_t i = 0; i < N; ++i) vals[i] = ConvertToFP8(v[i], r); } - explicit fp8_e5m2(bfloat16 const (&v)[N], rounding r = rounding::to_even) { -#ifdef __SYCL_DEVICE_ONLY__ - static_assert(N == 1 || N == 2, - "fp8_e5m2: Template argument N must be 1 or 2 on device"); -#else - CheckRoundingConstraints(r); -#endif + explicit fp8_e5m2_x(bfloat16 const (&v)[N], rounding r = rounding::to_even, + saturation s = saturation::finite) { + CheckConstraints(r, s); // TODO: optimize with vectorized builtin calls for (size_t i = 0; i < N; ++i) vals[i] = ConvertBF16ToFP8(v[i], r); } - explicit fp8_e5m2(float const (&v)[N], rounding r = rounding::to_even) { -#ifdef __SYCL_DEVICE_ONLY__ - static_assert(N == 1 || N == 2, - "fp8_e5m2: Template argument N must be 1 or 2 on device"); -#else - CheckRoundingConstraints(r); -#endif + explicit fp8_e5m2_x(float const (&v)[N], rounding r = rounding::to_even, + saturation s = saturation::finite) { + CheckConstraints(r, s); for (size_t i = 0; i < N; ++i) vals[i] = ConvertToFP8(v[i], r); } - explicit fp8_e5m2(double const (&v)[N]) { -#ifdef __SYCL_DEVICE_ONLY__ - static_assert(N == 1 || N == 2, - "fp8_e5m2: Template argument N must be 1 or 2 on device"); -#else - CheckRoundingConstraints(r); -#endif + explicit fp8_e5m2_x(double const (&v)[N]) { for (size_t i = 0; i < N; ++i) vals[i] = ConvertToFP8(v[i], rounding::to_even); } // Construct from an marray of half, bfloat16, float, double. - explicit fp8_e5m2(const sycl::marray &v, - rounding r = rounding::to_even) { -#ifdef __SYCL_DEVICE_ONLY__ - static_assert(N == 1 || N == 2, - "fp8_e5m2: Template argument N must be 1 or 2 on device"); -#else - CheckRoundingConstraints(r); -#endif + explicit fp8_e5m2_x(const sycl::marray &v, + rounding r = rounding::to_even, + saturation s = saturation::finite) { + CheckConstraints(r, s); for (size_t i = 0; i < N; ++i) vals[i] = ConvertToFP8(v[i], r); } - explicit fp8_e5m2(const sycl::marray &v, - rounding r = rounding::to_even) { -#ifdef __SYCL_DEVICE_ONLY__ - static_assert(N == 1 || N == 2, - "fp8_e5m2: Template argument N must be 1 or 2 on device"); -#else - CheckRoundingConstraints(r); -#endif + explicit fp8_e5m2_x(const sycl::marray &v, + rounding r = rounding::to_even, + saturation s = saturation::finite) { + CheckConstraints(r, s); for (size_t i = 0; i < N; ++i) vals[i] = ConvertBF16ToFP8(v[i], r); } - explicit fp8_e5m2(const sycl::marray &v, - rounding r = rounding::to_even) { -#ifdef __SYCL_DEVICE_ONLY__ - static_assert(N == 1 || N == 2, - "fp8_e5m2: Template argument N must be 1 or 2 on device"); -#else - CheckRoundingConstraints(r); -#endif + explicit fp8_e5m2_x(const sycl::marray &v, + rounding r = rounding::to_even, + saturation s = saturation::finite) { + CheckConstraints(r, s); for (size_t i = 0; i < N; ++i) vals[i] = ConvertToFP8(v[i], r); } - explicit fp8_e5m2(const sycl::marray &v) { -#ifdef __SYCL_DEVICE_ONLY__ - static_assert(N == 1 || N == 2, - "fp8_e5m2: Template argument N must be 1 or 2 on device"); -#else - CheckRoundingConstraints(r); -#endif + explicit fp8_e5m2_x(const sycl::marray &v) { for (size_t i = 0; i < N; ++i) vals[i] = ConvertToFP8(v[i], rounding::to_even); } @@ -1151,146 +1106,148 @@ template class fp8_e5m2 { // half, bfloat16, float. // should be removed once docs updated - explicit fp8_e5m2(half const (&vals)[N], const stochastic_seed &seed, - saturation s = saturation::finite); - explicit fp8_e5m2(bfloat16 const (&vals)[N], const stochastic_seed &seed, - saturation s = saturation::finite); - explicit fp8_e5m2(double const (&vals)[N], const stochastic_seed &seed, - saturation s = saturation::finite); + explicit fp8_e5m2_x(half const (&vals)[N], const stochastic_seed &seed, + saturation s = saturation::finite); + explicit fp8_e5m2_x(bfloat16 const (&vals)[N], const stochastic_seed &seed, + saturation s = saturation::finite); + explicit fp8_e5m2_x(double const (&vals)[N], const stochastic_seed &seed, + saturation s = saturation::finite); // Construct with stochastic rounding with user provided seed from an marray // of half, bfloat16, float. // should be removed once docs updated - explicit fp8_e5m2(const sycl::marray &vals, - const stochastic_seed &seed, - saturation s = saturation::finite); - explicit fp8_e5m2(const sycl::marray &vals, - const stochastic_seed &seed, - saturation s = saturation::finite); - explicit fp8_e5m2(const sycl::marray &vals, - const stochastic_seed &seed, - saturation s = saturation::finite); + explicit fp8_e5m2_x(const sycl::marray &vals, + const stochastic_seed &seed, + saturation s = saturation::finite); + explicit fp8_e5m2_x(const sycl::marray &vals, + const stochastic_seed &seed, + saturation s = saturation::finite); + explicit fp8_e5m2_x(const sycl::marray &vals, + const stochastic_seed &seed, + saturation s = saturation::finite); // Construct from integer types. // Available only when N==1. - explicit fp8_e5m2(short val) { - assert(N == 1 && "fp8_e5m2: N must be 1 for short constructor"); + explicit fp8_e5m2_x(short val) { + assert(N == 1 && "fp8_e5m2_x: N must be 1 for short constructor"); vals[0] = ConvertToFP8(val, rounding::to_even); } - explicit fp8_e5m2(int val) { - assert(N == 1 && "fp8_e5m2: N must be 1 for int constructor"); + explicit fp8_e5m2_x(int val) { + assert(N == 1 && "fp8_e5m2_x: N must be 1 for int constructor"); vals[0] = ConvertToFP8(val, rounding::to_even); } - explicit fp8_e5m2(long val) { - assert(N == 1 && "fp8_e5m2: N must be 1 for long constructor"); + explicit fp8_e5m2_x(long val) { + assert(N == 1 && "fp8_e5m2_x: N must be 1 for long constructor"); vals[0] = ConvertToFP8(val, rounding::to_even); } - explicit fp8_e5m2(long long val) { - assert(N == 1 && "fp8_e5m2: N must be 1 for long long constructor"); + explicit fp8_e5m2_x(long long val) { + assert(N == 1 && "fp8_e5m2_x: N must be 1 for long long constructor"); vals[0] = ConvertToFP8(val, rounding::to_even); } - explicit fp8_e5m2(unsigned short val) { - assert(N == 1 && "fp8_e5m2: N must be 1 for unsigned short constructor"); + explicit fp8_e5m2_x(unsigned short val) { + assert(N == 1 && "fp8_e5m2_x: N must be 1 for unsigned short constructor"); vals[0] = ConvertToFP8(val, rounding::to_even); } - explicit fp8_e5m2(unsigned int val) { - assert(N == 1 && "fp8_e5m2: N must be 1 for unsigned int constructor"); + explicit fp8_e5m2_x(unsigned int val) { + assert(N == 1 && "fp8_e5m2_x: N must be 1 for unsigned int constructor"); vals[0] = ConvertToFP8(val, rounding::to_even); } - explicit fp8_e5m2(unsigned long val) { - assert(N == 1 && "fp8_e5m2: N must be 1 for unsigned long constructor"); + explicit fp8_e5m2_x(unsigned long val) { + assert(N == 1 && "fp8_e5m2_x: N must be 1 for unsigned long constructor"); vals[0] = ConvertToFP8(val, rounding::to_even); } - explicit fp8_e5m2(unsigned long long val) { + explicit fp8_e5m2_x(unsigned long long val) { assert(N == 1 && - "fp8_e5m2: N must be 1 for unsigned long long constructor"); + "fp8_e5m2_x: N must be 1 for unsigned long long constructor"); vals[0] = ConvertToFP8(val, rounding::to_even); } // Assign (operator) from half, bfloat16, float, double, and integer types. // Available only when N==1. - fp8_e5m2 &operator=(sycl::half val) { - assert(N == 1 && "fp8_e5m2: N must be 1 for half assignment operator"); + fp8_e5m2_x &operator=(sycl::half val) { + assert(N == 1 && "fp8_e5m2_x: N must be 1 for half assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } - fp8_e5m2 &operator=(bfloat16 val) { - assert(N == 1 && "fp8_e5m2: N must be 1 for half bfloat16 operator"); + fp8_e5m2_x &operator=(bfloat16 val) { + assert(N == 1 && "fp8_e5m2_x: N must be 1 for half bfloat16 operator"); vals[0] = ConvertBF16ToFP8(val, rounding::to_even); return *this; } - fp8_e5m2 &operator=(float val) { - assert(N == 1 && "fp8_e5m2: N must be 1 for float assignment operator"); + fp8_e5m2_x &operator=(float val) { + assert(N == 1 && "fp8_e5m2_x: N must be 1 for float assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } - fp8_e5m2 &operator=(double val) { - assert(N == 1 && "fp8_e5m2: N must be 1 for double assignment operator"); + fp8_e5m2_x &operator=(double val) { + assert(N == 1 && "fp8_e5m2_x: N must be 1 for double assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } - fp8_e5m2 &operator=(short val) { - assert(N == 1 && "fp8_e5m2: N must be 1 for short assignment operator"); + fp8_e5m2_x &operator=(short val) { + assert(N == 1 && "fp8_e5m2_x: N must be 1 for short assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } - fp8_e5m2 &operator=(int val) { - assert(N == 1 && "fp8_e5m2: N must be 1 for int assignment operator"); + fp8_e5m2_x &operator=(int val) { + assert(N == 1 && "fp8_e5m2_x: N must be 1 for int assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } - fp8_e5m2 &operator=(long val) { - assert(N == 1 && "fp8_e5m2: N must be 1 for long assignment operator"); + fp8_e5m2_x &operator=(long val) { + assert(N == 1 && "fp8_e5m2_x: N must be 1 for long assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } - fp8_e5m2 &operator=(long long val) { - assert(N == 1 && "fp8_e5m2: N must be 1 for long long assignment operator"); + fp8_e5m2_x &operator=(long long val) { + assert(N == 1 && + "fp8_e5m2_x: N must be 1 for long long assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } - fp8_e5m2 &operator=(unsigned short val) { + fp8_e5m2_x &operator=(unsigned short val) { assert(N == 1 && - "fp8_e5m2: N must be 1 for unsigned short assignment operator"); + "fp8_e5m2_x: N must be 1 for unsigned short assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } - fp8_e5m2 &operator=(unsigned int val) { + fp8_e5m2_x &operator=(unsigned int val) { assert(N == 1 && - "fp8_e5m2: N must be 1 for unsigned int assignment operator"); + "fp8_e5m2_x: N must be 1 for unsigned int assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } - fp8_e5m2 &operator=(unsigned long val) { + fp8_e5m2_x &operator=(unsigned long val) { assert(N == 1 && - "fp8_e5m2: N must be 1 for unsigned long assignment operator"); + "fp8_e5m2_x: N must be 1 for unsigned long assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } - fp8_e5m2 &operator=(unsigned long long val) { - assert(N == 1 && - "fp8_e5m2: N must be 1 for unsigned long long assignment operator"); + fp8_e5m2_x &operator=(unsigned long long val) { + assert( + N == 1 && + "fp8_e5m2_x: N must be 1 for unsigned long long assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } @@ -1299,22 +1256,23 @@ template class fp8_e5m2 { // Available only when N==1. explicit operator half() const { - assert(N == 1 && "fp8_e5m2: N must be 1 for half conversion operator"); + assert(N == 1 && "fp8_e5m2_x: N must be 1 for half conversion operator"); return ConvertFromFP8(vals[0]); } explicit operator bfloat16() const { - assert(N == 1 && "fp8_e5m2: N must be 1 for bfloat16 conversion operator"); + assert(N == 1 && + "fp8_e5m2_x: N must be 1 for bfloat16 conversion operator"); return ConvertFP16FromFP8(vals[0]); } explicit operator float() const { - assert(N == 1 && "fp8_e5m2: N must be 1 for float conversion operator"); + assert(N == 1 && "fp8_e5m2_x: N must be 1 for float conversion operator"); return ConvertFromFP8(vals[0]); } explicit operator double() const { - assert(N == 1 && "fp8_e5m2: N must be 1 for double conversion operator"); + assert(N == 1 && "fp8_e5m2_x: N must be 1 for double conversion operator"); return ConvertFromFP8(vals[0]); } @@ -1322,63 +1280,65 @@ template class fp8_e5m2 { // Available only when N==1. explicit operator char() const { - assert(N == 1 && "fp8_e5m2: N must be 1 for char conversion operator"); + assert(N == 1 && "fp8_e5m2_x: N must be 1 for char conversion operator"); return ConvertFromFP8(vals[0]); } explicit operator signed char() const { assert(N == 1 && - "fp8_e5m2: N must be 1 for signed char conversion operator"); + "fp8_e5m2_x: N must be 1 for signed char conversion operator"); return ConvertFromFP8(vals[0]); } explicit operator short() const { - assert(N == 1 && "fp8_e5m2: N must be 1 for short conversion operator"); + assert(N == 1 && "fp8_e5m2_x: N must be 1 for short conversion operator"); return ConvertFromFP8(vals[0]); } explicit operator int() const { - assert(N == 1 && "fp8_e5m2: N must be 1 for int conversion operator"); + assert(N == 1 && "fp8_e5m2_x: N must be 1 for int conversion operator"); return ConvertFromFP8(vals[0]); } explicit operator long() const { - assert(N == 1 && "fp8_e5m2: N must be 1 for long conversion operator"); + assert(N == 1 && "fp8_e5m2_x: N must be 1 for long conversion operator"); return ConvertFromFP8(vals[0]); } explicit operator long long() const { - assert(N == 1 && "fp8_e5m2: N must be 1 for long long conversion operator"); + assert(N == 1 && + "fp8_e5m2_x: N must be 1 for long long conversion operator"); return ConvertFromFP8(vals[0]); } explicit operator unsigned char() const { assert(N == 1 && - "fp8_e5m2: N must be 1 for unsigned char conversion operator"); + "fp8_e5m2_x: N must be 1 for unsigned char conversion operator"); return ConvertFromFP8(vals[0]); } explicit operator unsigned short() const { assert(N == 1 && - "fp8_e5m2: N must be 1 for unsigned short conversion operator"); + "fp8_e5m2_x: N must be 1 for unsigned short conversion operator"); return ConvertFromFP8(vals[0]); } explicit operator unsigned int() const { assert(N == 1 && - "fp8_e5m2: N must be 1 for unsigned int conversion operator"); + "fp8_e5m2_x: N must be 1 for unsigned int conversion operator"); return ConvertFromFP8(vals[0]); } explicit operator unsigned long() const { assert(N == 1 && - "fp8_e5m2: N must be 1 for unsigned long conversion operator"); + "fp8_e5m2_x: N must be 1 for unsigned long conversion operator"); return ConvertFromFP8(vals[0]); } explicit operator unsigned long long() const { - assert(N == 1 && - "fp8_e5m2: N must be 1 for unsigned long long conversion operator"); + assert( + N == 1 && + "fp8_e5m2_x: N must be 1 for unsigned long long conversion operator"); return ConvertFromFP8(vals[0]); } @@ -1386,7 +1346,7 @@ template class fp8_e5m2 { // Available only when N==1. explicit operator bool() const { - static_assert(N == 1, "fp8_e5m2: operator() requires size N=1"); + static_assert(N == 1, "fp8_e5m2_x: operator() requires size N=1"); // false iff +0 or -0; otherwise true. return vals[0] != 0x00 && vals[0] != 0x80; } @@ -1507,12 +1467,21 @@ static inline ToT ConvertFromE8M0_CPU(uint8_t code) noexcept { return ConvertFloatToTarget(v, rounding::to_even); } -template class fp8_e8m0 { +template class fp8_e8m0_x { + + void CheckConstraints(rounding r) const { + static_assert(N == 1 || N == 2, + "fp8_e8m0_x: Template argument N must be 1 or 2"); + if (r != rounding::upward && r != rounding::toward_zero) + throw std::invalid_argument("fp8_e8m0_x: only rounding::upward and " + "rounding::toward_zero are supported"); + } + public: - fp8_e8m0() = default; - fp8_e8m0(const fp8_e8m0 &) = default; - ~fp8_e8m0() = default; - fp8_e8m0 &operator=(const fp8_e8m0 &) = default; + fp8_e8m0_x() = default; + fp8_e8m0_x(const fp8_e8m0_x &) = default; + ~fp8_e8m0_x() = default; + fp8_e8m0_x &operator=(const fp8_e8m0_x &) = default; template class fp8_e8m0 { std::is_same_v, float> || std::is_same_v, double>) && ...))>> - explicit fp8_e8m0(Types... v) { + explicit fp8_e8m0_x(Types... v) { #ifdef __SYCL_DEVICE_ONLY__ static_assert(N == 1 || N == 2, - "fp8_e8m0: Template argument N must be 1 or 2 on device"); + "fp8_e8m0_x: Template argument N must be 1 or 2 on device"); #endif using InT = std::common_type_t...>; const InT in[N] = {v...}; @@ -1534,109 +1503,62 @@ template class fp8_e8m0 { saturation::finite); } - explicit fp8_e8m0(half const (&in)[N], rounding r = rounding::upward) { - if (r != rounding::upward && r != rounding::toward_zero) - throw std::invalid_argument( - "fp8_e8m0 supports only rounding upward and toward_zero"); -#ifdef __SYCL_DEVICE_ONLY__ - static_assert(N == 1 || N == 2, - "fp8_e8m0: Template argument N must be 1 or 2 on device"); -#endif + explicit fp8_e8m0_x(half const (&in)[N], rounding r = rounding::upward) { + CheckConstraints(r); for (size_t i = 0; i < N; ++i) vals[i] = ConvertToE8M0_CPU(static_cast(in[i]), r, saturation::finite); } - explicit fp8_e8m0(bfloat16 const (&in)[N], rounding r = rounding::upward) { - if (r != rounding::upward && r != rounding::toward_zero) - throw std::invalid_argument( - "fp8_e8m0 supports only rounding upward and toward_zero"); -#ifdef __SYCL_DEVICE_ONLY__ - static_assert(N == 1 || N == 2, - "fp8_e8m0: Template argument N must be 1 or 2 on device"); -#endif + explicit fp8_e8m0_x(bfloat16 const (&in)[N], rounding r = rounding::upward) { + CheckConstraints(r); for (size_t i = 0; i < N; ++i) vals[i] = ConvertToE8M0_CPU(static_cast(in[i]), r, saturation::finite); } - explicit fp8_e8m0(float const (&in)[N], rounding r = rounding::upward) { - if (r != rounding::upward && r != rounding::toward_zero) - throw std::invalid_argument( - "fp8_e8m0 supports only rounding upward and toward_zero"); -#ifdef __SYCL_DEVICE_ONLY__ - static_assert(N == 1 || N == 2, - "fp8_e8m0: Template argument N must be 1 or 2 on device"); -#endif + explicit fp8_e8m0_x(float const (&in)[N], rounding r = rounding::upward) { + CheckConstraints(r); for (size_t i = 0; i < N; ++i) vals[i] = ConvertToE8M0_CPU(in[i], r, saturation::finite); } - explicit fp8_e8m0(double const (&in)[N]) { -#ifdef __SYCL_DEVICE_ONLY__ + explicit fp8_e8m0_x(double const (&in)[N]) { static_assert(N == 1 || N == 2, - "fp8_e8m0: Template argument N must be 1 or 2 on device"); -#endif + "fp8_e8m0_x: Template argument N must be 1 or 2 on device"); for (size_t i = 0; i < N; ++i) vals[i] = ConvertToE8M0_CPU(static_cast(in[i]), rounding::upward, saturation::finite); } - explicit fp8_e8m0(const marray &vals, - rounding r = rounding::upward) { - if (r != rounding::upward && r != rounding::toward_zero) - throw std::invalid_argument( - "fp8_e8m0 supports only rounding upward and toward_zero"); -#ifdef __SYCL_DEVICE_ONLY__ - static_assert(N == 1 || N == 2, - "fp8_e8m0: Template argument N must be 1 or 2 on device"); - assert((r == rounding::upward) && - "fp8_e8m0: device supports rounding::upward only"); -#endif + explicit fp8_e8m0_x(const marray &in, + rounding r = rounding::upward) { + CheckConstraints(r); for (size_t i = 0; i < N; ++i) vals[i] = - ConvertToE8M0_CPU(static_cast(vals[i]), r, saturation::finite); + ConvertToE8M0_CPU(static_cast(in[i]), r, saturation::finite); } - explicit fp8_e8m0(const marray &vals, - rounding r = rounding::upward) { - if (r != rounding::upward && r != rounding::toward_zero) - throw std::invalid_argument( - "fp8_e8m0 supports only rounding upward and toward_zero"); -#ifdef __SYCL_DEVICE_ONLY__ - static_assert(N == 1 || N == 2, - "fp8_e8m0: Template argument N must be 1 or 2 on device"); -#endif + explicit fp8_e8m0_x(const marray &in, + rounding r = rounding::upward) { + CheckConstraints(r); for (size_t i = 0; i < N; ++i) vals[i] = - ConvertToE8M0_CPU(static_cast(vals[i]), r, saturation::finite); + ConvertToE8M0_CPU(static_cast(in[i]), r, saturation::finite); } - explicit fp8_e8m0(const marray &vals, - rounding r = rounding::upward) { - if (r != rounding::upward && r != rounding::toward_zero) - throw std::invalid_argument( - "fp8_e8m0 supports only rounding upward and toward_zero"); -#ifdef __SYCL_DEVICE_ONLY__ - static_assert(N == 1 || N == 2, - "fp8_e8m0: Template argument N must be 1 or 2 on device"); - assert((r == rounding::upward) && - "fp8_e8m0: device supports rounding::upward only"); -#endif + explicit fp8_e8m0_x(const marray &in, + rounding r = rounding::upward) { + CheckConstraints(r); for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToE8M0_CPU(vals[i], r, saturation::finite); + vals[i] = ConvertToE8M0_CPU(in[i], r, saturation::finite); } - explicit fp8_e8m0(const marray &vals) { - if (r != rounding::upward && r != rounding::toward_zero) - throw std::invalid_argument( - "fp8_e8m0 supports only rounding upward and toward_zero"); -#ifdef __SYCL_DEVICE_ONLY__ + explicit fp8_e8m0_x(const marray &in) { static_assert(N == 1 || N == 2, - "fp8_e8m0: Template argument N must be 1 or 2 on device"); -#endif + "fp8_e8m0_x: Template argument N must be 1 or 2 on device"); for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToE8M0_CPU(static_cast(vals[i]), rounding::upward, + vals[i] = ConvertToE8M0_CPU(static_cast(in[i]), rounding::upward, saturation::finite); } @@ -1644,94 +1566,98 @@ template class fp8_e8m0 { // half, bfloat16, float. // should be removed once docs updated - explicit fp8_e8m0(half const (&vals)[N], const stochastic_seed &seed); - explicit fp8_e8m0(bfloat16 const (&vals)[N], const stochastic_seed &seed); - explicit fp8_e8m0(double const (&vals)[N], const stochastic_seed &seed); + explicit fp8_e8m0_x(half const (&vals)[N], const stochastic_seed &seed); + explicit fp8_e8m0_x(bfloat16 const (&vals)[N], const stochastic_seed &seed); + explicit fp8_e8m0_x(double const (&vals)[N], const stochastic_seed &seed); // Construct with stochastic rounding with user provided seed from an marray // of half, bfloat16, float. // should be removed once docs updated - explicit fp8_e8m0(const sycl::marray &vals, - const stochastic_seed &seed); - explicit fp8_e8m0(const sycl::marray &vals, - const stochastic_seed &seed); - explicit fp8_e8m0(const sycl::marray &vals, - const stochastic_seed &seed); + explicit fp8_e8m0_x(const sycl::marray &vals, + const stochastic_seed &seed); + explicit fp8_e8m0_x(const sycl::marray &vals, + const stochastic_seed &seed); + explicit fp8_e8m0_x(const sycl::marray &vals, + const stochastic_seed &seed); // Construct from integer types. // Available only when N==1. - explicit fp8_e8m0(short val) { - assert(N == 1 && "fp8_e8m0: N must be 1 for short constructor"); + explicit fp8_e8m0_x(short val) { + assert(N == 1 && "fp8_e8m0_x: N must be 1 for short constructor"); vals[0] = ConvertToE8M0_CPU(static_cast(val), rounding::upward, saturation::finite); } - explicit fp8_e8m0(int val) : fp8_e8m0(static_cast(val)) {} - explicit fp8_e8m0(long val) : fp8_e8m0(static_cast(val)) {} - explicit fp8_e8m0(long long val) : fp8_e8m0(static_cast(val)) {} - explicit fp8_e8m0(unsigned short val) : fp8_e8m0(static_cast(val)) {} - explicit fp8_e8m0(unsigned int val) : fp8_e8m0(static_cast(val)) {} - explicit fp8_e8m0(unsigned long val) : fp8_e8m0(static_cast(val)) {} - explicit fp8_e8m0(unsigned long long val) - : fp8_e8m0(static_cast(val)) {} - - fp8_e8m0 &operator=(half val) { - static_assert(N == 1, "fp8_e8m0: N must be 1 for scalar assignment"); + explicit fp8_e8m0_x(int val) : fp8_e8m0_x(static_cast(val)) {} + explicit fp8_e8m0_x(long val) : fp8_e8m0_x(static_cast(val)) {} + explicit fp8_e8m0_x(long long val) : fp8_e8m0_x(static_cast(val)) {} + explicit fp8_e8m0_x(unsigned short val) + : fp8_e8m0_x(static_cast(val)) {} + explicit fp8_e8m0_x(unsigned int val) : fp8_e8m0_x(static_cast(val)) {} + explicit fp8_e8m0_x(unsigned long val) + : fp8_e8m0_x(static_cast(val)) {} + explicit fp8_e8m0_x(unsigned long long val) + : fp8_e8m0_x(static_cast(val)) {} + + fp8_e8m0_x &operator=(half val) { + static_assert(N == 1, "fp8_e8m0_x: N must be 1 for scalar assignment"); vals[0] = ConvertToE8M0_CPU(static_cast(val), rounding::upward, saturation::finite); return *this; } - fp8_e8m0 &operator=(bfloat16 val) { - static_assert(N == 1, "fp8_e8m0: N must be 1 for scalar assignment"); + fp8_e8m0_x &operator=(bfloat16 val) { + static_assert(N == 1, "fp8_e8m0_x: N must be 1 for scalar assignment"); vals[0] = ConvertToE8M0_CPU(static_cast(val), rounding::upward, saturation::finite); return *this; } - fp8_e8m0 &operator=(float val) { - static_assert(N == 1, "fp8_e8m0: N must be 1 for scalar assignment"); + fp8_e8m0_x &operator=(float val) { + static_assert(N == 1, "fp8_e8m0_x: N must be 1 for scalar assignment"); vals[0] = ConvertToE8M0_CPU(val, rounding::upward, saturation::finite); return *this; } - fp8_e8m0 &operator=(double val) { return (*this = static_cast(val)); } - fp8_e8m0 &operator=(short val) { return (*this = static_cast(val)); } - fp8_e8m0 &operator=(int val) { return (*this = static_cast(val)); } - fp8_e8m0 &operator=(long val) { return (*this = static_cast(val)); } - fp8_e8m0 &operator=(long long val) { + fp8_e8m0_x &operator=(double val) { + return (*this = static_cast(val)); + } + fp8_e8m0_x &operator=(short val) { return (*this = static_cast(val)); } + fp8_e8m0_x &operator=(int val) { return (*this = static_cast(val)); } + fp8_e8m0_x &operator=(long val) { return (*this = static_cast(val)); } + fp8_e8m0_x &operator=(long long val) { return (*this = static_cast(val)); } - fp8_e8m0 &operator=(unsigned short val) { + fp8_e8m0_x &operator=(unsigned short val) { return (*this = static_cast(val)); } - fp8_e8m0 &operator=(unsigned int val) { + fp8_e8m0_x &operator=(unsigned int val) { return (*this = static_cast(val)); } - fp8_e8m0 &operator=(unsigned long val) { + fp8_e8m0_x &operator=(unsigned long val) { return (*this = static_cast(val)); } - fp8_e8m0 &operator=(unsigned long long val) { + fp8_e8m0_x &operator=(unsigned long long val) { return (*this = static_cast(val)); } explicit operator half() const { - static_assert(N == 1, "fp8_e8m0: N must be 1 for scalar conversion"); + static_assert(N == 1, "fp8_e8m0_x: N must be 1 for scalar conversion"); return ConvertFromE8M0_CPU(vals[0]); } explicit operator bfloat16() const { - static_assert(N == 1, "fp8_e8m0: N must be 1 for scalar conversion"); + static_assert(N == 1, "fp8_e8m0_x: N must be 1 for scalar conversion"); return ConvertFromE8M0_CPU(vals[0]); } explicit operator float() const { - static_assert(N == 1, "fp8_e8m0: N must be 1 for scalar conversion"); + static_assert(N == 1, "fp8_e8m0_x: N must be 1 for scalar conversion"); return ConvertFromE8M0_CPU(vals[0]); } explicit operator double() const { - static_assert(N == 1, "fp8_e8m0: N must be 1 for scalar conversion"); + static_assert(N == 1, "fp8_e8m0_x: N must be 1 for scalar conversion"); return ConvertFromE8M0_CPU(vals[0]); } explicit operator char() const { - static_assert(N == 1, "fp8_e8m0: N must be 1 for scalar conversion"); + static_assert(N == 1, "fp8_e8m0_x: N must be 1 for scalar conversion"); return static_cast(static_cast(*this)); } explicit operator signed char() const { @@ -1766,7 +1692,7 @@ template class fp8_e8m0 { } explicit operator bool() const { - static_assert(N == 1, "fp8_e8m0: operator bool requires size N=1"); + static_assert(N == 1, "fp8_e8m0_x: operator bool requires size N=1"); return true; } @@ -1794,6 +1720,339 @@ template class fp8_e8m0 { uint8_t vals[N]; }; +template class fp8_e5m3_x { +private: + template uint8_t ConvertToFP8(T h, rounding r) { + if constexpr (std::is_integral_v) { + sycl::half hi = static_cast(h); + return ConvertToFP8_CPU<5, 3, sycl::half>(hi, r); + } + return ConvertToFP8_CPU<5, 3, T>(h, r); + } + + template + T ConvertFromFP8(uint8_t v, rounding r = rounding::to_even) const { + return ConvertFromFP8_CPU<5, 3, T>(v, r); + } + + bfloat16 ConvertBF16FromFP8(uint8_t v) const { + return ConvertFromFP8_CPU<5, 3, bfloat16>(v); + } + + void CheckConstraints(rounding r) const { + static_assert(N == 1 || N == 2, + "fp8_e5m3_x: Template argument N must be 1 or 2"); + if (r != rounding::to_even) + throw std::invalid_argument( + "fp8_e5m3_x: only rounding::to_even is supported"); + } + +public: + fp8_e5m3_x() = default; + fp8_e5m3_x(const fp8_e5m3_x &) = default; + ~fp8_e5m3_x() = default; + fp8_e5m3_x &operator=(const fp8_e5m3_x &) = default; + + // Construct from pack of half, bfloat16, float, double. + // Available only when the size of the pack is equal to N. + + template , half> || + std::is_same_v, bfloat16> || + std::is_same_v, float> || + std::is_same_v, double>) && + ...))>> + explicit fp8_e5m3_x(Types... v) { + static_assert(N == 1 || N == 2, + "fp8_e5m3_x: Template argument N must be 1 or 2"); + /*if constexpr (((std::is_same_v, bfloat16>) && ...)) { + const bfloat16 in[N] = {static_cast(v)...}; + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertBF16ToFP8(in[i], rounding::to_even); + return; + }*/ + const sycl::half in[N] = {v...}; + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertToFP8(in[i], rounding::to_even); + } + + // Construct from an array of half, bfloat16, float, double. + + explicit fp8_e5m3_x(half const (&in)[N], rounding r = rounding::to_even) { + CheckConstraints(r); + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertToFP8(in[i], r); + } + explicit fp8_e5m3_x(bfloat16 const (&in)[N], rounding r = rounding::to_even) { + CheckConstraints(r); + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertToFP8(in[i], r); + } + explicit fp8_e5m3_x(float const (&in)[N], rounding r = rounding::to_even) { + CheckConstraints(r); + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertToFP8(in[i], r); + } + explicit fp8_e5m3_x(double const (&in)[N]) { + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertToFP8(in[i], rounding::to_even); + } + + // Construct from an marray of half, bfloat16, float, double. + + explicit fp8_e5m3_x(const marray &in, + rounding r = rounding::to_even) { + CheckConstraints(r); + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertToFP8(in[i], r); + } + explicit fp8_e5m3_x(const marray &in, + rounding r = rounding::to_even) { + CheckConstraints(r); + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertToFP8(in[i], r); + } + explicit fp8_e5m3_x(const marray &in, + rounding r = rounding::to_even) { + CheckConstraints(r); + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertToFP8(in[i], r); + } + explicit fp8_e5m3_x(const marray &in) { + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertToFP8(in[i], rounding::to_even); + } + + // Construct from integer types. + // Available only when N==1. + + explicit fp8_e5m3_x(short val) { + assert(N == 1 && "fp8_e5m3_x: N must be 1 for short constructor"); + vals[0] = ConvertToFP8(val, rounding::to_even); + } + explicit fp8_e5m3_x(int val) { + assert(N == 1 && "fp8_e5m3_x: N must be 1 for int constructor"); + vals[0] = ConvertToFP8(val, rounding::to_even); + } + explicit fp8_e5m3_x(long val) { + assert(N == 1 && "fp8_e5m3_x: N must be 1 for long constructor"); + vals[0] = ConvertToFP8(val, rounding::to_even); + } + explicit fp8_e5m3_x(long long val) { + assert(N == 1 && "fp8_e5m3_x: N must be 1 for long long constructor"); + vals[0] = ConvertToFP8(val, rounding::to_even); + } + explicit fp8_e5m3_x(unsigned short val) { + assert(N == 1 && "fp8_e5m3_x: N must be 1 for unsigned short constructor"); + vals[0] = ConvertToFP8(val, rounding::to_even); + } + explicit fp8_e5m3_x(unsigned int val) { + assert(N == 1 && "fp8_e5m3_x: N must be 1 for unsigned int constructor"); + vals[0] = ConvertToFP8(val, rounding::to_even); + } + explicit fp8_e5m3_x(unsigned long val) { + assert(N == 1 && "fp8_e5m3_x: N must be 1 for unsigned long constructor"); + vals[0] = ConvertToFP8(val, rounding::to_even); + } + explicit fp8_e5m3_x(unsigned long long val) { + assert(N == 1 && + "fp8_e5m3_x: N must be 1 for unsigned long long constructor"); + vals[0] = ConvertToFP8(val, rounding::to_even); + } + + // Assign (operator) from half, bfloat16, float, double, and integer types. + // Available only when N==1. + + fp8_e5m3_x &operator=(half val) { + assert(N == 1 && "fp8_e5m3_x: N must be 1 for half assignment operator"); + vals[0] = ConvertToFP8(val, rounding::to_even); + return *this; + } + fp8_e5m3_x &operator=(bfloat16 val) { + assert(N == 1 && + "fp8_e5m3_x: N must be 1 for bfloat16 assignment operator"); + vals[0] = ConvertToFP8(val, rounding::to_even); + return *this; + } + fp8_e5m3_x &operator=(float val) { + assert(N == 1 && "fp8_e5m3_x: N must be 1 for float assignment operator"); + vals[0] = ConvertToFP8(val, rounding::to_even); + return *this; + } + fp8_e5m3_x &operator=(double val) { + assert(N == 1 && "fp8_e5m3_x: N must be 1 for double assignment operator"); + vals[0] = ConvertToFP8(val, rounding::to_even); + return *this; + } + fp8_e5m3_x &operator=(short val) { + assert(N == 1 && "fp8_e5m3_x: N must be 1 for short assignment operator"); + vals[0] = ConvertToFP8(val, rounding::to_even); + return *this; + } + fp8_e5m3_x &operator=(int val) { + assert(N == 1 && "fp8_e5m3_x: N must be 1 for int assignment operator"); + vals[0] = ConvertToFP8(val, rounding::to_even); + return *this; + } + fp8_e5m3_x &operator=(long val) { + assert(N == 1 && "fp8_e5m3_x: N must be 1 for long assignment operator"); + vals[0] = ConvertToFP8(val, rounding::to_even); + return *this; + } + fp8_e5m3_x &operator=(long long val) { + assert(N == 1 && + "fp8_e5m3_x: N must be 1 for long long assignment operator"); + vals[0] = ConvertToFP8(val, rounding::to_even); + return *this; + } + fp8_e5m3_x &operator=(unsigned short val) { + assert(N == 1 && + "fp8_e5m3_x: N must be 1 for unsigned short assignment operator"); + vals[0] = ConvertToFP8(val, rounding::to_even); + return *this; + } + fp8_e5m3_x &operator=(unsigned int val) { + assert(N == 1 && + "fp8_e5m3_x: N must be 1 for unsigned int assignment operator"); + vals[0] = ConvertToFP8(val, rounding::to_even); + return *this; + } + fp8_e5m3_x &operator=(unsigned long val) { + assert(N == 1 && + "fp8_e5m3_x: N must be 1 for unsigned long assignment operator"); + vals[0] = ConvertToFP8(val, rounding::to_even); + return *this; + } + fp8_e5m3_x &operator=(unsigned long long val) { + assert( + N == 1 && + "fp8_e5m3_x: N must be 1 for unsigned long long assignment operator"); + vals[0] = ConvertToFP8(val, rounding::to_even); + return *this; + } + + // Convert to half, bfloat16, float, double. + // Available only when N==1. + + explicit operator half() const { + assert(N == 1 && "fp8_e5m3_x: N must be 1 for half conversion operator"); + return ConvertFromFP8(vals[0]); + } + explicit operator bfloat16() const { + assert(N == 1 && + "fp8_e5m3_x: N must be 1 for bfloat16 conversion operator"); + return ConvertBF16FromFP8(vals[0]); + } + explicit operator float() const { + assert(N == 1 && "fp8_e5m3_x: N must be 1 for float conversion operator"); + return ConvertFromFP8(vals[0]); + } + explicit operator double() const { + assert(N == 1 && "fp8_e5m3_x: N must be 1 for double conversion operator"); + return ConvertFromFP8(vals[0]); + } + + // Convert to integer types. + // Available only when N==1. + + explicit operator char() const { + assert(N == 1 && "fp8_e5m3_x: N must be 1 for char conversion operator"); + return ConvertFromFP8(vals[0], rounding::toward_zero); + } + explicit operator signed char() const { + assert(N == 1 && + "fp8_e5m3_x: N must be 1 for signed char conversion operator"); + return ConvertFromFP8(vals[0], rounding::toward_zero); + } + explicit operator short() const { + assert(N == 1 && "fp8_e5m3_x: N must be 1 for short conversion operator"); + return ConvertFromFP8(vals[0], rounding::toward_zero); + } + explicit operator int() const { + assert(N == 1 && "fp8_e5m3_x: N must be 1 for int conversion operator"); + return ConvertFromFP8(vals[0], rounding::toward_zero); + } + explicit operator long() const { + assert(N == 1 && "fp8_e5m3_x: N must be 1 for long conversion operator"); + return ConvertFromFP8(vals[0], rounding::toward_zero); + } + explicit operator long long() const { + assert(N == 1 && + "fp8_e5m3_x: N must be 1 for long long conversion operator"); + return ConvertFromFP8(vals[0], rounding::toward_zero); + } + explicit operator unsigned char() const { + assert(N == 1 && + "fp8_e5m3_x: N must be 1 for unsigned char conversion operator"); + return ConvertFromFP8(vals[0], rounding::toward_zero); + } + explicit operator unsigned short() const { + assert(N == 1 && + "fp8_e5m3_x: N must be 1 for unsigned short conversion operator"); + return ConvertFromFP8(vals[0], rounding::toward_zero); + } + explicit operator unsigned int() const { + assert(N == 1 && + "fp8_e5m3_x: N must be 1 for unsigned int conversion operator"); + return ConvertFromFP8(vals[0], rounding::toward_zero); + } + explicit operator unsigned long() const { + assert(N == 1 && + "fp8_e5m3_x: N must be 1 for unsigned long conversion operator"); + return ConvertFromFP8(vals[0], rounding::toward_zero); + } + explicit operator unsigned long long() const { + assert( + N == 1 && + "fp8_e5m3_x: N must be 1 for unsigned long long conversion operator"); + return ConvertFromFP8(vals[0], rounding::toward_zero); + } + + // Convert to bool + // Available only when N==1. + + explicit operator bool() const { + static_assert(N == 1, "fp8_e5m3_x: operator() requires size N=1"); + return vals[0] != 0x00 && vals[0] != 0x80; + } + + // Convert to marray of half, bfloat16, float + + explicit operator marray() const { + marray out; + for (size_t i = 0; i < N; ++i) + out[i] = ConvertFromFP8(vals[i]); + return out; + } + explicit operator marray() const { + marray out; + for (size_t i = 0; i < N; ++i) + out[i] = ConvertBF16FromFP8(vals[i]); + return out; + } + explicit operator marray() const { + marray out; + for (size_t i = 0; i < N; ++i) + out[i] = ConvertFromFP8(vals[i]); + return out; + } + + // Intentionally public to allow access to the raw values. + + uint8_t vals[N]; +}; + +using fp8_e4m3 = fp8_e4m3_x<1>; +using fp8_e4m3_x2 = fp8_e4m3_x<2>; +using fp8_e5m2 = fp8_e5m2_x<1>; +using fp8_e5m2_x2 = fp8_e5m2_x<2>; +using fp8_e8m0 = fp8_e8m0_x<1>; +using fp8_e8m0_x2 = fp8_e8m0_x<2>; +using fp8_e5m3 = fp8_e5m3_x<1>; +using fp8_e5m3_x2 = fp8_e5m3_x<2>; + #endif // __SYCL_TARGET_INTEL_GPU_CRI__ } // namespace ext::oneapi::experimental diff --git a/sycl/unittests/Extensions/fp8/CMakeLists.txt b/sycl/unittests/Extensions/fp8/CMakeLists.txt index 2d0c53daf4268..45778104df248 100644 --- a/sycl/unittests/Extensions/fp8/CMakeLists.txt +++ b/sycl/unittests/Extensions/fp8/CMakeLists.txt @@ -2,6 +2,7 @@ add_sycl_unittest(FP8TypesTests OBJECT fp8_e4m3.cpp fp8_e5m2.cpp fp8_e8m0.cpp + fp8_e5m3.cpp ) target_compile_options(FP8TypesTests_Preview_Tests PUBLIC -D__SYCL_TARGET_INTEL_GPU_CRI__) diff --git a/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp b/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp index 41cc44e881de3..809efd179c5e7 100644 --- a/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp +++ b/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp @@ -8,78 +8,82 @@ using namespace sycl::ext::oneapi::experimental; TEST(FP8E4M3Test, VariadicConstructorHalf) { - fp8_e4m3<2> a(sycl::half(1.0f), sycl::half(2.0f)); + fp8_e4m3_x2 a(sycl::half(1.0f), sycl::half(2.0f)); EXPECT_EQ(sizeof(a.vals), 2u); EXPECT_EQ(a.vals[0], 0x38); // 1.0 -> 0b0_0111_000 EXPECT_EQ(a.vals[1], 0x40); // 2.0 -> 0b0_1000_000 - fp8_e4m3<1> b(sycl::half(1.1f)); + fp8_e4m3 b(sycl::half(1.1f)); EXPECT_EQ(sizeof(b.vals), 1u); EXPECT_EQ(b.vals[0], 0x39); // 1.1 rounds to 1.125 -> frac=1 } TEST(FP8E4M3Test, VariadicConstructorBFloat16) { - fp8_e4m3<2> a(sycl::ext::oneapi::bfloat16(1.0f), sycl::ext::oneapi::bfloat16(2.0f)); + fp8_e4m3_x2 a(sycl::ext::oneapi::bfloat16(1.0f), sycl::ext::oneapi::bfloat16(2.0f)); EXPECT_EQ(sizeof(a.vals), 2u); EXPECT_EQ(a.vals[0], 0x38); EXPECT_EQ(a.vals[1], 0x40); - fp8_e4m3<1> b(sycl::ext::oneapi::bfloat16(1.1f)); + fp8_e4m3 b(sycl::ext::oneapi::bfloat16(1.1f)); EXPECT_EQ(sizeof(b.vals), 1u); EXPECT_EQ(b.vals[0], 0x39); } TEST(FP8E4M3Test, VariadicConstructorFloat) { - fp8_e4m3<2> a(1.0f, 2.0f); + fp8_e4m3_x2 a(1.0f, 2.0f); EXPECT_EQ(sizeof(a.vals), 2u); EXPECT_EQ(a.vals[0], 0x38); EXPECT_EQ(a.vals[1], 0x40); - fp8_e4m3<1> b(1.1f); + fp8_e4m3 b(1.1f); EXPECT_EQ(sizeof(b.vals), 1u); EXPECT_EQ(b.vals[0], 0x39); } TEST(FP8E4M3Test, VariadicBoundaryEncodingsFloat) { - // CPU host path: variadic constructors use rounding::to_even and saturation::finite. - fp8_e4m3<6> a( - 448.0f, // max normal -> S.1111.110 - 0.015625f, // min normal -> S.0001.000 (2^-6) - 0.013671875f, // max subnorm -> S.0000.111 (0.875 * 2^-6) - 0.001953125f, // min subnorm -> S.0000.001 (2^-9) - 0.0f, // +0 - -0.0f // -0 + // CPU host path: variadic constructors use rounding::to_even and + // saturation::finite. + fp8_e4m3_x2 a(448.0f, // max normal -> S.1111.110 + 0.015625f // min normal -> S.0001.000 (2^-6) ); - EXPECT_EQ(sizeof(a.vals), 6u); + fp8_e4m3_x2 b(0.013671875f, // max subnorm -> S.0000.111 (0.875 * 2^-6) + 0.001953125f // min subnorm -> S.0000.001 (2^-9) + ); + + fp8_e4m3_x2 c(0.0f, // +0 + -0.0f // -0 + ); + EXPECT_EQ(sizeof(a.vals), 2u); + EXPECT_EQ(sizeof(b.vals), 2u); + EXPECT_EQ(sizeof(c.vals), 2u); EXPECT_EQ(a.vals[0], 0x7E); // +448.0 -> 0b0_1111_110 EXPECT_EQ(a.vals[1], 0x08); // +2^-6 -> 0b0_0001_000 - EXPECT_EQ(a.vals[2], 0x07); // +max subnorm -> 0b0_0000_111 - EXPECT_EQ(a.vals[3], 0x01); // +min subnorm -> 0b0_0000_001 - EXPECT_EQ(a.vals[4], 0x00); // +0 -> 0b0_0000_000 - EXPECT_EQ(a.vals[5], 0x80); // -0 -> 0b1_0000_000 + EXPECT_EQ(b.vals[0], 0x07); // +max subnorm -> 0b0_0000_111 + EXPECT_EQ(b.vals[1], 0x01); // +min subnorm -> 0b0_0000_001 + EXPECT_EQ(c.vals[0], 0x00); // +0 -> 0b0_0000_000 + EXPECT_EQ(c.vals[1], 0x80); // -0 -> 0b1_0000_000 } TEST(FP8E4M3Test, VariadicNaNEncodingFloat) { // NaN is encoded as S.1111.111; sign is permitted. - fp8_e4m3<2> a(std::numeric_limits::quiet_NaN(), + fp8_e4m3_x2 a(std::numeric_limits::quiet_NaN(), -std::numeric_limits::quiet_NaN()); - EXPECT_EQ(sizeof(a.vals), 2u); EXPECT_EQ(a.vals[0], 0x7F); // +NaN -> 0b0_1111_111 EXPECT_EQ(a.vals[1], 0xFF); // -NaN -> 0b1_1111_111 } TEST(FP8E4M3Test, IntegerToEvenFiniteAndSize) { // Integer constructors: to_even + finite saturation (CPU). - fp8_e4m3<1> a0(0); - fp8_e4m3<1> a1(1); - fp8_e4m3<1> a2(2); - fp8_e4m3<1> an1(-1); + fp8_e4m3 a0(0); + fp8_e4m3 a1(1); + fp8_e4m3 a2(2); + fp8_e4m3 an1(-1); EXPECT_EQ(sizeof(a0.vals), 1u); EXPECT_EQ(sizeof(a1.vals), 1u); @@ -94,7 +98,7 @@ TEST(FP8E4M3Test, IntegerToEvenFiniteAndSize) { TEST(FP8E4M3Test, AssignmentOperatorToEvenFiniteAndSize) { // operator= from scalar: to_even + finite saturation (CPU). - fp8_e4m3<1> a(0.0f); + fp8_e4m3 a(0.0f); EXPECT_EQ(sizeof(a.vals), 1u); EXPECT_EQ(a.vals[0], 0x00); @@ -110,10 +114,10 @@ TEST(FP8E4M3Test, AssignmentOperatorToEvenFiniteAndSize) { TEST(FP8E4M3Test, FloatingPointConversionOperators) { // Floating-point operators: convert stored fp8 to the respective type. - fp8_e4m3<1> one(1.0f); - fp8_e4m3<1> zero_pos(0.0f); - fp8_e4m3<1> zero_neg(-0.0f); - fp8_e4m3<1> min_norm(0.015625f); + fp8_e4m3 one(1.0f); + fp8_e4m3 zero_pos(0.0f); + fp8_e4m3 zero_neg(-0.0f); + fp8_e4m3 min_norm(0.015625f); EXPECT_EQ(sizeof(one.vals), 1u); EXPECT_EQ(one.vals[0], 0x38); @@ -125,7 +129,8 @@ TEST(FP8E4M3Test, FloatingPointConversionOperators) { EXPECT_EQ(f1, 1.0f); EXPECT_EQ(fz, 0.0f); - // -0.0 compares equal to +0.0; check signbit to validate negative zero survives. + // -0.0 compares equal to +0.0; check signbit to validate negative zero + // survives. EXPECT_EQ(fnz, 0.0f); EXPECT_TRUE(std::signbit(fnz)); @@ -134,8 +139,8 @@ TEST(FP8E4M3Test, FloatingPointConversionOperators) { TEST(FP8E4M3Test, IntegerConversionOperatorsTowardZero) { // Integer operators: convert using rounding::toward_zero. - fp8_e4m3<1> p(1.5f); // 1.5 exactly representable: 0b0_0111_100 (0x3C) - fp8_e4m3<1> n(-1.5f); // 0xBC + fp8_e4m3 p(1.5f); // 1.5 exactly representable: 0b0_0111_100 (0x3C) + fp8_e4m3 n(-1.5f); // 0xBC EXPECT_EQ(sizeof(p.vals), 1u); EXPECT_EQ(sizeof(n.vals), 1u); @@ -145,16 +150,16 @@ TEST(FP8E4M3Test, IntegerConversionOperatorsTowardZero) { int ip = static_cast(p); int in = static_cast(n); - EXPECT_EQ(ip, 1); // toward zero - EXPECT_EQ(in, -1); // toward zero + EXPECT_EQ(ip, 1); // toward zero + EXPECT_EQ(in, -1); // toward zero } TEST(FP8E4M3Test, BoolOperatorZeroRules) { // bool operator: false iff +0 or -0; otherwise true. - fp8_e4m3<1> zp(0.0f); - fp8_e4m3<1> zn(-0.0f); - fp8_e4m3<1> one(1.0f); - fp8_e4m3<1> sub(0.001953125f); // min subnormal + fp8_e4m3 zp(0.0f); + fp8_e4m3 zn(-0.0f); + fp8_e4m3 one(1.0f); + fp8_e4m3 sub(0.001953125f); // min subnormal EXPECT_EQ(sizeof(zp.vals), 1u); EXPECT_EQ(sizeof(zn.vals), 1u); @@ -167,125 +172,139 @@ TEST(FP8E4M3Test, BoolOperatorZeroRules) { EXPECT_TRUE(static_cast(sub)); } -TEST(FP8E4M3Test, VariadicSaturatesFinite) { - // Variadic constructors: to_even + finite saturation (CPU). - fp8_e4m3<4> a( - 1.0f, - 1000.0f, // above max normal: clamp to +448 - -1000.0f, // clamp to -448 - -0.0f); - - EXPECT_EQ(sizeof(a.vals), 4u); - EXPECT_EQ(a.vals[0], 0x38); - EXPECT_EQ(a.vals[1], 0x7E); // +max normal - EXPECT_EQ(a.vals[2], 0xFE); // -max normal - EXPECT_EQ(a.vals[3], 0x80); // -0 -} - -TEST(FP8E4M3Test, VariadicToEvenTie) { - // Tie case: between 1.0 (0x38) and 1.125 (0x39) is 1.0625 exactly. - // to_even => choose 1.0 because its LSB (fraction) is even (0). - fp8_e4m3<2> a(1.0625f, -1.0625f); - EXPECT_EQ(sizeof(a.vals), 2u); - EXPECT_EQ(a.vals[0], 0x38); - EXPECT_EQ(a.vals[1], 0xB8); -} - TEST(FP8E4M3Test, CArrayFloatHostToEvenFinite) { // Host code supports only rounding::to_even and saturation::finite. - const float in[5] = {1.0f, 1.1f, 1.0625f, 1000.0f, -0.0f}; - fp8_e4m3<5> a(in); - - EXPECT_EQ(sizeof(a.vals), 5u); - EXPECT_EQ(a.vals[0], 0x38); // 1.0 - EXPECT_EQ(a.vals[1], 0x39); // 1.1 -> 1.125 - EXPECT_EQ(a.vals[2], 0x38); // tie -> to_even => 1.0 - EXPECT_EQ(a.vals[3], 0x7E); // finite saturation => +448 - EXPECT_EQ(a.vals[4], 0x80); // -0 + const float in[2] = {1.0f, 1.1f}; + const float in1[2] = {1.0625f, 1000.0f}; + const float in2[2] = {-0.0f, 0.0f}; + fp8_e4m3_x2 a(in); + fp8_e4m3_x2 a1(in1); + fp8_e4m3_x2 a2(in2); + + EXPECT_EQ(sizeof(a.vals), 2u); + EXPECT_EQ(sizeof(a1.vals), 2u); + EXPECT_EQ(sizeof(a2.vals), 2u); + EXPECT_EQ(a.vals[0], 0x38); // 1.0 + EXPECT_EQ(a.vals[1], 0x39); // 1.1 -> 1.125 + EXPECT_EQ(a1.vals[0], 0x38); // tie -> to_even => 1.0 + EXPECT_EQ(a1.vals[1], 0x7E); // finite saturation => +448 + EXPECT_EQ(a2.vals[0], 0x80); // -0 + EXPECT_EQ(a2.vals[1], 0x00); // 0 } TEST(FP8E4M3Test, CArrayDoubleToEvenFinite) { // Double c-array: to_even + finite saturation. - const double in[6] = {448.0, 449.0, 0.015625, 0.013671875, 0.001953125, std::numeric_limits::quiet_NaN()}; - fp8_e4m3<6> a(in); - - EXPECT_EQ(sizeof(a.vals), 6u); - EXPECT_EQ(a.vals[0], 0x7E); // +448 - EXPECT_EQ(a.vals[1], 0x7E); // 449 -> clamp to +448 - EXPECT_EQ(a.vals[2], 0x08); // min normal - EXPECT_EQ(a.vals[3], 0x07); // max subnormal - EXPECT_EQ(a.vals[4], 0x01); // min subnormal - EXPECT_EQ(a.vals[5], 0x7F); // NaN + const double in[2] = {448.0, 449.0}; + const double in1[2] = {0.015625, 0.013671875}; + const double in2[2] = {0.001953125, std::numeric_limits::quiet_NaN()}; + fp8_e4m3_x2 a(in); + fp8_e4m3_x2 a1(in1); + fp8_e4m3_x2 a2(in2); + + EXPECT_EQ(sizeof(a.vals), 2u); + EXPECT_EQ(sizeof(a1.vals), 2u); + EXPECT_EQ(sizeof(a2.vals), 2u); + EXPECT_EQ(a.vals[0], 0x7E); // +448 + EXPECT_EQ(a.vals[1], 0x7E); // 449 -> clamp to +448 + EXPECT_EQ(a1.vals[0], 0x08); // min normal + EXPECT_EQ(a1.vals[1], 0x07); // max subnormal + EXPECT_EQ(a2.vals[0], 0x01); // min subnormal + EXPECT_EQ(a2.vals[1], 0x7F); // NaN } TEST(FP8E4M3Test, CArrayHalfHostToEvenFinite) { // Host code supports only rounding::to_even and saturation::finite. - const sycl::half in[6] = {sycl::half(448.0f), sycl::half(449.0f), - sycl::half(0.015625f), sycl::half(0.013671875f), - sycl::half(0.001953125f), sycl::half(-0.0f)}; - fp8_e4m3<6> a(in); - - EXPECT_EQ(sizeof(a.vals), 6u); - EXPECT_EQ(a.vals[0], 0x7E); // +448 - EXPECT_EQ(a.vals[1], 0x7E); // 449 -> clamp to +448 - EXPECT_EQ(a.vals[2], 0x08); // min normal - EXPECT_EQ(a.vals[3], 0x07); // max subnormal - EXPECT_EQ(a.vals[4], 0x01); // min subnormal - EXPECT_EQ(a.vals[5], 0x80); // -0 + const sycl::half in[2] = {sycl::half(448.0f), sycl::half(449.0f)}; + const sycl::half in1[2] = {sycl::half(0.015625f), sycl::half(0.013671875f)}; + const sycl::half in2[2] = {sycl::half(0.001953125f), sycl::half(-0.0f)}; + + fp8_e4m3_x2 a(in); + fp8_e4m3_x2 a1(in1); + fp8_e4m3_x2 a2(in2); + + EXPECT_EQ(sizeof(a.vals), 2u); + EXPECT_EQ(sizeof(a1.vals), 2u); + EXPECT_EQ(sizeof(a2.vals), 2u); + EXPECT_EQ(a.vals[0], 0x7E); // +448 + EXPECT_EQ(a.vals[1], 0x7E); // 449 -> clamp to +448 + EXPECT_EQ(a1.vals[0], 0x08); // min normal + EXPECT_EQ(a1.vals[1], 0x07); // max subnormal + EXPECT_EQ(a2.vals[0], 0x01); // min subnormal + EXPECT_EQ(a2.vals[1], 0x80); // -0 } TEST(FP8E4M3Test, CArrayBFloat16HostToEvenFinite) { // Host code supports only rounding::to_even and saturation::finite. - const sycl::ext::oneapi::bfloat16 in[6] = { - sycl::ext::oneapi::bfloat16(448.0f), - sycl::ext::oneapi::bfloat16(449.0f), + const sycl::ext::oneapi::bfloat16 in[2] = { + sycl::ext::oneapi::bfloat16(448.0f), sycl::ext::oneapi::bfloat16(449.0f)}; + const sycl::ext::oneapi::bfloat16 in1[2] = { sycl::ext::oneapi::bfloat16(0.015625f), - sycl::ext::oneapi::bfloat16(0.013671875f), + sycl::ext::oneapi::bfloat16(0.013671875f)}; + const sycl::ext::oneapi::bfloat16 in2[2] = { sycl::ext::oneapi::bfloat16(0.001953125f), sycl::ext::oneapi::bfloat16(-0.0f)}; - fp8_e4m3<6> a(in); - - EXPECT_EQ(sizeof(a.vals), 6u); - EXPECT_EQ(a.vals[0], 0x7E); // +448 - EXPECT_EQ(a.vals[1], 0x7E); // 449 -> clamp to +448 - EXPECT_EQ(a.vals[2], 0x08); // min normal - EXPECT_EQ(a.vals[3], 0x07); // max subnormal - EXPECT_EQ(a.vals[4], 0x01); // min subnormal - EXPECT_EQ(a.vals[5], 0x80); // -0 + + fp8_e4m3_x2 a(in); + fp8_e4m3_x2 a1(in1); + fp8_e4m3_x2 a2(in2); + + EXPECT_EQ(sizeof(a.vals), 2u); + EXPECT_EQ(sizeof(a1.vals), 2u); + EXPECT_EQ(sizeof(a2.vals), 2u); + EXPECT_EQ(a.vals[0], 0x7E); // +448 + EXPECT_EQ(a.vals[1], 0x7E); // 449 -> clamp to +448 + EXPECT_EQ(a1.vals[0], 0x08); // min normal + EXPECT_EQ(a1.vals[1], 0x07); // max subnormal + EXPECT_EQ(a2.vals[0], 0x01); // min subnormal + EXPECT_EQ(a2.vals[1], 0x80); // -0 } TEST(FP8E4M3Test, MarrayAndOperatorsHostAllN) { // marray constructors/operators: host supports all N. - sycl::marray in = {1.0f, 2.0f, 0.0f, -0.0f, 448.0f, 1000.0f, 0.001953125f, -1.5f}; - fp8_e4m3<8> a(in); + sycl::marray in = {1.0f, 2.0f}; + sycl::marray in1 = {0.0f, -0.0f}; + sycl::marray in2 = {448.0f, 1000.0f}; + sycl::marray in3 = {0.001953125f, -1.5f}; + + fp8_e4m3_x2 a(in); + fp8_e4m3_x2 a1(in1); + fp8_e4m3_x2 a2(in2); + fp8_e4m3_x2 a3(in3); + + EXPECT_EQ(sizeof(a.vals), 2u); + EXPECT_EQ(sizeof(a1.vals), 2u); + EXPECT_EQ(sizeof(a2.vals), 2u); + EXPECT_EQ(sizeof(a3.vals), 2u); - EXPECT_EQ(sizeof(a.vals), 8u); EXPECT_EQ(a.vals[0], 0x38); EXPECT_EQ(a.vals[1], 0x40); - EXPECT_EQ(a.vals[2], 0x00); - EXPECT_EQ(a.vals[3], 0x80); - EXPECT_EQ(a.vals[4], 0x7E); - EXPECT_EQ(a.vals[5], 0x7E); // finite saturation - EXPECT_EQ(a.vals[6], 0x01); - EXPECT_EQ(a.vals[7], 0xBC); // -1.5 + EXPECT_EQ(a1.vals[0], 0x00); + EXPECT_EQ(a1.vals[1], 0x80); + EXPECT_EQ(a2.vals[0], 0x7E); + EXPECT_EQ(a2.vals[1], 0x7E); // finite saturation + EXPECT_EQ(a3.vals[0], 0x01); + EXPECT_EQ(a3.vals[1], 0xBC); // -1.5 // marray operator: convert fp8 vector back to marray. - sycl::marray out = static_cast>(a); + sycl::marray out = static_cast>(a); + sycl::marray out1 = static_cast>(a1); + sycl::marray out2 = static_cast>(a2); + sycl::marray out3 = static_cast>(a3); EXPECT_EQ(out[0], 1.0f); EXPECT_EQ(out[1], 2.0f); - EXPECT_EQ(out[2], 0.0f); - EXPECT_EQ(out[3], 0.0f); - EXPECT_TRUE(std::signbit(out[3])); // preserve -0 - EXPECT_EQ(out[4], 448.0f); - EXPECT_EQ(out[5], 448.0f); - EXPECT_EQ(out[6], 0.001953125f); - EXPECT_EQ(out[7], -1.5f); + EXPECT_EQ(out1[0], 0.0f); + EXPECT_EQ(out1[1], 0.0f); + EXPECT_TRUE(std::signbit(out1[1])); // preserve -0 + EXPECT_EQ(out2[0], 448.0f); + EXPECT_EQ(out2[1], 448.0f); + EXPECT_EQ(out3[0], 0.001953125f); + EXPECT_EQ(out3[1], -1.5f); } TEST(FP8E4M3Test, FloatingPointConversionOperatorsMoreTypes) { - fp8_e4m3<1> a(1.0f); - fp8_e4m3<1> b(0.015625f); - fp8_e4m3<1> nanv(std::numeric_limits::quiet_NaN()); + fp8_e4m3 a(1.0f); + fp8_e4m3 b(0.015625f); + fp8_e4m3 nanv(std::numeric_limits::quiet_NaN()); EXPECT_EQ(sizeof(a.vals), 1u); EXPECT_EQ(sizeof(b.vals), 1u); @@ -306,394 +325,77 @@ TEST(FP8E4M3Test, FloatingPointConversionOperatorsMoreTypes) { } TEST(FP8E4M3Test, IntegerConversionOperatorsMultipleWidthsTowardZero) { - fp8_e4m3<1> p(1.5f); - fp8_e4m3<1> n(-1.5f); - - std::int32_t i32p = static_cast(p); - std::int32_t i32n = static_cast(n); - std::int64_t i64p = static_cast(p); - std::int64_t i64n = static_cast(n); - - EXPECT_EQ(i32p, 1); - EXPECT_EQ(i32n, -1); - EXPECT_EQ(i64p, 1); - EXPECT_EQ(i64n, -1); -} - -TEST(FP8E4M3Test, VariadicHalfBoundaryEncodings) { - fp8_e4m3<4> a(sycl::half(448.0f), sycl::half(0.015625f), sycl::half(0.001953125f), - sycl::half(-0.0f)); - - EXPECT_EQ(sizeof(a.vals), 4u); - EXPECT_EQ(a.vals[0], 0x7E); // +max normal - EXPECT_EQ(a.vals[1], 0x08); // min normal - EXPECT_EQ(a.vals[2], 0x01); // min subnormal - EXPECT_EQ(a.vals[3], 0x80); // -0 -} - -TEST(FP8E4M3Test, VariadicBFloat16BoundaryEncodings) { - fp8_e4m3<4> a(sycl::ext::oneapi::bfloat16(1.0f), - sycl::ext::oneapi::bfloat16(2.0f), - sycl::ext::oneapi::bfloat16(0.001953125f), - sycl::ext::oneapi::bfloat16(-0.0f)); - - EXPECT_EQ(sizeof(a.vals), 4u); - EXPECT_EQ(a.vals[0], 0x38); - EXPECT_EQ(a.vals[1], 0x40); - EXPECT_EQ(a.vals[2], 0x01); - EXPECT_EQ(a.vals[3], 0x80); -} - -TEST(FP8E4M3Test, VariadicDoubleBoundaryEncodingsAndSaturation) { - fp8_e4m3<5> a(448.0, 449.0, 0.013671875, 0.001953125, -1000.0); + fp8_e4m3 p(1.5f); + fp8_e4m3 n(-1.5f); - EXPECT_EQ(sizeof(a.vals), 5u); - EXPECT_EQ(a.vals[0], 0x7E); // +448 - EXPECT_EQ(a.vals[1], 0x7E); // clamp to +448 (finite saturation) - EXPECT_EQ(a.vals[2], 0x07); // max subnormal - EXPECT_EQ(a.vals[3], 0x01); // min subnormal - EXPECT_EQ(a.vals[4], 0xFE); // clamp to -448 -} - -TEST(FP8E4M3Test, BoolOperatorWithNaN) { - float pz = 0.0f; - fp8_e4m3<1> zp(pz); - float zv = -0.0f; - fp8_e4m3<1> zn(zv); - float nv = {std::numeric_limits::quiet_NaN()}; - fp8_e4m3<1> nanv(nv); + int i = static_cast(p); + short s = static_cast(n); + long l = static_cast(p); + long long ll = static_cast(n); - EXPECT_EQ(sizeof(zp.vals), 1u); - EXPECT_EQ(sizeof(zn.vals), 1u); - EXPECT_EQ(sizeof(nanv.vals), 1u); - - EXPECT_FALSE(static_cast(zp)); - EXPECT_FALSE(static_cast(zn)); - EXPECT_TRUE(static_cast(nanv)); // not +0 or -0 - EXPECT_EQ(nanv.vals[0], 0x7F); // NaN encoding remains S.1111.111 + EXPECT_EQ(i, 1); + EXPECT_EQ(s, -1); + EXPECT_EQ(l, 1); + EXPECT_EQ(ll, -1); } TEST(FP8E4M3Test, CArrayFloatRoundingToEven) { - const float in[3] = {0.012f, 1.0625f, 1000.0f}; - fp8_e4m3<3> a(in, rounding::to_even); + const float in[2] = {0.012f, 1000.0f}; + fp8_e4m3_x2 a(in, rounding::to_even); EXPECT_EQ(a.vals[0], 0x06); - EXPECT_EQ(a.vals[1], 0x38); - EXPECT_EQ(a.vals[2], 0x7E); -} - -TEST(FP8E4M3Test, CArrayFloatRoundingUpward) { - const float in[3] = {0.012f, 1.0625f, 1000.0f}; - fp8_e4m3<3> a(in, rounding::upward); - - EXPECT_EQ(a.vals[0], 0x07); - EXPECT_EQ(a.vals[1], 0x39); - EXPECT_EQ(a.vals[2], 0x7E); -} - -TEST(FP8E4M3Test, CArrayFloatRoundingDownward) { - const float in[3] = {0.012f, 1.0625f, 1000.0f}; - fp8_e4m3<3> a(in, rounding::downward); - - EXPECT_EQ(a.vals[0], 0x06); - EXPECT_EQ(a.vals[1], 0x38); - EXPECT_EQ(a.vals[2], 0x7E); -} - -TEST(FP8E4M3Test, CArrayFloatRoundingTowardZero) { - const float in[3] = {0.012f, 1.0625f, 1000.0f}; - fp8_e4m3<3> a(in, rounding::toward_zero); - - EXPECT_EQ(a.vals[0], 0x06); - EXPECT_EQ(a.vals[1], 0x38); - EXPECT_EQ(a.vals[2], 0x7E); -} - -TEST(FP8E4M3Test, CArrayFloatRoundingToAway) { - const float in[3] = {0.012f, 1.0625f, 1000.0f}; - fp8_e4m3<3> a(in, rounding::to_away); - - EXPECT_EQ(a.vals[0], 0x06); - EXPECT_EQ(a.vals[1], 0x38); - EXPECT_EQ(a.vals[2], 0x7E); + EXPECT_EQ(a.vals[1], 0x7E); } TEST(FP8E4M3Test, CArrayHalfRoundingToEven) { - const sycl::half in[3] = {sycl::half(0.012f), sycl::half(1.0625f), - sycl::half(1000.0f)}; - fp8_e4m3<3> a(in, rounding::to_even); + const sycl::half in[2] = {sycl::half(0.012f), sycl::half(1000.0f)}; + fp8_e4m3_x2 a(in, rounding::to_even); EXPECT_EQ(a.vals[0], 0x06); - EXPECT_EQ(a.vals[1], 0x38); - EXPECT_EQ(a.vals[2], 0x7E); -} - -TEST(FP8E4M3Test, CArrayHalfRoundingUpward) { - const sycl::half in[3] = {sycl::half(0.012f), sycl::half(1.0625f), - sycl::half(1000.0f)}; - fp8_e4m3<3> a(in, rounding::upward); - - EXPECT_EQ(a.vals[0], 0x07); - EXPECT_EQ(a.vals[1], 0x39); - EXPECT_EQ(a.vals[2], 0x7E); -} - -TEST(FP8E4M3Test, CArrayHalfRoundingDownward) { - const sycl::half in[3] = {sycl::half(0.012f), sycl::half(1.0625f), - sycl::half(1000.0f)}; - fp8_e4m3<3> a(in, rounding::downward); - - EXPECT_EQ(a.vals[0], 0x06); - EXPECT_EQ(a.vals[1], 0x38); - EXPECT_EQ(a.vals[2], 0x7E); -} - -TEST(FP8E4M3Test, CArrayHalfRoundingTowardZero) { - const sycl::half in[3] = {sycl::half(0.012f), sycl::half(1.0625f), - sycl::half(1000.0f)}; - fp8_e4m3<3> a(in, rounding::toward_zero); - - EXPECT_EQ(a.vals[0], 0x06); - EXPECT_EQ(a.vals[1], 0x38); - EXPECT_EQ(a.vals[2], 0x7E); -} - -TEST(FP8E4M3Test, CArrayHalfRoundingToAway) { - const sycl::half in[3] = {sycl::half(0.012f), sycl::half(1.0625f), - sycl::half(1000.0f)}; - fp8_e4m3<3> a(in, rounding::to_away); - - EXPECT_EQ(a.vals[0], 0x06); - EXPECT_EQ(a.vals[1], 0x38); - EXPECT_EQ(a.vals[2], 0x7E); + EXPECT_EQ(a.vals[1], 0x7E); } TEST(FP8E4M3Test, CArrayBFloat16RoundingToEven) { - const sycl::ext::oneapi::bfloat16 in[3] = { + const sycl::ext::oneapi::bfloat16 in[2] = { sycl::ext::oneapi::bfloat16(0.012f), - sycl::ext::oneapi::bfloat16(1.0625f), sycl::ext::oneapi::bfloat16(1000.0f)}; - fp8_e4m3<3> a(in, rounding::to_even); + fp8_e4m3_x2 a(in, rounding::to_even); EXPECT_EQ(a.vals[0], 0x06); - EXPECT_EQ(a.vals[1], 0x38); - EXPECT_EQ(a.vals[2], 0x7E); -} - -TEST(FP8E4M3Test, CArrayBFloat16RoundingUpward) { - const sycl::ext::oneapi::bfloat16 in[3] = { - sycl::ext::oneapi::bfloat16(0.012f), - sycl::ext::oneapi::bfloat16(1.0625f), - sycl::ext::oneapi::bfloat16(1000.0f)}; - fp8_e4m3<3> a(in, rounding::upward); - - EXPECT_EQ(a.vals[0], 0x07); - EXPECT_EQ(a.vals[1], 0x39); - EXPECT_EQ(a.vals[2], 0x7E); -} - -TEST(FP8E4M3Test, CArrayBFloat16Downward) { - const sycl::ext::oneapi::bfloat16 in[3] = { - sycl::ext::oneapi::bfloat16(0.012f), - sycl::ext::oneapi::bfloat16(1.0625f), - sycl::ext::oneapi::bfloat16(1000.0f)}; - fp8_e4m3<3> a(in, rounding::downward); - - EXPECT_EQ(a.vals[0], 0x06); - EXPECT_EQ(a.vals[1], 0x38); - EXPECT_EQ(a.vals[2], 0x7E); -} - -TEST(FP8E4M3Test, CArrayBFloat16TowardZero) { - const sycl::ext::oneapi::bfloat16 in[3] = { - sycl::ext::oneapi::bfloat16(0.012f), - sycl::ext::oneapi::bfloat16(1.0625f), - sycl::ext::oneapi::bfloat16(1000.0f)}; - fp8_e4m3<3> a(in, rounding::toward_zero); - - EXPECT_EQ(a.vals[0], 0x06); - EXPECT_EQ(a.vals[1], 0x38); - EXPECT_EQ(a.vals[2], 0x7E); -} - -TEST(FP8E4M3Test, CArrayBFloat16ToAway) { - const sycl::ext::oneapi::bfloat16 in[3] = { - sycl::ext::oneapi::bfloat16(0.012f), - sycl::ext::oneapi::bfloat16(1.0625f), - sycl::ext::oneapi::bfloat16(1000.0f)}; - fp8_e4m3<3> a(in, rounding::to_away); - - EXPECT_EQ(a.vals[0], 0x06); - EXPECT_EQ(a.vals[1], 0x38); - EXPECT_EQ(a.vals[2], 0x7E); + EXPECT_EQ(a.vals[1], 0x7E); } TEST(FP8E4M3Test, MarrayHalfRoundingToEven) { - const sycl::marray in = {sycl::half(0.012f), - sycl::half(1.0625f), - sycl::half(1000.0f)}; - fp8_e4m3<3> a(in, rounding::to_even); + const sycl::marray in = {sycl::half(0.012f), + sycl::half(1.0625f)}; + fp8_e4m3_x2 a(in, rounding::to_even); EXPECT_EQ(a.vals[0], 0x06); EXPECT_EQ(a.vals[1], 0x38); - EXPECT_EQ(a.vals[2], 0x7E); -} - -TEST(FP8E4M3Test, MarrayHalfRoundingUpward) { - const sycl::marray in = {sycl::half(0.012f), - sycl::half(1.0625f), - sycl::half(1000.0f)}; - fp8_e4m3<3> a(in, rounding::upward); - - EXPECT_EQ(a.vals[0], 0x07); - EXPECT_EQ(a.vals[1], 0x39); - EXPECT_EQ(a.vals[2], 0x7E); -} - -TEST(FP8E4M3Test, MarrayHalfRoundingDownward) { - const sycl::marray in = {sycl::half(0.012f), - sycl::half(1.0625f), - sycl::half(1000.0f)}; - fp8_e4m3<3> a(in, rounding::downward); - - EXPECT_EQ(a.vals[0], 0x06); - EXPECT_EQ(a.vals[1], 0x38); - EXPECT_EQ(a.vals[2], 0x7E); -} - -TEST(FP8E4M3Test, MarrayHalfRoundingTowardZero) { - const sycl::marray in = {sycl::half(0.012f), - sycl::half(1.0625f), - sycl::half(1000.0f)}; - fp8_e4m3<3> a(in, rounding::toward_zero); - - EXPECT_EQ(a.vals[0], 0x06); - EXPECT_EQ(a.vals[1], 0x38); - EXPECT_EQ(a.vals[2], 0x7E); -} - -TEST(FP8E4M3Test, MarrayHalfRoundingToAway) { - const sycl::marray in = {sycl::half(0.012f), - sycl::half(1.0625f), - sycl::half(1000.0f)}; - fp8_e4m3<3> a(in, rounding::to_away); - - EXPECT_EQ(a.vals[0], 0x06); - EXPECT_EQ(a.vals[1], 0x38); - EXPECT_EQ(a.vals[2], 0x7E); } TEST(FP8E4M3Test, MarrayBFloat16RoundingToEven) { - const sycl::marray in = { - sycl::ext::oneapi::bfloat16(0.012f), - sycl::ext::oneapi::bfloat16(1.0625f), - sycl::ext::oneapi::bfloat16(1000.0f)}; - fp8_e4m3<3> a(in, rounding::to_even); - - EXPECT_EQ(a.vals[0], 0x06); - EXPECT_EQ(a.vals[1], 0x38); - EXPECT_EQ(a.vals[2], 0x7E); -} - -TEST(FP8E4M3Test, MarrayBFloat16RoundingUpward) { - const sycl::marray in = { - sycl::ext::oneapi::bfloat16(0.012f), - sycl::ext::oneapi::bfloat16(1.0625f), - sycl::ext::oneapi::bfloat16(1000.0f)}; - fp8_e4m3<3> a(in, rounding::upward); - - EXPECT_EQ(a.vals[0], 0x07); - EXPECT_EQ(a.vals[1], 0x39); - EXPECT_EQ(a.vals[2], 0x7E); -} - -TEST(FP8E4M3Test, MarrayBFloat16RoundingDownward) { - const sycl::marray in = { - sycl::ext::oneapi::bfloat16(0.012f), - sycl::ext::oneapi::bfloat16(1.0625f), - sycl::ext::oneapi::bfloat16(1000.0f)}; - fp8_e4m3<3> a(in, rounding::downward); - - EXPECT_EQ(a.vals[0], 0x06); - EXPECT_EQ(a.vals[1], 0x38); - EXPECT_EQ(a.vals[2], 0x7E); -} - -TEST(FP8E4M3Test, MarrayBFloat16RoundingTowardZero) { - const sycl::marray in = { + const sycl::marray in = { sycl::ext::oneapi::bfloat16(0.012f), - sycl::ext::oneapi::bfloat16(1.0625f), - sycl::ext::oneapi::bfloat16(1000.0f)}; - fp8_e4m3<3> a(in, rounding::toward_zero); + sycl::ext::oneapi::bfloat16(1.0625f)}; + fp8_e4m3_x2 a(in, rounding::to_even); EXPECT_EQ(a.vals[0], 0x06); EXPECT_EQ(a.vals[1], 0x38); - EXPECT_EQ(a.vals[2], 0x7E); -} - -TEST(FP8E4M3Test, MarrayBFloat16RoundingToAway) { - const sycl::marray in = { - sycl::ext::oneapi::bfloat16(0.012f), - sycl::ext::oneapi::bfloat16(1.0625f), - sycl::ext::oneapi::bfloat16(1000.0f)}; - fp8_e4m3<3> a(in, rounding::to_away); - - EXPECT_EQ(a.vals[0], 0x06); - EXPECT_EQ(a.vals[1], 0x38); - EXPECT_EQ(a.vals[2], 0x7E); } TEST(FP8E4M3Test, MarrayFloatRoundingToEven) { - const sycl::marray in = {0.012f, 1.0625f, 1000.0f}; - fp8_e4m3<3> a(in, rounding::to_even); + const sycl::marray in = {0.012f, 1.0625f}; + fp8_e4m3_x2 a(in, rounding::to_even); EXPECT_EQ(a.vals[0], 0x06); EXPECT_EQ(a.vals[1], 0x38); - EXPECT_EQ(a.vals[2], 0x7E); -} - -TEST(FP8E4M3Test, MarrayFloatRoundingUpward) { - const sycl::marray in = {0.012f, 1.0625f, 1000.0f}; - fp8_e4m3<3> a(in, rounding::upward); - - EXPECT_EQ(a.vals[0], 0x07); - EXPECT_EQ(a.vals[1], 0x39); - EXPECT_EQ(a.vals[2], 0x7E); } -TEST(FP8E4M3Test, MarrayFloatRoundingDownward) { - const sycl::marray in = {0.012f, 1.0625f, 1000.0f}; - fp8_e4m3<3> a(in, rounding::downward); - - EXPECT_EQ(a.vals[0], 0x06); - EXPECT_EQ(a.vals[1], 0x38); - EXPECT_EQ(a.vals[2], 0x7E); -} - -TEST(FP8E4M3Test, MarrayFloatRoundingTowardZero) { - const sycl::marray in = {0.012f, 1.0625f, 1000.0f}; - fp8_e4m3<3> a(in, rounding::toward_zero); - - EXPECT_EQ(a.vals[0], 0x06); - EXPECT_EQ(a.vals[1], 0x38); - EXPECT_EQ(a.vals[2], 0x7E); -} - -TEST(FP8E4M3Test, MarrayFloatRoundingToAway) { - const sycl::marray in = {0.012f, 1.0625f, 1000.0f}; - fp8_e4m3<3> a(in, rounding::to_away); - - EXPECT_EQ(a.vals[0], 0x06); - EXPECT_EQ(a.vals[1], 0x38); - EXPECT_EQ(a.vals[2], 0x7E); -} - - TEST(FP8E4M3Test, MarrayDoubleToEven) { - const sycl::marray in = {0.012, 1.0625, 1000.0}; - fp8_e4m3<3> a(in); + const sycl::marray in = {0.012, 1.0625}; + fp8_e4m3_x2 a(in); EXPECT_EQ(a.vals[0], 0x06); EXPECT_EQ(a.vals[1], 0x38); - EXPECT_EQ(a.vals[2], 0x7E); } diff --git a/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp b/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp index 8455ba8f93752..c1ea19ea3fa47 100644 --- a/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp +++ b/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp @@ -8,64 +8,70 @@ using namespace sycl::ext::oneapi::experimental; TEST(FP8E5M2Test, VariadicConstructorHalf) { - fp8_e5m2<2> a(sycl::half(1.0f), sycl::half(2.0f)); + fp8_e5m2_x2 a(sycl::half(1.0f), sycl::half(2.0f)); EXPECT_EQ(sizeof(a.vals), 2u); EXPECT_EQ(a.vals[0], 0x3C); // 1.0 -> 0b0_01111_00 EXPECT_EQ(a.vals[1], 0x40); // 2.0 -> 0b0_10000_00 - fp8_e5m2<1> b(sycl::half(1.1f)); + fp8_e5m2 b(sycl::half(1.1f)); EXPECT_EQ(sizeof(b.vals), 1u); EXPECT_EQ(b.vals[0], 0x3C); // 1.1 rounds to 1.0 } TEST(FP8E5M2Test, VariadicConstructorBFloat16) { - fp8_e5m2<2> a(sycl::ext::oneapi::bfloat16(1.0f), + fp8_e5m2_x2 a(sycl::ext::oneapi::bfloat16(1.0f), sycl::ext::oneapi::bfloat16(2.0f)); EXPECT_EQ(sizeof(a.vals), 2u); EXPECT_EQ(a.vals[0], 0x3C); EXPECT_EQ(a.vals[1], 0x40); - fp8_e5m2<1> b(sycl::ext::oneapi::bfloat16(1.1f)); + fp8_e5m2 b(sycl::ext::oneapi::bfloat16(1.1f)); EXPECT_EQ(sizeof(b.vals), 1u); EXPECT_EQ(b.vals[0], 0x3C); } TEST(FP8E5M2Test, VariadicConstructorFloat) { - fp8_e5m2<2> a(1.0f, 2.0f); + fp8_e5m2_x2 a(1.0f, 2.0f); EXPECT_EQ(sizeof(a.vals), 2u); EXPECT_EQ(a.vals[0], 0x3C); EXPECT_EQ(a.vals[1], 0x40); - fp8_e5m2<1> b(1.1f); + fp8_e5m2 b(1.1f); EXPECT_EQ(sizeof(b.vals), 1u); EXPECT_EQ(b.vals[0], 0x3C); } TEST(FP8E5M2Test, VariadicConstructorBoundaryEncodingsFloat) { - fp8_e5m2<6> a( - 57344.0f, // max normal -> S.11110.11 - 0.00006103515625f, // min normal -> S.00001.00 (2^-14) + fp8_e5m2_x2 a(57344.0f, // max normal -> S.11110.11 + 0.00006103515625f // min normal -> S.00001.00 (2^-14) + ); + + fp8_e5m2_x2 a1( 0.0000457763671875f, // max subnorm -> S.00000.11 (0.75 * 2^-14) - 0.0000152587890625f, // min subnorm -> S.00000.01 (2^-16) - 0.0f, // +0 - -0.0f // -0 + 0.0000152587890625f // min subnorm -> S.00000.01 (2^-16) ); - EXPECT_EQ(sizeof(a.vals), 6u); + fp8_e5m2_x2 a2(0.0f, // +0 + -0.0f // -0 + ); - EXPECT_EQ(a.vals[0], 0x7B); // +57344.0 -> 0b0_11110_11 - EXPECT_EQ(a.vals[1], 0x04); // +2^-14 -> 0b0_00001_00 - EXPECT_EQ(a.vals[2], 0x03); // +max subnorm -> 0b0_00000_11 - EXPECT_EQ(a.vals[3], 0x01); // +min subnorm -> 0b0_00000_01 - EXPECT_EQ(a.vals[4], 0x00); // +0 -> 0b0_00000_00 - EXPECT_EQ(a.vals[5], 0x80); // -0 -> 0b1_00000_00 + EXPECT_EQ(sizeof(a.vals), 2u); + EXPECT_EQ(sizeof(a1.vals), 2u); + EXPECT_EQ(sizeof(a2.vals), 2u); + + EXPECT_EQ(a.vals[0], 0x7B); // +57344.0 -> 0b0_11110_11 + EXPECT_EQ(a.vals[1], 0x04); // +2^-14 -> 0b0_00001_00 + EXPECT_EQ(a1.vals[0], 0x03); // +max subnorm -> 0b0_00000_11 + EXPECT_EQ(a1.vals[1], 0x01); // +min subnorm -> 0b0_00000_01 + EXPECT_EQ(a2.vals[0], 0x00); // +0 -> 0b0_00000_00 + EXPECT_EQ(a2.vals[1], 0x80); // -0 -> 0b1_00000_00 } TEST(FP8E5M2Test, VariadicConstructorNaNEncodingFloat) { - fp8_e5m2<2> a(std::numeric_limits::quiet_NaN(), + fp8_e5m2_x2 a(std::numeric_limits::quiet_NaN(), -std::numeric_limits::quiet_NaN()); EXPECT_EQ(sizeof(a.vals), 2u); @@ -74,10 +80,10 @@ TEST(FP8E5M2Test, VariadicConstructorNaNEncodingFloat) { } TEST(FP8E5M2Test, IntegerConstructorToEvenFiniteAndSize) { - fp8_e5m2<1> a0(0); - fp8_e5m2<1> a1(1); - fp8_e5m2<1> a2(2); - fp8_e5m2<1> an1(-1); + fp8_e5m2 a0(0); + fp8_e5m2 a1(1); + fp8_e5m2 a2(2); + fp8_e5m2 an1(-1); EXPECT_EQ(sizeof(a0.vals), 1u); EXPECT_EQ(sizeof(a1.vals), 1u); @@ -91,7 +97,7 @@ TEST(FP8E5M2Test, IntegerConstructorToEvenFiniteAndSize) { } TEST(FP8E5M2Test, AssignmentOperatorToEvenFiniteAndSize) { - fp8_e5m2<1> a(0.0f); + fp8_e5m2 a(0.0f); EXPECT_EQ(sizeof(a.vals), 1u); EXPECT_EQ(a.vals[0], 0x00); @@ -107,10 +113,10 @@ TEST(FP8E5M2Test, AssignmentOperatorToEvenFiniteAndSize) { TEST(FP8E5M2Test, FloatingPointConversionOperators) { // Floating-point operators: convert stored fp8 to the respective type. - fp8_e5m2<1> one(1.0f); - fp8_e5m2<1> zero_pos(0.0f); - fp8_e5m2<1> zero_neg(-0.0f); - fp8_e5m2<1> min_norm(0.00006103515625f); + fp8_e5m2 one(1.0f); + fp8_e5m2 zero_pos(0.0f); + fp8_e5m2 zero_neg(-0.0f); + fp8_e5m2 min_norm(0.00006103515625f); EXPECT_EQ(sizeof(one.vals), 1u); EXPECT_EQ(one.vals[0], 0x3C); @@ -130,8 +136,8 @@ TEST(FP8E5M2Test, FloatingPointConversionOperators) { TEST(FP8E5M2Test, IntegerConversionOperatorsTowardZero) { // Integer operators: convert using rounding::toward_zero. - fp8_e5m2<1> p(1.5f); // 1.5 exactly representable: 0b0_01111_10 (0x3E) - fp8_e5m2<1> n(-1.5f); // 0xBE + fp8_e5m2 p(1.5f); // 1.5 exactly representable: 0b0_01111_10 (0x3E) + fp8_e5m2 n(-1.5f); // 0xBE EXPECT_EQ(sizeof(p.vals), 1u); EXPECT_EQ(sizeof(n.vals), 1u); @@ -147,10 +153,10 @@ TEST(FP8E5M2Test, IntegerConversionOperatorsTowardZero) { TEST(FP8E5M2Test, BoolOperatorZeroRules) { // bool operator: false iff +0 or -0; otherwise true. - fp8_e5m2<1> zp(0.0f); - fp8_e5m2<1> zn(-0.0f); - fp8_e5m2<1> one(1.0f); - fp8_e5m2<1> sub(0.0000152587890625f); // min subnormal + fp8_e5m2 zp(0.0f); + fp8_e5m2 zn(-0.0f); + fp8_e5m2 one(1.0f); + fp8_e5m2 sub(0.0000152587890625f); // min subnormal EXPECT_EQ(sizeof(zp.vals), 1u); EXPECT_EQ(sizeof(zn.vals), 1u); @@ -165,23 +171,26 @@ TEST(FP8E5M2Test, BoolOperatorZeroRules) { TEST(FP8E5M2Test, VariadicConstructorSaturatesFinite) { // Variadic constructors: to_even + finite saturation (CPU). - fp8_e5m2<4> a(1.0f, - 100000.0f, // above max normal: clamp to +57344 - -100000.0f, // clamp to -57344 + fp8_e5m2_x2 a(1.0f, + 100000.0f // above max normal: clamp to +57344 + ); + + fp8_e5m2_x2 a1(-100000.0f, // clamp to -57344 -0.0f); - EXPECT_EQ(sizeof(a.vals), 4u); + EXPECT_EQ(sizeof(a.vals), 2u); + EXPECT_EQ(sizeof(a1.vals), 2u); EXPECT_EQ(a.vals[0], 0x3C); EXPECT_EQ(a.vals[1], 0x7B); // +max normal - EXPECT_EQ(a.vals[2], 0xFB); // -max normal - EXPECT_EQ(a.vals[3], 0x80); // -0 + EXPECT_EQ(a1.vals[0], 0xFB); // -max normal + EXPECT_EQ(a1.vals[1], 0x80); // -0 } TEST(FP8E5M2Test, VariadicConstructorToEvenTie) { // Tie case: between 1.0 (0x3C) and 1.25 (0x3D) is 1.125 exactly. // to_even => choose 1.0 because its LSB (fraction) is even (0). // Tie between 1.25 (0x3D) and 1.5 (0x3E) is 1.375 exactly => choose 1.5. - fp8_e5m2<2> a(1.125f, -1.375f); + fp8_e5m2_x2 a(1.125f, -1.375f); EXPECT_EQ(sizeof(a.vals), 2u); EXPECT_EQ(a.vals[0], 0x3C); EXPECT_EQ(a.vals[1], 0xBE); @@ -189,126 +198,128 @@ TEST(FP8E5M2Test, VariadicConstructorToEvenTie) { TEST(FP8E5M2Test, CArrayConstructorFloatHostToEvenFinite) { // Host code supports only rounding::to_even and saturation::finite. - const float in[5] = {1.0f, 1.1f, 1.125f, 100000.0f, -0.0f}; - fp8_e5m2<5> a(in); + const float in[2] = {1.0f, 1.1f}; + const float in1[2] = {1.125f, 100000.0f}; + fp8_e5m2_x2 a(in); + fp8_e5m2_x2 a1(in1); - EXPECT_EQ(sizeof(a.vals), 5u); + EXPECT_EQ(sizeof(a.vals), 2u); + EXPECT_EQ(sizeof(a1.vals), 2u); EXPECT_EQ(a.vals[0], 0x3C); // 1.0 EXPECT_EQ(a.vals[1], 0x3C); // 1.1 -> 1.0 - EXPECT_EQ(a.vals[2], 0x3C); // tie -> to_even => 1.0 - EXPECT_EQ(a.vals[3], 0x7B); // finite saturation => +57344 - EXPECT_EQ(a.vals[4], 0x80); // -0 + EXPECT_EQ(a1.vals[0], 0x3C); // tie -> to_even => 1.0 + EXPECT_EQ(a1.vals[1], 0x7B); // finite saturation => +57344 } TEST(FP8E5M2Test, CArrayConstructorDoubleToEvenFinite) { // Double c-array: to_even + finite saturation. - const double in[6] = {57344.0, - 60000.0, - 0.00006103515625, - 0.0000457763671875, - 0.0000152587890625, - std::numeric_limits::quiet_NaN()}; - fp8_e5m2<6> a(in); - - EXPECT_EQ(sizeof(a.vals), 6u); + const double in[2] = {57344.0, 60000.0}; + const double in1[2] = {0.00006103515625, 0.0000457763671875}; + const double in2[2] = {0.0000152587890625, + std::numeric_limits::quiet_NaN()}; + fp8_e5m2_x2 a(in); + fp8_e5m2_x2 a1(in1); + fp8_e5m2_x2 a2(in2); + + EXPECT_EQ(sizeof(a.vals), 2u); + EXPECT_EQ(sizeof(a1.vals), 2u); + EXPECT_EQ(sizeof(a2.vals), 2u); EXPECT_EQ(a.vals[0], 0x7B); // +57344 EXPECT_EQ(a.vals[1], 0x7B); // 60000 -> clamp to +57344 - EXPECT_EQ(a.vals[2], 0x04); // min normal - EXPECT_EQ(a.vals[3], 0x03); // max subnormal - EXPECT_EQ(a.vals[4], 0x01); // min subnormal - EXPECT_EQ(a.vals[5], 0x7F); // NaN + EXPECT_EQ(a1.vals[0], 0x04); // min normal + EXPECT_EQ(a1.vals[1], 0x03); // max subnormal + EXPECT_EQ(a2.vals[0], 0x01); // min subnormal + EXPECT_EQ(a2.vals[1], 0x7F); // NaN } TEST(FP8E5M2Test, CArrayConstructorHalfHostToEvenFinite) { - const sycl::half in[4] = {sycl::half(1.0f), sycl::half(2.0f), - sycl::half(1.125f), sycl::half(-0.0f)}; - fp8_e5m2<4> a(in); + const sycl::half in[2] = {sycl::half(1.0f), sycl::half(2.0f)}; + const sycl::half in1[2] = {sycl::half(1.125f), sycl::half(-0.0f)}; + fp8_e5m2_x2 a(in); + fp8_e5m2_x2 a1(in1); - EXPECT_EQ(sizeof(a.vals), 4u); + EXPECT_EQ(sizeof(a.vals), 2u); EXPECT_EQ(a.vals[0], 0x3C); EXPECT_EQ(a.vals[1], 0x40); - EXPECT_EQ(a.vals[2], 0x3C); // tie -> to_even => 1.0 - EXPECT_EQ(a.vals[3], 0x80); + EXPECT_EQ(a1.vals[0], 0x3C); // tie -> to_even => 1.0 + EXPECT_EQ(a1.vals[1], 0x80); } TEST(FP8E5M2Test, CArrayConstructorBFloat16HostToEvenFinite) { - const sycl::ext::oneapi::bfloat16 in[4] = { - sycl::ext::oneapi::bfloat16(1.0f), sycl::ext::oneapi::bfloat16(2.0f), + const sycl::ext::oneapi::bfloat16 in[2] = {sycl::ext::oneapi::bfloat16(1.0f), + sycl::ext::oneapi::bfloat16(2.0f)}; + const sycl::ext::oneapi::bfloat16 in1[2] = { sycl::ext::oneapi::bfloat16(1.125f), sycl::ext::oneapi::bfloat16(-0.0f)}; - fp8_e5m2<4> a(in); + fp8_e5m2_x2 a(in); + fp8_e5m2_x2 a1(in1); - EXPECT_EQ(sizeof(a.vals), 4u); + EXPECT_EQ(sizeof(a.vals), 2u); + EXPECT_EQ(sizeof(a1.vals), 2u); EXPECT_EQ(a.vals[0], 0x3C); EXPECT_EQ(a.vals[1], 0x40); - EXPECT_EQ(a.vals[2], 0x3C); // tie -> to_even => 1.0 - EXPECT_EQ(a.vals[3], 0x80); + EXPECT_EQ(a1.vals[0], 0x3C); // tie -> to_even => 1.0 + EXPECT_EQ(a1.vals[1], 0x80); } -TEST(FP8E5M2Test, MarrayConstructorAndOperatorsHostAllN) { - // marray constructors/operators: host supports all N. - sycl::marray in = { - 1.0f, 2.0f, 0.0f, -0.0f, 57344.0f, 100000.0f, 0.0000152587890625f, -1.5f}; - fp8_e5m2<8> a(in); +TEST(FP8E5M2Test, MarrayConstructorAndOperators) { + sycl::marray in = {1.0f, 2.0f}; + sycl::marray in1 = {0.0f, -0.0f}; + sycl::marray in2 = {57344.0f, 100000.0f}; + sycl::marray in3 = {0.0000152587890625f, -1.5f}; + fp8_e5m2_x2 a(in); + fp8_e5m2_x2 a1(in1); + fp8_e5m2_x2 a2(in2); + fp8_e5m2_x2 a3(in3); - EXPECT_EQ(sizeof(a.vals), 8u); + EXPECT_EQ(sizeof(a.vals), 2u); + EXPECT_EQ(sizeof(a1.vals), 2u); + EXPECT_EQ(sizeof(a2.vals), 2u); + EXPECT_EQ(sizeof(a3.vals), 2u); EXPECT_EQ(a.vals[0], 0x3C); EXPECT_EQ(a.vals[1], 0x40); - EXPECT_EQ(a.vals[2], 0x00); - EXPECT_EQ(a.vals[3], 0x80); - EXPECT_EQ(a.vals[4], 0x7B); - EXPECT_EQ(a.vals[5], 0x7B); // finite saturation - EXPECT_EQ(a.vals[6], 0x01); - EXPECT_EQ(a.vals[7], 0xBE); // -1.5 - - sycl::marray out = static_cast>(a); + EXPECT_EQ(a1.vals[0], 0x00); + EXPECT_EQ(a1.vals[1], 0x80); + EXPECT_EQ(a2.vals[0], 0x7B); + EXPECT_EQ(a2.vals[1], 0x7B); // finite saturation + EXPECT_EQ(a3.vals[0], 0x01); + EXPECT_EQ(a3.vals[1], 0xBE); // -1.5 + + sycl::marray out = static_cast>(a); + sycl::marray out1 = static_cast>(a1); + sycl::marray out2 = static_cast>(a2); + sycl::marray out3 = static_cast>(a3); + EXPECT_EQ(out[0], 1.0f); EXPECT_EQ(out[1], 2.0f); - EXPECT_EQ(out[2], 0.0f); - EXPECT_EQ(out[3], 0.0f); - EXPECT_TRUE(std::signbit(out[3])); - EXPECT_EQ(out[4], 57344.0f); - EXPECT_EQ(out[5], 57344.0f); - EXPECT_EQ(out[6], 0.0000152587890625f); - EXPECT_EQ(out[7], -1.5f); + EXPECT_EQ(out1[0], 0.0f); + EXPECT_EQ(out1[1], 0.0f); + EXPECT_TRUE(std::signbit(out1[1])); + EXPECT_EQ(out2[0], 57344.0f); + EXPECT_EQ(out2[1], 57344.0f); + EXPECT_EQ(out3[0], 0.0000152587890625f); + EXPECT_EQ(out3[1], -1.5f); } -TEST(FP8E5M2Test, MarrayConstructorHalfBFloat16Double) { - sycl::marray hvals = {sycl::half(1.0f), sycl::half(2.0f), - sycl::half(57344.0f), sycl::half(-0.0f)}; - sycl::marray bvals = { - sycl::ext::oneapi::bfloat16(1.0f), sycl::ext::oneapi::bfloat16(2.0f), - sycl::ext::oneapi::bfloat16(0.0000152587890625f), - sycl::ext::oneapi::bfloat16(-0.0f)}; - sycl::marray dvals = {1.0, 2.0, 57344.0, -0.0}; +TEST(FP8E5M2Test, MarrayConstructorDouble) { + sycl::marray dvals = {1.0, 2.0}; + sycl::marray dvals1 = {57344.0, -0.0}; - fp8_e5m2<4> ah(hvals); - fp8_e5m2<4> ab(bvals); - fp8_e5m2<4> ad(dvals); + fp8_e5m2_x2 ah(dvals); + fp8_e5m2_x2 ah1(dvals1); - EXPECT_EQ(sizeof(ah.vals), 4u); - EXPECT_EQ(sizeof(ab.vals), 4u); - EXPECT_EQ(sizeof(ad.vals), 4u); + EXPECT_EQ(sizeof(ah.vals), 2u); + EXPECT_EQ(sizeof(ah1.vals), 2u); EXPECT_EQ(ah.vals[0], 0x3C); EXPECT_EQ(ah.vals[1], 0x40); - EXPECT_EQ(ah.vals[2], 0x7B); - EXPECT_EQ(ah.vals[3], 0x80); - - EXPECT_EQ(ab.vals[0], 0x3C); - EXPECT_EQ(ab.vals[1], 0x40); - EXPECT_EQ(ab.vals[2], 0x01); - EXPECT_EQ(ab.vals[3], 0x80); - - EXPECT_EQ(ad.vals[0], 0x3C); - EXPECT_EQ(ad.vals[1], 0x40); - EXPECT_EQ(ad.vals[2], 0x7B); - EXPECT_EQ(ad.vals[3], 0x80); + EXPECT_EQ(ah1.vals[0], 0x7B); + EXPECT_EQ(ah1.vals[1], 0x80); } TEST(FP8E5M2Test, FloatingPointConversionOperatorsMoreTypes) { - fp8_e5m2<1> a(1.0f); - fp8_e5m2<1> b(0.00006103515625f); - fp8_e5m2<1> nanv(std::numeric_limits::quiet_NaN()); + fp8_e5m2 a(1.0f); + fp8_e5m2 b(0.00006103515625f); + fp8_e5m2 nanv(std::numeric_limits::quiet_NaN()); EXPECT_EQ(sizeof(a.vals), 1u); EXPECT_EQ(sizeof(b.vals), 1u); @@ -329,7 +340,7 @@ TEST(FP8E5M2Test, FloatingPointConversionOperatorsMoreTypes) { } TEST(FP8E5M2Test, MarrayConversionOperatorsHalfBFloat16) { - fp8_e5m2<2> a(1.0f, -0.0f); + fp8_e5m2_x2 a(1.0f, -0.0f); EXPECT_EQ(sizeof(a.vals), 2u); EXPECT_EQ(a.vals[0], 0x3C); @@ -348,24 +359,9 @@ TEST(FP8E5M2Test, MarrayConversionOperatorsHalfBFloat16) { EXPECT_TRUE(std::signbit(static_cast(bo[1]))); } -TEST(FP8E5M2Test, IntegerConversionOperatorsMultipleWidthsTowardZero) { - fp8_e5m2<1> p(1.5f); - fp8_e5m2<1> n(-1.5f); - - std::int32_t i32p = static_cast(p); - std::int32_t i32n = static_cast(n); - std::int64_t i64p = static_cast(p); - std::int64_t i64n = static_cast(n); - - EXPECT_EQ(i32p, 1); - EXPECT_EQ(i32n, -1); - EXPECT_EQ(i64p, 1); - EXPECT_EQ(i64n, -1); -} - -TEST(FP8E5M2Test, IntegerConversionOperatorsAllTypesTowardZero) { - fp8_e5m2<1> p(1.5f); - fp8_e5m2<1> n(-1.5f); +TEST(FP8E5M2Test, IntegerConversionOperators) { + fp8_e5m2 p(1.5f); + fp8_e5m2 n(-1.5f); EXPECT_EQ(sizeof(p.vals), 1u); EXPECT_EQ(sizeof(n.vals), 1u); @@ -385,73 +381,8 @@ TEST(FP8E5M2Test, IntegerConversionOperatorsAllTypesTowardZero) { EXPECT_EQ(static_cast(p), 1u); } -TEST(FP8E5M2Test, VariadicConstructorHalfBoundaryEncodings) { - fp8_e5m2<4> a(sycl::half(57344.0f), sycl::half(0.00006103515625f), - sycl::half(0.0000152587890625f), sycl::half(-0.0f)); - - EXPECT_EQ(sizeof(a.vals), 4u); - EXPECT_EQ(a.vals[0], 0x7B); // +max normal - EXPECT_EQ(a.vals[1], 0x04); // min normal - EXPECT_EQ(a.vals[2], 0x01); // min subnormal - EXPECT_EQ(a.vals[3], 0x80); // -0 -} - -TEST(FP8E5M2Test, VariadicConstructorBFloat16BoundaryEncodings) { - fp8_e5m2<4> a(sycl::ext::oneapi::bfloat16(1.0f), - sycl::ext::oneapi::bfloat16(2.0f), - sycl::ext::oneapi::bfloat16(0.0000152587890625f), - sycl::ext::oneapi::bfloat16(-0.0f)); - - EXPECT_EQ(sizeof(a.vals), 4u); - EXPECT_EQ(a.vals[0], 0x3C); - EXPECT_EQ(a.vals[1], 0x40); - EXPECT_EQ(a.vals[2], 0x01); - EXPECT_EQ(a.vals[3], 0x80); -} - -TEST(FP8E5M2Test, VariadicConstructorDoubleBoundaryEncodingsAndSaturation) { - fp8_e5m2<5> a(57344.0, 60000.0, 0.0000457763671875, 0.0000152587890625, - -100000.0); - - EXPECT_EQ(sizeof(a.vals), 5u); - EXPECT_EQ(a.vals[0], 0x7B); // +57344 - EXPECT_EQ(a.vals[1], 0x7B); // clamp to +57344 (finite saturation) - EXPECT_EQ(a.vals[2], 0x03); // max subnormal - EXPECT_EQ(a.vals[3], 0x01); // min subnormal - EXPECT_EQ(a.vals[4], 0xFB); // clamp to -57344 -} - -TEST(FP8E5M2Test, IntegerConstructorsAllTypes) { - fp8_e5m2<1> s(static_cast(1)); - fp8_e5m2<1> i(static_cast(2)); - fp8_e5m2<1> l(static_cast(3)); - fp8_e5m2<1> ll(static_cast(-1)); - fp8_e5m2<1> us(static_cast(1)); - fp8_e5m2<1> ui(static_cast(2)); - fp8_e5m2<1> ul(static_cast(3)); - fp8_e5m2<1> ull(static_cast(4)); - - EXPECT_EQ(sizeof(s.vals), 1u); - EXPECT_EQ(sizeof(i.vals), 1u); - EXPECT_EQ(sizeof(l.vals), 1u); - EXPECT_EQ(sizeof(ll.vals), 1u); - EXPECT_EQ(sizeof(us.vals), 1u); - EXPECT_EQ(sizeof(ui.vals), 1u); - EXPECT_EQ(sizeof(ul.vals), 1u); - EXPECT_EQ(sizeof(ull.vals), 1u); - - EXPECT_EQ(s.vals[0], 0x3C); - EXPECT_EQ(i.vals[0], 0x40); - EXPECT_EQ(l.vals[0], 0x42); // 3.0 -> 0b0_10000_10 - EXPECT_EQ(ll.vals[0], 0xBC); // -1.0 - EXPECT_EQ(us.vals[0], 0x3C); - EXPECT_EQ(ui.vals[0], 0x40); - EXPECT_EQ(ul.vals[0], 0x42); // 3.0 - EXPECT_EQ(ull.vals[0], 0x44); // 4.0 -> 0b0_10001_00 -} - TEST(FP8E5M2Test, AssignmentOperatorsAllTypes) { - fp8_e5m2<1> a(0.0f); + fp8_e5m2 a(0.0f); EXPECT_EQ(sizeof(a.vals), 1u); EXPECT_EQ(a.vals[0], 0x00); @@ -495,11 +426,11 @@ TEST(FP8E5M2Test, AssignmentOperatorsAllTypes) { TEST(FP8E5M2Test, BoolOperatorWithNaN) { float pz = 0.0f; - fp8_e5m2<1> zp(pz); + fp8_e5m2 zp(pz); float zv = -0.0f; - fp8_e5m2<1> zn(zv); + fp8_e5m2 zn(zv); float nv = {std::numeric_limits::quiet_NaN()}; - fp8_e5m2<1> nanv(nv); + fp8_e5m2 nanv(nv); EXPECT_EQ(sizeof(zp.vals), 1u); EXPECT_EQ(sizeof(zn.vals), 1u); diff --git a/sycl/unittests/Extensions/fp8/fp8_e5m3.cpp b/sycl/unittests/Extensions/fp8/fp8_e5m3.cpp new file mode 100644 index 0000000000000..f2a69c446d6a5 --- /dev/null +++ b/sycl/unittests/Extensions/fp8/fp8_e5m3.cpp @@ -0,0 +1,495 @@ +#include +#include + +#include +#include +#include + +using namespace sycl::ext::oneapi::experimental; +using sycl::ext::oneapi::bfloat16; + +TEST(FP8E5M3ArrayCtor, HalfDefaultRounding) { + const sycl::half vals[2] = {sycl::half(1.0f), + sycl::half(std::ldexp(1.0f, -14))}; + fp8_e5m3_x2 v(vals); + + EXPECT_EQ(sizeof(v.vals), 2u); + EXPECT_EQ(v.vals[0], 0x78); + EXPECT_EQ(v.vals[1], 0x08); +} + +TEST(FP8E5M3ArrayCtor, HalfExplicitRounding) { + const sycl::half vals[2] = {sycl::half(1.5f), sycl::half(0.75f)}; + fp8_e5m3_x2 v(vals, rounding::to_even); + + EXPECT_EQ(sizeof(v.vals), 2u); + EXPECT_EQ(v.vals[0], 0x7C); + EXPECT_EQ(v.vals[1], 0x74); +} + +TEST(FP8E5M3ArrayCtor, Bfloat16DefaultRounding) { + const bfloat16 vals[2] = {bfloat16(2.0f), bfloat16(std::ldexp(1.0f, -13))}; + fp8_e5m3_x2 v(vals); + + EXPECT_EQ(sizeof(v.vals), 2u); + EXPECT_EQ(v.vals[0], 0x80); + EXPECT_EQ(v.vals[1], 0x10); +} + +TEST(FP8E5M3ArrayCtor, Bfloat16ExplicitRounding) { + const bfloat16 vals[2] = {bfloat16(3.0f), bfloat16(0.5f)}; + fp8_e5m3_x2 v(vals, rounding::to_even); + + EXPECT_EQ(sizeof(v.vals), 2u); + EXPECT_EQ(v.vals[0], 0x84); + EXPECT_EQ(v.vals[1], 0x70); +} + +TEST(FP8E5M3ArrayCtor, FloatDefaultRounding) { + const float vals[2] = {114688.0f, std::ldexp(1.0f, -17)}; + fp8_e5m3_x2 v(vals); + + EXPECT_EQ(sizeof(v.vals), 2u); + EXPECT_EQ(v.vals[0], 0xFE); + EXPECT_EQ(v.vals[1], 0x01); +} + +TEST(FP8E5M3ArrayCtor, FloatExplicitRounding) { + const float vals[2] = {1.25f, 0.875f * std::ldexp(1.0f, -14)}; + fp8_e5m3_x2 v(vals, rounding::to_even); + + EXPECT_EQ(sizeof(v.vals), 2u); + EXPECT_EQ(v.vals[0], 0x7A); + EXPECT_EQ(v.vals[1], 0x07); +} + +TEST(FP8E5M3ArrayCtor, DoubleDefaultRounding) { + const double vals[2] = {8.0, 0.125}; + fp8_e5m3_x2 v(vals); + + EXPECT_EQ(sizeof(v.vals), 2u); + EXPECT_EQ(v.vals[0], 0x90); + EXPECT_EQ(v.vals[1], 0x60); +} + +TEST(FP8E5M3MarrayCtor, HalfDefaultRounding) { + const sycl::marray vals{sycl::half(1.75f), sycl::half(0.5f)}; + fp8_e5m3_x2 v(vals); + + EXPECT_EQ(sizeof(v.vals), 2u); + EXPECT_EQ(v.vals[0], 0x7E); + EXPECT_EQ(v.vals[1], 0x70); +} + +TEST(FP8E5M3MarrayCtor, HalfExplicitRounding) { + const sycl::marray vals{sycl::half(2.0f), sycl::half(0.125f)}; + fp8_e5m3_x2 v(vals, rounding::to_even); + + EXPECT_EQ(sizeof(v.vals), 2u); + EXPECT_EQ(v.vals[0], 0x80); + EXPECT_EQ(v.vals[1], 0x60); +} + +TEST(FP8E5M3MarrayCtor, Bfloat16DefaultRounding) { + const sycl::marray vals{bfloat16(3.0f), bfloat16(0.25f)}; + fp8_e5m3_x2 v(vals); + + EXPECT_EQ(sizeof(v.vals), 2u); + EXPECT_EQ(v.vals[0], 0x84); + EXPECT_EQ(v.vals[1], 0x68); +} + +TEST(FP8E5M3MarrayCtor, Bfloat16ExplicitRounding) { + const sycl::marray vals{bfloat16(6.0f), bfloat16(0.75f)}; + fp8_e5m3_x2 v(vals, rounding::to_even); + + EXPECT_EQ(sizeof(v.vals), 2u); + EXPECT_EQ(v.vals[0], 0x8C); + EXPECT_EQ(v.vals[1], 0x74); +} + +TEST(FP8E5M3MarrayCtor, FloatDefaultRounding) { + const sycl::marray vals{12.0f, 0.03125f}; + fp8_e5m3_x2 v(vals); + + EXPECT_EQ(sizeof(v.vals), 2u); + EXPECT_EQ(v.vals[0], 0x94); + EXPECT_EQ(v.vals[1], 0x50); +} + +TEST(FP8E5M3MarrayCtor, FloatExplicitRounding) { + const sycl::marray vals{1.25f, std::ldexp(0.375f, -14)}; + fp8_e5m3_x2 v(vals, rounding::to_even); + + EXPECT_EQ(sizeof(v.vals), 2u); + EXPECT_EQ(v.vals[0], 0x7A); + EXPECT_EQ(v.vals[1], 0x03); +} + +TEST(FP8E5M3MarrayCtor, DoubleDefaultRounding) { + const sycl::marray vals{16.0, 0.0625}; + fp8_e5m3_x2 v(vals); + + EXPECT_EQ(sizeof(v.vals), 2u); + EXPECT_EQ(v.vals[0], 0x98); + EXPECT_EQ(v.vals[1], 0x58); +} + +TEST(FP8E5M3ScalarIntCtor, ShortValue) { + const short val = 5; + fp8_e5m3 v(val); + + EXPECT_EQ(sizeof(v.vals), 1u); + EXPECT_EQ(v.vals[0], 0x8A); +} + +TEST(FP8E5M3ScalarIntCtor, IntValue) { + const int val = 7; + fp8_e5m3 v(val); + + EXPECT_EQ(sizeof(v.vals), 1u); + EXPECT_EQ(v.vals[0], 0x8E); +} + +TEST(FP8E5M3ScalarIntCtor, LongValue) { + const long val = 9; + fp8_e5m3 v(val); + + EXPECT_EQ(sizeof(v.vals), 1u); + EXPECT_EQ(v.vals[0], 0x91); +} + +TEST(FP8E5M3ScalarIntCtor, LongLongValue) { + const long long val = 10; + fp8_e5m3 v(val); + + EXPECT_EQ(sizeof(v.vals), 1u); + EXPECT_EQ(v.vals[0], 0x92); +} + +TEST(FP8E5M3ScalarIntCtor, UnsignedShortValue) { + const unsigned short val = 14; + fp8_e5m3 v(val); + + EXPECT_EQ(sizeof(v.vals), 1u); + EXPECT_EQ(v.vals[0], 0x96); +} + +TEST(FP8E5M3ScalarIntCtor, UnsignedIntValue) { + const unsigned int val = 15; + fp8_e5m3 v(val); + + EXPECT_EQ(sizeof(v.vals), 1u); + EXPECT_EQ(v.vals[0], 0x97); +} + +TEST(FP8E5M3ScalarIntCtor, UnsignedLongValue) { + const unsigned long val = 18; + fp8_e5m3 v(val); + + EXPECT_EQ(sizeof(v.vals), 1u); + EXPECT_EQ(v.vals[0], 0x99); +} + +TEST(FP8E5M3ScalarIntCtor, UnsignedLongLongValue) { + const unsigned long long val = 20; + fp8_e5m3 v(val); + + EXPECT_EQ(sizeof(v.vals), 1u); + EXPECT_EQ(v.vals[0], 0x9A); +} + +TEST(FP8E5M3ScalarIntCtor, UnsignedLimitsSaturate) { + const unsigned short usmax = std::numeric_limits::max(); + const unsigned int uimax = std::numeric_limits::max(); + const unsigned long ulmax = std::numeric_limits::max(); + const unsigned long long ullmax = + std::numeric_limits::max(); + + fp8_e5m3 vus(usmax); + fp8_e5m3 vui(uimax); + fp8_e5m3 vul(ulmax); + fp8_e5m3 vull(ullmax); + + EXPECT_EQ(sizeof(vus.vals), 1u); + EXPECT_EQ(vus.vals[0], 0xFE); + EXPECT_EQ(vui.vals[0], 0xFE); + EXPECT_EQ(vul.vals[0], 0xFE); + EXPECT_EQ(vull.vals[0], 0xFE); +} + +TEST(FP8E5M3AssignOp, HalfValue) { + fp8_e5m3 v(sycl::half(1.0f)); + v = sycl::half(1.125f); + + EXPECT_EQ(sizeof(v.vals), 1u); + EXPECT_EQ(v.vals[0], 0x79); +} + +TEST(FP8E5M3AssignOp, Bfloat16Value) { + fp8_e5m3 v(bfloat16(1.0f)); + v = bfloat16(1.875f); + + EXPECT_EQ(sizeof(v.vals), 1u); + EXPECT_EQ(v.vals[0], 0x7F); +} + +TEST(FP8E5M3AssignOp, FloatValue) { + fp8_e5m3 v(1.0f); + v = 11.0f; + + EXPECT_EQ(sizeof(v.vals), 1u); + EXPECT_EQ(v.vals[0], 0x93); +} + +TEST(FP8E5M3AssignOp, DoubleValue) { + fp8_e5m3 v(1.0); + v = 5.5; + + EXPECT_EQ(sizeof(v.vals), 1u); + EXPECT_EQ(v.vals[0], 0x8B); +} + +TEST(FP8E5M3AssignOp, ShortValue) { + fp8_e5m3 v(1); + v = static_cast(6); + + EXPECT_EQ(sizeof(v.vals), 1u); + EXPECT_EQ(v.vals[0], 0x8C); +} + +TEST(FP8E5M3AssignOp, IntValue) { + fp8_e5m3 v(1); + v = 12; + + EXPECT_EQ(sizeof(v.vals), 1u); + EXPECT_EQ(v.vals[0], 0x94); +} + +TEST(FP8E5M3AssignOp, LongValue) { + fp8_e5m3 v(1); + v = static_cast(13); + + EXPECT_EQ(sizeof(v.vals), 1u); + EXPECT_EQ(v.vals[0], 0x95); +} + +TEST(FP8E5M3AssignOp, LongLongValue) { + fp8_e5m3 v(1); + v = static_cast(17); + + EXPECT_EQ(sizeof(v.vals), 1u); + EXPECT_EQ(v.vals[0], 0x98); +} + +TEST(FP8E5M3AssignOp, UnsignedShortValue) { + fp8_e5m3 v(1); + v = static_cast(21); + + EXPECT_EQ(sizeof(v.vals), 1u); + EXPECT_EQ(v.vals[0], 0x9A); +} + +TEST(FP8E5M3AssignOp, UnsignedIntValue) { + fp8_e5m3 v(1); + v = static_cast(22); + + EXPECT_EQ(sizeof(v.vals), 1u); + EXPECT_EQ(v.vals[0], 0x9B); +} + +TEST(FP8E5M3AssignOp, UnsignedLongValue) { + fp8_e5m3 v(1); + v = static_cast(24); + + EXPECT_EQ(sizeof(v.vals), 1u); + EXPECT_EQ(v.vals[0], 0x9C); +} + +TEST(FP8E5M3AssignOp, UnsignedLongLongValue) { + fp8_e5m3 v(1); + v = static_cast(26); + + EXPECT_EQ(sizeof(v.vals), 1u); + EXPECT_EQ(v.vals[0], 0x9D); +} + +TEST(FP8E5M3ConvertOp, HalfValue) { + fp8_e5m3 v(2.5f); + auto out = static_cast(v); + + EXPECT_EQ(sizeof(v.vals), 1u); + EXPECT_EQ(v.vals[0], 0x82); + EXPECT_EQ(static_cast(out), 2.5f); +} + +TEST(FP8E5M3ConvertOp, Bfloat16Value) { + fp8_e5m3 v(0.375f); + auto out = static_cast(v); + + EXPECT_EQ(sizeof(v.vals), 1u); + EXPECT_EQ(v.vals[0], 0x6C); + EXPECT_EQ(static_cast(out), 0.375f); +} + +TEST(FP8E5M3ConvertOp, FloatValue) { + fp8_e5m3 v(2.25f); + float out = static_cast(v); + + EXPECT_EQ(sizeof(v.vals), 1u); + EXPECT_EQ(v.vals[0], 0x81); + EXPECT_EQ(out, 2.25f); +} + +TEST(FP8E5M3ConvertOp, DoubleValue) { + fp8_e5m3 v(4.0f); + double out = static_cast(v); + + EXPECT_EQ(sizeof(v.vals), 1u); + EXPECT_EQ(v.vals[0], 0x88); + EXPECT_EQ(out, 4.0); +} + +TEST(FP8E5M3ConvertIntOp, CharValue) { + fp8_e5m3 v(3.5f); + char out = static_cast(v); + + EXPECT_EQ(sizeof(v.vals), 1u); + EXPECT_EQ(v.vals[0], 0x86); + EXPECT_EQ(out, static_cast(3)); +} + +TEST(FP8E5M3ConvertIntOp, SignedCharValue) { + fp8_e5m3 v(6.5f); + signed char out = static_cast(v); + + EXPECT_EQ(sizeof(v.vals), 1u); + EXPECT_EQ(v.vals[0], 0x8D); + EXPECT_EQ(out, static_cast(6)); +} + +TEST(FP8E5M3ConvertIntOp, ShortValue) { + fp8_e5m3 v(7.5f); + short out = static_cast(v); + + EXPECT_EQ(sizeof(v.vals), 1u); + EXPECT_EQ(v.vals[0], 0x8F); + EXPECT_EQ(out, static_cast(7)); +} + +TEST(FP8E5M3ConvertIntOp, IntValue) { + fp8_e5m3 v(8.0f); + int out = static_cast(v); + + EXPECT_EQ(sizeof(v.vals), 1u); + EXPECT_EQ(v.vals[0], 0x90); + EXPECT_EQ(out, 8); +} + +TEST(FP8E5M3ConvertIntOp, LongValue) { + fp8_e5m3 v(9.0f); + long out = static_cast(v); + + EXPECT_EQ(sizeof(v.vals), 1u); + EXPECT_EQ(v.vals[0], 0x91); + EXPECT_EQ(out, 9L); +} + +TEST(FP8E5M3ConvertIntOp, LongLongValue) { + fp8_e5m3 v(10.0f); + long long out = static_cast(v); + + EXPECT_EQ(sizeof(v.vals), 1u); + EXPECT_EQ(v.vals[0], 0x92); + EXPECT_EQ(out, 10LL); +} + +TEST(FP8E5M3ConvertIntOp, UnsignedCharValue) { + fp8_e5m3 v(11.0f); + unsigned char out = static_cast(v); + + EXPECT_EQ(sizeof(v.vals), 1u); + EXPECT_EQ(v.vals[0], 0x93); + EXPECT_EQ(out, static_cast(11)); +} + +TEST(FP8E5M3ConvertIntOp, UnsignedShortValue) { + fp8_e5m3 v(12.0f); + unsigned short out = static_cast(v); + + EXPECT_EQ(sizeof(v.vals), 1u); + EXPECT_EQ(v.vals[0], 0x94); + EXPECT_EQ(out, static_cast(12)); +} + +TEST(FP8E5M3ConvertIntOp, UnsignedIntValue) { + fp8_e5m3 v(13.0f); + unsigned int out = static_cast(v); + + EXPECT_EQ(sizeof(v.vals), 1u); + EXPECT_EQ(v.vals[0], 0x95); + EXPECT_EQ(out, 13u); +} + +TEST(FP8E5M3ConvertIntOp, UnsignedLongValue) { + fp8_e5m3 v(14.0f); + unsigned long out = static_cast(v); + + EXPECT_EQ(sizeof(v.vals), 1u); + EXPECT_EQ(v.vals[0], 0x96); + EXPECT_EQ(out, 14UL); +} + +TEST(FP8E5M3ConvertIntOp, UnsignedLongLongValue) { + fp8_e5m3 v(15.0f); + unsigned long long out = static_cast(v); + + EXPECT_EQ(sizeof(v.vals), 1u); + EXPECT_EQ(v.vals[0], 0x97); + EXPECT_EQ(out, 15ULL); +} + +TEST(FP8E5M3ConvertOp, BoolFalse) { + fp8_e5m3 v(0.0f); + bool out = static_cast(v); + + EXPECT_EQ(sizeof(v.vals), 1u); + EXPECT_EQ(v.vals[0], 0x00); + EXPECT_FALSE(out); +} + +TEST(FP8E5M3ConvertOp, BoolTrue) { + fp8_e5m3 v(0.25f); + bool out = static_cast(v); + + EXPECT_EQ(sizeof(v.vals), 1u); + EXPECT_EQ(v.vals[0], 0x68); + EXPECT_TRUE(out); +} + +TEST(FP8E5M3ConvertOp, MarrayHalf) { + fp8_e5m3 v(1.625f); + auto out = static_cast>(v); + + EXPECT_EQ(sizeof(v.vals), 1u); + EXPECT_EQ(v.vals[0], 0x7D); + EXPECT_EQ(static_cast(out[0]), 1.625f); +} + +TEST(FP8E5M3ConvertOp, MarrayBfloat16) { + fp8_e5m3 v(2.75f); + auto out = static_cast>(v); + + EXPECT_EQ(sizeof(v.vals), 1u); + EXPECT_EQ(v.vals[0], 0x83); + EXPECT_EQ(static_cast(out[0]), 2.75f); +} + +TEST(FP8E5M3ConvertOp, MarrayFloat) { + fp8_e5m3 v(5.0f); + auto out = static_cast>(v); + + EXPECT_EQ(sizeof(v.vals), 1u); + EXPECT_EQ(v.vals[0], 0x8A); + EXPECT_EQ(out[0], 5.0f); +} diff --git a/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp b/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp index 30313da4b9264..e8478bf9447db 100644 --- a/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp +++ b/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp @@ -8,273 +8,263 @@ using namespace sycl::ext::oneapi::experimental; TEST(FP8E8M0Test, VariadicConstructorFloat) { - fp8_e8m0<4> a(1.0f, 2.0f, 1.1f, 0.0f); - - EXPECT_EQ(sizeof(a.vals), 4u); - EXPECT_EQ(a.vals[0], 0x7F); // 1.0 -> exp=127 - EXPECT_EQ(a.vals[1], 0x80); // 2.0 -> exp=128 - EXPECT_EQ(a.vals[2], 0x80); // 1.1 -> upward to 2.0 - EXPECT_EQ(a.vals[3], 0x00); // 0.0 -> min normal + fp8_e8m0_x2 a(1.0f, 2.0f); + fp8_e8m0_x2 a1(1.1f, 0.0f); + + EXPECT_EQ(sizeof(a.vals), 2u); + EXPECT_EQ(sizeof(a1.vals), 2u); + EXPECT_EQ(a.vals[0], 0x7F); // 1.0 -> exp=127 + EXPECT_EQ(a.vals[1], 0x80); // 2.0 -> exp=128 + EXPECT_EQ(a1.vals[0], 0x80); // 1.1 -> upward to 2.0 + EXPECT_EQ(a1.vals[1], 0x00); // 0.0 -> min normal } TEST(FP8E8M0Test, VariadicConstructorHalf) { - fp8_e8m0<2> a(sycl::half(1.0f), sycl::half(3.0f)); + fp8_e8m0_x2 a(sycl::half(1.0f), sycl::half(3.0f)); - EXPECT_EQ(sizeof(a.vals), 2u); - EXPECT_EQ(a.vals[0], 0x7F); - EXPECT_EQ(a.vals[1], 0x81); // 3.0 -> upward to 4.0 + EXPECT_EQ(sizeof(a.vals), 2u); + EXPECT_EQ(a.vals[0], 0x7F); + EXPECT_EQ(a.vals[1], 0x81); // 3.0 -> upward to 4.0 } TEST(FP8E8M0Test, VariadicConstructorBFloat16) { - fp8_e8m0<2> a(sycl::ext::oneapi::bfloat16(1.0f), - sycl::ext::oneapi::bfloat16(2.0f)); + fp8_e8m0_x2 a(sycl::ext::oneapi::bfloat16(1.0f), + sycl::ext::oneapi::bfloat16(2.0f)); - EXPECT_EQ(sizeof(a.vals), 2u); - EXPECT_EQ(a.vals[0], 0x7F); - EXPECT_EQ(a.vals[1], 0x80); + EXPECT_EQ(sizeof(a.vals), 2u); + EXPECT_EQ(a.vals[0], 0x7F); + EXPECT_EQ(a.vals[1], 0x80); } TEST(FP8E8M0Test, VariadicConstructorDouble) { - fp8_e8m0<2> a(1.0, 3.0); + fp8_e8m0_x2 a(1.0, 3.0); - EXPECT_EQ(sizeof(a.vals), 2u); - EXPECT_EQ(a.vals[0], 0x7F); - EXPECT_EQ(a.vals[1], 0x81); + EXPECT_EQ(sizeof(a.vals), 2u); + EXPECT_EQ(a.vals[0], 0x7F); + EXPECT_EQ(a.vals[1], 0x81); } TEST(FP8E8M0Test, VariadicConstructorBoundaryEncodings) { - fp8_e8m0<3> a(std::ldexp(1.0f, 127), std::ldexp(1.0f, -127), - std::numeric_limits::quiet_NaN()); + fp8_e8m0_x2 a(std::ldexp(1.0f, -127), + std::numeric_limits::quiet_NaN()); - EXPECT_EQ(sizeof(a.vals), 3u); - EXPECT_EQ(a.vals[0], 0xFE); // max normal - EXPECT_EQ(a.vals[1], 0x00); // min normal - EXPECT_EQ(a.vals[2], 0xFF); // NaN + EXPECT_EQ(sizeof(a.vals), 2u); + EXPECT_EQ(a.vals[0], 0x00); // min normal + EXPECT_EQ(a.vals[1], 0xFF); // NaN } TEST(FP8E8M0Test, CArrayConstructorFloatHostUpwardFinite) { - const float in[5] = {1.0f, 1.1f, 3.0f, 0.0f, 1000.0f}; - fp8_e8m0<5> a(in, rounding::upward); - - EXPECT_EQ(sizeof(a.vals), 5u); - EXPECT_EQ(a.vals[0], 0x7F); - EXPECT_EQ(a.vals[1], 0x80); // upward to 2.0 - EXPECT_EQ(a.vals[2], 0x81); // upward to 4.0 - EXPECT_EQ(a.vals[3], 0x00); // min normal - EXPECT_EQ(a.vals[4], 0x89); // upward to 2^10 = 1024 + const float in[2] = {1.0f, 1.1f}; + const float in1[2] = {3.0f, 1000.0f}; + fp8_e8m0_x2 a(in, rounding::upward); + fp8_e8m0_x2 a1(in1, rounding::upward); + + EXPECT_EQ(sizeof(a.vals), 2u); + EXPECT_EQ(sizeof(a1.vals), 2u); + EXPECT_EQ(a.vals[0], 0x7F); + EXPECT_EQ(a.vals[1], 0x80); // upward to 2.0 + EXPECT_EQ(a1.vals[0], 0x81); // upward to 4.0 + EXPECT_EQ(a1.vals[1], 0x89); // upward to 2^10 = 1024 } TEST(FP8E8M0Test, CArrayConstructorHalfHostUpwardFinite) { - const sycl::half in[4] = {sycl::half(1.0f), sycl::half(1.1f), - sycl::half(3.0f), sycl::half(0.0f)}; - fp8_e8m0<4> a(in, rounding::upward); - - EXPECT_EQ(sizeof(a.vals), 4u); - EXPECT_EQ(a.vals[0], 0x7F); - EXPECT_EQ(a.vals[1], 0x80); - EXPECT_EQ(a.vals[2], 0x81); - EXPECT_EQ(a.vals[3], 0x00); + const sycl::half in[2] = {sycl::half(1.0f), sycl::half(1.1f)}; + const sycl::half in1[2] = {sycl::half(3.0f), sycl::half(0.0f)}; + + fp8_e8m0_x2 a(in, rounding::upward); + fp8_e8m0_x2 a1(in1, rounding::upward); + + EXPECT_EQ(sizeof(a.vals), 2u); + EXPECT_EQ(sizeof(a1.vals), 2u); + EXPECT_EQ(a.vals[0], 0x7F); + EXPECT_EQ(a.vals[1], 0x80); + EXPECT_EQ(a1.vals[0], 0x81); + EXPECT_EQ(a1.vals[1], 0x00); } TEST(FP8E8M0Test, CArrayConstructorBFloat16HostUpwardFinite) { - const sycl::ext::oneapi::bfloat16 in[3] = { - sycl::ext::oneapi::bfloat16(1.0f), - sycl::ext::oneapi::bfloat16(2.0f), - sycl::ext::oneapi::bfloat16(0.0f)}; - fp8_e8m0<3> a(in, rounding::upward); - - EXPECT_EQ(sizeof(a.vals), 3u); - EXPECT_EQ(a.vals[0], 0x7F); - EXPECT_EQ(a.vals[1], 0x80); - EXPECT_EQ(a.vals[2], 0x00); + const sycl::ext::oneapi::bfloat16 in[2] = {sycl::ext::oneapi::bfloat16(1.0f), + sycl::ext::oneapi::bfloat16(2.0f)}; + fp8_e8m0_x2 a(in, rounding::upward); + + EXPECT_EQ(sizeof(a.vals), 2u); + EXPECT_EQ(a.vals[0], 0x7F); + EXPECT_EQ(a.vals[1], 0x80); } TEST(FP8E8M0Test, CArrayConstructorDoubleDefaultUpwardFinite) { - const double in[3] = {1.0, 3.0, 0.0}; - fp8_e8m0<3> a(in); + const double in[2] = {1.0, 3.0}; + fp8_e8m0_x2 a(in); - EXPECT_EQ(sizeof(a.vals), 3u); - EXPECT_EQ(a.vals[0], 0x7F); - EXPECT_EQ(a.vals[1], 0x81); - EXPECT_EQ(a.vals[2], 0x00); + EXPECT_EQ(sizeof(a.vals), 2u); + EXPECT_EQ(a.vals[0], 0x7F); + EXPECT_EQ(a.vals[1], 0x81); } TEST(FP8E8M0Test, MarrayConstructorAndOperatorsFloat) { - sycl::marray in = {1.0f, 2.0f, 3.0f, 0.0f}; - fp8_e8m0<4> a(in, rounding::upward); - - EXPECT_EQ(sizeof(a.vals), 4u); - EXPECT_EQ(a.vals[0], 0x7F); - EXPECT_EQ(a.vals[1], 0x80); - EXPECT_EQ(a.vals[2], 0x81); - EXPECT_EQ(a.vals[3], 0x00); - - sycl::marray out = static_cast>(a); - EXPECT_EQ(out[0], 1.0f); - EXPECT_EQ(out[1], 2.0f); - EXPECT_EQ(out[2], 4.0f); - EXPECT_EQ(out[3], std::ldexp(1.0f, -127)); + sycl::marray in = {1.0f, 2.0f}; + sycl::marray in1 = {3.0f, 0.0f}; + + fp8_e8m0_x2 a(in, rounding::upward); + fp8_e8m0_x2 a1(in1, rounding::upward); + + EXPECT_EQ(sizeof(a.vals), 2u); + EXPECT_EQ(sizeof(a1.vals), 2u); + EXPECT_EQ(a.vals[0], 0x7F); + EXPECT_EQ(a.vals[1], 0x80); + EXPECT_EQ(a1.vals[0], 0x81); + EXPECT_EQ(a1.vals[1], 0x00); + + sycl::marray out = static_cast>(a); + sycl::marray out1 = static_cast>(a1); + EXPECT_EQ(out[0], 1.0f); + EXPECT_EQ(out[1], 2.0f); + EXPECT_EQ(out1[0], 4.0f); + EXPECT_EQ(out1[1], std::ldexp(1.0f, -127)); } TEST(FP8E8M0Test, MarrayConstructorHalfBFloat16Double) { - sycl::marray hvals = {sycl::half(1.0f), sycl::half(3.0f)}; - sycl::marray bvals = { - sycl::ext::oneapi::bfloat16(1.0f), - sycl::ext::oneapi::bfloat16(2.0f)}; - sycl::marray dvals = {1.0, 3.0}; - - fp8_e8m0<2> ah(hvals, rounding::upward); - fp8_e8m0<2> ab(bvals, rounding::upward); - fp8_e8m0<2> ad(dvals); - - EXPECT_EQ(sizeof(ah.vals), 2u); - EXPECT_EQ(sizeof(ab.vals), 2u); - EXPECT_EQ(sizeof(ad.vals), 2u); - - EXPECT_EQ(ah.vals[0], 0x7F); - EXPECT_EQ(ah.vals[1], 0x81); - EXPECT_EQ(ab.vals[0], 0x7F); - EXPECT_EQ(ab.vals[1], 0x80); - EXPECT_EQ(ad.vals[0], 0x7F); - EXPECT_EQ(ad.vals[1], 0x81); + sycl::marray hvals = {sycl::half(1.0f), sycl::half(3.0f)}; + sycl::marray bvals = { + sycl::ext::oneapi::bfloat16(1.0f), sycl::ext::oneapi::bfloat16(2.0f)}; + sycl::marray dvals = {1.0, 3.0}; + + fp8_e8m0_x2 ah(hvals, rounding::upward); + fp8_e8m0_x2 ab(bvals, rounding::upward); + fp8_e8m0_x2 ad(dvals); + + EXPECT_EQ(sizeof(ah.vals), 2u); + EXPECT_EQ(sizeof(ab.vals), 2u); + EXPECT_EQ(sizeof(ad.vals), 2u); + + EXPECT_EQ(ah.vals[0], 0x7F); + EXPECT_EQ(ah.vals[1], 0x81); + EXPECT_EQ(ab.vals[0], 0x7F); + EXPECT_EQ(ab.vals[1], 0x80); + EXPECT_EQ(ad.vals[0], 0x7F); + EXPECT_EQ(ad.vals[1], 0x81); } TEST(FP8E8M0Test, IntegerConstructorsAllTypes) { - fp8_e8m0<1> s(static_cast(1)); - fp8_e8m0<1> i(static_cast(2)); - fp8_e8m0<1> l(static_cast(3)); - fp8_e8m0<1> ll(static_cast(4)); - fp8_e8m0<1> us(static_cast(1)); - fp8_e8m0<1> ui(static_cast(2)); - fp8_e8m0<1> ul(static_cast(3)); - fp8_e8m0<1> ull(static_cast(4)); - - EXPECT_EQ(sizeof(s.vals), 1u); - EXPECT_EQ(sizeof(i.vals), 1u); - EXPECT_EQ(sizeof(l.vals), 1u); - EXPECT_EQ(sizeof(ll.vals), 1u); - EXPECT_EQ(sizeof(us.vals), 1u); - EXPECT_EQ(sizeof(ui.vals), 1u); - EXPECT_EQ(sizeof(ul.vals), 1u); - EXPECT_EQ(sizeof(ull.vals), 1u); - - EXPECT_EQ(s.vals[0], 0x7F); // 1.0 - EXPECT_EQ(i.vals[0], 0x80); // 2.0 - EXPECT_EQ(l.vals[0], 0x81); // 3.0 -> upward to 4.0 - EXPECT_EQ(ll.vals[0], 0x81); // 4.0 - EXPECT_EQ(us.vals[0], 0x7F); - EXPECT_EQ(ui.vals[0], 0x80); - EXPECT_EQ(ul.vals[0], 0x81); - EXPECT_EQ(ull.vals[0], 0x81); + fp8_e8m0 s(static_cast(1)); + fp8_e8m0 i(static_cast(2)); + fp8_e8m0 l(static_cast(3)); + fp8_e8m0 ll(static_cast(4)); + fp8_e8m0 us(static_cast(1)); + fp8_e8m0 ui(static_cast(2)); + fp8_e8m0 ul(static_cast(3)); + fp8_e8m0 ull(static_cast(4)); + + EXPECT_EQ(sizeof(s.vals), 1u); + EXPECT_EQ(sizeof(i.vals), 1u); + EXPECT_EQ(sizeof(l.vals), 1u); + EXPECT_EQ(sizeof(ll.vals), 1u); + EXPECT_EQ(sizeof(us.vals), 1u); + EXPECT_EQ(sizeof(ui.vals), 1u); + EXPECT_EQ(sizeof(ul.vals), 1u); + EXPECT_EQ(sizeof(ull.vals), 1u); + + EXPECT_EQ(s.vals[0], 0x7F); // 1.0 + EXPECT_EQ(i.vals[0], 0x80); // 2.0 + EXPECT_EQ(l.vals[0], 0x81); // 3.0 -> upward to 4.0 + EXPECT_EQ(ll.vals[0], 0x81); // 4.0 + EXPECT_EQ(us.vals[0], 0x7F); + EXPECT_EQ(ui.vals[0], 0x80); + EXPECT_EQ(ul.vals[0], 0x81); + EXPECT_EQ(ull.vals[0], 0x81); } TEST(FP8E8M0Test, AssignmentOperatorsAllTypes) { - fp8_e8m0<1> a(1.0f); - EXPECT_EQ(sizeof(a.vals), 1u); + fp8_e8m0 a(1.0f); + EXPECT_EQ(sizeof(a.vals), 1u); - a = sycl::half(1.0f); - EXPECT_EQ(a.vals[0], 0x7F); + a = sycl::half(1.0f); + EXPECT_EQ(a.vals[0], 0x7F); - a = sycl::ext::oneapi::bfloat16(2.0f); - EXPECT_EQ(a.vals[0], 0x80); + a = sycl::ext::oneapi::bfloat16(2.0f); + EXPECT_EQ(a.vals[0], 0x80); - a = 3.0f; - EXPECT_EQ(a.vals[0], 0x81); + a = 3.0f; + EXPECT_EQ(a.vals[0], 0x81); - a = 4.0; - EXPECT_EQ(a.vals[0], 0x81); + a = 4.0; + EXPECT_EQ(a.vals[0], 0x81); - a = static_cast(1); - EXPECT_EQ(a.vals[0], 0x7F); + a = static_cast(1); + EXPECT_EQ(a.vals[0], 0x7F); - a = static_cast(2); - EXPECT_EQ(a.vals[0], 0x80); + a = static_cast(2); + EXPECT_EQ(a.vals[0], 0x80); - a = static_cast(3); - EXPECT_EQ(a.vals[0], 0x81); + a = static_cast(3); + EXPECT_EQ(a.vals[0], 0x81); - a = static_cast(4); - EXPECT_EQ(a.vals[0], 0x81); + a = static_cast(4); + EXPECT_EQ(a.vals[0], 0x81); - a = static_cast(1); - EXPECT_EQ(a.vals[0], 0x7F); + a = static_cast(1); + EXPECT_EQ(a.vals[0], 0x7F); - a = static_cast(2); - EXPECT_EQ(a.vals[0], 0x80); + a = static_cast(2); + EXPECT_EQ(a.vals[0], 0x80); - a = static_cast(3); - EXPECT_EQ(a.vals[0], 0x81); + a = static_cast(3); + EXPECT_EQ(a.vals[0], 0x81); - a = static_cast(4); - EXPECT_EQ(a.vals[0], 0x81); + a = static_cast(4); + EXPECT_EQ(a.vals[0], 0x81); } TEST(FP8E8M0Test, FloatingPointConversionOperators) { - fp8_e8m0<1> one(1.0f); - fp8_e8m0<1> max(std::ldexp(1.0f, 127)); - fp8_e8m0<1> min(std::ldexp(1.0f, -127)); - - EXPECT_EQ(sizeof(one.vals), 1u); - EXPECT_EQ(one.vals[0], 0x7F); - EXPECT_EQ(max.vals[0], 0xFE); - EXPECT_EQ(min.vals[0], 0x00); - - float fo = static_cast(one); - double doo = static_cast(one); - sycl::half ho = static_cast(one); - sycl::ext::oneapi::bfloat16 bo = static_cast(one); - - EXPECT_EQ(fo, 1.0f); - EXPECT_EQ(doo, 1.0); - EXPECT_EQ(static_cast(ho), 1.0f); - EXPECT_EQ(static_cast(bo), 1.0f); - - sycl::half hmax = static_cast(max); - EXPECT_TRUE(std::isinf(static_cast(hmax))); - EXPECT_FALSE(std::signbit(static_cast(hmax))); - - EXPECT_EQ(static_cast(min), std::ldexp(1.0f, -127)); -} - -TEST(FP8E8M0Test, UnsignedConversionOperatorsTowardZero) { - fp8_e8m0<1> a(3.0f); // upward to 4.0 - - EXPECT_EQ(sizeof(a.vals), 1u); - EXPECT_EQ(a.vals[0], 0x81); - - EXPECT_EQ(static_cast(a), 4u); - EXPECT_EQ(static_cast(a), 4u); - EXPECT_EQ(static_cast(a), 4u); - EXPECT_EQ(static_cast(a), 4u); - EXPECT_EQ(static_cast(a), 4u); + fp8_e8m0 one(1.0f); + fp8_e8m0 max(std::ldexp(1.0f, 127)); + fp8_e8m0 min(std::ldexp(1.0f, -127)); + + EXPECT_EQ(sizeof(one.vals), 1u); + EXPECT_EQ(one.vals[0], 0x7F); + EXPECT_EQ(max.vals[0], 0xFE); + EXPECT_EQ(min.vals[0], 0x00); + + float fo = static_cast(one); + double doo = static_cast(one); + sycl::half ho = static_cast(one); + sycl::ext::oneapi::bfloat16 bo = + static_cast(one); + + EXPECT_EQ(fo, 1.0f); + EXPECT_EQ(doo, 1.0); + EXPECT_EQ(static_cast(ho), 1.0f); + EXPECT_EQ(static_cast(bo), 1.0f); + + sycl::half hmax = static_cast(max); + EXPECT_TRUE(std::isinf(static_cast(hmax))); + EXPECT_FALSE(std::signbit(static_cast(hmax))); + + EXPECT_EQ(static_cast(min), std::ldexp(1.0f, -127)); } TEST(FP8E8M0Test, BoolOperatorAlwaysTrue) { - fp8_e8m0<1> min(std::ldexp(1.0f, -127)); - fp8_e8m0<1> nanv(std::numeric_limits::quiet_NaN()); + fp8_e8m0 min(std::ldexp(1.0f, -127)); + fp8_e8m0 nanv(std::numeric_limits::quiet_NaN()); - EXPECT_TRUE(static_cast(min)); - EXPECT_TRUE(static_cast(nanv)); + EXPECT_TRUE(static_cast(min)); + EXPECT_TRUE(static_cast(nanv)); } TEST(FP8E8M0Test, MarrayConversionOperators) { - fp8_e8m0<3> a(1.0f, 3.0f, std::ldexp(1.0f, 127)); + fp8_e8m0_x2 a(1.0f, 3.0f); - sycl::marray ho = static_cast>(a); - sycl::marray bo = - static_cast>(a); - sycl::marray fo = static_cast>(a); + sycl::marray ho = static_cast>(a); + sycl::marray bo = + static_cast>(a); + sycl::marray fo = static_cast>(a); - EXPECT_EQ(static_cast(ho[0]), 1.0f); - EXPECT_EQ(static_cast(ho[1]), 4.0f); - EXPECT_TRUE(std::isinf(static_cast(ho[2]))); + EXPECT_EQ(static_cast(ho[0]), 1.0f); + EXPECT_EQ(static_cast(ho[1]), 4.0f); - EXPECT_EQ(static_cast(bo[0]), 1.0f); - EXPECT_EQ(static_cast(bo[1]), 4.0f); - EXPECT_EQ(static_cast(bo[2]), std::ldexp(1.0f, 127)); + EXPECT_EQ(static_cast(bo[0]), 1.0f); + EXPECT_EQ(static_cast(bo[1]), 4.0f); - EXPECT_EQ(fo[0], 1.0f); - EXPECT_EQ(fo[1], 4.0f); - EXPECT_EQ(fo[2], std::ldexp(1.0f, 127)); + EXPECT_EQ(fo[0], 1.0f); + EXPECT_EQ(fo[1], 4.0f); } - From 2573ebeb1ce8ab508dbbd572f77bd7995f8dc1f9 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Thu, 19 Feb 2026 14:51:48 +0100 Subject: [PATCH 04/89] [SYCL] remove extra types --- .../oneapi/experimental/float_8bit/types.hpp | 340 ------------ sycl/unittests/Extensions/fp8/CMakeLists.txt | 1 - sycl/unittests/Extensions/fp8/fp8_e5m3.cpp | 495 ------------------ 3 files changed, 836 deletions(-) delete mode 100644 sycl/unittests/Extensions/fp8/fp8_e5m3.cpp diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index b5c22ee582cdb..3ea979c46f91c 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -188,9 +188,6 @@ static inline ToT ConvertFromFP8_CPU(uint8_t b, if constexpr (Ebits == 8 && Mbits == 0) { sign_bit = 0u; exp = b; - } else if constexpr (Ebits == 5 && Mbits == 3) { - // E5M3 is unsigned: MSB belongs to exponent, no sign bit. - sign_bit = 0u; } auto make_nan = [&]() -> ToT { @@ -210,11 +207,6 @@ static inline ToT ConvertFromFP8_CPU(uint8_t b, if (frac != 0) return make_nan(); // frac==00 -> normal finite - } else if constexpr (Ebits == 5 && Mbits == 3) { - // E5M3: only frac==111 -> NaN, otherwise normal. - if (frac == MaxFrac) - return make_nan(); - // treat as normal finite } else // E8M0: exp all ones -> NaN return make_nan(); } @@ -375,12 +367,6 @@ ConvertToFP8_CPU(T h, rounding R = rounding::to_even) noexcept { sign | ((ExpAllOnes << Mbits) | MaxFracMask)); // S.1111.111 -> NaN uint8_t sign_bit = sign ? 1u : 0u; float ax = std::fabs(x); - if constexpr (Ebits == 5 && Mbits == 3) { - // E5M3 is unsigned: ignore sign and treat input as magnitude. - sign = 0x00; - sign_bit = 0u; - } - const float max_finite = (2.0f - std::ldexp(1.0f, 1 - Mbits)) * std::ldexp(1.0f, emax); const float min_sub = std::ldexp(1.0f, emin - Mbits); @@ -1720,338 +1706,12 @@ template class fp8_e8m0_x { uint8_t vals[N]; }; -template class fp8_e5m3_x { -private: - template uint8_t ConvertToFP8(T h, rounding r) { - if constexpr (std::is_integral_v) { - sycl::half hi = static_cast(h); - return ConvertToFP8_CPU<5, 3, sycl::half>(hi, r); - } - return ConvertToFP8_CPU<5, 3, T>(h, r); - } - - template - T ConvertFromFP8(uint8_t v, rounding r = rounding::to_even) const { - return ConvertFromFP8_CPU<5, 3, T>(v, r); - } - - bfloat16 ConvertBF16FromFP8(uint8_t v) const { - return ConvertFromFP8_CPU<5, 3, bfloat16>(v); - } - - void CheckConstraints(rounding r) const { - static_assert(N == 1 || N == 2, - "fp8_e5m3_x: Template argument N must be 1 or 2"); - if (r != rounding::to_even) - throw std::invalid_argument( - "fp8_e5m3_x: only rounding::to_even is supported"); - } - -public: - fp8_e5m3_x() = default; - fp8_e5m3_x(const fp8_e5m3_x &) = default; - ~fp8_e5m3_x() = default; - fp8_e5m3_x &operator=(const fp8_e5m3_x &) = default; - - // Construct from pack of half, bfloat16, float, double. - // Available only when the size of the pack is equal to N. - - template , half> || - std::is_same_v, bfloat16> || - std::is_same_v, float> || - std::is_same_v, double>) && - ...))>> - explicit fp8_e5m3_x(Types... v) { - static_assert(N == 1 || N == 2, - "fp8_e5m3_x: Template argument N must be 1 or 2"); - /*if constexpr (((std::is_same_v, bfloat16>) && ...)) { - const bfloat16 in[N] = {static_cast(v)...}; - for (size_t i = 0; i < N; ++i) - vals[i] = ConvertBF16ToFP8(in[i], rounding::to_even); - return; - }*/ - const sycl::half in[N] = {v...}; - for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(in[i], rounding::to_even); - } - - // Construct from an array of half, bfloat16, float, double. - - explicit fp8_e5m3_x(half const (&in)[N], rounding r = rounding::to_even) { - CheckConstraints(r); - for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(in[i], r); - } - explicit fp8_e5m3_x(bfloat16 const (&in)[N], rounding r = rounding::to_even) { - CheckConstraints(r); - for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(in[i], r); - } - explicit fp8_e5m3_x(float const (&in)[N], rounding r = rounding::to_even) { - CheckConstraints(r); - for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(in[i], r); - } - explicit fp8_e5m3_x(double const (&in)[N]) { - for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(in[i], rounding::to_even); - } - - // Construct from an marray of half, bfloat16, float, double. - - explicit fp8_e5m3_x(const marray &in, - rounding r = rounding::to_even) { - CheckConstraints(r); - for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(in[i], r); - } - explicit fp8_e5m3_x(const marray &in, - rounding r = rounding::to_even) { - CheckConstraints(r); - for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(in[i], r); - } - explicit fp8_e5m3_x(const marray &in, - rounding r = rounding::to_even) { - CheckConstraints(r); - for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(in[i], r); - } - explicit fp8_e5m3_x(const marray &in) { - for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(in[i], rounding::to_even); - } - - // Construct from integer types. - // Available only when N==1. - - explicit fp8_e5m3_x(short val) { - assert(N == 1 && "fp8_e5m3_x: N must be 1 for short constructor"); - vals[0] = ConvertToFP8(val, rounding::to_even); - } - explicit fp8_e5m3_x(int val) { - assert(N == 1 && "fp8_e5m3_x: N must be 1 for int constructor"); - vals[0] = ConvertToFP8(val, rounding::to_even); - } - explicit fp8_e5m3_x(long val) { - assert(N == 1 && "fp8_e5m3_x: N must be 1 for long constructor"); - vals[0] = ConvertToFP8(val, rounding::to_even); - } - explicit fp8_e5m3_x(long long val) { - assert(N == 1 && "fp8_e5m3_x: N must be 1 for long long constructor"); - vals[0] = ConvertToFP8(val, rounding::to_even); - } - explicit fp8_e5m3_x(unsigned short val) { - assert(N == 1 && "fp8_e5m3_x: N must be 1 for unsigned short constructor"); - vals[0] = ConvertToFP8(val, rounding::to_even); - } - explicit fp8_e5m3_x(unsigned int val) { - assert(N == 1 && "fp8_e5m3_x: N must be 1 for unsigned int constructor"); - vals[0] = ConvertToFP8(val, rounding::to_even); - } - explicit fp8_e5m3_x(unsigned long val) { - assert(N == 1 && "fp8_e5m3_x: N must be 1 for unsigned long constructor"); - vals[0] = ConvertToFP8(val, rounding::to_even); - } - explicit fp8_e5m3_x(unsigned long long val) { - assert(N == 1 && - "fp8_e5m3_x: N must be 1 for unsigned long long constructor"); - vals[0] = ConvertToFP8(val, rounding::to_even); - } - - // Assign (operator) from half, bfloat16, float, double, and integer types. - // Available only when N==1. - - fp8_e5m3_x &operator=(half val) { - assert(N == 1 && "fp8_e5m3_x: N must be 1 for half assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even); - return *this; - } - fp8_e5m3_x &operator=(bfloat16 val) { - assert(N == 1 && - "fp8_e5m3_x: N must be 1 for bfloat16 assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even); - return *this; - } - fp8_e5m3_x &operator=(float val) { - assert(N == 1 && "fp8_e5m3_x: N must be 1 for float assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even); - return *this; - } - fp8_e5m3_x &operator=(double val) { - assert(N == 1 && "fp8_e5m3_x: N must be 1 for double assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even); - return *this; - } - fp8_e5m3_x &operator=(short val) { - assert(N == 1 && "fp8_e5m3_x: N must be 1 for short assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even); - return *this; - } - fp8_e5m3_x &operator=(int val) { - assert(N == 1 && "fp8_e5m3_x: N must be 1 for int assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even); - return *this; - } - fp8_e5m3_x &operator=(long val) { - assert(N == 1 && "fp8_e5m3_x: N must be 1 for long assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even); - return *this; - } - fp8_e5m3_x &operator=(long long val) { - assert(N == 1 && - "fp8_e5m3_x: N must be 1 for long long assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even); - return *this; - } - fp8_e5m3_x &operator=(unsigned short val) { - assert(N == 1 && - "fp8_e5m3_x: N must be 1 for unsigned short assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even); - return *this; - } - fp8_e5m3_x &operator=(unsigned int val) { - assert(N == 1 && - "fp8_e5m3_x: N must be 1 for unsigned int assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even); - return *this; - } - fp8_e5m3_x &operator=(unsigned long val) { - assert(N == 1 && - "fp8_e5m3_x: N must be 1 for unsigned long assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even); - return *this; - } - fp8_e5m3_x &operator=(unsigned long long val) { - assert( - N == 1 && - "fp8_e5m3_x: N must be 1 for unsigned long long assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even); - return *this; - } - - // Convert to half, bfloat16, float, double. - // Available only when N==1. - - explicit operator half() const { - assert(N == 1 && "fp8_e5m3_x: N must be 1 for half conversion operator"); - return ConvertFromFP8(vals[0]); - } - explicit operator bfloat16() const { - assert(N == 1 && - "fp8_e5m3_x: N must be 1 for bfloat16 conversion operator"); - return ConvertBF16FromFP8(vals[0]); - } - explicit operator float() const { - assert(N == 1 && "fp8_e5m3_x: N must be 1 for float conversion operator"); - return ConvertFromFP8(vals[0]); - } - explicit operator double() const { - assert(N == 1 && "fp8_e5m3_x: N must be 1 for double conversion operator"); - return ConvertFromFP8(vals[0]); - } - - // Convert to integer types. - // Available only when N==1. - - explicit operator char() const { - assert(N == 1 && "fp8_e5m3_x: N must be 1 for char conversion operator"); - return ConvertFromFP8(vals[0], rounding::toward_zero); - } - explicit operator signed char() const { - assert(N == 1 && - "fp8_e5m3_x: N must be 1 for signed char conversion operator"); - return ConvertFromFP8(vals[0], rounding::toward_zero); - } - explicit operator short() const { - assert(N == 1 && "fp8_e5m3_x: N must be 1 for short conversion operator"); - return ConvertFromFP8(vals[0], rounding::toward_zero); - } - explicit operator int() const { - assert(N == 1 && "fp8_e5m3_x: N must be 1 for int conversion operator"); - return ConvertFromFP8(vals[0], rounding::toward_zero); - } - explicit operator long() const { - assert(N == 1 && "fp8_e5m3_x: N must be 1 for long conversion operator"); - return ConvertFromFP8(vals[0], rounding::toward_zero); - } - explicit operator long long() const { - assert(N == 1 && - "fp8_e5m3_x: N must be 1 for long long conversion operator"); - return ConvertFromFP8(vals[0], rounding::toward_zero); - } - explicit operator unsigned char() const { - assert(N == 1 && - "fp8_e5m3_x: N must be 1 for unsigned char conversion operator"); - return ConvertFromFP8(vals[0], rounding::toward_zero); - } - explicit operator unsigned short() const { - assert(N == 1 && - "fp8_e5m3_x: N must be 1 for unsigned short conversion operator"); - return ConvertFromFP8(vals[0], rounding::toward_zero); - } - explicit operator unsigned int() const { - assert(N == 1 && - "fp8_e5m3_x: N must be 1 for unsigned int conversion operator"); - return ConvertFromFP8(vals[0], rounding::toward_zero); - } - explicit operator unsigned long() const { - assert(N == 1 && - "fp8_e5m3_x: N must be 1 for unsigned long conversion operator"); - return ConvertFromFP8(vals[0], rounding::toward_zero); - } - explicit operator unsigned long long() const { - assert( - N == 1 && - "fp8_e5m3_x: N must be 1 for unsigned long long conversion operator"); - return ConvertFromFP8(vals[0], rounding::toward_zero); - } - - // Convert to bool - // Available only when N==1. - - explicit operator bool() const { - static_assert(N == 1, "fp8_e5m3_x: operator() requires size N=1"); - return vals[0] != 0x00 && vals[0] != 0x80; - } - - // Convert to marray of half, bfloat16, float - - explicit operator marray() const { - marray out; - for (size_t i = 0; i < N; ++i) - out[i] = ConvertFromFP8(vals[i]); - return out; - } - explicit operator marray() const { - marray out; - for (size_t i = 0; i < N; ++i) - out[i] = ConvertBF16FromFP8(vals[i]); - return out; - } - explicit operator marray() const { - marray out; - for (size_t i = 0; i < N; ++i) - out[i] = ConvertFromFP8(vals[i]); - return out; - } - - // Intentionally public to allow access to the raw values. - - uint8_t vals[N]; -}; - using fp8_e4m3 = fp8_e4m3_x<1>; using fp8_e4m3_x2 = fp8_e4m3_x<2>; using fp8_e5m2 = fp8_e5m2_x<1>; using fp8_e5m2_x2 = fp8_e5m2_x<2>; using fp8_e8m0 = fp8_e8m0_x<1>; using fp8_e8m0_x2 = fp8_e8m0_x<2>; -using fp8_e5m3 = fp8_e5m3_x<1>; -using fp8_e5m3_x2 = fp8_e5m3_x<2>; #endif // __SYCL_TARGET_INTEL_GPU_CRI__ diff --git a/sycl/unittests/Extensions/fp8/CMakeLists.txt b/sycl/unittests/Extensions/fp8/CMakeLists.txt index 45778104df248..2d0c53daf4268 100644 --- a/sycl/unittests/Extensions/fp8/CMakeLists.txt +++ b/sycl/unittests/Extensions/fp8/CMakeLists.txt @@ -2,7 +2,6 @@ add_sycl_unittest(FP8TypesTests OBJECT fp8_e4m3.cpp fp8_e5m2.cpp fp8_e8m0.cpp - fp8_e5m3.cpp ) target_compile_options(FP8TypesTests_Preview_Tests PUBLIC -D__SYCL_TARGET_INTEL_GPU_CRI__) diff --git a/sycl/unittests/Extensions/fp8/fp8_e5m3.cpp b/sycl/unittests/Extensions/fp8/fp8_e5m3.cpp deleted file mode 100644 index f2a69c446d6a5..0000000000000 --- a/sycl/unittests/Extensions/fp8/fp8_e5m3.cpp +++ /dev/null @@ -1,495 +0,0 @@ -#include -#include - -#include -#include -#include - -using namespace sycl::ext::oneapi::experimental; -using sycl::ext::oneapi::bfloat16; - -TEST(FP8E5M3ArrayCtor, HalfDefaultRounding) { - const sycl::half vals[2] = {sycl::half(1.0f), - sycl::half(std::ldexp(1.0f, -14))}; - fp8_e5m3_x2 v(vals); - - EXPECT_EQ(sizeof(v.vals), 2u); - EXPECT_EQ(v.vals[0], 0x78); - EXPECT_EQ(v.vals[1], 0x08); -} - -TEST(FP8E5M3ArrayCtor, HalfExplicitRounding) { - const sycl::half vals[2] = {sycl::half(1.5f), sycl::half(0.75f)}; - fp8_e5m3_x2 v(vals, rounding::to_even); - - EXPECT_EQ(sizeof(v.vals), 2u); - EXPECT_EQ(v.vals[0], 0x7C); - EXPECT_EQ(v.vals[1], 0x74); -} - -TEST(FP8E5M3ArrayCtor, Bfloat16DefaultRounding) { - const bfloat16 vals[2] = {bfloat16(2.0f), bfloat16(std::ldexp(1.0f, -13))}; - fp8_e5m3_x2 v(vals); - - EXPECT_EQ(sizeof(v.vals), 2u); - EXPECT_EQ(v.vals[0], 0x80); - EXPECT_EQ(v.vals[1], 0x10); -} - -TEST(FP8E5M3ArrayCtor, Bfloat16ExplicitRounding) { - const bfloat16 vals[2] = {bfloat16(3.0f), bfloat16(0.5f)}; - fp8_e5m3_x2 v(vals, rounding::to_even); - - EXPECT_EQ(sizeof(v.vals), 2u); - EXPECT_EQ(v.vals[0], 0x84); - EXPECT_EQ(v.vals[1], 0x70); -} - -TEST(FP8E5M3ArrayCtor, FloatDefaultRounding) { - const float vals[2] = {114688.0f, std::ldexp(1.0f, -17)}; - fp8_e5m3_x2 v(vals); - - EXPECT_EQ(sizeof(v.vals), 2u); - EXPECT_EQ(v.vals[0], 0xFE); - EXPECT_EQ(v.vals[1], 0x01); -} - -TEST(FP8E5M3ArrayCtor, FloatExplicitRounding) { - const float vals[2] = {1.25f, 0.875f * std::ldexp(1.0f, -14)}; - fp8_e5m3_x2 v(vals, rounding::to_even); - - EXPECT_EQ(sizeof(v.vals), 2u); - EXPECT_EQ(v.vals[0], 0x7A); - EXPECT_EQ(v.vals[1], 0x07); -} - -TEST(FP8E5M3ArrayCtor, DoubleDefaultRounding) { - const double vals[2] = {8.0, 0.125}; - fp8_e5m3_x2 v(vals); - - EXPECT_EQ(sizeof(v.vals), 2u); - EXPECT_EQ(v.vals[0], 0x90); - EXPECT_EQ(v.vals[1], 0x60); -} - -TEST(FP8E5M3MarrayCtor, HalfDefaultRounding) { - const sycl::marray vals{sycl::half(1.75f), sycl::half(0.5f)}; - fp8_e5m3_x2 v(vals); - - EXPECT_EQ(sizeof(v.vals), 2u); - EXPECT_EQ(v.vals[0], 0x7E); - EXPECT_EQ(v.vals[1], 0x70); -} - -TEST(FP8E5M3MarrayCtor, HalfExplicitRounding) { - const sycl::marray vals{sycl::half(2.0f), sycl::half(0.125f)}; - fp8_e5m3_x2 v(vals, rounding::to_even); - - EXPECT_EQ(sizeof(v.vals), 2u); - EXPECT_EQ(v.vals[0], 0x80); - EXPECT_EQ(v.vals[1], 0x60); -} - -TEST(FP8E5M3MarrayCtor, Bfloat16DefaultRounding) { - const sycl::marray vals{bfloat16(3.0f), bfloat16(0.25f)}; - fp8_e5m3_x2 v(vals); - - EXPECT_EQ(sizeof(v.vals), 2u); - EXPECT_EQ(v.vals[0], 0x84); - EXPECT_EQ(v.vals[1], 0x68); -} - -TEST(FP8E5M3MarrayCtor, Bfloat16ExplicitRounding) { - const sycl::marray vals{bfloat16(6.0f), bfloat16(0.75f)}; - fp8_e5m3_x2 v(vals, rounding::to_even); - - EXPECT_EQ(sizeof(v.vals), 2u); - EXPECT_EQ(v.vals[0], 0x8C); - EXPECT_EQ(v.vals[1], 0x74); -} - -TEST(FP8E5M3MarrayCtor, FloatDefaultRounding) { - const sycl::marray vals{12.0f, 0.03125f}; - fp8_e5m3_x2 v(vals); - - EXPECT_EQ(sizeof(v.vals), 2u); - EXPECT_EQ(v.vals[0], 0x94); - EXPECT_EQ(v.vals[1], 0x50); -} - -TEST(FP8E5M3MarrayCtor, FloatExplicitRounding) { - const sycl::marray vals{1.25f, std::ldexp(0.375f, -14)}; - fp8_e5m3_x2 v(vals, rounding::to_even); - - EXPECT_EQ(sizeof(v.vals), 2u); - EXPECT_EQ(v.vals[0], 0x7A); - EXPECT_EQ(v.vals[1], 0x03); -} - -TEST(FP8E5M3MarrayCtor, DoubleDefaultRounding) { - const sycl::marray vals{16.0, 0.0625}; - fp8_e5m3_x2 v(vals); - - EXPECT_EQ(sizeof(v.vals), 2u); - EXPECT_EQ(v.vals[0], 0x98); - EXPECT_EQ(v.vals[1], 0x58); -} - -TEST(FP8E5M3ScalarIntCtor, ShortValue) { - const short val = 5; - fp8_e5m3 v(val); - - EXPECT_EQ(sizeof(v.vals), 1u); - EXPECT_EQ(v.vals[0], 0x8A); -} - -TEST(FP8E5M3ScalarIntCtor, IntValue) { - const int val = 7; - fp8_e5m3 v(val); - - EXPECT_EQ(sizeof(v.vals), 1u); - EXPECT_EQ(v.vals[0], 0x8E); -} - -TEST(FP8E5M3ScalarIntCtor, LongValue) { - const long val = 9; - fp8_e5m3 v(val); - - EXPECT_EQ(sizeof(v.vals), 1u); - EXPECT_EQ(v.vals[0], 0x91); -} - -TEST(FP8E5M3ScalarIntCtor, LongLongValue) { - const long long val = 10; - fp8_e5m3 v(val); - - EXPECT_EQ(sizeof(v.vals), 1u); - EXPECT_EQ(v.vals[0], 0x92); -} - -TEST(FP8E5M3ScalarIntCtor, UnsignedShortValue) { - const unsigned short val = 14; - fp8_e5m3 v(val); - - EXPECT_EQ(sizeof(v.vals), 1u); - EXPECT_EQ(v.vals[0], 0x96); -} - -TEST(FP8E5M3ScalarIntCtor, UnsignedIntValue) { - const unsigned int val = 15; - fp8_e5m3 v(val); - - EXPECT_EQ(sizeof(v.vals), 1u); - EXPECT_EQ(v.vals[0], 0x97); -} - -TEST(FP8E5M3ScalarIntCtor, UnsignedLongValue) { - const unsigned long val = 18; - fp8_e5m3 v(val); - - EXPECT_EQ(sizeof(v.vals), 1u); - EXPECT_EQ(v.vals[0], 0x99); -} - -TEST(FP8E5M3ScalarIntCtor, UnsignedLongLongValue) { - const unsigned long long val = 20; - fp8_e5m3 v(val); - - EXPECT_EQ(sizeof(v.vals), 1u); - EXPECT_EQ(v.vals[0], 0x9A); -} - -TEST(FP8E5M3ScalarIntCtor, UnsignedLimitsSaturate) { - const unsigned short usmax = std::numeric_limits::max(); - const unsigned int uimax = std::numeric_limits::max(); - const unsigned long ulmax = std::numeric_limits::max(); - const unsigned long long ullmax = - std::numeric_limits::max(); - - fp8_e5m3 vus(usmax); - fp8_e5m3 vui(uimax); - fp8_e5m3 vul(ulmax); - fp8_e5m3 vull(ullmax); - - EXPECT_EQ(sizeof(vus.vals), 1u); - EXPECT_EQ(vus.vals[0], 0xFE); - EXPECT_EQ(vui.vals[0], 0xFE); - EXPECT_EQ(vul.vals[0], 0xFE); - EXPECT_EQ(vull.vals[0], 0xFE); -} - -TEST(FP8E5M3AssignOp, HalfValue) { - fp8_e5m3 v(sycl::half(1.0f)); - v = sycl::half(1.125f); - - EXPECT_EQ(sizeof(v.vals), 1u); - EXPECT_EQ(v.vals[0], 0x79); -} - -TEST(FP8E5M3AssignOp, Bfloat16Value) { - fp8_e5m3 v(bfloat16(1.0f)); - v = bfloat16(1.875f); - - EXPECT_EQ(sizeof(v.vals), 1u); - EXPECT_EQ(v.vals[0], 0x7F); -} - -TEST(FP8E5M3AssignOp, FloatValue) { - fp8_e5m3 v(1.0f); - v = 11.0f; - - EXPECT_EQ(sizeof(v.vals), 1u); - EXPECT_EQ(v.vals[0], 0x93); -} - -TEST(FP8E5M3AssignOp, DoubleValue) { - fp8_e5m3 v(1.0); - v = 5.5; - - EXPECT_EQ(sizeof(v.vals), 1u); - EXPECT_EQ(v.vals[0], 0x8B); -} - -TEST(FP8E5M3AssignOp, ShortValue) { - fp8_e5m3 v(1); - v = static_cast(6); - - EXPECT_EQ(sizeof(v.vals), 1u); - EXPECT_EQ(v.vals[0], 0x8C); -} - -TEST(FP8E5M3AssignOp, IntValue) { - fp8_e5m3 v(1); - v = 12; - - EXPECT_EQ(sizeof(v.vals), 1u); - EXPECT_EQ(v.vals[0], 0x94); -} - -TEST(FP8E5M3AssignOp, LongValue) { - fp8_e5m3 v(1); - v = static_cast(13); - - EXPECT_EQ(sizeof(v.vals), 1u); - EXPECT_EQ(v.vals[0], 0x95); -} - -TEST(FP8E5M3AssignOp, LongLongValue) { - fp8_e5m3 v(1); - v = static_cast(17); - - EXPECT_EQ(sizeof(v.vals), 1u); - EXPECT_EQ(v.vals[0], 0x98); -} - -TEST(FP8E5M3AssignOp, UnsignedShortValue) { - fp8_e5m3 v(1); - v = static_cast(21); - - EXPECT_EQ(sizeof(v.vals), 1u); - EXPECT_EQ(v.vals[0], 0x9A); -} - -TEST(FP8E5M3AssignOp, UnsignedIntValue) { - fp8_e5m3 v(1); - v = static_cast(22); - - EXPECT_EQ(sizeof(v.vals), 1u); - EXPECT_EQ(v.vals[0], 0x9B); -} - -TEST(FP8E5M3AssignOp, UnsignedLongValue) { - fp8_e5m3 v(1); - v = static_cast(24); - - EXPECT_EQ(sizeof(v.vals), 1u); - EXPECT_EQ(v.vals[0], 0x9C); -} - -TEST(FP8E5M3AssignOp, UnsignedLongLongValue) { - fp8_e5m3 v(1); - v = static_cast(26); - - EXPECT_EQ(sizeof(v.vals), 1u); - EXPECT_EQ(v.vals[0], 0x9D); -} - -TEST(FP8E5M3ConvertOp, HalfValue) { - fp8_e5m3 v(2.5f); - auto out = static_cast(v); - - EXPECT_EQ(sizeof(v.vals), 1u); - EXPECT_EQ(v.vals[0], 0x82); - EXPECT_EQ(static_cast(out), 2.5f); -} - -TEST(FP8E5M3ConvertOp, Bfloat16Value) { - fp8_e5m3 v(0.375f); - auto out = static_cast(v); - - EXPECT_EQ(sizeof(v.vals), 1u); - EXPECT_EQ(v.vals[0], 0x6C); - EXPECT_EQ(static_cast(out), 0.375f); -} - -TEST(FP8E5M3ConvertOp, FloatValue) { - fp8_e5m3 v(2.25f); - float out = static_cast(v); - - EXPECT_EQ(sizeof(v.vals), 1u); - EXPECT_EQ(v.vals[0], 0x81); - EXPECT_EQ(out, 2.25f); -} - -TEST(FP8E5M3ConvertOp, DoubleValue) { - fp8_e5m3 v(4.0f); - double out = static_cast(v); - - EXPECT_EQ(sizeof(v.vals), 1u); - EXPECT_EQ(v.vals[0], 0x88); - EXPECT_EQ(out, 4.0); -} - -TEST(FP8E5M3ConvertIntOp, CharValue) { - fp8_e5m3 v(3.5f); - char out = static_cast(v); - - EXPECT_EQ(sizeof(v.vals), 1u); - EXPECT_EQ(v.vals[0], 0x86); - EXPECT_EQ(out, static_cast(3)); -} - -TEST(FP8E5M3ConvertIntOp, SignedCharValue) { - fp8_e5m3 v(6.5f); - signed char out = static_cast(v); - - EXPECT_EQ(sizeof(v.vals), 1u); - EXPECT_EQ(v.vals[0], 0x8D); - EXPECT_EQ(out, static_cast(6)); -} - -TEST(FP8E5M3ConvertIntOp, ShortValue) { - fp8_e5m3 v(7.5f); - short out = static_cast(v); - - EXPECT_EQ(sizeof(v.vals), 1u); - EXPECT_EQ(v.vals[0], 0x8F); - EXPECT_EQ(out, static_cast(7)); -} - -TEST(FP8E5M3ConvertIntOp, IntValue) { - fp8_e5m3 v(8.0f); - int out = static_cast(v); - - EXPECT_EQ(sizeof(v.vals), 1u); - EXPECT_EQ(v.vals[0], 0x90); - EXPECT_EQ(out, 8); -} - -TEST(FP8E5M3ConvertIntOp, LongValue) { - fp8_e5m3 v(9.0f); - long out = static_cast(v); - - EXPECT_EQ(sizeof(v.vals), 1u); - EXPECT_EQ(v.vals[0], 0x91); - EXPECT_EQ(out, 9L); -} - -TEST(FP8E5M3ConvertIntOp, LongLongValue) { - fp8_e5m3 v(10.0f); - long long out = static_cast(v); - - EXPECT_EQ(sizeof(v.vals), 1u); - EXPECT_EQ(v.vals[0], 0x92); - EXPECT_EQ(out, 10LL); -} - -TEST(FP8E5M3ConvertIntOp, UnsignedCharValue) { - fp8_e5m3 v(11.0f); - unsigned char out = static_cast(v); - - EXPECT_EQ(sizeof(v.vals), 1u); - EXPECT_EQ(v.vals[0], 0x93); - EXPECT_EQ(out, static_cast(11)); -} - -TEST(FP8E5M3ConvertIntOp, UnsignedShortValue) { - fp8_e5m3 v(12.0f); - unsigned short out = static_cast(v); - - EXPECT_EQ(sizeof(v.vals), 1u); - EXPECT_EQ(v.vals[0], 0x94); - EXPECT_EQ(out, static_cast(12)); -} - -TEST(FP8E5M3ConvertIntOp, UnsignedIntValue) { - fp8_e5m3 v(13.0f); - unsigned int out = static_cast(v); - - EXPECT_EQ(sizeof(v.vals), 1u); - EXPECT_EQ(v.vals[0], 0x95); - EXPECT_EQ(out, 13u); -} - -TEST(FP8E5M3ConvertIntOp, UnsignedLongValue) { - fp8_e5m3 v(14.0f); - unsigned long out = static_cast(v); - - EXPECT_EQ(sizeof(v.vals), 1u); - EXPECT_EQ(v.vals[0], 0x96); - EXPECT_EQ(out, 14UL); -} - -TEST(FP8E5M3ConvertIntOp, UnsignedLongLongValue) { - fp8_e5m3 v(15.0f); - unsigned long long out = static_cast(v); - - EXPECT_EQ(sizeof(v.vals), 1u); - EXPECT_EQ(v.vals[0], 0x97); - EXPECT_EQ(out, 15ULL); -} - -TEST(FP8E5M3ConvertOp, BoolFalse) { - fp8_e5m3 v(0.0f); - bool out = static_cast(v); - - EXPECT_EQ(sizeof(v.vals), 1u); - EXPECT_EQ(v.vals[0], 0x00); - EXPECT_FALSE(out); -} - -TEST(FP8E5M3ConvertOp, BoolTrue) { - fp8_e5m3 v(0.25f); - bool out = static_cast(v); - - EXPECT_EQ(sizeof(v.vals), 1u); - EXPECT_EQ(v.vals[0], 0x68); - EXPECT_TRUE(out); -} - -TEST(FP8E5M3ConvertOp, MarrayHalf) { - fp8_e5m3 v(1.625f); - auto out = static_cast>(v); - - EXPECT_EQ(sizeof(v.vals), 1u); - EXPECT_EQ(v.vals[0], 0x7D); - EXPECT_EQ(static_cast(out[0]), 1.625f); -} - -TEST(FP8E5M3ConvertOp, MarrayBfloat16) { - fp8_e5m3 v(2.75f); - auto out = static_cast>(v); - - EXPECT_EQ(sizeof(v.vals), 1u); - EXPECT_EQ(v.vals[0], 0x83); - EXPECT_EQ(static_cast(out[0]), 2.75f); -} - -TEST(FP8E5M3ConvertOp, MarrayFloat) { - fp8_e5m3 v(5.0f); - auto out = static_cast>(v); - - EXPECT_EQ(sizeof(v.vals), 1u); - EXPECT_EQ(v.vals[0], 0x8A); - EXPECT_EQ(out[0], 5.0f); -} From bb0cc94c6e7b53477f64bf914a0447b0c49e7921 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Fri, 20 Feb 2026 12:30:18 +0100 Subject: [PATCH 05/89] [SYCL][FP8] implement stochastic rounding --- .../oneapi/experimental/float_8bit/types.hpp | 224 +++++++++--------- 1 file changed, 115 insertions(+), 109 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index 3ea979c46f91c..4a99e19aa7fdf 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -16,58 +16,6 @@ namespace ext::oneapi::experimental { #ifdef __SYCL_TARGET_INTEL_GPU_CRI__ -#ifdef __SYCL_DEVICE_ONLY__ - -// New FP8 builtins -extern __DPCPP_SYCL_EXTERNAL sycl::half -__builtin_spirv_ConvertE4M3ToFP16EXT(uint8_t) noexcept; -extern __DPCPP_SYCL_EXTERNAL sycl::half -__builtin_spirv_ConvertE5M2ToFP16EXT(uint8_t) noexcept; -extern __DPCPP_SYCL_EXTERNAL sycl::ext::oneapi::bfloat16 -__builtin_spirv_ConvertE4M3ToBF16EXT(uint8_t) noexcept; -extern __DPCPP_SYCL_EXTERNAL sycl::ext::oneapi::bfloat16 -__builtin_spirv_ConvertE5M2ToBF16EXT(uint8_t) noexcept; -extern __DPCPP_SYCL_EXTERNAL - uint8_t __builtin_spirv_ConvertFP16ToE4M3EXT(sycl::half) noexcept; -extern __DPCPP_SYCL_EXTERNAL - uint8_t __builtin_spirv_ConvertFP16ToE5M2EXT(sycl::half) noexcept; -extern __DPCPP_SYCL_EXTERNAL uint8_t - __builtin_spirv_ConvertBF16ToE4M3EXT(sycl::ext::oneapi::bfloat16) noexcept; -extern __DPCPP_SYCL_EXTERNAL uint8_t - __builtin_spirv_ConvertBF16ToE5M2EXT(sycl::ext::oneapi::bfloat16) noexcept; -extern __DPCPP_SYCL_EXTERNAL - uint8_t __builtin_spirv_ClampConvertFP16ToE4M3INTEL(sycl::half) noexcept; -extern __DPCPP_SYCL_EXTERNAL - uint8_t __builtin_spirv_ClampConvertBF16ToE4M3INTEL( - sycl::ext::oneapi::bfloat16) noexcept; -extern __DPCPP_SYCL_EXTERNAL - uint8_t __builtin_spirv_ClampConvertFP16ToE5M2INTEL(sycl::half) noexcept; -extern __DPCPP_SYCL_EXTERNAL - uint8_t __builtin_spirv_ClampConvertBF16ToE5M2INTEL( - sycl::ext::oneapi::bfloat16) noexcept; -extern __DPCPP_SYCL_EXTERNAL - uint8_t __builtin_spirv_StochasticRoundFP16ToE5M2INTEL(sycl::half) noexcept; -extern __DPCPP_SYCL_EXTERNAL - uint8_t __builtin_spirv_StochasticRoundFP16ToE4M3INTEL(sycl::half) noexcept; -extern __DPCPP_SYCL_EXTERNAL - uint8_t __builtin_spirv_StochasticRoundBF16ToE5M2INTEL( - sycl::ext::oneapi::bfloat16) noexcept; -extern __DPCPP_SYCL_EXTERNAL - uint8_t __builtin_spirv_StochasticRoundBF16ToE4M3INTEL( - sycl::ext::oneapi::bfloat16) noexcept; -extern __DPCPP_SYCL_EXTERNAL uint8_t - __builtin_spirv_ClampStochasticRoundFP16ToE5M2INTEL(sycl::half) noexcept; -extern __DPCPP_SYCL_EXTERNAL uint8_t - __builtin_spirv_ClampStochasticRoundFP16ToE4M3INTEL(sycl::half) noexcept; -extern __DPCPP_SYCL_EXTERNAL - uint8_t __builtin_spirv_ClampStochasticRoundBF16ToE5M2INTEL( - sycl::ext::oneapi::bfloat16) noexcept; -extern __DPCPP_SYCL_EXTERNAL - uint8_t __builtin_spirv_ClampStochasticRoundBF16ToE4M3INTEL( - sycl::ext::oneapi::bfloat16) noexcept; - -#endif // __SYCL_DEVICE_ONLY__ - enum class saturation { none, finite }; enum class rounding { @@ -639,29 +587,6 @@ template class fp8_e4m3_x { vals[i] = ConvertToFP8(v[i], rounding::to_even); } - // Construct with stochastic rounding with user provided seed from an array of - // half, bfloat16, float. - // Should be removed once docs updated - explicit fp8_e4m3_x(half const (&vals)[N], const stochastic_seed &seed, - saturation s = saturation::finite); - explicit fp8_e4m3_x(bfloat16 const (&vals)[N], const stochastic_seed &seed, - saturation s = saturation::finite); - explicit fp8_e4m3_x(float const (&vals)[N], const stochastic_seed &seed, - saturation s = saturation::finite); - - // Construct with stochastic rounding with user provided seed from an marray - // of half, bfloat16, float. - - // Should be removed once docs updated - explicit fp8_e4m3_x(const sycl::marray &vals, - const stochastic_seed &seed, - saturation s = saturation::finite); - explicit fp8_e4m3_x(const sycl::marray &vals, - const stochastic_seed &seed, - saturation s = saturation::finite); - explicit fp8_e4m3_x(const sycl::marray &vals, - const stochastic_seed &seed, - saturation s = saturation::finite); // Construct from integer types. // Available only when N==1. @@ -1015,7 +940,7 @@ template class fp8_e5m2_x { ...))>> explicit fp8_e5m2_x(Types... v) { static_assert(N == 1 || N == 2, - "fp8_e5m2_x: Template argument N must be 1 or 2 on device"); + "fp8_e5m2_x: Template argument N must be 1 or 2"); if constexpr (((std::is_same_v, bfloat16>) && ...)) { const bfloat16 in[N] = {static_cast(v)...}; for (size_t i = 0; i < N; ++i) @@ -1091,27 +1016,127 @@ template class fp8_e5m2_x { // Construct with stochastic rounding with user provided seed from an array of // half, bfloat16, float. - // should be removed once docs updated - explicit fp8_e5m2_x(half const (&vals)[N], const stochastic_seed &seed, - saturation s = saturation::finite); - explicit fp8_e5m2_x(bfloat16 const (&vals)[N], const stochastic_seed &seed, - saturation s = saturation::finite); - explicit fp8_e5m2_x(double const (&vals)[N], const stochastic_seed &seed, - saturation s = saturation::finite); + explicit fp8_e5m2_x(half const (&in)[N], const stochastic_seed &seed, + saturation s = saturation::finite) { + static_assert(N == 1 || N == 2, + "fp8_e5m2_x: Template argument N must be 1 or 2"); +#ifdef __SYCL_DEVICE_ONLY__ + uint32_t current_seed = *seed.pseed; + for (size_t i = 0; i < N; ++i) { + if (s == saturation::finite) { + vals[i] = __builtin_spirv_ClampStochasticRoundFP16ToE5M2INTEL( + in[i], static_cast(current_seed), seed.pseed); + } else { + vals[i] = __builtin_spirv_StochasticRoundFP16ToE5M2INTEL( + in[i], static_cast(current_seed), seed.pseed); + } + current_seed = *seed.pseed; + } +#endif + } + + explicit fp8_e5m2_x(bfloat16 const (&in)[N], const stochastic_seed &seed, + saturation s = saturation::finite) { + static_assert(N == 1 || N == 2, + "fp8_e5m2_x: Template argument N must be 1 or 2"); +#ifdef __SYCL_DEVICE_ONLY__ + uint32_t current_seed = *seed.pseed; + for (size_t i = 0; i < N; ++i) { + if (s == saturation::finite) { + vals[i] = __builtin_spirv_ClampStochasticRoundBF16ToE5M2INTEL( + in[i], static_cast(current_seed), seed.pseed); + } else { + vals[i] = __builtin_spirv_StochasticRoundBF16ToE5M2INTEL( + in[i], static_cast(current_seed), seed.pseed); + } + current_seed = *seed.pseed; + } +#endif + } + + explicit fp8_e5m2_x(double const (&in)[N], const stochastic_seed &seed, + saturation s = saturation::finite) { + static_assert(N == 1 || N == 2, + "fp8_e5m2_x: Template argument N must be 1 or 2"); +#ifdef __SYCL_DEVICE_ONLY__ + uint32_t current_seed = *seed.pseed; + for (size_t i = 0; i < N; ++i) { + sycl::half h = static_cast(in[i]); + if (s == saturation::finite) { + vals[i] = __builtin_spirv_ClampStochasticRoundFP16ToE5M2INTEL( + h, static_cast(current_seed), seed.pseed); + } else { + vals[i] = __builtin_spirv_StochasticRoundFP16ToE5M2INTEL( + h, static_cast(current_seed), seed.pseed); + } + current_seed = *seed.pseed; + } +#endif + } // Construct with stochastic rounding with user provided seed from an marray // of half, bfloat16, float. - // should be removed once docs updated - explicit fp8_e5m2_x(const sycl::marray &vals, + explicit fp8_e5m2_x(const sycl::marray &in, const stochastic_seed &seed, - saturation s = saturation::finite); - explicit fp8_e5m2_x(const sycl::marray &vals, + saturation s = saturation::finite) { + static_assert(N == 1 || N == 2, + "fp8_e5m2_x: Template argument N must be 1 or 2"); +#ifdef __SYCL_DEVICE_ONLY__ + uint32_t current_seed = *seed.pseed; + for (size_t i = 0; i < N; ++i) { + if (s == saturation::finite) { + vals[i] = __builtin_spirv_ClampStochasticRoundFP16ToE5M2INTEL( + in[i], static_cast(current_seed), seed.pseed); + } else { + vals[i] = __builtin_spirv_StochasticRoundFP16ToE5M2INTEL( + in[i], static_cast(current_seed), seed.pseed); + } + current_seed = *seed.pseed; + } +#endif + } + + explicit fp8_e5m2_x(const sycl::marray &in, const stochastic_seed &seed, - saturation s = saturation::finite); - explicit fp8_e5m2_x(const sycl::marray &vals, + saturation s = saturation::finite) { + static_assert(N == 1 || N == 2, + "fp8_e5m2_x: Template argument N must be 1 or 2"); +#ifdef __SYCL_DEVICE_ONLY__ + uint32_t current_seed = *seed.pseed; + for (size_t i = 0; i < N; ++i) { + if (s == saturation::finite) { + vals[i] = __builtin_spirv_ClampStochasticRoundBF16ToE5M2INTEL( + in[i], static_cast(current_seed), seed.pseed); + } else { + vals[i] = __builtin_spirv_StochasticRoundBF16ToE5M2INTEL( + in[i], static_cast(current_seed), seed.pseed); + } + current_seed = *seed.pseed; + } +#endif + } + + explicit fp8_e5m2_x(const sycl::marray &in, const stochastic_seed &seed, - saturation s = saturation::finite); + saturation s = saturation::finite) { + static_assert(N == 1 || N == 2, + "fp8_e5m2_x: Template argument N must be 1 or 2"); +#ifdef __SYCL_DEVICE_ONLY__ + uint32_t current_seed = *seed.pseed; + for (size_t i = 0; i < N; ++i) { + sycl::half h = static_cast(in[i]); + if (s == saturation::finite) { + vals[i] = __builtin_spirv_ClampStochasticRoundFP16ToE5M2INTEL( + h, static_cast(current_seed), seed.pseed); + } else { + vals[i] = __builtin_spirv_StochasticRoundFP16ToE5M2INTEL( + h, static_cast(current_seed), seed.pseed); + } + current_seed = *seed.pseed; + } +#endif + } // Construct from integer types. // Available only when N==1. @@ -1548,25 +1573,6 @@ template class fp8_e8m0_x { saturation::finite); } - // Construct with stochastic rounding with user provided seed from an array of - // half, bfloat16, float. - - // should be removed once docs updated - explicit fp8_e8m0_x(half const (&vals)[N], const stochastic_seed &seed); - explicit fp8_e8m0_x(bfloat16 const (&vals)[N], const stochastic_seed &seed); - explicit fp8_e8m0_x(double const (&vals)[N], const stochastic_seed &seed); - - // Construct with stochastic rounding with user provided seed from an marray - // of half, bfloat16, float. - - // should be removed once docs updated - explicit fp8_e8m0_x(const sycl::marray &vals, - const stochastic_seed &seed); - explicit fp8_e8m0_x(const sycl::marray &vals, - const stochastic_seed &seed); - explicit fp8_e8m0_x(const sycl::marray &vals, - const stochastic_seed &seed); - // Construct from integer types. // Available only when N==1. From 6e0a2e482e27a7a539a365c1ffa7403033ac232e Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Thu, 19 Mar 2026 14:32:59 +0100 Subject: [PATCH 06/89] [SYCL] update fp8 implemetation --- .../oneapi/experimental/float_8bit/types.hpp | 414 +++++++++++------- sycl/unittests/Extensions/fp8/CMakeLists.txt | 5 +- .../Extensions/fp8/builtin_call_tests.cpp | 123 ++++++ .../Extensions/fp8/builtin_mocks.hpp | 155 +++++++ sycl/unittests/Extensions/fp8/fp8_e4m3.cpp | 5 + sycl/unittests/Extensions/fp8/fp8_e5m2.cpp | 17 +- sycl/unittests/Extensions/fp8/fp8_e8m0.cpp | 5 + 7 files changed, 562 insertions(+), 162 deletions(-) create mode 100644 sycl/unittests/Extensions/fp8/builtin_call_tests.cpp create mode 100644 sycl/unittests/Extensions/fp8/builtin_mocks.hpp diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index 4a99e19aa7fdf..10d5c51a51aa3 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -1,6 +1,13 @@ +//==----------- types.hpp - sycl_ext_oneapi_fp8 ------------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + #pragma once -#include #include #include @@ -8,14 +15,64 @@ #include #include #include +#include #include +#ifdef __SYCL_DEVICE_ONLY__ +// New FP8 builtins +extern __DPCPP_SYCL_EXTERNAL sycl::half +__builtin_spirv_ClampConvertE4M3ToFP16INTEL(uint8_t) noexcept; +extern __DPCPP_SYCL_EXTERNAL sycl::half +__builtin_spirv_ConvertE4M3ToFP16INTEL(uint8_t) noexcept; +extern __DPCPP_SYCL_EXTERNAL sycl::half +__builtin_spirv_ConvertE5M2ToFP16INTEL(uint8_t) noexcept; +extern __DPCPP_SYCL_EXTERNAL sycl::ext::oneapi::bfloat16 +__builtin_spirv_ConvertE4M3ToBF16INTEL(uint8_t) noexcept; +extern __DPCPP_SYCL_EXTERNAL sycl::ext::oneapi::bfloat16 +__builtin_spirv_ConvertE5M2ToBF16INTEL(uint8_t) noexcept; +extern __DPCPP_SYCL_EXTERNAL + uint8_t __builtin_spirv_ConvertFP16ToE4M3INTEL(sycl::half) noexcept; +extern __DPCPP_SYCL_EXTERNAL uint8_t __builtin_spirv_ConvertBF16ToE5M2INTEL( + sycl::ext::oneapi::bfloat16) noexcept; +extern __DPCPP_SYCL_EXTERNAL + uint8_t __builtin_spirv_ClampConvertFP16ToE4M3INTEL(sycl::half) noexcept; +extern __DPCPP_SYCL_EXTERNAL + uint8_t __builtin_spirv_ClampConvertBF16ToE4M3INTEL( + sycl::ext::oneapi::bfloat16) noexcept; +extern __DPCPP_SYCL_EXTERNAL + uint8_t __builtin_spirv_ClampConvertFP16ToE5M2INTEL(sycl::half) noexcept; +extern __DPCPP_SYCL_EXTERNAL + uint8_t __builtin_spirv_ClampConvertBF16ToE5M2INTEL( + sycl::ext::oneapi::bfloat16) noexcept; +extern __DPCPP_SYCL_EXTERNAL uint8_t +__builtin_spirv_StochasticRoundFP16ToE5M2INTEL(sycl::half, uint32_t, + uint32_t *) noexcept; +extern __DPCPP_SYCL_EXTERNAL + uint8_t __builtin_spirv_StochasticRoundFP16ToE4M3INTEL(sycl::half) noexcept; +extern __DPCPP_SYCL_EXTERNAL uint8_t +__builtin_spirv_StochasticRoundBF16ToE5M2INTEL(sycl::ext::oneapi::bfloat16, + uint32_t, uint32_t *) noexcept; +extern __DPCPP_SYCL_EXTERNAL + uint8_t __builtin_spirv_StochasticRoundBF16ToE4M3INTEL( + sycl::ext::oneapi::bfloat16) noexcept; +extern __DPCPP_SYCL_EXTERNAL uint8_t +__builtin_spirv_ClampStochasticRoundFP16ToE5M2INTEL(sycl::half, uint32_t, + uint32_t *) noexcept; +extern __DPCPP_SYCL_EXTERNAL uint8_t + __builtin_spirv_ClampStochasticRoundFP16ToE4M3INTEL(sycl::half) noexcept; +extern __DPCPP_SYCL_EXTERNAL uint8_t +__builtin_spirv_ClampStochasticRoundBF16ToE5M2INTEL(sycl::ext::oneapi::bfloat16, + uint32_t, + uint32_t *) noexcept; +extern __DPCPP_SYCL_EXTERNAL + uint8_t __builtin_spirv_ClampStochasticRoundBF16ToE4M3INTEL( + sycl::ext::oneapi::bfloat16) noexcept; +#endif // __SYCL_DEVICE_ONLY__ + namespace sycl { inline namespace _V1 { namespace ext::oneapi::experimental { -#ifdef __SYCL_TARGET_INTEL_GPU_CRI__ - enum class saturation { none, finite }; enum class rounding { @@ -423,7 +480,7 @@ template class fp8_e4m3_x { static constexpr size_t NExpBits = 4; static constexpr size_t NFracBits = 3; static constexpr float MaxNormal = 448.0f; - static constexpr float MinSubnormal = 0.001953125f; // 2^-9 + static constexpr float MinSubnormal = 0.00000762939453125f; // 2^-17 static constexpr uint8_t NaNCode = 0xFF; static constexpr uint8_t MaxFiniteCode = 0x7E; // 0.1111.110 (positive max normal) @@ -433,7 +490,7 @@ template class fp8_e4m3_x { #ifdef __SYCL_DEVICE_ONLY__ // TODO: optimize with vectorized builtin calls const uint8_t sign = std::signbit(hi) ? 0x80u : 0x00u; - const float ax = sycl::fabs(hi); + const float ax = std::fabs(hi); if (ax > MaxNormal) return static_cast(sign | MaxFiniteCode); @@ -441,11 +498,11 @@ template class fp8_e4m3_x { if (ax < MinSubnormal) return sign; - uint8_t b = __builtin_spirv_ConvertFP16ToE4M3EXT(h); + uint8_t b = __builtin_spirv_ClampConvertFP16ToE4M3INTEL(h); if (r == rounding::to_even) return b; - const sycl::half yi = __builtin_spirv_ConvertE4M3ToFP16EXT(b); + const sycl::half yi = __builtin_spirv_ConvertE4M3ToFP16INTEL(b); return round(r, b, yi, hi); #else @@ -456,7 +513,7 @@ template class fp8_e4m3_x { uint8_t ConvertBF16ToFP8(bfloat16 h, rounding r) { #ifdef __SYCL_DEVICE_ONLY__ const uint8_t sign = std::signbit(h) ? 0x80u : 0x00u; - const float ax = sycl::fabs(h); + const float ax = std::fabs(h); if (ax > MaxNormal) return static_cast(sign | MaxFiniteCode); @@ -464,11 +521,10 @@ template class fp8_e4m3_x { if (ax < MinSubnormal) return sign; - uint8_t b = __builtin_spirv_ConvertBF16ToE4M3EXT(h); + uint8_t b = __builtin_spirv_ClampConvertBF16ToE4M3INTEL(h); if (r == rounding::to_even) return b; - - const half yi = __builtin_spirv_ConvertBF16ToE4M3EXT(b); + const half yi = __builtin_spirv_ConvertE4M3ToFP16INTEL(b); return round(r, b, yi, h); #else return ConvertToFP8_CPU<4, 3, bfloat16>(h, r); @@ -477,7 +533,8 @@ template class fp8_e4m3_x { template T ConvertFromFP8(uint8_t v) const { #ifdef __SYCL_DEVICE_ONLY__ - sycl::half hi = __builtin_spirv_ConvertE4M3ToFP16EXT(v); + sycl::half hi = __builtin_spirv_ClampConvertE4M3ToFP16INTEL( + v); // sycl_fp8_ClampConvertE4M3ToFP16INTEL(v); return static_cast(hi); #else return ConvertFromFP8_CPU<4, 3, T>(v); @@ -486,7 +543,7 @@ template class fp8_e4m3_x { bfloat16 ConvertBF16FromFP8(uint8_t v) const { #ifdef __SYCL_DEVICE_ONLY__ - return __builtin_spirv_ConvertE4M3ToBF16EXT(v); + return __builtin_spirv_ConvertE4M3ToBF16INTEL(v); #else return ConvertFromFP8_CPU<4, 3, bfloat16>(v); #endif @@ -587,48 +644,51 @@ template class fp8_e4m3_x { vals[i] = ConvertToFP8(v[i], rounding::to_even); } - // Construct from integer types. // Available only when N==1. explicit fp8_e4m3_x(short val) { - assert(N == 1 && "fp8_e4m3_x: N must be 1 for short constructor"); + static_assert(N == 1 && "fp8_e4m3_x: N must be 1 for short constructor"); vals[0] = ConvertToFP8(val, rounding::to_even); } explicit fp8_e4m3_x(int val) { - assert(N == 1 && "fp8_e4m3_x: N must be 1 for int constructor"); + static_assert(N == 1 && "fp8_e4m3_x: N must be 1 for int constructor"); vals[0] = ConvertToFP8(val, rounding::to_even); } explicit fp8_e4m3_x(long val) { - assert(N == 1 && "fp8_e4m3_x: N must be 1 for long constructor"); + static_assert(N == 1 && "fp8_e4m3_x: N must be 1 for long constructor"); vals[0] = ConvertToFP8(val, rounding::to_even); } explicit fp8_e4m3_x(long long val) { - assert(N == 1 && "fp8_e4m3_x: N must be 1 for long long constructor"); + static_assert(N == 1 && + "fp8_e4m3_x: N must be 1 for long long constructor"); vals[0] = ConvertToFP8(val, rounding::to_even); } explicit fp8_e4m3_x(unsigned short val) { - assert(N == 1 && "fp8_e4m3_x: N must be 1 for unsigned short constructor"); + static_assert(N == 1 && + "fp8_e4m3_x: N must be 1 for unsigned short constructor"); vals[0] = ConvertToFP8(val, rounding::to_even); } explicit fp8_e4m3_x(unsigned int val) { - assert(N == 1 && "fp8_e4m3_x: N must be 1 for unsigned int constructor"); + static_assert(N == 1 && + "fp8_e4m3_x: N must be 1 for unsigned int constructor"); vals[0] = ConvertToFP8(val, rounding::to_even); } explicit fp8_e4m3_x(unsigned long val) { - assert(N == 1 && "fp8_e4m3_x: N must be 1 for unsigned long constructor"); + static_assert(N == 1 && + "fp8_e4m3_x: N must be 1 for unsigned long constructor"); vals[0] = ConvertToFP8(val, rounding::to_even); } explicit fp8_e4m3_x(unsigned long long val) { - assert(N == 1 && - "fp8_e4m3_x: N must be 1 for unsigned long long constructor"); + static_assert(N == 1 && + "fp8_e4m3_x: N must be 1 for unsigned long long constructor"); vals[0] = ConvertToFP8(val, rounding::to_even); } @@ -636,78 +696,87 @@ template class fp8_e4m3_x { // Available only when N==1. fp8_e4m3_x &operator=(sycl::half val) { - assert(N == 1 && "fp8_e4m3_x: N must be 1 for half assignment operator"); + static_assert(N == 1 && + "fp8_e4m3_x: N must be 1 for half assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } fp8_e4m3_x &operator=(bfloat16 val) { - assert(N == 1 && - "fp8_e4m3_x: N must be 1 for bfloat16 assignment operator"); + static_assert(N == 1 && + "fp8_e4m3_x: N must be 1 for bfloat16 assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } fp8_e4m3_x &operator=(float val) { - assert(N == 1 && "fp8_e4m3_x: N must be 1 for float assignment operator"); + static_assert(N == 1 && + "fp8_e4m3_x: N must be 1 for float assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } fp8_e4m3_x &operator=(double val) { - assert(N == 1 && "fp8_e4m3_x: N must be 1 for double assignment operator"); + static_assert(N == 1 && + "fp8_e4m3_x: N must be 1 for double assignment operator"); vals[0] = ConvertBF16ToFP8(val, rounding::to_even); return *this; } fp8_e4m3_x &operator=(short val) { - assert(N == 1 && "fp8_e4m3_x: N must be 1 for short assignment operator"); + static_assert(N == 1 && + "fp8_e4m3_x: N must be 1 for short assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } fp8_e4m3_x &operator=(int val) { - assert(N == 1 && "fp8_e4m3_x: N must be 1 for int assignment operator"); + static_assert(N == 1 && + "fp8_e4m3_x: N must be 1 for int assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } fp8_e4m3_x &operator=(long val) { - assert(N == 1 && "fp8_e4m3_x: N must be 1 for long assignment operator"); + static_assert(N == 1 && + "fp8_e4m3_x: N must be 1 for long assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } fp8_e4m3_x &operator=(long long val) { - assert(N == 1 && - "fp8_e4m3_x: N must be 1 for long long assignment operator"); + static_assert(N == 1 && + "fp8_e4m3_x: N must be 1 for long long assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } fp8_e4m3_x &operator=(unsigned short val) { - assert(N == 1 && - "fp8_e4m3_x: N must be 1 for unsigned short assignment operator"); + static_assert( + N == 1 && + "fp8_e4m3_x: N must be 1 for unsigned short assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } fp8_e4m3_x &operator=(unsigned int val) { - assert(N == 1 && - "fp8_e4m3_x: N must be 1 for unsigned int assignment operator"); + static_assert( + N == 1 && + "fp8_e4m3_x: N must be 1 for unsigned int assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } fp8_e4m3_x &operator=(unsigned long val) { - assert(N == 1 && - "fp8_e4m3_x: N must be 1 for unsigned long assignment operator"); + static_assert( + N == 1 && + "fp8_e4m3_x: N must be 1 for unsigned long assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } fp8_e4m3_x &operator=(unsigned long long val) { - assert( + static_assert( N == 1 && "fp8_e4m3_x: N must be 1 for unsigned long long assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even); @@ -718,21 +787,24 @@ template class fp8_e4m3_x { // Available only when N==1. explicit operator half() const { - assert(N == 1 && "fp8_e4m3_x: N must be 1 for half conversion operator"); + static_assert(N == 1 && + "fp8_e4m3_x: N must be 1 for half conversion operator"); return ConvertFromFP8(vals[0]); } explicit operator bfloat16() const { - assert(N == 1 && - "fp8_e4m3_x: N must be 1 for bfloat16 conversion operator"); + static_assert(N == 1 && + "fp8_e4m3_x: N must be 1 for bfloat16 conversion operator"); return ConvertBF16FromFP8(vals[0]); } explicit operator float() const { - assert(N == 1 && "fp8_e4m3_x: N must be 1 for float conversion operator"); + static_assert(N == 1 && + "fp8_e4m3_x: N must be 1 for float conversion operator"); return ConvertFromFP8(vals[0]); } explicit operator double() const { - assert(N == 1 && "fp8_e4m3_x: N must be 1 for double conversion operator"); + static_assert(N == 1 && + "fp8_e4m3_x: N must be 1 for double conversion operator"); return ConvertFromFP8(vals[0]); } @@ -740,62 +812,71 @@ template class fp8_e4m3_x { // Available only when N==1. explicit operator char() const { - assert(N == 1 && "fp8_e4m3_x: N must be 1 for char conversion operator"); + static_assert(N == 1 && + "fp8_e4m3_x: N must be 1 for char conversion operator"); return ConvertFromFP8(vals[0]); } explicit operator signed char() const { - assert(N == 1 && - "fp8_e4m3_x: N must be 1 for signed char conversion operator"); + static_assert( + N == 1 && + "fp8_e4m3_x: N must be 1 for signed char conversion operator"); return ConvertFromFP8(vals[0]); } explicit operator short() const { - assert(N == 1 && "fp8_e4m3_x: N must be 1 for short conversion operator"); + static_assert(N == 1 && + "fp8_e4m3_x: N must be 1 for short conversion operator"); return ConvertFromFP8(vals[0]); } explicit operator int() const { - assert(N == 1 && "fp8_e4m3_x: N must be 1 for int conversion operator"); + static_assert(N == 1 && + "fp8_e4m3_x: N must be 1 for int conversion operator"); return ConvertFromFP8(vals[0]); } explicit operator long() const { - assert(N == 1 && "fp8_e4m3_x: N must be 1 for long conversion operator"); + static_assert(N == 1 && + "fp8_e4m3_x: N must be 1 for long conversion operator"); return ConvertFromFP8(vals[0]); } explicit operator long long() const { - assert(N == 1 && - "fp8_e4m3_x: N must be 1 for long long conversion operator"); + static_assert(N == 1 && + "fp8_e4m3_x: N must be 1 for long long conversion operator"); return ConvertFromFP8(vals[0]); } explicit operator unsigned char() const { - assert(N == 1 && - "fp8_e4m3_x: N must be 1 for unsigned char conversion operator"); + static_assert( + N == 1 && + "fp8_e4m3_x: N must be 1 for unsigned char conversion operator"); return ConvertFromFP8(vals[0]); } explicit operator unsigned short() const { - assert(N == 1 && - "fp8_e4m3_x: N must be 1 for unsigned short conversion operator"); + static_assert( + N == 1 && + "fp8_e4m3_x: N must be 1 for unsigned short conversion operator"); return ConvertFromFP8(vals[0]); } explicit operator unsigned int() const { - assert(N == 1 && - "fp8_e4m3_x: N must be 1 for unsigned int conversion operator"); + static_assert( + N == 1 && + "fp8_e4m3_x: N must be 1 for unsigned int conversion operator"); return ConvertFromFP8(vals[0]); } explicit operator unsigned long() const { - assert(N == 1 && - "fp8_e4m3_x: N must be 1 for unsigned long conversion operator"); + static_assert( + N == 1 && + "fp8_e4m3_x: N must be 1 for unsigned long conversion operator"); return ConvertFromFP8(vals[0]); } explicit operator unsigned long long() const { - assert( + static_assert( N == 1 && "fp8_e4m3_x: N must be 1 for unsigned long long conversion operator"); return ConvertFromFP8(vals[0]); @@ -808,7 +889,7 @@ template class fp8_e4m3_x { static_assert(N == 1, "fp8_e4m3_x: operator() requires size N=1"); #ifdef __SYCL_DEVICE_ONLY__ // detect +0 / -0 - sycl::half h = __builtin_spirv_ConvertE4M3ToFP16EXT(vals[0]); + sycl::half h = __builtin_spirv_ConvertE4M3ToFP16INTEL(vals[0]); return h != 0; #else // no need to convert, just check sign bit amd 0s @@ -844,12 +925,17 @@ template class fp8_e4m3_x { }; template class fp8_e5m2_x { + static constexpr size_t NExpBits = 5; + static constexpr size_t NFracBits = 2; + static constexpr float MaxNormal = 114688.0f; // 1.75 * 2^16 + static constexpr float MinSubnormal = 0.0000152587890625f; // 2^-16 + static constexpr uint8_t MaxFiniteCode = 0x7C; // 0.11111.00 uint8_t ConvertToFP8(sycl::half h, rounding r) { #ifdef __SYCL_DEVICE_ONLY__ // TODO: optimize with vectorized builtin calls const uint8_t sign = std::signbit(h) ? 0x80u : 0x00u; - const float ax = sycl::fabs(h); + const float ax = std::fabs(h); if (ax > MaxNormal) return static_cast(sign | MaxFiniteCode); @@ -857,11 +943,10 @@ template class fp8_e5m2_x { if (ax < MinSubnormal) return sign; - uint8_t b = __builtin_spirv_ConvertFP16ToE5M2EXT(h); + uint8_t b = __builtin_spirv_ClampConvertFP16ToE5M2INTEL(h); if (r == rounding::to_even) return b; - - const sycl::half yi = __builtin_spirv_ConvertFP16ToE5M2EXT(b); + const sycl::half yi = __builtin_spirv_ConvertE5M2ToFP16INTEL(b); return round(r, b, yi, h); #else @@ -872,7 +957,7 @@ template class fp8_e5m2_x { uint8_t ConvertBF16ToFP8(bfloat16 h, rounding r) { #ifdef __SYCL_DEVICE_ONLY__ const uint8_t sign = std::signbit(h) ? 0x80u : 0x00u; - const float ax = sycl::fabs(h); + const bfloat16 ax = std::fabs(h); if (ax > MaxNormal) return static_cast(sign | MaxFiniteCode); @@ -880,11 +965,10 @@ template class fp8_e5m2_x { if (ax < MinSubnormal) return sign; - uint8_t b = __builtin_spirv_ConvertBF16ToE5M2EXT(h); + uint8_t b = __builtin_spirv_ClampConvertBF16ToE5M2INTEL(h); if (r == rounding::to_even) return b; - - const half yi = __builtin_spirv_ConvertBF16ToE5M2EXT(b); + const sycl::half yi = __builtin_spirv_ConvertE5M2ToFP16INTEL(b); return round(r, b, yi, h); #else return ConvertToFP8_CPU<5, 2, bfloat16>(h, r); @@ -893,7 +977,7 @@ template class fp8_e5m2_x { template T ConvertFromFP8(uint8_t v) const { #ifdef __SYCL_DEVICE_ONLY__ - sycl::half hi = __builtin_spirv_ConvertE5M2ToFP16EXT(v); + sycl::half hi = __builtin_spirv_ConvertE5M2ToFP16INTEL(v); return static_cast(hi); #else return ConvertFromFP8_CPU<5, 2, T>(v); @@ -902,7 +986,7 @@ template class fp8_e5m2_x { bfloat16 ConvertFP16FromFP8(uint8_t v) const { #ifdef __SYCL_DEVICE_ONLY__ - return __builtin_spirv_ConvertE5M2ToBF16EXT(v); + return __builtin_spirv_ConvertE5M2ToBF16INTEL(v); #else return ConvertFromFP8_CPU<5, 2, bfloat16>(v); #endif @@ -1016,8 +1100,9 @@ template class fp8_e5m2_x { // Construct with stochastic rounding with user provided seed from an array of // half, bfloat16, float. - explicit fp8_e5m2_x(half const (&in)[N], const stochastic_seed &seed, - saturation s = saturation::finite) { + explicit fp8_e5m2_x([[maybe_unused]] half const (&in)[N], + [[maybe_unused]] const stochastic_seed &seed, + [[maybe_unused]] saturation s = saturation::finite) { static_assert(N == 1 || N == 2, "fp8_e5m2_x: Template argument N must be 1 or 2"); #ifdef __SYCL_DEVICE_ONLY__ @@ -1025,18 +1110,19 @@ template class fp8_e5m2_x { for (size_t i = 0; i < N; ++i) { if (s == saturation::finite) { vals[i] = __builtin_spirv_ClampStochasticRoundFP16ToE5M2INTEL( - in[i], static_cast(current_seed), seed.pseed); + in[i], static_cast(current_seed), seed.pseed); } else { vals[i] = __builtin_spirv_StochasticRoundFP16ToE5M2INTEL( - in[i], static_cast(current_seed), seed.pseed); + in[i], static_cast(current_seed), seed.pseed); } current_seed = *seed.pseed; } #endif } - explicit fp8_e5m2_x(bfloat16 const (&in)[N], const stochastic_seed &seed, - saturation s = saturation::finite) { + explicit fp8_e5m2_x([[maybe_unused]] bfloat16 const (&in)[N], + [[maybe_unused]] const stochastic_seed &seed, + [[maybe_unused]] saturation s = saturation::finite) { static_assert(N == 1 || N == 2, "fp8_e5m2_x: Template argument N must be 1 or 2"); #ifdef __SYCL_DEVICE_ONLY__ @@ -1044,18 +1130,19 @@ template class fp8_e5m2_x { for (size_t i = 0; i < N; ++i) { if (s == saturation::finite) { vals[i] = __builtin_spirv_ClampStochasticRoundBF16ToE5M2INTEL( - in[i], static_cast(current_seed), seed.pseed); + in[i], static_cast(current_seed), seed.pseed); } else { vals[i] = __builtin_spirv_StochasticRoundBF16ToE5M2INTEL( - in[i], static_cast(current_seed), seed.pseed); + in[i], static_cast(current_seed), seed.pseed); } current_seed = *seed.pseed; } #endif } - explicit fp8_e5m2_x(double const (&in)[N], const stochastic_seed &seed, - saturation s = saturation::finite) { + explicit fp8_e5m2_x([[maybe_unused]] double const (&in)[N], + [[maybe_unused]] const stochastic_seed &seed, + [[maybe_unused]] saturation s = saturation::finite) { static_assert(N == 1 || N == 2, "fp8_e5m2_x: Template argument N must be 1 or 2"); #ifdef __SYCL_DEVICE_ONLY__ @@ -1064,10 +1151,10 @@ template class fp8_e5m2_x { sycl::half h = static_cast(in[i]); if (s == saturation::finite) { vals[i] = __builtin_spirv_ClampStochasticRoundFP16ToE5M2INTEL( - h, static_cast(current_seed), seed.pseed); + h, static_cast(current_seed), seed.pseed); } else { vals[i] = __builtin_spirv_StochasticRoundFP16ToE5M2INTEL( - h, static_cast(current_seed), seed.pseed); + h, static_cast(current_seed), seed.pseed); } current_seed = *seed.pseed; } @@ -1077,9 +1164,9 @@ template class fp8_e5m2_x { // Construct with stochastic rounding with user provided seed from an marray // of half, bfloat16, float. - explicit fp8_e5m2_x(const sycl::marray &in, - const stochastic_seed &seed, - saturation s = saturation::finite) { + explicit fp8_e5m2_x([[maybe_unused]] const sycl::marray &in, + [[maybe_unused]] const stochastic_seed &seed, + [[maybe_unused]] saturation s = saturation::finite) { static_assert(N == 1 || N == 2, "fp8_e5m2_x: Template argument N must be 1 or 2"); #ifdef __SYCL_DEVICE_ONLY__ @@ -1087,19 +1174,19 @@ template class fp8_e5m2_x { for (size_t i = 0; i < N; ++i) { if (s == saturation::finite) { vals[i] = __builtin_spirv_ClampStochasticRoundFP16ToE5M2INTEL( - in[i], static_cast(current_seed), seed.pseed); + in[i], static_cast(current_seed), seed.pseed); } else { vals[i] = __builtin_spirv_StochasticRoundFP16ToE5M2INTEL( - in[i], static_cast(current_seed), seed.pseed); + in[i], static_cast(current_seed), seed.pseed); } current_seed = *seed.pseed; } #endif } - explicit fp8_e5m2_x(const sycl::marray &in, - const stochastic_seed &seed, - saturation s = saturation::finite) { + explicit fp8_e5m2_x([[maybe_unused]] const sycl::marray &in, + [[maybe_unused]] const stochastic_seed &seed, + [[maybe_unused]] saturation s = saturation::finite) { static_assert(N == 1 || N == 2, "fp8_e5m2_x: Template argument N must be 1 or 2"); #ifdef __SYCL_DEVICE_ONLY__ @@ -1107,19 +1194,19 @@ template class fp8_e5m2_x { for (size_t i = 0; i < N; ++i) { if (s == saturation::finite) { vals[i] = __builtin_spirv_ClampStochasticRoundBF16ToE5M2INTEL( - in[i], static_cast(current_seed), seed.pseed); + in[i], static_cast(current_seed), seed.pseed); } else { vals[i] = __builtin_spirv_StochasticRoundBF16ToE5M2INTEL( - in[i], static_cast(current_seed), seed.pseed); + in[i], static_cast(current_seed), seed.pseed); } current_seed = *seed.pseed; } #endif } - explicit fp8_e5m2_x(const sycl::marray &in, - const stochastic_seed &seed, - saturation s = saturation::finite) { + explicit fp8_e5m2_x([[maybe_unused]] const sycl::marray &in, + [[maybe_unused]] const stochastic_seed &seed, + [[maybe_unused]] saturation s = saturation::finite) { static_assert(N == 1 || N == 2, "fp8_e5m2_x: Template argument N must be 1 or 2"); #ifdef __SYCL_DEVICE_ONLY__ @@ -1128,10 +1215,10 @@ template class fp8_e5m2_x { sycl::half h = static_cast(in[i]); if (s == saturation::finite) { vals[i] = __builtin_spirv_ClampStochasticRoundFP16ToE5M2INTEL( - h, static_cast(current_seed), seed.pseed); + h, static_cast(current_seed), seed.pseed); } else { vals[i] = __builtin_spirv_StochasticRoundFP16ToE5M2INTEL( - h, static_cast(current_seed), seed.pseed); + h, static_cast(current_seed), seed.pseed); } current_seed = *seed.pseed; } @@ -1142,43 +1229,47 @@ template class fp8_e5m2_x { // Available only when N==1. explicit fp8_e5m2_x(short val) { - assert(N == 1 && "fp8_e5m2_x: N must be 1 for short constructor"); + static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for short constructor"); vals[0] = ConvertToFP8(val, rounding::to_even); } explicit fp8_e5m2_x(int val) { - assert(N == 1 && "fp8_e5m2_x: N must be 1 for int constructor"); + static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for int constructor"); vals[0] = ConvertToFP8(val, rounding::to_even); } explicit fp8_e5m2_x(long val) { - assert(N == 1 && "fp8_e5m2_x: N must be 1 for long constructor"); + static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for long constructor"); vals[0] = ConvertToFP8(val, rounding::to_even); } explicit fp8_e5m2_x(long long val) { - assert(N == 1 && "fp8_e5m2_x: N must be 1 for long long constructor"); + static_assert(N == 1 && + "fp8_e5m2_x: N must be 1 for long long constructor"); vals[0] = ConvertToFP8(val, rounding::to_even); } explicit fp8_e5m2_x(unsigned short val) { - assert(N == 1 && "fp8_e5m2_x: N must be 1 for unsigned short constructor"); + static_assert(N == 1 && + "fp8_e5m2_x: N must be 1 for unsigned short constructor"); vals[0] = ConvertToFP8(val, rounding::to_even); } explicit fp8_e5m2_x(unsigned int val) { - assert(N == 1 && "fp8_e5m2_x: N must be 1 for unsigned int constructor"); + static_assert(N == 1 && + "fp8_e5m2_x: N must be 1 for unsigned int constructor"); vals[0] = ConvertToFP8(val, rounding::to_even); } explicit fp8_e5m2_x(unsigned long val) { - assert(N == 1 && "fp8_e5m2_x: N must be 1 for unsigned long constructor"); + static_assert(N == 1 && + "fp8_e5m2_x: N must be 1 for unsigned long constructor"); vals[0] = ConvertToFP8(val, rounding::to_even); } explicit fp8_e5m2_x(unsigned long long val) { - assert(N == 1 && - "fp8_e5m2_x: N must be 1 for unsigned long long constructor"); + static_assert(N == 1 && + "fp8_e5m2_x: N must be 1 for unsigned long long constructor"); vals[0] = ConvertToFP8(val, rounding::to_even); } @@ -1186,77 +1277,87 @@ template class fp8_e5m2_x { // Available only when N==1. fp8_e5m2_x &operator=(sycl::half val) { - assert(N == 1 && "fp8_e5m2_x: N must be 1 for half assignment operator"); + static_assert(N == 1 && + "fp8_e5m2_x: N must be 1 for half assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } fp8_e5m2_x &operator=(bfloat16 val) { - assert(N == 1 && "fp8_e5m2_x: N must be 1 for half bfloat16 operator"); + static_assert(N == 1 && + "fp8_e5m2_x: N must be 1 for half bfloat16 operator"); vals[0] = ConvertBF16ToFP8(val, rounding::to_even); return *this; } fp8_e5m2_x &operator=(float val) { - assert(N == 1 && "fp8_e5m2_x: N must be 1 for float assignment operator"); + static_assert(N == 1 && + "fp8_e5m2_x: N must be 1 for float assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } fp8_e5m2_x &operator=(double val) { - assert(N == 1 && "fp8_e5m2_x: N must be 1 for double assignment operator"); + static_assert(N == 1 && + "fp8_e5m2_x: N must be 1 for double assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } fp8_e5m2_x &operator=(short val) { - assert(N == 1 && "fp8_e5m2_x: N must be 1 for short assignment operator"); + static_assert(N == 1 && + "fp8_e5m2_x: N must be 1 for short assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } fp8_e5m2_x &operator=(int val) { - assert(N == 1 && "fp8_e5m2_x: N must be 1 for int assignment operator"); + static_assert(N == 1 && + "fp8_e5m2_x: N must be 1 for int assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } fp8_e5m2_x &operator=(long val) { - assert(N == 1 && "fp8_e5m2_x: N must be 1 for long assignment operator"); + static_assert(N == 1 && + "fp8_e5m2_x: N must be 1 for long assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } fp8_e5m2_x &operator=(long long val) { - assert(N == 1 && - "fp8_e5m2_x: N must be 1 for long long assignment operator"); + static_assert(N == 1 && + "fp8_e5m2_x: N must be 1 for long long assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } fp8_e5m2_x &operator=(unsigned short val) { - assert(N == 1 && - "fp8_e5m2_x: N must be 1 for unsigned short assignment operator"); + static_assert( + N == 1 && + "fp8_e5m2_x: N must be 1 for unsigned short assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } fp8_e5m2_x &operator=(unsigned int val) { - assert(N == 1 && - "fp8_e5m2_x: N must be 1 for unsigned int assignment operator"); + static_assert( + N == 1 && + "fp8_e5m2_x: N must be 1 for unsigned int assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } fp8_e5m2_x &operator=(unsigned long val) { - assert(N == 1 && - "fp8_e5m2_x: N must be 1 for unsigned long assignment operator"); + static_assert( + N == 1 && + "fp8_e5m2_x: N must be 1 for unsigned long assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } fp8_e5m2_x &operator=(unsigned long long val) { - assert( + static_assert( N == 1 && "fp8_e5m2_x: N must be 1 for unsigned long long assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even); @@ -1267,23 +1368,26 @@ template class fp8_e5m2_x { // Available only when N==1. explicit operator half() const { - assert(N == 1 && "fp8_e5m2_x: N must be 1 for half conversion operator"); + static_assert(N == 1 && + "fp8_e5m2_x: N must be 1 for half conversion operator"); return ConvertFromFP8(vals[0]); } explicit operator bfloat16() const { - assert(N == 1 && - "fp8_e5m2_x: N must be 1 for bfloat16 conversion operator"); + static_assert(N == 1 && + "fp8_e5m2_x: N must be 1 for bfloat16 conversion operator"); return ConvertFP16FromFP8(vals[0]); } explicit operator float() const { - assert(N == 1 && "fp8_e5m2_x: N must be 1 for float conversion operator"); + static_assert(N == 1 && + "fp8_e5m2_x: N must be 1 for float conversion operator"); return ConvertFromFP8(vals[0]); } explicit operator double() const { - assert(N == 1 && "fp8_e5m2_x: N must be 1 for double conversion operator"); + static_assert(N == 1 && + "fp8_e5m2_x: N must be 1 for double conversion operator"); return ConvertFromFP8(vals[0]); } @@ -1291,63 +1395,72 @@ template class fp8_e5m2_x { // Available only when N==1. explicit operator char() const { - assert(N == 1 && "fp8_e5m2_x: N must be 1 for char conversion operator"); + static_assert(N == 1 && + "fp8_e5m2_x: N must be 1 for char conversion operator"); return ConvertFromFP8(vals[0]); } explicit operator signed char() const { - assert(N == 1 && - "fp8_e5m2_x: N must be 1 for signed char conversion operator"); + static_assert( + N == 1 && + "fp8_e5m2_x: N must be 1 for signed char conversion operator"); return ConvertFromFP8(vals[0]); } explicit operator short() const { - assert(N == 1 && "fp8_e5m2_x: N must be 1 for short conversion operator"); + static_assert(N == 1 && + "fp8_e5m2_x: N must be 1 for short conversion operator"); return ConvertFromFP8(vals[0]); } explicit operator int() const { - assert(N == 1 && "fp8_e5m2_x: N must be 1 for int conversion operator"); + static_assert(N == 1 && + "fp8_e5m2_x: N must be 1 for int conversion operator"); return ConvertFromFP8(vals[0]); } explicit operator long() const { - assert(N == 1 && "fp8_e5m2_x: N must be 1 for long conversion operator"); + static_assert(N == 1 && + "fp8_e5m2_x: N must be 1 for long conversion operator"); return ConvertFromFP8(vals[0]); } explicit operator long long() const { - assert(N == 1 && - "fp8_e5m2_x: N must be 1 for long long conversion operator"); + static_assert(N == 1 && + "fp8_e5m2_x: N must be 1 for long long conversion operator"); return ConvertFromFP8(vals[0]); } explicit operator unsigned char() const { - assert(N == 1 && - "fp8_e5m2_x: N must be 1 for unsigned char conversion operator"); + static_assert( + N == 1 && + "fp8_e5m2_x: N must be 1 for unsigned char conversion operator"); return ConvertFromFP8(vals[0]); } explicit operator unsigned short() const { - assert(N == 1 && - "fp8_e5m2_x: N must be 1 for unsigned short conversion operator"); + static_assert( + N == 1 && + "fp8_e5m2_x: N must be 1 for unsigned short conversion operator"); return ConvertFromFP8(vals[0]); } explicit operator unsigned int() const { - assert(N == 1 && - "fp8_e5m2_x: N must be 1 for unsigned int conversion operator"); + static_assert( + N == 1 && + "fp8_e5m2_x: N must be 1 for unsigned int conversion operator"); return ConvertFromFP8(vals[0]); } explicit operator unsigned long() const { - assert(N == 1 && - "fp8_e5m2_x: N must be 1 for unsigned long conversion operator"); + static_assert( + N == 1 && + "fp8_e5m2_x: N must be 1 for unsigned long conversion operator"); return ConvertFromFP8(vals[0]); } explicit operator unsigned long long() const { - assert( + static_assert( N == 1 && "fp8_e5m2_x: N must be 1 for unsigned long long conversion operator"); return ConvertFromFP8(vals[0]); @@ -1479,7 +1592,6 @@ static inline ToT ConvertFromE8M0_CPU(uint8_t code) noexcept { } template class fp8_e8m0_x { - void CheckConstraints(rounding r) const { static_assert(N == 1 || N == 2, "fp8_e8m0_x: Template argument N must be 1 or 2"); @@ -1577,7 +1689,7 @@ template class fp8_e8m0_x { // Available only when N==1. explicit fp8_e8m0_x(short val) { - assert(N == 1 && "fp8_e8m0_x: N must be 1 for short constructor"); + static_assert(N == 1 && "fp8_e8m0_x: N must be 1 for short constructor"); vals[0] = ConvertToE8M0_CPU(static_cast(val), rounding::upward, saturation::finite); } @@ -1719,8 +1831,6 @@ using fp8_e5m2_x2 = fp8_e5m2_x<2>; using fp8_e8m0 = fp8_e8m0_x<1>; using fp8_e8m0_x2 = fp8_e8m0_x<2>; -#endif // __SYCL_TARGET_INTEL_GPU_CRI__ - } // namespace ext::oneapi::experimental } // namespace _V1 -} // namespace sycl \ No newline at end of file +} // namespace sycl diff --git a/sycl/unittests/Extensions/fp8/CMakeLists.txt b/sycl/unittests/Extensions/fp8/CMakeLists.txt index 2d0c53daf4268..9b7c7677f9c6a 100644 --- a/sycl/unittests/Extensions/fp8/CMakeLists.txt +++ b/sycl/unittests/Extensions/fp8/CMakeLists.txt @@ -2,8 +2,5 @@ add_sycl_unittest(FP8TypesTests OBJECT fp8_e4m3.cpp fp8_e5m2.cpp fp8_e8m0.cpp + builtin_call_tests.cpp ) - -target_compile_options(FP8TypesTests_Preview_Tests PUBLIC -D__SYCL_TARGET_INTEL_GPU_CRI__) -target_compile_options(FP8TypesTests_Non_Preview_Tests PUBLIC -D__SYCL_TARGET_INTEL_GPU_CRI__) - diff --git a/sycl/unittests/Extensions/fp8/builtin_call_tests.cpp b/sycl/unittests/Extensions/fp8/builtin_call_tests.cpp new file mode 100644 index 0000000000000..c0551f3b5d746 --- /dev/null +++ b/sycl/unittests/Extensions/fp8/builtin_call_tests.cpp @@ -0,0 +1,123 @@ +#include "builtin_mocks.hpp" +#include +#include + +namespace { + +using namespace sycl::ext::oneapi::experimental; + +class Fp8BuiltinCallTest : public ::testing::Test { +protected: + void SetUp() override { fp8_builtin_mock::resetCounters(); } +}; + +TEST_F(Fp8BuiltinCallTest, E4M3CtorFromHalfCallsClampConvertFP16ToE4M3) { + fp8_e4m3 Value(static_cast(1.25f)); + (void)Value; + EXPECT_EQ(fp8_builtin_mock::getCounters().ClampConvertFP16ToE4M3INTEL, 1); +} + +TEST_F(Fp8BuiltinCallTest, E4M3CtorFromBf16CallsClampConvertBF16ToE4M3) { + fp8_e4m3 Value(static_cast(1.25f)); + (void)Value; + EXPECT_EQ(fp8_builtin_mock::getCounters().ClampConvertBF16ToE4M3INTEL, 1); +} + +TEST_F(Fp8BuiltinCallTest, E4M3CastToHalfCallsClampConvertE4M3ToFP16) { + fp8_e4m3 Value(static_cast(1.0f)); + fp8_builtin_mock::resetCounters(); + (void)static_cast(Value); + EXPECT_EQ(fp8_builtin_mock::getCounters().ClampConvertE4M3ToFP16INTEL, 1); +} + +TEST_F(Fp8BuiltinCallTest, E4M3CastToBf16CallsConvertE4M3ToBF16) { + fp8_e4m3 Value(static_cast(1.0f)); + fp8_builtin_mock::resetCounters(); + (void)static_cast(Value); + EXPECT_EQ(fp8_builtin_mock::getCounters().ConvertE4M3ToBF16INTEL, 1); +} + +TEST_F(Fp8BuiltinCallTest, E4M3CastToBoolCallsConvertE4M3ToFP16) { + fp8_e4m3 Value(static_cast(1.0f)); + fp8_builtin_mock::resetCounters(); + (void)static_cast(Value); + EXPECT_EQ(fp8_builtin_mock::getCounters().ConvertE4M3ToFP16INTEL, 1); +} + +TEST_F(Fp8BuiltinCallTest, E5M2CtorFromHalfCallsClampConvertFP16ToE5M2) { + fp8_e5m2 Value(static_cast(2.0f)); + (void)Value; + EXPECT_EQ(fp8_builtin_mock::getCounters().ClampConvertFP16ToE5M2INTEL, 1); +} + +TEST_F(Fp8BuiltinCallTest, E5M2CtorFromBf16CallsClampConvertBF16ToE5M2) { + fp8_e5m2 Value(static_cast(2.0f)); + (void)Value; + EXPECT_EQ(fp8_builtin_mock::getCounters().ClampConvertBF16ToE5M2INTEL, 1); +} + +TEST_F(Fp8BuiltinCallTest, E5M2CastToHalfCallsConvertE5M2ToFP16) { + fp8_e5m2 Value(static_cast(2.0f)); + fp8_builtin_mock::resetCounters(); + (void)static_cast(Value); + EXPECT_EQ(fp8_builtin_mock::getCounters().ConvertE5M2ToFP16INTEL, 1); +} + +TEST_F(Fp8BuiltinCallTest, E5M2CastToBf16CallsConvertE5M2ToBF16) { + fp8_e5m2 Value(static_cast(2.0f)); + fp8_builtin_mock::resetCounters(); + (void)static_cast(Value); + EXPECT_EQ(fp8_builtin_mock::getCounters().ConvertE5M2ToBF16INTEL, 1); +} + +TEST_F(Fp8BuiltinCallTest, E5M2StochasticHalfFiniteCallsClampStochastic) { + sycl::half Input[1] = {static_cast(3.0f)}; + uint32_t SeedValue = 10; + stochastic_seed Seed(&SeedValue); + + fp8_e5m2 Value(Input, Seed, saturation::finite); + (void)Value; + + EXPECT_EQ(fp8_builtin_mock::getCounters().ClampStochasticRoundFP16ToE5M2INTEL, + 1); + EXPECT_EQ(SeedValue, 11u); +} + +TEST_F(Fp8BuiltinCallTest, E5M2StochasticHalfNoneCallsNonClampStochastic) { + sycl::half Input[1] = {static_cast(3.0f)}; + uint32_t SeedValue = 20; + stochastic_seed Seed(&SeedValue); + + fp8_e5m2 Value(Input, Seed, saturation::none); + (void)Value; + + EXPECT_EQ(fp8_builtin_mock::getCounters().StochasticRoundFP16ToE5M2INTEL, 1); + EXPECT_EQ(SeedValue, 21u); +} + +TEST_F(Fp8BuiltinCallTest, E5M2StochasticBf16FiniteCallsClampStochastic) { + sycl::ext::oneapi::bfloat16 Input[1] = { + static_cast(3.0f)}; + uint32_t SeedValue = 30; + stochastic_seed Seed(&SeedValue); + + fp8_e5m2 Value(Input, Seed, saturation::finite); + (void)Value; + + EXPECT_EQ(fp8_builtin_mock::getCounters().ClampStochasticRoundBF16ToE5M2INTEL, + 1); +} + +TEST_F(Fp8BuiltinCallTest, E5M2StochasticBf16NoneCallsNonClampStochastic) { + sycl::ext::oneapi::bfloat16 Input[1] = { + static_cast(3.0f)}; + uint32_t SeedValue = 40; + stochastic_seed Seed(&SeedValue); + + fp8_e5m2 Value(Input, Seed, saturation::none); + (void)Value; + + EXPECT_EQ(fp8_builtin_mock::getCounters().StochasticRoundBF16ToE5M2INTEL, 1); +} + +} // namespace diff --git a/sycl/unittests/Extensions/fp8/builtin_mocks.hpp b/sycl/unittests/Extensions/fp8/builtin_mocks.hpp new file mode 100644 index 0000000000000..7a4aa8180b57b --- /dev/null +++ b/sycl/unittests/Extensions/fp8/builtin_mocks.hpp @@ -0,0 +1,155 @@ +//===-- FP8 builtin helpers, mocks and stubs for float_8bit/types.hpp +//---------*- C++ -*-===// + +#pragma once + +#include +#include +#include + +// Force code path that uses helpers.hpp wrappers. +#ifndef __SYCL_DEVICE_ONLY__ +#define __SYCL_DEVICE_ONLY__ 1 +#endif + +namespace fp8_builtin_mock { + +struct Counters { + int ClampConvertE4M3ToFP16INTEL = 0; + int ConvertE4M3ToFP16INTEL = 0; + int ConvertE5M2ToFP16INTEL = 0; + int ConvertE4M3ToBF16INTEL = 0; + int ConvertE5M2ToBF16INTEL = 0; + int ClampConvertFP16ToE4M3INTEL = 0; + int ClampConvertBF16ToE4M3INTEL = 0; + int ClampConvertFP16ToE5M2INTEL = 0; + int ClampConvertBF16ToE5M2INTEL = 0; + int StochasticRoundFP16ToE5M2INTEL = 0; + int StochasticRoundBF16ToE5M2INTEL = 0; + int ClampStochasticRoundFP16ToE5M2INTEL = 0; + int ClampStochasticRoundBF16ToE5M2INTEL = 0; +}; + +inline Counters &getCounters() { + static Counters Value; + return Value; +} + +inline void resetCounters() { getCounters() = Counters{}; } + +} // namespace fp8_builtin_mock + +// Builtin mocks (do not replace helpers.hpp; provide symbols here). +inline sycl::half +__builtin_spirv_ClampConvertE4M3ToFP16INTEL(uint8_t) noexcept { + ++fp8_builtin_mock::getCounters().ClampConvertE4M3ToFP16INTEL; + return static_cast(2.0f); +} + +inline sycl::half __builtin_spirv_ConvertE4M3ToFP16INTEL(uint8_t) noexcept { + ++fp8_builtin_mock::getCounters().ConvertE4M3ToFP16INTEL; + return static_cast(1.0f); +} + +inline sycl::half __builtin_spirv_ConvertE5M2ToFP16INTEL(uint8_t) noexcept { + ++fp8_builtin_mock::getCounters().ConvertE5M2ToFP16INTEL; + return static_cast(3.0f); +} + +inline sycl::ext::oneapi::bfloat16 +__builtin_spirv_ConvertE4M3ToBF16INTEL(uint8_t) noexcept { + ++fp8_builtin_mock::getCounters().ConvertE4M3ToBF16INTEL; + return static_cast(4.0f); +} + +inline sycl::ext::oneapi::bfloat16 +__builtin_spirv_ConvertE5M2ToBF16INTEL(uint8_t) noexcept { + ++fp8_builtin_mock::getCounters().ConvertE5M2ToBF16INTEL; + return static_cast(5.0f); +} + +inline uint8_t __builtin_spirv_ConvertFP16ToE4M3INTEL(sycl::half) noexcept { + return 0x00; +} + +inline uint8_t +__builtin_spirv_ConvertBF16ToE5M2INTEL(sycl::ext::oneapi::bfloat16) noexcept { + return 0x00; +} + +inline uint8_t +__builtin_spirv_ClampConvertFP16ToE4M3INTEL(sycl::half) noexcept { + ++fp8_builtin_mock::getCounters().ClampConvertFP16ToE4M3INTEL; + return 0x11; +} + +inline uint8_t __builtin_spirv_ClampConvertBF16ToE4M3INTEL( + sycl::ext::oneapi::bfloat16) noexcept { + ++fp8_builtin_mock::getCounters().ClampConvertBF16ToE4M3INTEL; + return 0x12; +} + +inline uint8_t +__builtin_spirv_ClampConvertFP16ToE5M2INTEL(sycl::half) noexcept { + ++fp8_builtin_mock::getCounters().ClampConvertFP16ToE5M2INTEL; + return 0x21; +} + +inline uint8_t __builtin_spirv_ClampConvertBF16ToE5M2INTEL( + sycl::ext::oneapi::bfloat16) noexcept { + ++fp8_builtin_mock::getCounters().ClampConvertBF16ToE5M2INTEL; + return 0x22; +} + +inline uint8_t +__builtin_spirv_StochasticRoundFP16ToE5M2INTEL(sycl::half, uint32_t Seed, + uint32_t *NextSeed) noexcept { + ++fp8_builtin_mock::getCounters().StochasticRoundFP16ToE5M2INTEL; + if (NextSeed) + *NextSeed = Seed + 1; + return 0x31; +} + +inline uint8_t +__builtin_spirv_StochasticRoundFP16ToE4M3INTEL(sycl::half) noexcept { + return 0x00; +} + +inline uint8_t __builtin_spirv_StochasticRoundBF16ToE5M2INTEL( + sycl::ext::oneapi::bfloat16, uint32_t Seed, uint32_t *NextSeed) noexcept { + ++fp8_builtin_mock::getCounters().StochasticRoundBF16ToE5M2INTEL; + if (NextSeed) + *NextSeed = Seed + 1; + return 0x32; +} + +inline uint8_t __builtin_spirv_StochasticRoundBF16ToE4M3INTEL( + sycl::ext::oneapi::bfloat16) noexcept { + return 0x00; +} + +inline uint8_t __builtin_spirv_ClampStochasticRoundFP16ToE5M2INTEL( + sycl::half, uint32_t Seed, uint32_t *NextSeed) noexcept { + ++fp8_builtin_mock::getCounters().ClampStochasticRoundFP16ToE5M2INTEL; + if (NextSeed) + *NextSeed = Seed + 1; + return 0x41; +} + +inline uint8_t +__builtin_spirv_ClampStochasticRoundFP16ToE4M3INTEL(sycl::half) noexcept { + return 0x00; +} + +inline uint8_t __builtin_spirv_ClampStochasticRoundBF16ToE5M2INTEL( + sycl::ext::oneapi::bfloat16, uint32_t Seed, uint32_t *NextSeed) noexcept { + ++fp8_builtin_mock::getCounters().ClampStochasticRoundBF16ToE5M2INTEL; + if (NextSeed) + *NextSeed = Seed + 1; + return 0x42; +} + +inline uint8_t __builtin_spirv_ClampStochasticRoundBF16ToE4M3INTEL( + sycl::ext::oneapi::bfloat16) noexcept { + return 0x00; +} diff --git a/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp b/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp index 809efd179c5e7..6b4b558936e34 100644 --- a/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp +++ b/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp @@ -5,6 +5,11 @@ #include #include +/* +Unit tests check only CPU versions. Most of the constraints related to device +code thus unit tests check only API +*/ + using namespace sycl::ext::oneapi::experimental; TEST(FP8E4M3Test, VariadicConstructorHalf) { diff --git a/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp b/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp index c1ea19ea3fa47..e73e4f1a24624 100644 --- a/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp +++ b/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp @@ -5,6 +5,11 @@ #include #include +/* +Unit tests check only CPU versions. Most of the constraints related to device +code thus unit tests check only API +*/ + using namespace sycl::ext::oneapi::experimental; TEST(FP8E5M2Test, VariadicConstructorHalf) { @@ -176,12 +181,12 @@ TEST(FP8E5M2Test, VariadicConstructorSaturatesFinite) { ); fp8_e5m2_x2 a1(-100000.0f, // clamp to -57344 - -0.0f); + -0.0f); EXPECT_EQ(sizeof(a.vals), 2u); EXPECT_EQ(sizeof(a1.vals), 2u); EXPECT_EQ(a.vals[0], 0x3C); - EXPECT_EQ(a.vals[1], 0x7B); // +max normal + EXPECT_EQ(a.vals[1], 0x7B); // +max normal EXPECT_EQ(a1.vals[0], 0xFB); // -max normal EXPECT_EQ(a1.vals[1], 0x80); // -0 } @@ -205,8 +210,8 @@ TEST(FP8E5M2Test, CArrayConstructorFloatHostToEvenFinite) { EXPECT_EQ(sizeof(a.vals), 2u); EXPECT_EQ(sizeof(a1.vals), 2u); - EXPECT_EQ(a.vals[0], 0x3C); // 1.0 - EXPECT_EQ(a.vals[1], 0x3C); // 1.1 -> 1.0 + EXPECT_EQ(a.vals[0], 0x3C); // 1.0 + EXPECT_EQ(a.vals[1], 0x3C); // 1.1 -> 1.0 EXPECT_EQ(a1.vals[0], 0x3C); // tie -> to_even => 1.0 EXPECT_EQ(a1.vals[1], 0x7B); // finite saturation => +57344 } @@ -224,8 +229,8 @@ TEST(FP8E5M2Test, CArrayConstructorDoubleToEvenFinite) { EXPECT_EQ(sizeof(a.vals), 2u); EXPECT_EQ(sizeof(a1.vals), 2u); EXPECT_EQ(sizeof(a2.vals), 2u); - EXPECT_EQ(a.vals[0], 0x7B); // +57344 - EXPECT_EQ(a.vals[1], 0x7B); // 60000 -> clamp to +57344 + EXPECT_EQ(a.vals[0], 0x7B); // +57344 + EXPECT_EQ(a.vals[1], 0x7B); // 60000 -> clamp to +57344 EXPECT_EQ(a1.vals[0], 0x04); // min normal EXPECT_EQ(a1.vals[1], 0x03); // max subnormal EXPECT_EQ(a2.vals[0], 0x01); // min subnormal diff --git a/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp b/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp index e8478bf9447db..74576226d6173 100644 --- a/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp +++ b/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp @@ -5,6 +5,11 @@ #include #include +/* +Unit tests check only CPU versions. Most of the constraints related to device +code thus unit tests check only API +*/ + using namespace sycl::ext::oneapi::experimental; TEST(FP8E8M0Test, VariadicConstructorFloat) { From 6d03f08e45f2cf06961b391e86deb822e590a57a Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Thu, 19 Mar 2026 14:59:32 +0100 Subject: [PATCH 07/89] [SYCL] fix formatting --- sycl/unittests/Extensions/fp8/fp8_e4m3.cpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp b/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp index 6b4b558936e34..4aecd26ff1eb2 100644 --- a/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp +++ b/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp @@ -25,7 +25,8 @@ TEST(FP8E4M3Test, VariadicConstructorHalf) { } TEST(FP8E4M3Test, VariadicConstructorBFloat16) { - fp8_e4m3_x2 a(sycl::ext::oneapi::bfloat16(1.0f), sycl::ext::oneapi::bfloat16(2.0f)); + fp8_e4m3_x2 a(sycl::ext::oneapi::bfloat16(1.0f), + sycl::ext::oneapi::bfloat16(2.0f)); EXPECT_EQ(sizeof(a.vals), 2u); EXPECT_EQ(a.vals[0], 0x38); @@ -77,7 +78,7 @@ TEST(FP8E4M3Test, VariadicBoundaryEncodingsFloat) { TEST(FP8E4M3Test, VariadicNaNEncodingFloat) { // NaN is encoded as S.1111.111; sign is permitted. fp8_e4m3_x2 a(std::numeric_limits::quiet_NaN(), - -std::numeric_limits::quiet_NaN()); + -std::numeric_limits::quiet_NaN()); EXPECT_EQ(a.vals[0], 0x7F); // +NaN -> 0b0_1111_111 EXPECT_EQ(a.vals[1], 0xFF); // -NaN -> 0b1_1111_111 @@ -95,9 +96,9 @@ TEST(FP8E4M3Test, IntegerToEvenFiniteAndSize) { EXPECT_EQ(sizeof(a2.vals), 1u); EXPECT_EQ(sizeof(an1.vals), 1u); - EXPECT_EQ(a0.vals[0], 0x00); // +0 - EXPECT_EQ(a1.vals[0], 0x38); // +1.0 -> 0b0_0111_000 - EXPECT_EQ(a2.vals[0], 0x40); // +2.0 -> 0b0_1000_000 + EXPECT_EQ(a0.vals[0], 0x00); // +0 + EXPECT_EQ(a1.vals[0], 0x38); // +1.0 -> 0b0_0111_000 + EXPECT_EQ(a2.vals[0], 0x40); // +2.0 -> 0b0_1000_000 EXPECT_EQ(an1.vals[0], 0xB8); // -1.0 -> sign set: 0b1_0111_000 } From c9386426f29d6df5ef113b97e76b2218625aa05c Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Thu, 19 Mar 2026 17:04:30 +0100 Subject: [PATCH 08/89] [SYCL] do not use extra rounding modes --- .../oneapi/experimental/float_8bit/types.hpp | 20 ++----------------- 1 file changed, 2 insertions(+), 18 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index 10d5c51a51aa3..e943af983807c 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -78,10 +78,7 @@ enum class saturation { none, finite }; enum class rounding { to_even, upward, - downward, toward_zero, - to_away, - stochastic }; struct stochastic_seed { @@ -451,11 +448,6 @@ uint8_t round(rounding r, uint8_t b, sycl::half yi, T vi) { return nextE4M3(b, /*up=*/true); break; } - case rounding::downward: { - if (yi > vi) - return nextE4M3(b, /*up=*/false); - break; - } case rounding::toward_zero: if (vi > 0.0f && yi > vi) { return nextE4M3(b, /*up=*/false); @@ -463,13 +455,6 @@ uint8_t round(rounding r, uint8_t b, sycl::half yi, T vi) { return nextE4M3(b, /*up=*/true); } break; - case rounding::to_away: - if (vi > 0.0f && yi < vi) { - return nextE4M3(b, /*up=*/true); - } else if (vi < 0.0f && yi > vi) { - return nextE4M3(b, /*up=*/false); - } - break; default: break; } @@ -1140,7 +1125,7 @@ template class fp8_e5m2_x { #endif } - explicit fp8_e5m2_x([[maybe_unused]] double const (&in)[N], + explicit fp8_e5m2_x([[maybe_unused]] float const (&in)[N], [[maybe_unused]] const stochastic_seed &seed, [[maybe_unused]] saturation s = saturation::finite) { static_assert(N == 1 || N == 2, @@ -1204,7 +1189,7 @@ template class fp8_e5m2_x { #endif } - explicit fp8_e5m2_x([[maybe_unused]] const sycl::marray &in, + explicit fp8_e5m2_x([[maybe_unused]] const sycl::marray &in, [[maybe_unused]] const stochastic_seed &seed, [[maybe_unused]] saturation s = saturation::finite) { static_assert(N == 1 || N == 2, @@ -1545,7 +1530,6 @@ static inline uint8_t ConvertToE8M0_CPU(float x, rounding R, if (!is_exact_power_of_two && E < Emax) ++E; break; - case rounding::downward: case rounding::toward_zero: // toward -inf / toward 0: both pick the lower power for non-exact. break; From a55e2754c2382f9e1e308b1741926f020349443b Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Tue, 24 Mar 2026 13:41:32 +0100 Subject: [PATCH 09/89] [SYCL][FP8] use saturation --- .../oneapi/experimental/float_8bit/types.hpp | 171 ++++++++++-------- sycl/unittests/Extensions/fp8/fp8_e4m3.cpp | 20 ++ sycl/unittests/Extensions/fp8/fp8_e5m2.cpp | 39 ++++ sycl/unittests/Extensions/fp8/fp8_e8m0.cpp | 15 ++ 4 files changed, 169 insertions(+), 76 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index e943af983807c..d1ed476a3fbfa 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -107,6 +107,11 @@ static inline uint8_t RoundClip(float x, uint8_t max, rounding R, return 1u; return 0u; } + if (R == rounding::toward_zero) { + if (std::isnan(x) || x <= 0.0f) + return 0u; + return static_cast(std::floor(x)); + } // Default / to_even if (std::isnan(x)) return 0u; @@ -133,6 +138,14 @@ static inline uint8_t RoundClip(float x, uint8_t max, rounding R, return static_cast(fi); } } + if (R == rounding::toward_zero) { + if (std::isnan(x) || x <= 0.0f) + return 0u; + uint32_t truncated = static_cast(std::floor(x)); + if (truncated > max) + truncated = max; + return static_cast(truncated); + } // default: round-to-nearest-even return RneClip(x, max); } @@ -256,7 +269,8 @@ static inline ToT ConvertFromFP8_CPU(uint8_t b, /// Ebits bits mantissa, Mbits bits exponent. template static inline uint8_t -ConvertToFP8_CPU(T h, rounding R = rounding::to_even) noexcept { +ConvertToFP8_CPU(T h, rounding R = rounding::to_even, + saturation S = saturation::finite) noexcept { // Specialized implementation for fp8_e8m0_x (Ebits=8, Mbits=0) if constexpr (Ebits == 8 && Mbits == 0) { // Format characteristics (finite-only, no zero, no infinity): @@ -288,17 +302,16 @@ ConvertToFP8_CPU(T h, rounding R = rounding::to_even) noexcept { if (std::isnan(x)) return NaNCode; - uint8_t sign = std::signbit(x) ? 0x80 : 0x00; float ax = std::fabs(x); - // Handle underflow (|x| < min_normal) and x == 0: encode smallest normal - // with sign. + // Handle underflow (|x| < min_normal) and x == 0: encode smallest normal. if (ax == 0.0f || ax < min_normal) - return sign; // exp field = 0 -> E = -127 + return 0x00; // exp field = 0 -> E = -127 - // Handle overflow (|x| >= max_normal * (anything beyond representable)): - if (ax >= max_normal) - return static_cast(sign | (MaxExpField)); // E = +127 + // Handle overflow (|x| > max_normal): clamp or return NaN depending on + // saturation. E8M0 has no sign bit and no infinity representation. + if (ax > max_normal) + return (S == saturation::finite) ? MaxExpField : NaNCode; // Determine exponent E such that 2^E <= ax < 2^{E+1} int e2; @@ -316,7 +329,7 @@ ConvertToFP8_CPU(T h, rounding R = rounding::to_even) noexcept { rounding effR = (R == rounding::upward) ? R : rounding::upward; if (effR == rounding::upward) { - if (sign == 0x00) { + if (x >= 0.0f) { if (!is_exact_power_of_two) { // Round up (increase exponent) if possible. if (E < Emax) @@ -337,7 +350,7 @@ ConvertToFP8_CPU(T h, rounding R = rounding::to_even) noexcept { uint8_t ecode = static_cast(E + Bias); // 0 .. 254 // ecode must never be 255 here. - return static_cast(sign | ecode); + return ecode; } constexpr int bias = (1 << (Ebits - 1)) - 1; @@ -374,16 +387,17 @@ ConvertToFP8_CPU(T h, rounding R = rounding::to_even) noexcept { const float min_sub = std::ldexp(1.0f, emin - Mbits); if (ax > max_finite) { - return static_cast( - sign | ((MaxExpForMaxNormal << Mbits) | MaxFracForMaxNormal)); - } - if (ax >= max_finite) { + if (S == saturation::none) { + if constexpr (Ebits == 5 && Mbits == 2) + return static_cast(sign | (ExpAllOnes << Mbits)); + return static_cast(sign | ((ExpAllOnes << Mbits) | MaxFracMask)); + } return static_cast( sign | ((MaxExpForMaxNormal << Mbits) | MaxFracForMaxNormal)); } - if (ax < min_sub) - return sign; // underflow + if (ax == 0.0f) + return sign; int e2; float m = std::frexp(ax, &e2); @@ -405,6 +419,11 @@ ConvertToFP8_CPU(T h, rounding R = rounding::to_even) noexcept { ++E; } if (E > emax) { + if (S == saturation::none) { + if constexpr (Ebits == 5 && Mbits == 2) + return static_cast(sign | (ExpAllOnes << Mbits)); + return static_cast(sign | ((ExpAllOnes << Mbits) | MaxFracMask)); + } auto ret = static_cast( sign | ((MaxExpForMaxNormal << Mbits) | MaxFracForMaxNormal)); return ret; @@ -470,18 +489,20 @@ template class fp8_e4m3_x { static constexpr uint8_t MaxFiniteCode = 0x7E; // 0.1111.110 (positive max normal) - template uint8_t ConvertToFP8(T h, rounding r) { + + template + uint8_t ConvertToFP8(T h, rounding r, saturation s = saturation::finite) { sycl::half hi = static_cast(h); #ifdef __SYCL_DEVICE_ONLY__ // TODO: optimize with vectorized builtin calls const uint8_t sign = std::signbit(hi) ? 0x80u : 0x00u; const float ax = std::fabs(hi); - if (ax > MaxNormal) - return static_cast(sign | MaxFiniteCode); - - if (ax < MinSubnormal) - return sign; + if (ax > MaxNormal) { + if (s == saturation::finite) + return static_cast(sign | MaxFiniteCode); + return static_cast(sign | NaNCode); + } uint8_t b = __builtin_spirv_ClampConvertFP16ToE4M3INTEL(h); if (r == rounding::to_even) @@ -491,7 +512,7 @@ template class fp8_e4m3_x { return round(r, b, yi, hi); #else - return ConvertToFP8_CPU<4, 3, sycl::half>(hi, r); + return ConvertToFP8_CPU<4, 3, sycl::half>(hi, r, s); #endif } @@ -503,23 +524,19 @@ template class fp8_e4m3_x { if (ax > MaxNormal) return static_cast(sign | MaxFiniteCode); - if (ax < MinSubnormal) - return sign; - uint8_t b = __builtin_spirv_ClampConvertBF16ToE4M3INTEL(h); if (r == rounding::to_even) return b; const half yi = __builtin_spirv_ConvertE4M3ToFP16INTEL(b); return round(r, b, yi, h); #else - return ConvertToFP8_CPU<4, 3, bfloat16>(h, r); + return ConvertToFP8_CPU<4, 3, bfloat16>(h, r, saturation::finite); #endif } template T ConvertFromFP8(uint8_t v) const { #ifdef __SYCL_DEVICE_ONLY__ - sycl::half hi = __builtin_spirv_ClampConvertE4M3ToFP16INTEL( - v); // sycl_fp8_ClampConvertE4M3ToFP16INTEL(v); + sycl::half hi = __builtin_spirv_ClampConvertE4M3ToFP16INTEL(v); return static_cast(hi); #else return ConvertFromFP8_CPU<4, 3, T>(v); @@ -914,19 +931,20 @@ template class fp8_e5m2_x { static constexpr size_t NFracBits = 2; static constexpr float MaxNormal = 114688.0f; // 1.75 * 2^16 static constexpr float MinSubnormal = 0.0000152587890625f; // 2^-16 - static constexpr uint8_t MaxFiniteCode = 0x7C; // 0.11111.00 + static constexpr uint8_t MaxFiniteCode = 0x7B; // 0.11110.11 + static constexpr uint8_t InfinityCode = 0x7C; // 0.11111.00 - uint8_t ConvertToFP8(sycl::half h, rounding r) { + uint8_t ConvertToFP8(sycl::half h, rounding r, saturation s) { #ifdef __SYCL_DEVICE_ONLY__ // TODO: optimize with vectorized builtin calls const uint8_t sign = std::signbit(h) ? 0x80u : 0x00u; const float ax = std::fabs(h); - if (ax > MaxNormal) - return static_cast(sign | MaxFiniteCode); - - if (ax < MinSubnormal) - return sign; + if (ax > MaxNormal || std::isinf(ax)) { + if (s == saturation::finite) + return static_cast(sign | MaxFiniteCode); + return static_cast(sign | InfinityCode); + } uint8_t b = __builtin_spirv_ClampConvertFP16ToE5M2INTEL(h); if (r == rounding::to_even) @@ -935,20 +953,20 @@ template class fp8_e5m2_x { return round(r, b, yi, h); #else - return ConvertToFP8_CPU<5, 2, sycl::half>(h, r); + return ConvertToFP8_CPU<5, 2, sycl::half>(h, r, s); #endif } - uint8_t ConvertBF16ToFP8(bfloat16 h, rounding r) { + uint8_t ConvertBF16ToFP8(bfloat16 h, rounding r, saturation s) { #ifdef __SYCL_DEVICE_ONLY__ const uint8_t sign = std::signbit(h) ? 0x80u : 0x00u; const bfloat16 ax = std::fabs(h); - if (ax > MaxNormal) - return static_cast(sign | MaxFiniteCode); - - if (ax < MinSubnormal) - return sign; + if (ax > MaxNormal || std::isinf(ax)) { + if (s == saturation::finite) + return static_cast(sign | MaxFiniteCode); + return static_cast(sign | InfinityCode); + } uint8_t b = __builtin_spirv_ClampConvertBF16ToE5M2INTEL(h); if (r == rounding::to_even) @@ -956,7 +974,7 @@ template class fp8_e5m2_x { const sycl::half yi = __builtin_spirv_ConvertE5M2ToFP16INTEL(b); return round(r, b, yi, h); #else - return ConvertToFP8_CPU<5, 2, bfloat16>(h, r); + return ConvertToFP8_CPU<5, 2, bfloat16>(h, r, s); #endif } @@ -983,9 +1001,9 @@ template class fp8_e5m2_x { if (r != rounding::to_even) throw std::invalid_argument( "fp8_e5m2_x: only rounding::to_even is supported"); - if (s != saturation::finite) + if (s != saturation::finite && s != saturation::none) throw std::invalid_argument( - "fp8_e5m2_x: only saturation::finite is supported"); + "fp8_e5m2_x: unsupported saturation mode"); } public: @@ -1013,12 +1031,13 @@ template class fp8_e5m2_x { if constexpr (((std::is_same_v, bfloat16>) && ...)) { const bfloat16 in[N] = {static_cast(v)...}; for (size_t i = 0; i < N; ++i) - vals[i] = ConvertBF16ToFP8(in[i], rounding::to_even); + vals[i] = ConvertBF16ToFP8(in[i], rounding::to_even, + saturation::finite); return; } const sycl::half in[N] = {v...}; for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(in[i], rounding::to_even); + vals[i] = ConvertToFP8(in[i], rounding::to_even, saturation::finite); } // Construct from an array of half, bfloat16, float, double. @@ -1028,7 +1047,7 @@ template class fp8_e5m2_x { CheckConstraints(r, s); // TODO: optimize with vectorized builtin calls for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i], r); + vals[i] = ConvertToFP8(v[i], r, s); } explicit fp8_e5m2_x(bfloat16 const (&v)[N], rounding r = rounding::to_even, @@ -1036,19 +1055,19 @@ template class fp8_e5m2_x { CheckConstraints(r, s); // TODO: optimize with vectorized builtin calls for (size_t i = 0; i < N; ++i) - vals[i] = ConvertBF16ToFP8(v[i], r); + vals[i] = ConvertBF16ToFP8(v[i], r, s); } explicit fp8_e5m2_x(float const (&v)[N], rounding r = rounding::to_even, saturation s = saturation::finite) { CheckConstraints(r, s); for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i], r); + vals[i] = ConvertToFP8(v[i], r, s); } explicit fp8_e5m2_x(double const (&v)[N]) { for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i], rounding::to_even); + vals[i] = ConvertToFP8(v[i], rounding::to_even, saturation::finite); } // Construct from an marray of half, bfloat16, float, double. @@ -1058,7 +1077,7 @@ template class fp8_e5m2_x { saturation s = saturation::finite) { CheckConstraints(r, s); for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i], r); + vals[i] = ConvertToFP8(v[i], r, s); } explicit fp8_e5m2_x(const sycl::marray &v, @@ -1066,7 +1085,7 @@ template class fp8_e5m2_x { saturation s = saturation::finite) { CheckConstraints(r, s); for (size_t i = 0; i < N; ++i) - vals[i] = ConvertBF16ToFP8(v[i], r); + vals[i] = ConvertBF16ToFP8(v[i], r, s); } explicit fp8_e5m2_x(const sycl::marray &v, @@ -1074,12 +1093,12 @@ template class fp8_e5m2_x { saturation s = saturation::finite) { CheckConstraints(r, s); for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i], r); + vals[i] = ConvertToFP8(v[i], r, s); } explicit fp8_e5m2_x(const sycl::marray &v) { for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i], rounding::to_even); + vals[i] = ConvertToFP8(v[i], rounding::to_even, saturation::finite); } // Construct with stochastic rounding with user provided seed from an array of @@ -1215,47 +1234,47 @@ template class fp8_e5m2_x { explicit fp8_e5m2_x(short val) { static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for short constructor"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); } explicit fp8_e5m2_x(int val) { static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for int constructor"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); } explicit fp8_e5m2_x(long val) { static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for long constructor"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); } explicit fp8_e5m2_x(long long val) { static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for long long constructor"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); } explicit fp8_e5m2_x(unsigned short val) { static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for unsigned short constructor"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); } explicit fp8_e5m2_x(unsigned int val) { static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for unsigned int constructor"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); } explicit fp8_e5m2_x(unsigned long val) { static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for unsigned long constructor"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); } explicit fp8_e5m2_x(unsigned long long val) { static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for unsigned long long constructor"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); } // Assign (operator) from half, bfloat16, float, double, and integer types. @@ -1264,56 +1283,56 @@ template class fp8_e5m2_x { fp8_e5m2_x &operator=(sycl::half val) { static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for half assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } fp8_e5m2_x &operator=(bfloat16 val) { static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for half bfloat16 operator"); - vals[0] = ConvertBF16ToFP8(val, rounding::to_even); + vals[0] = ConvertBF16ToFP8(val, rounding::to_even, saturation::finite); return *this; } fp8_e5m2_x &operator=(float val) { static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for float assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } fp8_e5m2_x &operator=(double val) { static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for double assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } fp8_e5m2_x &operator=(short val) { static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for short assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } fp8_e5m2_x &operator=(int val) { static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for int assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } fp8_e5m2_x &operator=(long val) { static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for long assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } fp8_e5m2_x &operator=(long long val) { static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for long long assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } @@ -1321,7 +1340,7 @@ template class fp8_e5m2_x { static_assert( N == 1 && "fp8_e5m2_x: N must be 1 for unsigned short assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } @@ -1329,7 +1348,7 @@ template class fp8_e5m2_x { static_assert( N == 1 && "fp8_e5m2_x: N must be 1 for unsigned int assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } @@ -1337,7 +1356,7 @@ template class fp8_e5m2_x { static_assert( N == 1 && "fp8_e5m2_x: N must be 1 for unsigned long assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } @@ -1345,7 +1364,7 @@ template class fp8_e5m2_x { static_assert( N == 1 && "fp8_e5m2_x: N must be 1 for unsigned long long assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } diff --git a/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp b/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp index 4aecd26ff1eb2..dee5a8ce13069 100644 --- a/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp +++ b/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp @@ -75,6 +75,26 @@ TEST(FP8E4M3Test, VariadicBoundaryEncodingsFloat) { EXPECT_EQ(c.vals[1], 0x80); // -0 -> 0b1_0000_000 } +TEST(FP8E4M3Test, FiniteOverflowClampsToMaxNormal) { + fp8_e4m3_x2 a(std::numeric_limits::infinity(), + -std::numeric_limits::infinity()); + + EXPECT_EQ(a.vals[0], 0x7E); // +max normal + EXPECT_EQ(a.vals[1], 0xFE); // -max normal +} + +TEST(FP8E4M3Test, FiniteUnderflowRoundsUsingToEven) { + constexpr float MinSubnormal = 0.001953125f; // 2^-9 + + fp8_e4m3_x2 tie(0.5f * MinSubnormal, -0.5f * MinSubnormal); + fp8_e4m3_x2 up(0.75f * MinSubnormal, -0.75f * MinSubnormal); + + EXPECT_EQ(tie.vals[0], 0x00); // tie -> even => +0 + EXPECT_EQ(tie.vals[1], 0x80); // tie -> even => -0 + EXPECT_EQ(up.vals[0], 0x01); // +min subnormal + EXPECT_EQ(up.vals[1], 0x81); // -min subnormal +} + TEST(FP8E4M3Test, VariadicNaNEncodingFloat) { // NaN is encoded as S.1111.111; sign is permitted. fp8_e4m3_x2 a(std::numeric_limits::quiet_NaN(), diff --git a/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp b/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp index e73e4f1a24624..99d98109217c7 100644 --- a/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp +++ b/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp @@ -75,6 +75,45 @@ TEST(FP8E5M2Test, VariadicConstructorBoundaryEncodingsFloat) { EXPECT_EQ(a2.vals[1], 0x80); // -0 -> 0b1_00000_00 } +TEST(FP8E5M2Test, FiniteOverflowClampsToMaxNormal) { + const float in[2] = {std::numeric_limits::infinity(), + -std::numeric_limits::infinity()}; + fp8_e5m2_x2 a(in, rounding::to_even, saturation::finite); + + EXPECT_EQ(a.vals[0], 0x7B); // +max normal + EXPECT_EQ(a.vals[1], 0xFB); // -max normal +} + +TEST(FP8E5M2Test, NoneOverflowProducesInfinity) { + const float in[2] = {std::numeric_limits::infinity(), + -std::numeric_limits::infinity()}; + fp8_e5m2_x2 a(in, rounding::to_even, saturation::none); + + EXPECT_EQ(a.vals[0], 0x7C); // +infinity + EXPECT_EQ(a.vals[1], 0xFC); // -infinity +} + +TEST(FP8E5M2Test, UnderflowRoundsSameForFiniteAndNone) { + constexpr float MinSubnormal = 0.0000152587890625f; // 2^-16 + const float tie[2] = {0.5f * MinSubnormal, -0.5f * MinSubnormal}; + const float up[2] = {0.75f * MinSubnormal, -0.75f * MinSubnormal}; + + fp8_e5m2_x2 tieFinite(tie, rounding::to_even, saturation::finite); + fp8_e5m2_x2 tieNone(tie, rounding::to_even, saturation::none); + fp8_e5m2_x2 upFinite(up, rounding::to_even, saturation::finite); + fp8_e5m2_x2 upNone(up, rounding::to_even, saturation::none); + + EXPECT_EQ(tieFinite.vals[0], 0x00); // tie -> even => +0 + EXPECT_EQ(tieFinite.vals[1], 0x80); // tie -> even => -0 + EXPECT_EQ(tieNone.vals[0], 0x00); + EXPECT_EQ(tieNone.vals[1], 0x80); + + EXPECT_EQ(upFinite.vals[0], 0x01); // +min subnormal + EXPECT_EQ(upFinite.vals[1], 0x81); // -min subnormal + EXPECT_EQ(upNone.vals[0], 0x01); + EXPECT_EQ(upNone.vals[1], 0x81); +} + TEST(FP8E5M2Test, VariadicConstructorNaNEncodingFloat) { fp8_e5m2_x2 a(std::numeric_limits::quiet_NaN(), -std::numeric_limits::quiet_NaN()); diff --git a/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp b/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp index 74576226d6173..cecd7af085494 100644 --- a/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp +++ b/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp @@ -58,6 +58,21 @@ TEST(FP8E8M0Test, VariadicConstructorBoundaryEncodings) { EXPECT_EQ(a.vals[1], 0xFF); // NaN } +TEST(FP8E8M0Test, FiniteOverflowClampsAndDropsSign) { + fp8_e8m0_x2 a(std::numeric_limits::infinity(), + -std::numeric_limits::infinity()); + + EXPECT_EQ(a.vals[0], 0xFE); // max normal + EXPECT_EQ(a.vals[1], 0xFE); // sign dropped +} + +TEST(FP8E8M0Test, FiniteUnderflowMapsToMinNormalAndDropsSign) { + fp8_e8m0_x2 a(std::ldexp(1.0f, -128), -std::ldexp(1.0f, -128)); + + EXPECT_EQ(a.vals[0], 0x00); // min normal + EXPECT_EQ(a.vals[1], 0x00); // sign dropped +} + TEST(FP8E8M0Test, CArrayConstructorFloatHostUpwardFinite) { const float in[2] = {1.0f, 1.1f}; const float in1[2] = {3.0f, 1000.0f}; From 28565bc7e54987e3f82cfd4ff80c6b74dfd46632 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Tue, 24 Mar 2026 17:02:40 +0100 Subject: [PATCH 10/89] Revert "[SYCL][FP8] use saturation" This reverts commit a55e2754c2382f9e1e308b1741926f020349443b. --- .../oneapi/experimental/float_8bit/types.hpp | 171 ++++++++---------- sycl/unittests/Extensions/fp8/fp8_e4m3.cpp | 20 -- sycl/unittests/Extensions/fp8/fp8_e5m2.cpp | 39 ---- sycl/unittests/Extensions/fp8/fp8_e8m0.cpp | 15 -- 4 files changed, 76 insertions(+), 169 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index d1ed476a3fbfa..e943af983807c 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -107,11 +107,6 @@ static inline uint8_t RoundClip(float x, uint8_t max, rounding R, return 1u; return 0u; } - if (R == rounding::toward_zero) { - if (std::isnan(x) || x <= 0.0f) - return 0u; - return static_cast(std::floor(x)); - } // Default / to_even if (std::isnan(x)) return 0u; @@ -138,14 +133,6 @@ static inline uint8_t RoundClip(float x, uint8_t max, rounding R, return static_cast(fi); } } - if (R == rounding::toward_zero) { - if (std::isnan(x) || x <= 0.0f) - return 0u; - uint32_t truncated = static_cast(std::floor(x)); - if (truncated > max) - truncated = max; - return static_cast(truncated); - } // default: round-to-nearest-even return RneClip(x, max); } @@ -269,8 +256,7 @@ static inline ToT ConvertFromFP8_CPU(uint8_t b, /// Ebits bits mantissa, Mbits bits exponent. template static inline uint8_t -ConvertToFP8_CPU(T h, rounding R = rounding::to_even, - saturation S = saturation::finite) noexcept { +ConvertToFP8_CPU(T h, rounding R = rounding::to_even) noexcept { // Specialized implementation for fp8_e8m0_x (Ebits=8, Mbits=0) if constexpr (Ebits == 8 && Mbits == 0) { // Format characteristics (finite-only, no zero, no infinity): @@ -302,16 +288,17 @@ ConvertToFP8_CPU(T h, rounding R = rounding::to_even, if (std::isnan(x)) return NaNCode; + uint8_t sign = std::signbit(x) ? 0x80 : 0x00; float ax = std::fabs(x); - // Handle underflow (|x| < min_normal) and x == 0: encode smallest normal. + // Handle underflow (|x| < min_normal) and x == 0: encode smallest normal + // with sign. if (ax == 0.0f || ax < min_normal) - return 0x00; // exp field = 0 -> E = -127 + return sign; // exp field = 0 -> E = -127 - // Handle overflow (|x| > max_normal): clamp or return NaN depending on - // saturation. E8M0 has no sign bit and no infinity representation. - if (ax > max_normal) - return (S == saturation::finite) ? MaxExpField : NaNCode; + // Handle overflow (|x| >= max_normal * (anything beyond representable)): + if (ax >= max_normal) + return static_cast(sign | (MaxExpField)); // E = +127 // Determine exponent E such that 2^E <= ax < 2^{E+1} int e2; @@ -329,7 +316,7 @@ ConvertToFP8_CPU(T h, rounding R = rounding::to_even, rounding effR = (R == rounding::upward) ? R : rounding::upward; if (effR == rounding::upward) { - if (x >= 0.0f) { + if (sign == 0x00) { if (!is_exact_power_of_two) { // Round up (increase exponent) if possible. if (E < Emax) @@ -350,7 +337,7 @@ ConvertToFP8_CPU(T h, rounding R = rounding::to_even, uint8_t ecode = static_cast(E + Bias); // 0 .. 254 // ecode must never be 255 here. - return ecode; + return static_cast(sign | ecode); } constexpr int bias = (1 << (Ebits - 1)) - 1; @@ -387,17 +374,16 @@ ConvertToFP8_CPU(T h, rounding R = rounding::to_even, const float min_sub = std::ldexp(1.0f, emin - Mbits); if (ax > max_finite) { - if (S == saturation::none) { - if constexpr (Ebits == 5 && Mbits == 2) - return static_cast(sign | (ExpAllOnes << Mbits)); - return static_cast(sign | ((ExpAllOnes << Mbits) | MaxFracMask)); - } + return static_cast( + sign | ((MaxExpForMaxNormal << Mbits) | MaxFracForMaxNormal)); + } + if (ax >= max_finite) { return static_cast( sign | ((MaxExpForMaxNormal << Mbits) | MaxFracForMaxNormal)); } - if (ax == 0.0f) - return sign; + if (ax < min_sub) + return sign; // underflow int e2; float m = std::frexp(ax, &e2); @@ -419,11 +405,6 @@ ConvertToFP8_CPU(T h, rounding R = rounding::to_even, ++E; } if (E > emax) { - if (S == saturation::none) { - if constexpr (Ebits == 5 && Mbits == 2) - return static_cast(sign | (ExpAllOnes << Mbits)); - return static_cast(sign | ((ExpAllOnes << Mbits) | MaxFracMask)); - } auto ret = static_cast( sign | ((MaxExpForMaxNormal << Mbits) | MaxFracForMaxNormal)); return ret; @@ -489,20 +470,18 @@ template class fp8_e4m3_x { static constexpr uint8_t MaxFiniteCode = 0x7E; // 0.1111.110 (positive max normal) - - template - uint8_t ConvertToFP8(T h, rounding r, saturation s = saturation::finite) { + template uint8_t ConvertToFP8(T h, rounding r) { sycl::half hi = static_cast(h); #ifdef __SYCL_DEVICE_ONLY__ // TODO: optimize with vectorized builtin calls const uint8_t sign = std::signbit(hi) ? 0x80u : 0x00u; const float ax = std::fabs(hi); - if (ax > MaxNormal) { - if (s == saturation::finite) - return static_cast(sign | MaxFiniteCode); - return static_cast(sign | NaNCode); - } + if (ax > MaxNormal) + return static_cast(sign | MaxFiniteCode); + + if (ax < MinSubnormal) + return sign; uint8_t b = __builtin_spirv_ClampConvertFP16ToE4M3INTEL(h); if (r == rounding::to_even) @@ -512,7 +491,7 @@ template class fp8_e4m3_x { return round(r, b, yi, hi); #else - return ConvertToFP8_CPU<4, 3, sycl::half>(hi, r, s); + return ConvertToFP8_CPU<4, 3, sycl::half>(hi, r); #endif } @@ -524,19 +503,23 @@ template class fp8_e4m3_x { if (ax > MaxNormal) return static_cast(sign | MaxFiniteCode); + if (ax < MinSubnormal) + return sign; + uint8_t b = __builtin_spirv_ClampConvertBF16ToE4M3INTEL(h); if (r == rounding::to_even) return b; const half yi = __builtin_spirv_ConvertE4M3ToFP16INTEL(b); return round(r, b, yi, h); #else - return ConvertToFP8_CPU<4, 3, bfloat16>(h, r, saturation::finite); + return ConvertToFP8_CPU<4, 3, bfloat16>(h, r); #endif } template T ConvertFromFP8(uint8_t v) const { #ifdef __SYCL_DEVICE_ONLY__ - sycl::half hi = __builtin_spirv_ClampConvertE4M3ToFP16INTEL(v); + sycl::half hi = __builtin_spirv_ClampConvertE4M3ToFP16INTEL( + v); // sycl_fp8_ClampConvertE4M3ToFP16INTEL(v); return static_cast(hi); #else return ConvertFromFP8_CPU<4, 3, T>(v); @@ -931,20 +914,19 @@ template class fp8_e5m2_x { static constexpr size_t NFracBits = 2; static constexpr float MaxNormal = 114688.0f; // 1.75 * 2^16 static constexpr float MinSubnormal = 0.0000152587890625f; // 2^-16 - static constexpr uint8_t MaxFiniteCode = 0x7B; // 0.11110.11 - static constexpr uint8_t InfinityCode = 0x7C; // 0.11111.00 + static constexpr uint8_t MaxFiniteCode = 0x7C; // 0.11111.00 - uint8_t ConvertToFP8(sycl::half h, rounding r, saturation s) { + uint8_t ConvertToFP8(sycl::half h, rounding r) { #ifdef __SYCL_DEVICE_ONLY__ // TODO: optimize with vectorized builtin calls const uint8_t sign = std::signbit(h) ? 0x80u : 0x00u; const float ax = std::fabs(h); - if (ax > MaxNormal || std::isinf(ax)) { - if (s == saturation::finite) - return static_cast(sign | MaxFiniteCode); - return static_cast(sign | InfinityCode); - } + if (ax > MaxNormal) + return static_cast(sign | MaxFiniteCode); + + if (ax < MinSubnormal) + return sign; uint8_t b = __builtin_spirv_ClampConvertFP16ToE5M2INTEL(h); if (r == rounding::to_even) @@ -953,20 +935,20 @@ template class fp8_e5m2_x { return round(r, b, yi, h); #else - return ConvertToFP8_CPU<5, 2, sycl::half>(h, r, s); + return ConvertToFP8_CPU<5, 2, sycl::half>(h, r); #endif } - uint8_t ConvertBF16ToFP8(bfloat16 h, rounding r, saturation s) { + uint8_t ConvertBF16ToFP8(bfloat16 h, rounding r) { #ifdef __SYCL_DEVICE_ONLY__ const uint8_t sign = std::signbit(h) ? 0x80u : 0x00u; const bfloat16 ax = std::fabs(h); - if (ax > MaxNormal || std::isinf(ax)) { - if (s == saturation::finite) - return static_cast(sign | MaxFiniteCode); - return static_cast(sign | InfinityCode); - } + if (ax > MaxNormal) + return static_cast(sign | MaxFiniteCode); + + if (ax < MinSubnormal) + return sign; uint8_t b = __builtin_spirv_ClampConvertBF16ToE5M2INTEL(h); if (r == rounding::to_even) @@ -974,7 +956,7 @@ template class fp8_e5m2_x { const sycl::half yi = __builtin_spirv_ConvertE5M2ToFP16INTEL(b); return round(r, b, yi, h); #else - return ConvertToFP8_CPU<5, 2, bfloat16>(h, r, s); + return ConvertToFP8_CPU<5, 2, bfloat16>(h, r); #endif } @@ -1001,9 +983,9 @@ template class fp8_e5m2_x { if (r != rounding::to_even) throw std::invalid_argument( "fp8_e5m2_x: only rounding::to_even is supported"); - if (s != saturation::finite && s != saturation::none) + if (s != saturation::finite) throw std::invalid_argument( - "fp8_e5m2_x: unsupported saturation mode"); + "fp8_e5m2_x: only saturation::finite is supported"); } public: @@ -1031,13 +1013,12 @@ template class fp8_e5m2_x { if constexpr (((std::is_same_v, bfloat16>) && ...)) { const bfloat16 in[N] = {static_cast(v)...}; for (size_t i = 0; i < N; ++i) - vals[i] = ConvertBF16ToFP8(in[i], rounding::to_even, - saturation::finite); + vals[i] = ConvertBF16ToFP8(in[i], rounding::to_even); return; } const sycl::half in[N] = {v...}; for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(in[i], rounding::to_even, saturation::finite); + vals[i] = ConvertToFP8(in[i], rounding::to_even); } // Construct from an array of half, bfloat16, float, double. @@ -1047,7 +1028,7 @@ template class fp8_e5m2_x { CheckConstraints(r, s); // TODO: optimize with vectorized builtin calls for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i], r, s); + vals[i] = ConvertToFP8(v[i], r); } explicit fp8_e5m2_x(bfloat16 const (&v)[N], rounding r = rounding::to_even, @@ -1055,19 +1036,19 @@ template class fp8_e5m2_x { CheckConstraints(r, s); // TODO: optimize with vectorized builtin calls for (size_t i = 0; i < N; ++i) - vals[i] = ConvertBF16ToFP8(v[i], r, s); + vals[i] = ConvertBF16ToFP8(v[i], r); } explicit fp8_e5m2_x(float const (&v)[N], rounding r = rounding::to_even, saturation s = saturation::finite) { CheckConstraints(r, s); for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i], r, s); + vals[i] = ConvertToFP8(v[i], r); } explicit fp8_e5m2_x(double const (&v)[N]) { for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i], rounding::to_even, saturation::finite); + vals[i] = ConvertToFP8(v[i], rounding::to_even); } // Construct from an marray of half, bfloat16, float, double. @@ -1077,7 +1058,7 @@ template class fp8_e5m2_x { saturation s = saturation::finite) { CheckConstraints(r, s); for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i], r, s); + vals[i] = ConvertToFP8(v[i], r); } explicit fp8_e5m2_x(const sycl::marray &v, @@ -1085,7 +1066,7 @@ template class fp8_e5m2_x { saturation s = saturation::finite) { CheckConstraints(r, s); for (size_t i = 0; i < N; ++i) - vals[i] = ConvertBF16ToFP8(v[i], r, s); + vals[i] = ConvertBF16ToFP8(v[i], r); } explicit fp8_e5m2_x(const sycl::marray &v, @@ -1093,12 +1074,12 @@ template class fp8_e5m2_x { saturation s = saturation::finite) { CheckConstraints(r, s); for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i], r, s); + vals[i] = ConvertToFP8(v[i], r); } explicit fp8_e5m2_x(const sycl::marray &v) { for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i], rounding::to_even, saturation::finite); + vals[i] = ConvertToFP8(v[i], rounding::to_even); } // Construct with stochastic rounding with user provided seed from an array of @@ -1234,47 +1215,47 @@ template class fp8_e5m2_x { explicit fp8_e5m2_x(short val) { static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for short constructor"); - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val, rounding::to_even); } explicit fp8_e5m2_x(int val) { static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for int constructor"); - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val, rounding::to_even); } explicit fp8_e5m2_x(long val) { static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for long constructor"); - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val, rounding::to_even); } explicit fp8_e5m2_x(long long val) { static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for long long constructor"); - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val, rounding::to_even); } explicit fp8_e5m2_x(unsigned short val) { static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for unsigned short constructor"); - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val, rounding::to_even); } explicit fp8_e5m2_x(unsigned int val) { static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for unsigned int constructor"); - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val, rounding::to_even); } explicit fp8_e5m2_x(unsigned long val) { static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for unsigned long constructor"); - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val, rounding::to_even); } explicit fp8_e5m2_x(unsigned long long val) { static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for unsigned long long constructor"); - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val, rounding::to_even); } // Assign (operator) from half, bfloat16, float, double, and integer types. @@ -1283,56 +1264,56 @@ template class fp8_e5m2_x { fp8_e5m2_x &operator=(sycl::half val) { static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for half assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } fp8_e5m2_x &operator=(bfloat16 val) { static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for half bfloat16 operator"); - vals[0] = ConvertBF16ToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertBF16ToFP8(val, rounding::to_even); return *this; } fp8_e5m2_x &operator=(float val) { static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for float assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } fp8_e5m2_x &operator=(double val) { static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for double assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } fp8_e5m2_x &operator=(short val) { static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for short assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } fp8_e5m2_x &operator=(int val) { static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for int assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } fp8_e5m2_x &operator=(long val) { static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for long assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } fp8_e5m2_x &operator=(long long val) { static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for long long assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } @@ -1340,7 +1321,7 @@ template class fp8_e5m2_x { static_assert( N == 1 && "fp8_e5m2_x: N must be 1 for unsigned short assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } @@ -1348,7 +1329,7 @@ template class fp8_e5m2_x { static_assert( N == 1 && "fp8_e5m2_x: N must be 1 for unsigned int assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } @@ -1356,7 +1337,7 @@ template class fp8_e5m2_x { static_assert( N == 1 && "fp8_e5m2_x: N must be 1 for unsigned long assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } @@ -1364,7 +1345,7 @@ template class fp8_e5m2_x { static_assert( N == 1 && "fp8_e5m2_x: N must be 1 for unsigned long long assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val, rounding::to_even); return *this; } diff --git a/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp b/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp index dee5a8ce13069..4aecd26ff1eb2 100644 --- a/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp +++ b/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp @@ -75,26 +75,6 @@ TEST(FP8E4M3Test, VariadicBoundaryEncodingsFloat) { EXPECT_EQ(c.vals[1], 0x80); // -0 -> 0b1_0000_000 } -TEST(FP8E4M3Test, FiniteOverflowClampsToMaxNormal) { - fp8_e4m3_x2 a(std::numeric_limits::infinity(), - -std::numeric_limits::infinity()); - - EXPECT_EQ(a.vals[0], 0x7E); // +max normal - EXPECT_EQ(a.vals[1], 0xFE); // -max normal -} - -TEST(FP8E4M3Test, FiniteUnderflowRoundsUsingToEven) { - constexpr float MinSubnormal = 0.001953125f; // 2^-9 - - fp8_e4m3_x2 tie(0.5f * MinSubnormal, -0.5f * MinSubnormal); - fp8_e4m3_x2 up(0.75f * MinSubnormal, -0.75f * MinSubnormal); - - EXPECT_EQ(tie.vals[0], 0x00); // tie -> even => +0 - EXPECT_EQ(tie.vals[1], 0x80); // tie -> even => -0 - EXPECT_EQ(up.vals[0], 0x01); // +min subnormal - EXPECT_EQ(up.vals[1], 0x81); // -min subnormal -} - TEST(FP8E4M3Test, VariadicNaNEncodingFloat) { // NaN is encoded as S.1111.111; sign is permitted. fp8_e4m3_x2 a(std::numeric_limits::quiet_NaN(), diff --git a/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp b/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp index 99d98109217c7..e73e4f1a24624 100644 --- a/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp +++ b/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp @@ -75,45 +75,6 @@ TEST(FP8E5M2Test, VariadicConstructorBoundaryEncodingsFloat) { EXPECT_EQ(a2.vals[1], 0x80); // -0 -> 0b1_00000_00 } -TEST(FP8E5M2Test, FiniteOverflowClampsToMaxNormal) { - const float in[2] = {std::numeric_limits::infinity(), - -std::numeric_limits::infinity()}; - fp8_e5m2_x2 a(in, rounding::to_even, saturation::finite); - - EXPECT_EQ(a.vals[0], 0x7B); // +max normal - EXPECT_EQ(a.vals[1], 0xFB); // -max normal -} - -TEST(FP8E5M2Test, NoneOverflowProducesInfinity) { - const float in[2] = {std::numeric_limits::infinity(), - -std::numeric_limits::infinity()}; - fp8_e5m2_x2 a(in, rounding::to_even, saturation::none); - - EXPECT_EQ(a.vals[0], 0x7C); // +infinity - EXPECT_EQ(a.vals[1], 0xFC); // -infinity -} - -TEST(FP8E5M2Test, UnderflowRoundsSameForFiniteAndNone) { - constexpr float MinSubnormal = 0.0000152587890625f; // 2^-16 - const float tie[2] = {0.5f * MinSubnormal, -0.5f * MinSubnormal}; - const float up[2] = {0.75f * MinSubnormal, -0.75f * MinSubnormal}; - - fp8_e5m2_x2 tieFinite(tie, rounding::to_even, saturation::finite); - fp8_e5m2_x2 tieNone(tie, rounding::to_even, saturation::none); - fp8_e5m2_x2 upFinite(up, rounding::to_even, saturation::finite); - fp8_e5m2_x2 upNone(up, rounding::to_even, saturation::none); - - EXPECT_EQ(tieFinite.vals[0], 0x00); // tie -> even => +0 - EXPECT_EQ(tieFinite.vals[1], 0x80); // tie -> even => -0 - EXPECT_EQ(tieNone.vals[0], 0x00); - EXPECT_EQ(tieNone.vals[1], 0x80); - - EXPECT_EQ(upFinite.vals[0], 0x01); // +min subnormal - EXPECT_EQ(upFinite.vals[1], 0x81); // -min subnormal - EXPECT_EQ(upNone.vals[0], 0x01); - EXPECT_EQ(upNone.vals[1], 0x81); -} - TEST(FP8E5M2Test, VariadicConstructorNaNEncodingFloat) { fp8_e5m2_x2 a(std::numeric_limits::quiet_NaN(), -std::numeric_limits::quiet_NaN()); diff --git a/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp b/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp index cecd7af085494..74576226d6173 100644 --- a/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp +++ b/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp @@ -58,21 +58,6 @@ TEST(FP8E8M0Test, VariadicConstructorBoundaryEncodings) { EXPECT_EQ(a.vals[1], 0xFF); // NaN } -TEST(FP8E8M0Test, FiniteOverflowClampsAndDropsSign) { - fp8_e8m0_x2 a(std::numeric_limits::infinity(), - -std::numeric_limits::infinity()); - - EXPECT_EQ(a.vals[0], 0xFE); // max normal - EXPECT_EQ(a.vals[1], 0xFE); // sign dropped -} - -TEST(FP8E8M0Test, FiniteUnderflowMapsToMinNormalAndDropsSign) { - fp8_e8m0_x2 a(std::ldexp(1.0f, -128), -std::ldexp(1.0f, -128)); - - EXPECT_EQ(a.vals[0], 0x00); // min normal - EXPECT_EQ(a.vals[1], 0x00); // sign dropped -} - TEST(FP8E8M0Test, CArrayConstructorFloatHostUpwardFinite) { const float in[2] = {1.0f, 1.1f}; const float in1[2] = {3.0f, 1000.0f}; From 9712f3456eb052873a964785ff189a97740920ad Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Tue, 24 Mar 2026 18:39:00 +0100 Subject: [PATCH 11/89] [SYCL] update list of builtins used in fp8 types --- .../oneapi/experimental/float_8bit/types.hpp | 264 ++++++++---------- .../Extensions/fp8/builtin_call_tests.cpp | 29 +- .../Extensions/fp8/builtin_mocks.hpp | 55 ++-- 3 files changed, 173 insertions(+), 175 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index e943af983807c..abd9d82a7d4d1 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -19,54 +19,46 @@ #include #ifdef __SYCL_DEVICE_ONLY__ -// New FP8 builtins -extern __DPCPP_SYCL_EXTERNAL sycl::half -__builtin_spirv_ClampConvertE4M3ToFP16INTEL(uint8_t) noexcept; -extern __DPCPP_SYCL_EXTERNAL sycl::half -__builtin_spirv_ConvertE4M3ToFP16INTEL(uint8_t) noexcept; -extern __DPCPP_SYCL_EXTERNAL sycl::half -__builtin_spirv_ConvertE5M2ToFP16INTEL(uint8_t) noexcept; -extern __DPCPP_SYCL_EXTERNAL sycl::ext::oneapi::bfloat16 -__builtin_spirv_ConvertE4M3ToBF16INTEL(uint8_t) noexcept; -extern __DPCPP_SYCL_EXTERNAL sycl::ext::oneapi::bfloat16 -__builtin_spirv_ConvertE5M2ToBF16INTEL(uint8_t) noexcept; -extern __DPCPP_SYCL_EXTERNAL - uint8_t __builtin_spirv_ConvertFP16ToE4M3INTEL(sycl::half) noexcept; -extern __DPCPP_SYCL_EXTERNAL uint8_t __builtin_spirv_ConvertBF16ToE5M2INTEL( - sycl::ext::oneapi::bfloat16) noexcept; +// FP8 builtins extern __DPCPP_SYCL_EXTERNAL uint8_t __builtin_spirv_ClampConvertFP16ToE4M3INTEL(sycl::half) noexcept; +extern __DPCPP_SYCL_EXTERNAL + uint8_t __builtin_spirv_ConvertFP16ToE4M3EXT(sycl::half) noexcept; +extern __DPCPP_SYCL_EXTERNAL sycl::half +__builtin_spirv_ConvertE4M3ToFP16EXT(uint8_t) noexcept; extern __DPCPP_SYCL_EXTERNAL uint8_t __builtin_spirv_ClampConvertBF16ToE4M3INTEL( sycl::ext::oneapi::bfloat16) noexcept; +extern __DPCPP_SYCL_EXTERNAL uint8_t + __builtin_spirv_ConvertBF16ToE4M3EXT(sycl::ext::oneapi::bfloat16) noexcept; +extern __DPCPP_SYCL_EXTERNAL sycl::ext::oneapi::bfloat16 +__builtin_spirv_ConvertE4M3ToBF16EXT(uint8_t) noexcept; extern __DPCPP_SYCL_EXTERNAL uint8_t __builtin_spirv_ClampConvertFP16ToE5M2INTEL(sycl::half) noexcept; +extern __DPCPP_SYCL_EXTERNAL + uint8_t __builtin_spirv_ConvertFP16ToE5M2EXT(sycl::half) noexcept; +extern __DPCPP_SYCL_EXTERNAL sycl::half +__builtin_spirv_ConvertE5M2ToFP16EXT(uint8_t) noexcept; extern __DPCPP_SYCL_EXTERNAL uint8_t __builtin_spirv_ClampConvertBF16ToE5M2INTEL( sycl::ext::oneapi::bfloat16) noexcept; extern __DPCPP_SYCL_EXTERNAL uint8_t -__builtin_spirv_StochasticRoundFP16ToE5M2INTEL(sycl::half, uint32_t, - uint32_t *) noexcept; -extern __DPCPP_SYCL_EXTERNAL - uint8_t __builtin_spirv_StochasticRoundFP16ToE4M3INTEL(sycl::half) noexcept; -extern __DPCPP_SYCL_EXTERNAL uint8_t -__builtin_spirv_StochasticRoundBF16ToE5M2INTEL(sycl::ext::oneapi::bfloat16, - uint32_t, uint32_t *) noexcept; -extern __DPCPP_SYCL_EXTERNAL - uint8_t __builtin_spirv_StochasticRoundBF16ToE4M3INTEL( - sycl::ext::oneapi::bfloat16) noexcept; + __builtin_spirv_ConvertBF16ToE5M2EXT(sycl::ext::oneapi::bfloat16) noexcept; +extern __DPCPP_SYCL_EXTERNAL sycl::ext::oneapi::bfloat16 +__builtin_spirv_ConvertE5M2ToBF16EXT(uint8_t) noexcept; extern __DPCPP_SYCL_EXTERNAL uint8_t __builtin_spirv_ClampStochasticRoundFP16ToE5M2INTEL(sycl::half, uint32_t, uint32_t *) noexcept; extern __DPCPP_SYCL_EXTERNAL uint8_t - __builtin_spirv_ClampStochasticRoundFP16ToE4M3INTEL(sycl::half) noexcept; +__builtin_spirv_StochasticRoundFP16ToE5M2INTEL(sycl::half, uint32_t, + uint32_t *) noexcept; extern __DPCPP_SYCL_EXTERNAL uint8_t __builtin_spirv_ClampStochasticRoundBF16ToE5M2INTEL(sycl::ext::oneapi::bfloat16, uint32_t, uint32_t *) noexcept; -extern __DPCPP_SYCL_EXTERNAL - uint8_t __builtin_spirv_ClampStochasticRoundBF16ToE4M3INTEL( - sycl::ext::oneapi::bfloat16) noexcept; +extern __DPCPP_SYCL_EXTERNAL uint8_t +__builtin_spirv_StochasticRoundBF16ToE5M2INTEL(sycl::ext::oneapi::bfloat16, + uint32_t, uint32_t *) noexcept; #endif // __SYCL_DEVICE_ONLY__ namespace sycl { @@ -470,24 +462,19 @@ template class fp8_e4m3_x { static constexpr uint8_t MaxFiniteCode = 0x7E; // 0.1111.110 (positive max normal) - template uint8_t ConvertToFP8(T h, rounding r) { + template uint8_t ConvertToFP8(T h, rounding r, saturation s) { sycl::half hi = static_cast(h); #ifdef __SYCL_DEVICE_ONLY__ // TODO: optimize with vectorized builtin calls - const uint8_t sign = std::signbit(hi) ? 0x80u : 0x00u; - const float ax = std::fabs(hi); - - if (ax > MaxNormal) - return static_cast(sign | MaxFiniteCode); - - if (ax < MinSubnormal) - return sign; - - uint8_t b = __builtin_spirv_ClampConvertFP16ToE4M3INTEL(h); + uint8_t b = 0; + if (s == saturation::finite) + b = __builtin_spirv_ClampConvertFP16ToE4M3INTEL(h); + else + b = __builtin_spirv_ConvertFP16ToE4M3EXT(h); if (r == rounding::to_even) return b; - const sycl::half yi = __builtin_spirv_ConvertE4M3ToFP16INTEL(b); + const sycl::half yi = __builtin_spirv_ConvertE4M3ToFP16EXT(b); return round(r, b, yi, hi); #else @@ -495,21 +482,16 @@ template class fp8_e4m3_x { #endif } - uint8_t ConvertBF16ToFP8(bfloat16 h, rounding r) { + uint8_t ConvertBF16ToFP8(bfloat16 h, rounding r, saturation s) { #ifdef __SYCL_DEVICE_ONLY__ - const uint8_t sign = std::signbit(h) ? 0x80u : 0x00u; - const float ax = std::fabs(h); - - if (ax > MaxNormal) - return static_cast(sign | MaxFiniteCode); - - if (ax < MinSubnormal) - return sign; - - uint8_t b = __builtin_spirv_ClampConvertBF16ToE4M3INTEL(h); + uint8_t b = 0; + if (s == saturation::finite) + b = __builtin_spirv_ClampConvertBF16ToE4M3INTEL(h); + else + b = __builtin_spirv_ConvertBF16ToE4M3EXT(h); if (r == rounding::to_even) return b; - const half yi = __builtin_spirv_ConvertE4M3ToFP16INTEL(b); + const half yi = __builtin_spirv_ConvertE4M3ToFP16EXT(b); return round(r, b, yi, h); #else return ConvertToFP8_CPU<4, 3, bfloat16>(h, r); @@ -518,8 +500,7 @@ template class fp8_e4m3_x { template T ConvertFromFP8(uint8_t v) const { #ifdef __SYCL_DEVICE_ONLY__ - sycl::half hi = __builtin_spirv_ClampConvertE4M3ToFP16INTEL( - v); // sycl_fp8_ClampConvertE4M3ToFP16INTEL(v); + sycl::half hi = __builtin_spirv_ConvertE4M3ToFP16EXT(v); return static_cast(hi); #else return ConvertFromFP8_CPU<4, 3, T>(v); @@ -528,7 +509,7 @@ template class fp8_e4m3_x { bfloat16 ConvertBF16FromFP8(uint8_t v) const { #ifdef __SYCL_DEVICE_ONLY__ - return __builtin_spirv_ConvertE4M3ToBF16INTEL(v); + return __builtin_spirv_ConvertE4M3ToBF16EXT(v); #else return ConvertFromFP8_CPU<4, 3, bfloat16>(v); #endif @@ -566,40 +547,40 @@ template class fp8_e4m3_x { if constexpr (((std::is_same_v, bfloat16>) && ...)) { const bfloat16 in[N] = {static_cast(v)...}; for (size_t i = 0; i < N; ++i) - vals[i] = ConvertBF16ToFP8(in[i], rounding::to_even); + vals[i] = + ConvertBF16ToFP8(in[i], rounding::to_even, saturation::finite); return; } const sycl::half in[N] = {v...}; for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(in[i], rounding::to_even); + vals[i] = ConvertToFP8(in[i], rounding::to_even, saturation::finite); } // Construct from an array of half, bfloat16, float, double. explicit fp8_e4m3_x(sycl::half const (&v)[N], rounding r = rounding::to_even) { CheckConstraints(r); - // TODO: optimize with vectorized builtin calls for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i], r); + vals[i] = ConvertToFP8(v[i], r, saturation::finite); } explicit fp8_e4m3_x(bfloat16 const (&v)[N], rounding r = rounding::to_even) { CheckConstraints(r); for (size_t i = 0; i < N; ++i) - vals[i] = ConvertBF16ToFP8(v[i], r); + vals[i] = ConvertBF16ToFP8(v[i], r, saturation::finite); } explicit fp8_e4m3_x(float const (&v)[N], rounding r = rounding::to_even) { CheckConstraints(r); for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i], r); + vals[i] = ConvertToFP8(v[i], r, saturation::finite); } explicit fp8_e4m3_x(double const (&v)[N]) { static_assert(N == 1 || N == 2, "fp8_e4m3_x: Template argument N must be 1 or 2"); for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i], rounding::to_even); + vals[i] = ConvertToFP8(v[i], rounding::to_even, saturation::finite); } // Construct from an marray of half, bfloat16, float, double. @@ -607,26 +588,26 @@ template class fp8_e4m3_x { rounding r = rounding::to_even) { CheckConstraints(r); for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i], r); + vals[i] = ConvertToFP8(v[i], r, saturation::finite); } explicit fp8_e4m3_x(const sycl::marray &v, rounding r = rounding::to_even) { CheckConstraints(r); for (size_t i = 0; i < N; ++i) - vals[i] = ConvertBF16ToFP8(v[i], r); + vals[i] = ConvertBF16ToFP8(v[i], r, saturation::finite); } explicit fp8_e4m3_x(const sycl::marray &v, rounding r = rounding::to_even) { CheckConstraints(r); for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i], r); + vals[i] = ConvertToFP8(v[i], r, saturation::finite); } explicit fp8_e4m3_x(const sycl::marray &v) { for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i], rounding::to_even); + vals[i] = ConvertToFP8(v[i], rounding::to_even, saturation::finite); } // Construct from integer types. @@ -634,47 +615,47 @@ template class fp8_e4m3_x { explicit fp8_e4m3_x(short val) { static_assert(N == 1 && "fp8_e4m3_x: N must be 1 for short constructor"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); } explicit fp8_e4m3_x(int val) { static_assert(N == 1 && "fp8_e4m3_x: N must be 1 for int constructor"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); } explicit fp8_e4m3_x(long val) { static_assert(N == 1 && "fp8_e4m3_x: N must be 1 for long constructor"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); } explicit fp8_e4m3_x(long long val) { static_assert(N == 1 && "fp8_e4m3_x: N must be 1 for long long constructor"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); } explicit fp8_e4m3_x(unsigned short val) { static_assert(N == 1 && "fp8_e4m3_x: N must be 1 for unsigned short constructor"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); } explicit fp8_e4m3_x(unsigned int val) { static_assert(N == 1 && "fp8_e4m3_x: N must be 1 for unsigned int constructor"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); } explicit fp8_e4m3_x(unsigned long val) { static_assert(N == 1 && "fp8_e4m3_x: N must be 1 for unsigned long constructor"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); } explicit fp8_e4m3_x(unsigned long long val) { static_assert(N == 1 && "fp8_e4m3_x: N must be 1 for unsigned long long constructor"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); } // Assign (operator) from half, bfloat16, float, double, and integer types. @@ -683,56 +664,56 @@ template class fp8_e4m3_x { fp8_e4m3_x &operator=(sycl::half val) { static_assert(N == 1 && "fp8_e4m3_x: N must be 1 for half assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } fp8_e4m3_x &operator=(bfloat16 val) { static_assert(N == 1 && "fp8_e4m3_x: N must be 1 for bfloat16 assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } fp8_e4m3_x &operator=(float val) { static_assert(N == 1 && "fp8_e4m3_x: N must be 1 for float assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } fp8_e4m3_x &operator=(double val) { static_assert(N == 1 && "fp8_e4m3_x: N must be 1 for double assignment operator"); - vals[0] = ConvertBF16ToFP8(val, rounding::to_even); + vals[0] = ConvertBF16ToFP8(val, rounding::to_even, saturation::finite); return *this; } fp8_e4m3_x &operator=(short val) { static_assert(N == 1 && "fp8_e4m3_x: N must be 1 for short assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } fp8_e4m3_x &operator=(int val) { static_assert(N == 1 && "fp8_e4m3_x: N must be 1 for int assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } fp8_e4m3_x &operator=(long val) { static_assert(N == 1 && "fp8_e4m3_x: N must be 1 for long assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } fp8_e4m3_x &operator=(long long val) { static_assert(N == 1 && "fp8_e4m3_x: N must be 1 for long long assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } @@ -740,7 +721,7 @@ template class fp8_e4m3_x { static_assert( N == 1 && "fp8_e4m3_x: N must be 1 for unsigned short assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } @@ -748,7 +729,7 @@ template class fp8_e4m3_x { static_assert( N == 1 && "fp8_e4m3_x: N must be 1 for unsigned int assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } @@ -756,7 +737,7 @@ template class fp8_e4m3_x { static_assert( N == 1 && "fp8_e4m3_x: N must be 1 for unsigned long assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } @@ -764,7 +745,7 @@ template class fp8_e4m3_x { static_assert( N == 1 && "fp8_e4m3_x: N must be 1 for unsigned long long assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } @@ -874,7 +855,7 @@ template class fp8_e4m3_x { static_assert(N == 1, "fp8_e4m3_x: operator() requires size N=1"); #ifdef __SYCL_DEVICE_ONLY__ // detect +0 / -0 - sycl::half h = __builtin_spirv_ConvertE4M3ToFP16INTEL(vals[0]); + sycl::half h = __builtin_spirv_ConvertE4M3ToFP16EXT(vals[0]); return h != 0; #else // no need to convert, just check sign bit amd 0s @@ -916,22 +897,16 @@ template class fp8_e5m2_x { static constexpr float MinSubnormal = 0.0000152587890625f; // 2^-16 static constexpr uint8_t MaxFiniteCode = 0x7C; // 0.11111.00 - uint8_t ConvertToFP8(sycl::half h, rounding r) { + uint8_t ConvertToFP8(sycl::half h, rounding r, saturation s) { #ifdef __SYCL_DEVICE_ONLY__ - // TODO: optimize with vectorized builtin calls - const uint8_t sign = std::signbit(h) ? 0x80u : 0x00u; - const float ax = std::fabs(h); - - if (ax > MaxNormal) - return static_cast(sign | MaxFiniteCode); - - if (ax < MinSubnormal) - return sign; - - uint8_t b = __builtin_spirv_ClampConvertFP16ToE5M2INTEL(h); + uint8_t b = 0; + if (s == saturation::finite) + b = __builtin_spirv_ClampConvertFP16ToE5M2INTEL(h); + else + b = __builtin_spirv_ConvertFP16ToE5M2EXT(h); if (r == rounding::to_even) return b; - const sycl::half yi = __builtin_spirv_ConvertE5M2ToFP16INTEL(b); + const sycl::half yi = __builtin_spirv_ConvertE5M2ToFP16EXT(b); return round(r, b, yi, h); #else @@ -939,21 +914,16 @@ template class fp8_e5m2_x { #endif } - uint8_t ConvertBF16ToFP8(bfloat16 h, rounding r) { + uint8_t ConvertBF16ToFP8(bfloat16 h, rounding r, saturation s) { #ifdef __SYCL_DEVICE_ONLY__ - const uint8_t sign = std::signbit(h) ? 0x80u : 0x00u; - const bfloat16 ax = std::fabs(h); - - if (ax > MaxNormal) - return static_cast(sign | MaxFiniteCode); - - if (ax < MinSubnormal) - return sign; - - uint8_t b = __builtin_spirv_ClampConvertBF16ToE5M2INTEL(h); + uint8_t b = 0; + if (s == saturation::finite) + b = __builtin_spirv_ClampConvertBF16ToE5M2INTEL(h); + else + b = __builtin_spirv_ConvertBF16ToE5M2EXT(h); if (r == rounding::to_even) return b; - const sycl::half yi = __builtin_spirv_ConvertE5M2ToFP16INTEL(b); + const sycl::half yi = __builtin_spirv_ConvertE5M2ToBF16EXT(b); return round(r, b, yi, h); #else return ConvertToFP8_CPU<5, 2, bfloat16>(h, r); @@ -962,7 +932,7 @@ template class fp8_e5m2_x { template T ConvertFromFP8(uint8_t v) const { #ifdef __SYCL_DEVICE_ONLY__ - sycl::half hi = __builtin_spirv_ConvertE5M2ToFP16INTEL(v); + sycl::half hi = __builtin_spirv_ConvertE5M2ToFP16EXT(v); return static_cast(hi); #else return ConvertFromFP8_CPU<5, 2, T>(v); @@ -971,7 +941,7 @@ template class fp8_e5m2_x { bfloat16 ConvertFP16FromFP8(uint8_t v) const { #ifdef __SYCL_DEVICE_ONLY__ - return __builtin_spirv_ConvertE5M2ToBF16INTEL(v); + return __builtin_spirv_ConvertE5M2ToBF16EXT(v); #else return ConvertFromFP8_CPU<5, 2, bfloat16>(v); #endif @@ -983,9 +953,6 @@ template class fp8_e5m2_x { if (r != rounding::to_even) throw std::invalid_argument( "fp8_e5m2_x: only rounding::to_even is supported"); - if (s != saturation::finite) - throw std::invalid_argument( - "fp8_e5m2_x: only saturation::finite is supported"); } public: @@ -1013,12 +980,13 @@ template class fp8_e5m2_x { if constexpr (((std::is_same_v, bfloat16>) && ...)) { const bfloat16 in[N] = {static_cast(v)...}; for (size_t i = 0; i < N; ++i) - vals[i] = ConvertBF16ToFP8(in[i], rounding::to_even); + vals[i] = + ConvertBF16ToFP8(in[i], rounding::to_even, saturation::finite); return; } const sycl::half in[N] = {v...}; for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(in[i], rounding::to_even); + vals[i] = ConvertToFP8(in[i], rounding::to_even, saturation::finite); } // Construct from an array of half, bfloat16, float, double. @@ -1028,7 +996,7 @@ template class fp8_e5m2_x { CheckConstraints(r, s); // TODO: optimize with vectorized builtin calls for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i], r); + vals[i] = ConvertToFP8(v[i], r, s); } explicit fp8_e5m2_x(bfloat16 const (&v)[N], rounding r = rounding::to_even, @@ -1036,19 +1004,19 @@ template class fp8_e5m2_x { CheckConstraints(r, s); // TODO: optimize with vectorized builtin calls for (size_t i = 0; i < N; ++i) - vals[i] = ConvertBF16ToFP8(v[i], r); + vals[i] = ConvertBF16ToFP8(v[i], r, s); } explicit fp8_e5m2_x(float const (&v)[N], rounding r = rounding::to_even, saturation s = saturation::finite) { CheckConstraints(r, s); for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i], r); + vals[i] = ConvertToFP8(v[i], r, s); } explicit fp8_e5m2_x(double const (&v)[N]) { for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i], rounding::to_even); + vals[i] = ConvertToFP8(v[i], rounding::to_even, saturation::finite); } // Construct from an marray of half, bfloat16, float, double. @@ -1058,7 +1026,7 @@ template class fp8_e5m2_x { saturation s = saturation::finite) { CheckConstraints(r, s); for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i], r); + vals[i] = ConvertToFP8(v[i], r, s); } explicit fp8_e5m2_x(const sycl::marray &v, @@ -1066,7 +1034,7 @@ template class fp8_e5m2_x { saturation s = saturation::finite) { CheckConstraints(r, s); for (size_t i = 0; i < N; ++i) - vals[i] = ConvertBF16ToFP8(v[i], r); + vals[i] = ConvertBF16ToFP8(v[i], r, s); } explicit fp8_e5m2_x(const sycl::marray &v, @@ -1074,12 +1042,12 @@ template class fp8_e5m2_x { saturation s = saturation::finite) { CheckConstraints(r, s); for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i], r); + vals[i] = ConvertToFP8(v[i], r, s); } explicit fp8_e5m2_x(const sycl::marray &v) { for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i], rounding::to_even); + vals[i] = ConvertToFP8(v[i], rounding::to_even, saturation::finite); } // Construct with stochastic rounding with user provided seed from an array of @@ -1215,47 +1183,47 @@ template class fp8_e5m2_x { explicit fp8_e5m2_x(short val) { static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for short constructor"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); } explicit fp8_e5m2_x(int val) { static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for int constructor"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); } explicit fp8_e5m2_x(long val) { static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for long constructor"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); } explicit fp8_e5m2_x(long long val) { static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for long long constructor"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); } explicit fp8_e5m2_x(unsigned short val) { static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for unsigned short constructor"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); } explicit fp8_e5m2_x(unsigned int val) { static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for unsigned int constructor"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); } explicit fp8_e5m2_x(unsigned long val) { static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for unsigned long constructor"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); } explicit fp8_e5m2_x(unsigned long long val) { static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for unsigned long long constructor"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); } // Assign (operator) from half, bfloat16, float, double, and integer types. @@ -1264,56 +1232,56 @@ template class fp8_e5m2_x { fp8_e5m2_x &operator=(sycl::half val) { static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for half assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } fp8_e5m2_x &operator=(bfloat16 val) { static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for half bfloat16 operator"); - vals[0] = ConvertBF16ToFP8(val, rounding::to_even); + vals[0] = ConvertBF16ToFP8(val, rounding::to_even, saturation::finite); return *this; } fp8_e5m2_x &operator=(float val) { static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for float assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } fp8_e5m2_x &operator=(double val) { static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for double assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } fp8_e5m2_x &operator=(short val) { static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for short assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } fp8_e5m2_x &operator=(int val) { static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for int assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } fp8_e5m2_x &operator=(long val) { static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for long assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } fp8_e5m2_x &operator=(long long val) { static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for long long assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } @@ -1321,7 +1289,7 @@ template class fp8_e5m2_x { static_assert( N == 1 && "fp8_e5m2_x: N must be 1 for unsigned short assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } @@ -1329,7 +1297,7 @@ template class fp8_e5m2_x { static_assert( N == 1 && "fp8_e5m2_x: N must be 1 for unsigned int assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } @@ -1337,7 +1305,7 @@ template class fp8_e5m2_x { static_assert( N == 1 && "fp8_e5m2_x: N must be 1 for unsigned long assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } @@ -1345,7 +1313,7 @@ template class fp8_e5m2_x { static_assert( N == 1 && "fp8_e5m2_x: N must be 1 for unsigned long long assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } diff --git a/sycl/unittests/Extensions/fp8/builtin_call_tests.cpp b/sycl/unittests/Extensions/fp8/builtin_call_tests.cpp index c0551f3b5d746..99c0ef0236731 100644 --- a/sycl/unittests/Extensions/fp8/builtin_call_tests.cpp +++ b/sycl/unittests/Extensions/fp8/builtin_call_tests.cpp @@ -27,21 +27,21 @@ TEST_F(Fp8BuiltinCallTest, E4M3CastToHalfCallsClampConvertE4M3ToFP16) { fp8_e4m3 Value(static_cast(1.0f)); fp8_builtin_mock::resetCounters(); (void)static_cast(Value); - EXPECT_EQ(fp8_builtin_mock::getCounters().ClampConvertE4M3ToFP16INTEL, 1); + EXPECT_EQ(fp8_builtin_mock::getCounters().ConvertE4M3ToFP16EXT, 1); } TEST_F(Fp8BuiltinCallTest, E4M3CastToBf16CallsConvertE4M3ToBF16) { fp8_e4m3 Value(static_cast(1.0f)); fp8_builtin_mock::resetCounters(); (void)static_cast(Value); - EXPECT_EQ(fp8_builtin_mock::getCounters().ConvertE4M3ToBF16INTEL, 1); + EXPECT_EQ(fp8_builtin_mock::getCounters().ConvertE4M3ToBF16EXT, 1); } TEST_F(Fp8BuiltinCallTest, E4M3CastToBoolCallsConvertE4M3ToFP16) { fp8_e4m3 Value(static_cast(1.0f)); fp8_builtin_mock::resetCounters(); (void)static_cast(Value); - EXPECT_EQ(fp8_builtin_mock::getCounters().ConvertE4M3ToFP16INTEL, 1); + EXPECT_EQ(fp8_builtin_mock::getCounters().ConvertE4M3ToFP16EXT, 1); } TEST_F(Fp8BuiltinCallTest, E5M2CtorFromHalfCallsClampConvertFP16ToE5M2) { @@ -60,14 +60,33 @@ TEST_F(Fp8BuiltinCallTest, E5M2CastToHalfCallsConvertE5M2ToFP16) { fp8_e5m2 Value(static_cast(2.0f)); fp8_builtin_mock::resetCounters(); (void)static_cast(Value); - EXPECT_EQ(fp8_builtin_mock::getCounters().ConvertE5M2ToFP16INTEL, 1); + EXPECT_EQ(fp8_builtin_mock::getCounters().ConvertE5M2ToFP16EXT, 1); } TEST_F(Fp8BuiltinCallTest, E5M2CastToBf16CallsConvertE5M2ToBF16) { fp8_e5m2 Value(static_cast(2.0f)); fp8_builtin_mock::resetCounters(); (void)static_cast(Value); - EXPECT_EQ(fp8_builtin_mock::getCounters().ConvertE5M2ToBF16INTEL, 1); + EXPECT_EQ(fp8_builtin_mock::getCounters().ConvertE5M2ToBF16EXT, 1); +} + +TEST_F(Fp8BuiltinCallTest, E5M2CtorFromHalfWithNoSaturationCallsConvertFP16ToE5M2) { + sycl::half Input[1] = {static_cast(2.0f)}; + + fp8_e5m2 Value(Input, rounding::to_even, saturation::none); + (void)Value; + + EXPECT_EQ(fp8_builtin_mock::getCounters().ConvertFP16ToE5M2EXT, 1); +} + +TEST_F(Fp8BuiltinCallTest, E5M2CtorFromBf16WithNoSaturationCallsConvertBF16ToE5M2) { + sycl::ext::oneapi::bfloat16 Input[1] = { + static_cast(2.0f)}; + + fp8_e5m2 Value(Input, rounding::to_even, saturation::none); + (void)Value; + + EXPECT_EQ(fp8_builtin_mock::getCounters().ConvertBF16ToE5M2EXT, 1); } TEST_F(Fp8BuiltinCallTest, E5M2StochasticHalfFiniteCallsClampStochastic) { diff --git a/sycl/unittests/Extensions/fp8/builtin_mocks.hpp b/sycl/unittests/Extensions/fp8/builtin_mocks.hpp index 7a4aa8180b57b..6c4ebb9cc8189 100644 --- a/sycl/unittests/Extensions/fp8/builtin_mocks.hpp +++ b/sycl/unittests/Extensions/fp8/builtin_mocks.hpp @@ -15,15 +15,18 @@ namespace fp8_builtin_mock { struct Counters { - int ClampConvertE4M3ToFP16INTEL = 0; - int ConvertE4M3ToFP16INTEL = 0; - int ConvertE5M2ToFP16INTEL = 0; - int ConvertE4M3ToBF16INTEL = 0; - int ConvertE5M2ToBF16INTEL = 0; + int ConvertE4M3ToFP16EXT = 0; + int ConvertE5M2ToFP16EXT = 0; + int ConvertE4M3ToBF16EXT = 0; + int ConvertE5M2ToBF16EXT = 0; int ClampConvertFP16ToE4M3INTEL = 0; int ClampConvertBF16ToE4M3INTEL = 0; + int ConvertFP16ToE4M3EXT = 0; + int ConvertBF16ToE4M3EXT = 0; int ClampConvertFP16ToE5M2INTEL = 0; int ClampConvertBF16ToE5M2INTEL = 0; + int ConvertFP16ToE5M2EXT = 0; + int ConvertBF16ToE5M2EXT = 0; int StochasticRoundFP16ToE5M2INTEL = 0; int StochasticRoundBF16ToE5M2INTEL = 0; int ClampStochasticRoundFP16ToE5M2INTEL = 0; @@ -41,40 +44,37 @@ inline void resetCounters() { getCounters() = Counters{}; } // Builtin mocks (do not replace helpers.hpp; provide symbols here). inline sycl::half -__builtin_spirv_ClampConvertE4M3ToFP16INTEL(uint8_t) noexcept { - ++fp8_builtin_mock::getCounters().ClampConvertE4M3ToFP16INTEL; +__builtin_spirv_ConvertE4M3ToFP16EXT(uint8_t) noexcept { + ++fp8_builtin_mock::getCounters().ConvertE4M3ToFP16EXT; return static_cast(2.0f); } -inline sycl::half __builtin_spirv_ConvertE4M3ToFP16INTEL(uint8_t) noexcept { - ++fp8_builtin_mock::getCounters().ConvertE4M3ToFP16INTEL; - return static_cast(1.0f); -} - -inline sycl::half __builtin_spirv_ConvertE5M2ToFP16INTEL(uint8_t) noexcept { - ++fp8_builtin_mock::getCounters().ConvertE5M2ToFP16INTEL; +inline sycl::half __builtin_spirv_ConvertE5M2ToFP16EXT(uint8_t) noexcept { + ++fp8_builtin_mock::getCounters().ConvertE5M2ToFP16EXT; return static_cast(3.0f); } inline sycl::ext::oneapi::bfloat16 -__builtin_spirv_ConvertE4M3ToBF16INTEL(uint8_t) noexcept { - ++fp8_builtin_mock::getCounters().ConvertE4M3ToBF16INTEL; +__builtin_spirv_ConvertE4M3ToBF16EXT(uint8_t) noexcept { + ++fp8_builtin_mock::getCounters().ConvertE4M3ToBF16EXT; return static_cast(4.0f); } inline sycl::ext::oneapi::bfloat16 -__builtin_spirv_ConvertE5M2ToBF16INTEL(uint8_t) noexcept { - ++fp8_builtin_mock::getCounters().ConvertE5M2ToBF16INTEL; +__builtin_spirv_ConvertE5M2ToBF16EXT(uint8_t) noexcept { + ++fp8_builtin_mock::getCounters().ConvertE5M2ToBF16EXT; return static_cast(5.0f); } -inline uint8_t __builtin_spirv_ConvertFP16ToE4M3INTEL(sycl::half) noexcept { - return 0x00; +inline uint8_t __builtin_spirv_ConvertFP16ToE4M3EXT(sycl::half) noexcept { + ++fp8_builtin_mock::getCounters().ConvertFP16ToE4M3EXT; + return 0x01; } inline uint8_t -__builtin_spirv_ConvertBF16ToE5M2INTEL(sycl::ext::oneapi::bfloat16) noexcept { - return 0x00; +__builtin_spirv_ConvertBF16ToE4M3EXT(sycl::ext::oneapi::bfloat16) noexcept { + ++fp8_builtin_mock::getCounters().ConvertBF16ToE4M3EXT; + return 0x02; } inline uint8_t @@ -89,12 +89,23 @@ inline uint8_t __builtin_spirv_ClampConvertBF16ToE4M3INTEL( return 0x12; } +inline uint8_t __builtin_spirv_ConvertFP16ToE5M2EXT(sycl::half) noexcept { + ++fp8_builtin_mock::getCounters().ConvertFP16ToE5M2EXT; + return 0x03; +} + inline uint8_t __builtin_spirv_ClampConvertFP16ToE5M2INTEL(sycl::half) noexcept { ++fp8_builtin_mock::getCounters().ClampConvertFP16ToE5M2INTEL; return 0x21; } +inline uint8_t +__builtin_spirv_ConvertBF16ToE5M2EXT(sycl::ext::oneapi::bfloat16) noexcept { + ++fp8_builtin_mock::getCounters().ConvertBF16ToE5M2EXT; + return 0x04; +} + inline uint8_t __builtin_spirv_ClampConvertBF16ToE5M2INTEL( sycl::ext::oneapi::bfloat16) noexcept { ++fp8_builtin_mock::getCounters().ClampConvertBF16ToE5M2INTEL; From 046affd72a654e200164741ceaa125d98b0e8fe3 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Tue, 24 Mar 2026 18:51:25 +0100 Subject: [PATCH 12/89] [SYCL] add more tests of builtin calls --- .../Extensions/fp8/builtin_call_tests.cpp | 135 +++++++++++++++++- .../Extensions/fp8/builtin_mocks.hpp | 3 +- 2 files changed, 134 insertions(+), 4 deletions(-) diff --git a/sycl/unittests/Extensions/fp8/builtin_call_tests.cpp b/sycl/unittests/Extensions/fp8/builtin_call_tests.cpp index 99c0ef0236731..461b0dfd0b644 100644 --- a/sycl/unittests/Extensions/fp8/builtin_call_tests.cpp +++ b/sycl/unittests/Extensions/fp8/builtin_call_tests.cpp @@ -23,6 +23,26 @@ TEST_F(Fp8BuiltinCallTest, E4M3CtorFromBf16CallsClampConvertBF16ToE4M3) { EXPECT_EQ(fp8_builtin_mock::getCounters().ClampConvertBF16ToE4M3INTEL, 1); } +TEST_F(Fp8BuiltinCallTest, E4M3ArrayCtorFromFloatCallsClampConvertFP16ToE4M3) { + float Input[2] = {1.25f, 2.5f}; + + fp8_e4m3_x2 Value(Input); + (void)Value; + + EXPECT_EQ(fp8_builtin_mock::getCounters().ClampConvertFP16ToE4M3INTEL, 2); +} + +TEST_F(Fp8BuiltinCallTest, E4M3MarrayCtorFromBf16CallsClampConvertBF16ToE4M3) { + sycl::marray Input = { + static_cast(1.25f), + static_cast(2.5f)}; + + fp8_e4m3_x2 Value(Input); + (void)Value; + + EXPECT_EQ(fp8_builtin_mock::getCounters().ClampConvertBF16ToE4M3INTEL, 2); +} + TEST_F(Fp8BuiltinCallTest, E4M3CastToHalfCallsClampConvertE4M3ToFP16) { fp8_e4m3 Value(static_cast(1.0f)); fp8_builtin_mock::resetCounters(); @@ -44,6 +64,37 @@ TEST_F(Fp8BuiltinCallTest, E4M3CastToBoolCallsConvertE4M3ToFP16) { EXPECT_EQ(fp8_builtin_mock::getCounters().ConvertE4M3ToFP16EXT, 1); } +TEST_F(Fp8BuiltinCallTest, E4M3MarrayCastToHalfCallsConvertE4M3ToFP16) { + sycl::half Input[2] = {static_cast(1.0f), + static_cast(2.0f)}; + fp8_e4m3_x2 Value(Input); + + fp8_builtin_mock::resetCounters(); + (void)static_cast>(Value); + + EXPECT_EQ(fp8_builtin_mock::getCounters().ConvertE4M3ToFP16EXT, 2); +} + +TEST_F(Fp8BuiltinCallTest, E4M3MarrayCastToBf16CallsConvertE4M3ToBF16) { + sycl::half Input[2] = {static_cast(1.0f), + static_cast(2.0f)}; + fp8_e4m3_x2 Value(Input); + + fp8_builtin_mock::resetCounters(); + (void)static_cast>(Value); + + EXPECT_EQ(fp8_builtin_mock::getCounters().ConvertE4M3ToBF16EXT, 2); +} + +TEST_F(Fp8BuiltinCallTest, E4M3AssignmentFromFloatCallsClampConvertFP16ToE4M3) { + fp8_e4m3 Value(static_cast(1.0f)); + + fp8_builtin_mock::resetCounters(); + Value = 1.25f; + + EXPECT_EQ(fp8_builtin_mock::getCounters().ClampConvertFP16ToE4M3INTEL, 1); +} + TEST_F(Fp8BuiltinCallTest, E5M2CtorFromHalfCallsClampConvertFP16ToE5M2) { fp8_e5m2 Value(static_cast(2.0f)); (void)Value; @@ -56,6 +107,27 @@ TEST_F(Fp8BuiltinCallTest, E5M2CtorFromBf16CallsClampConvertBF16ToE5M2) { EXPECT_EQ(fp8_builtin_mock::getCounters().ClampConvertBF16ToE5M2INTEL, 1); } +TEST_F(Fp8BuiltinCallTest, + E5M2ArrayCtorFromFloatFiniteCallsClampConvertFP16ToE5M2) { + float Input[2] = {2.0f, 4.0f}; + + fp8_e5m2_x2 Value(Input, rounding::to_even, saturation::finite); + (void)Value; + + EXPECT_EQ(fp8_builtin_mock::getCounters().ClampConvertFP16ToE5M2INTEL, 2); +} + +TEST_F(Fp8BuiltinCallTest, E5M2MarrayCtorFromBf16NoneCallsConvertBF16ToE5M2) { + sycl::marray Input = { + static_cast(2.0f), + static_cast(4.0f)}; + + fp8_e5m2_x2 Value(Input, rounding::to_even, saturation::none); + (void)Value; + + EXPECT_EQ(fp8_builtin_mock::getCounters().ConvertBF16ToE5M2EXT, 2); +} + TEST_F(Fp8BuiltinCallTest, E5M2CastToHalfCallsConvertE5M2ToFP16) { fp8_e5m2 Value(static_cast(2.0f)); fp8_builtin_mock::resetCounters(); @@ -70,7 +142,39 @@ TEST_F(Fp8BuiltinCallTest, E5M2CastToBf16CallsConvertE5M2ToBF16) { EXPECT_EQ(fp8_builtin_mock::getCounters().ConvertE5M2ToBF16EXT, 1); } -TEST_F(Fp8BuiltinCallTest, E5M2CtorFromHalfWithNoSaturationCallsConvertFP16ToE5M2) { +TEST_F(Fp8BuiltinCallTest, E5M2MarrayCastToHalfCallsConvertE5M2ToFP16) { + sycl::half Input[2] = {static_cast(2.0f), + static_cast(4.0f)}; + fp8_e5m2_x2 Value(Input); + + fp8_builtin_mock::resetCounters(); + (void)static_cast>(Value); + + EXPECT_EQ(fp8_builtin_mock::getCounters().ConvertE5M2ToFP16EXT, 2); +} + +TEST_F(Fp8BuiltinCallTest, E5M2MarrayCastToBf16CallsConvertE5M2ToBF16) { + sycl::half Input[2] = {static_cast(2.0f), + static_cast(4.0f)}; + fp8_e5m2_x2 Value(Input); + + fp8_builtin_mock::resetCounters(); + (void)static_cast>(Value); + + EXPECT_EQ(fp8_builtin_mock::getCounters().ConvertE5M2ToBF16EXT, 2); +} + +TEST_F(Fp8BuiltinCallTest, E5M2AssignmentFromFloatCallsClampConvertFP16ToE5M2) { + fp8_e5m2 Value(static_cast(2.0f)); + + fp8_builtin_mock::resetCounters(); + Value = 4.0f; + + EXPECT_EQ(fp8_builtin_mock::getCounters().ClampConvertFP16ToE5M2INTEL, 1); +} + +TEST_F(Fp8BuiltinCallTest, + E5M2CtorFromHalfWithNoSaturationCallsConvertFP16ToE5M2) { sycl::half Input[1] = {static_cast(2.0f)}; fp8_e5m2 Value(Input, rounding::to_even, saturation::none); @@ -79,7 +183,8 @@ TEST_F(Fp8BuiltinCallTest, E5M2CtorFromHalfWithNoSaturationCallsConvertFP16ToE5M EXPECT_EQ(fp8_builtin_mock::getCounters().ConvertFP16ToE5M2EXT, 1); } -TEST_F(Fp8BuiltinCallTest, E5M2CtorFromBf16WithNoSaturationCallsConvertBF16ToE5M2) { +TEST_F(Fp8BuiltinCallTest, + E5M2CtorFromBf16WithNoSaturationCallsConvertBF16ToE5M2) { sycl::ext::oneapi::bfloat16 Input[1] = { static_cast(2.0f)}; @@ -139,4 +244,30 @@ TEST_F(Fp8BuiltinCallTest, E5M2StochasticBf16NoneCallsNonClampStochastic) { EXPECT_EQ(fp8_builtin_mock::getCounters().StochasticRoundBF16ToE5M2INTEL, 1); } +TEST_F(Fp8BuiltinCallTest, E5M2StochasticFloatFiniteCallsClampStochastic) { + float Input[2] = {3.0f, 4.0f}; + uint32_t SeedValue = 50; + stochastic_seed Seed(&SeedValue); + + fp8_e5m2_x2 Value(Input, Seed, saturation::finite); + (void)Value; + + EXPECT_EQ(fp8_builtin_mock::getCounters().ClampStochasticRoundFP16ToE5M2INTEL, + 2); + EXPECT_EQ(SeedValue, 52u); +} + +TEST_F(Fp8BuiltinCallTest, + E5M2StochasticMarrayFloatNoneCallsNonClampStochastic) { + sycl::marray Input = {3.0f, 4.0f}; + uint32_t SeedValue = 60; + stochastic_seed Seed(&SeedValue); + + fp8_e5m2_x2 Value(Input, Seed, saturation::none); + (void)Value; + + EXPECT_EQ(fp8_builtin_mock::getCounters().StochasticRoundFP16ToE5M2INTEL, 2); + EXPECT_EQ(SeedValue, 62u); +} + } // namespace diff --git a/sycl/unittests/Extensions/fp8/builtin_mocks.hpp b/sycl/unittests/Extensions/fp8/builtin_mocks.hpp index 6c4ebb9cc8189..ac89f27cfe614 100644 --- a/sycl/unittests/Extensions/fp8/builtin_mocks.hpp +++ b/sycl/unittests/Extensions/fp8/builtin_mocks.hpp @@ -43,8 +43,7 @@ inline void resetCounters() { getCounters() = Counters{}; } } // namespace fp8_builtin_mock // Builtin mocks (do not replace helpers.hpp; provide symbols here). -inline sycl::half -__builtin_spirv_ConvertE4M3ToFP16EXT(uint8_t) noexcept { +inline sycl::half __builtin_spirv_ConvertE4M3ToFP16EXT(uint8_t) noexcept { ++fp8_builtin_mock::getCounters().ConvertE4M3ToFP16EXT; return static_cast(2.0f); } From a697eb8f7867af24c05974df23ef680f224d4a2f Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Wed, 1 Apr 2026 11:50:52 +0200 Subject: [PATCH 13/89] [SYCL] fix PR issues --- .../oneapi/experimental/float_8bit/types.hpp | 20 ++++++++----------- 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index abd9d82a7d4d1..a822f94a33e1d 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -245,7 +245,7 @@ static inline ToT ConvertFromFP8_CPU(uint8_t b, /// \param h The input value to be converted. /// \param R The rounding mode to be used during conversion. /// \return uint8_t The converted 8-bit floating point value, MSB is sign bit, -/// Ebits bits mantissa, Mbits bits exponent. +/// Ebits bits exponent, Mbits bits mantissa. template static inline uint8_t ConvertToFP8_CPU(T h, rounding R = rounding::to_even) noexcept { @@ -369,10 +369,6 @@ ConvertToFP8_CPU(T h, rounding R = rounding::to_even) noexcept { return static_cast( sign | ((MaxExpForMaxNormal << Mbits) | MaxFracForMaxNormal)); } - if (ax >= max_finite) { - return static_cast( - sign | ((MaxExpForMaxNormal << Mbits) | MaxFracForMaxNormal)); - } if (ax < min_sub) return sign; // underflow @@ -671,7 +667,7 @@ template class fp8_e4m3_x { fp8_e4m3_x &operator=(bfloat16 val) { static_assert(N == 1 && "fp8_e4m3_x: N must be 1 for bfloat16 assignment operator"); - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertBF16ToFP8(val, rounding::to_even, saturation::finite); return *this; } @@ -685,7 +681,7 @@ template class fp8_e4m3_x { fp8_e4m3_x &operator=(double val) { static_assert(N == 1 && "fp8_e4m3_x: N must be 1 for double assignment operator"); - vals[0] = ConvertBF16ToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } @@ -858,7 +854,7 @@ template class fp8_e4m3_x { sycl::half h = __builtin_spirv_ConvertE4M3ToFP16EXT(vals[0]); return h != 0; #else - // no need to convert, just check sign bit amd 0s + // no need to convert, just check sign bit and 0s return vals[0] != 0 && vals[0] != 0x80; #endif } @@ -923,7 +919,7 @@ template class fp8_e5m2_x { b = __builtin_spirv_ConvertBF16ToE5M2EXT(h); if (r == rounding::to_even) return b; - const sycl::half yi = __builtin_spirv_ConvertE5M2ToBF16EXT(b); + const bfloat16 yi = __builtin_spirv_ConvertE5M2ToBF16EXT(b); return round(r, b, yi, h); #else return ConvertToFP8_CPU<5, 2, bfloat16>(h, r); @@ -939,7 +935,7 @@ template class fp8_e5m2_x { #endif } - bfloat16 ConvertFP16FromFP8(uint8_t v) const { + bfloat16 ConvertBF16FromFP8(uint8_t v) const { #ifdef __SYCL_DEVICE_ONLY__ return __builtin_spirv_ConvertE5M2ToBF16EXT(v); #else @@ -1329,7 +1325,7 @@ template class fp8_e5m2_x { explicit operator bfloat16() const { static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for bfloat16 conversion operator"); - return ConvertFP16FromFP8(vals[0]); + return ConvertBF16FromFP8(vals[0]); } explicit operator float() const { @@ -1437,7 +1433,7 @@ template class fp8_e5m2_x { explicit operator sycl::marray() const { sycl::marray out; for (size_t i = 0; i < N; ++i) - out[i] = ConvertFP16FromFP8(vals[i]); + out[i] = ConvertBF16FromFP8(vals[i]); return out; } explicit operator sycl::marray() const { From 1f0808a566c4209a97b55ed13682102c5b7dbf59 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Wed, 1 Apr 2026 13:50:26 +0200 Subject: [PATCH 14/89] [SYCL] do not use extra check for e8m0 --- .../oneapi/experimental/float_8bit/types.hpp | 19 ++++------ sycl/unittests/Extensions/fp8/fp8_e8m0.cpp | 38 +++++++++++++++++++ 2 files changed, 46 insertions(+), 11 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index a822f94a33e1d..852764262630b 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -94,13 +94,14 @@ static inline uint8_t RoundClip(float x, uint8_t max, rounding R, if (max == 0) { // No fraction bits (E8M0 path) if (R == rounding::upward) { - // Any positive residual causes a carry; NaN / non-positive → 0 - if (!std::isnan(x) && x > 0.0f) + // For sign-preserving formats, roundTowardPositive increments only for + // positive values with a non-zero residual. Negative values stay at the + // lower-magnitude encoding. + if (!std::isnan(x) && sign_bit == 0u && x > 0.0f) return 1u; return 0u; } - // Default / to_even - if (std::isnan(x)) + if (R == rounding::toward_zero || std::isnan(x)) return 0u; if (x > 0.5f) return 1u; @@ -305,20 +306,16 @@ ConvertToFP8_CPU(T h, rounding R = rounding::to_even) noexcept { // Exact power-of-two: m == 0.5 (since frexp gives m in [0.5,1)) bool is_exact_power_of_two = (m == 0.5f); - rounding effR = (R == rounding::upward) ? R : rounding::upward; + //rounding effR = (R == rounding::upward) ? R : rounding::upward; - if (effR == rounding::upward) { + if (R == rounding::upward) { if (sign == 0x00) { - if (!is_exact_power_of_two) { // Round up (increase exponent) if possible. if (E < Emax) ++E; else E = Emax; - } - } else { - // Negative: leave E as-is (toward +inf reduces magnitude). - } + } } // Clamp exponent just in case. diff --git a/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp b/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp index 74576226d6173..fdb6a09018edf 100644 --- a/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp +++ b/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp @@ -12,6 +12,16 @@ code thus unit tests check only API using namespace sycl::ext::oneapi::experimental; +namespace { + +bool checkCode(float Input, rounding Mode, uint8_t Expected) { + const float Values[1] = {Input}; + const fp8_e8m0 Encoded(Values, Mode); + return Encoded.vals[0] == Expected; +} + +} // namespace + TEST(FP8E8M0Test, VariadicConstructorFloat) { fp8_e8m0_x2 a(1.0f, 2.0f); fp8_e8m0_x2 a1(1.1f, 0.0f); @@ -72,6 +82,34 @@ TEST(FP8E8M0Test, CArrayConstructorFloatHostUpwardFinite) { EXPECT_EQ(a1.vals[1], 0x89); // upward to 2^10 = 1024 } +TEST(FP8E8M0Test, CArrayConstructorFloatRoundingModes) { + EXPECT_TRUE(checkCode(3.0f, rounding::upward, 0x81)); + EXPECT_TRUE(checkCode(3.0f, rounding::toward_zero, 0x80)); + + // E8M0 drops sign per the extension specification, so negative inputs are + // rounded using their magnitude. + EXPECT_TRUE(checkCode(-3.0f, rounding::upward, 0x81)); + EXPECT_TRUE(checkCode(-3.0f, rounding::toward_zero, 0x80)); + EXPECT_TRUE(checkCode(-1.5f, rounding::upward, 0x80)); + EXPECT_TRUE(checkCode(-1.5f, rounding::toward_zero, 0x7F)); + EXPECT_TRUE(checkCode(-0.5f, rounding::upward, 0x7E)); + EXPECT_TRUE(checkCode(-0.5f, rounding::toward_zero, 0x7E)); + + EXPECT_TRUE(checkCode(1.0f, rounding::upward, 0x7F)); + EXPECT_TRUE(checkCode(0.5f, rounding::upward, 0x7E)); + EXPECT_TRUE(checkCode(0.5f, rounding::toward_zero, 0x7E)); + EXPECT_TRUE(checkCode(0.0f, rounding::toward_zero, 0x00)); + EXPECT_TRUE(checkCode(std::numeric_limits::quiet_NaN(), + rounding::upward, 0xFF)); +} + +TEST(FP8E8M0Test, RoundClipZeroFractionNegativeAndTieCases) { + EXPECT_EQ(RoundClip(0.25f, 0, rounding::upward, 0u), 1u); + EXPECT_EQ(RoundClip(0.25f, 0, rounding::upward, 1u), 0u); + EXPECT_EQ(RoundClip(0.5f, 0, rounding::to_even, 0u), 0u); + EXPECT_EQ(RoundClip(0.75f, 0, rounding::to_even, 0u), 1u); +} + TEST(FP8E8M0Test, CArrayConstructorHalfHostUpwardFinite) { const sycl::half in[2] = {sycl::half(1.0f), sycl::half(1.1f)}; const sycl::half in1[2] = {sycl::half(3.0f), sycl::half(0.0f)}; From ef5f6700bd0e765edaf0751af92df2abb422e25b Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Wed, 1 Apr 2026 14:44:42 +0200 Subject: [PATCH 15/89] [SYCL] fix formatting --- .../ext/oneapi/experimental/float_8bit/types.hpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index 852764262630b..a14d782baca84 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -306,16 +306,16 @@ ConvertToFP8_CPU(T h, rounding R = rounding::to_even) noexcept { // Exact power-of-two: m == 0.5 (since frexp gives m in [0.5,1)) bool is_exact_power_of_two = (m == 0.5f); - //rounding effR = (R == rounding::upward) ? R : rounding::upward; + // rounding effR = (R == rounding::upward) ? R : rounding::upward; if (R == rounding::upward) { if (sign == 0x00) { - // Round up (increase exponent) if possible. - if (E < Emax) - ++E; - else - E = Emax; - } + // Round up (increase exponent) if possible. + if (E < Emax) + ++E; + else + E = Emax; + } } // Clamp exponent just in case. From 8d9cc9f832c7e084760d27207b9a3c53939e7141 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Wed, 1 Apr 2026 15:57:07 +0200 Subject: [PATCH 16/89] [SYCL] remove unused variable --- .../include/sycl/ext/oneapi/experimental/float_8bit/types.hpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index a14d782baca84..e8ba3fa8ff736 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -303,10 +303,6 @@ ConvertToFP8_CPU(T h, rounding R = rounding::to_even) noexcept { // power (E+1) if within range. // - For negative numbers: rounding toward +inf moves value toward zero, so // keep current E. - // Exact power-of-two: m == 0.5 (since frexp gives m in [0.5,1)) - bool is_exact_power_of_two = (m == 0.5f); - - // rounding effR = (R == rounding::upward) ? R : rounding::upward; if (R == rounding::upward) { if (sign == 0x00) { From e5fd6c402d0da667a65a5e6f0f126b2f70efc49d Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Fri, 3 Apr 2026 15:23:54 +0200 Subject: [PATCH 17/89] [SYCL] remove unused variable --- sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index e8ba3fa8ff736..02956f223c4cf 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -295,7 +295,6 @@ ConvertToFP8_CPU(T h, rounding R = rounding::to_even) noexcept { // Determine exponent E such that 2^E <= ax < 2^{E+1} int e2; - float m = std::frexp(ax, &e2); // ax = m * 2^{e2}, m in [0.5,1) int E = e2 - 1; // Now 2^E <= ax < 2^{E+1} // Upward rounding semantics: From ae426a944aeba0226b0d17afe5a36aa57c365f4d Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Fri, 3 Apr 2026 15:41:45 +0200 Subject: [PATCH 18/89] [SYCL] fix formatting --- sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index 02956f223c4cf..1a43c5b38654a 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -295,7 +295,7 @@ ConvertToFP8_CPU(T h, rounding R = rounding::to_even) noexcept { // Determine exponent E such that 2^E <= ax < 2^{E+1} int e2; - int E = e2 - 1; // Now 2^E <= ax < 2^{E+1} + int E = e2 - 1; // Upward rounding semantics: // - For positive numbers: if not exact power-of-two, round up to next From 5b6da23d44bff357fd74f55c1f9ab19d8f2801be Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Tue, 7 Apr 2026 16:10:57 +0200 Subject: [PATCH 19/89] [SYCL] do not construct fp8 with mixture of parameters in pack --- .../oneapi/experimental/float_8bit/types.hpp | 276 +++++++++--------- sycl/unittests/Extensions/fp8/fp8_e4m3.cpp | 5 + sycl/unittests/Extensions/fp8/fp8_e5m2.cpp | 6 + sycl/unittests/Extensions/fp8/fp8_e8m0.cpp | 14 +- 4 files changed, 159 insertions(+), 142 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index 1a43c5b38654a..1a63e91594618 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -78,6 +78,7 @@ struct stochastic_seed { uint32_t *const pseed; }; +namespace detail { static inline uint8_t RneClip(float x, uint8_t max) noexcept { float f = std::floor(x); float frac = x - f; @@ -441,6 +442,99 @@ uint8_t round(rounding r, uint8_t b, sycl::half yi, T vi) { return b; } +static inline uint8_t ConvertToE8M0_CPU(float x, rounding R, + saturation S) noexcept { + // E8M0: unsigned 8-bit exponent code, bias 127. + // Code 0xFF reserved for NaN. No Inf, no subnormals, no signed zero. + constexpr int Bias = 127; + constexpr int Emin = -127; + constexpr int Emax = 127; + constexpr uint8_t NaNCode = 0xFF; + constexpr uint8_t MaxFiniteCode = 0xFE; + + if (std::isnan(x)) + return NaNCode; + + // No sign bit: negative inputs are treated as their magnitude. + float ax = std::fabs(x); + + // Infinity handling: depends on saturation. + if (std::isinf(ax)) + return (S == saturation::finite) ? MaxFiniteCode : NaNCode; + + // Zero and underflow: map to min normal (code 0). + // Min normal = 2^-127. + const float min_normal = std::ldexp(1.0f, Emin); + if (ax == 0.0f || ax < min_normal) + return 0x00; + + // Overflow and "too large": clamp or NaN depending on saturation. + const float max_normal = std::ldexp(1.0f, Emax); // 2^127 + if (ax >= max_normal) + return (S == saturation::finite) ? MaxFiniteCode : NaNCode; + + // Determine E such that 2^E <= ax < 2^(E+1). + int e2 = 0; + float m = std::frexp(ax, &e2); // ax = m * 2^e2, m in [0.5, 1) + int E = e2 - 1; + + // With no mantissa, representables are exact powers of two. + // Choose between 2^E and 2^(E+1) based on rounding mode. + const bool is_exact_power_of_two = (m == 0.5f); + + switch (R) { + case rounding::upward: + // toward +inf; with no sign, this is "ceil in magnitude". + if (!is_exact_power_of_two && E < Emax) + ++E; + break; + case rounding::toward_zero: + // toward -inf / toward 0: both pick the lower power for non-exact. + break; + case rounding::to_even: + default: { + if (!is_exact_power_of_two) { + // Nearest of {2^E, 2^(E+1)} w/ ties-to-even (even exponent on tie). + float lo = std::ldexp(1.0f, E); + float hi = std::ldexp(1.0f, E + 1); + float dlo = ax - lo; + float dhi = hi - ax; + if (dhi < dlo) { + if (E < Emax) + ++E; + } else if (dhi == dlo) { + // tie -> even exponent + if ((E & 1) != 0 && E < Emax) + ++E; + } + } + break; + } + } + + if (E < Emin) + E = Emin; + if (E > Emax) + E = Emax; + + uint8_t code = static_cast(E + Bias); // 0..254 + return code; +} + +template +static inline ToT ConvertFromE8M0_CPU(uint8_t code) noexcept { + constexpr int Bias = 127; + if (code == 0xFF) { + float qn = std::numeric_limits::quiet_NaN(); + return static_cast(qn); + } + int E = static_cast(code) - Bias; // includes code==0 -> -127 + float v = std::ldexp(1.0f, E); + return ConvertFloatToTarget(v, rounding::to_even); +} + +} // namespace detail + template class fp8_e4m3_x { static constexpr size_t NExpBits = 4; static constexpr size_t NFracBits = 3; @@ -463,10 +557,10 @@ template class fp8_e4m3_x { return b; const sycl::half yi = __builtin_spirv_ConvertE4M3ToFP16EXT(b); - return round(r, b, yi, hi); + return detail::round(r, b, yi, hi); #else - return ConvertToFP8_CPU<4, 3, sycl::half>(hi, r); + return detail::ConvertToFP8_CPU<4, 3, sycl::half>(hi, r); #endif } @@ -480,9 +574,9 @@ template class fp8_e4m3_x { if (r == rounding::to_even) return b; const half yi = __builtin_spirv_ConvertE4M3ToFP16EXT(b); - return round(r, b, yi, h); + return detail::round(r, b, yi, h); #else - return ConvertToFP8_CPU<4, 3, bfloat16>(h, r); + return detail::ConvertToFP8_CPU<4, 3, bfloat16>(h, r); #endif } @@ -491,7 +585,7 @@ template class fp8_e4m3_x { sycl::half hi = __builtin_spirv_ConvertE4M3ToFP16EXT(v); return static_cast(hi); #else - return ConvertFromFP8_CPU<4, 3, T>(v); + return detail::ConvertFromFP8_CPU<4, 3, T>(v); #endif } @@ -499,7 +593,7 @@ template class fp8_e4m3_x { #ifdef __SYCL_DEVICE_ONLY__ return __builtin_spirv_ConvertE4M3ToBF16EXT(v); #else - return ConvertFromFP8_CPU<4, 3, bfloat16>(v); + return detail::ConvertFromFP8_CPU<4, 3, bfloat16>(v); #endif } @@ -524,11 +618,10 @@ template class fp8_e4m3_x { template , half> || - std::is_same_v, bfloat16> || - std::is_same_v, float> || - std::is_same_v, double>) && - ...))>> + (((std::is_same_v, half>) && ...) || + ((std::is_same_v, bfloat16>) && ...) || + ((std::is_same_v, float>) && ...) || + ((std::is_same_v, double>) && ...))>> explicit fp8_e4m3_x(Types... v) { static_assert(N == 1 || N == 2, "fp8_e4m3_x: Template argument N must be 1 or 2"); @@ -895,10 +988,10 @@ template class fp8_e5m2_x { if (r == rounding::to_even) return b; const sycl::half yi = __builtin_spirv_ConvertE5M2ToFP16EXT(b); - return round(r, b, yi, h); + return detail::round(r, b, yi, h); #else - return ConvertToFP8_CPU<5, 2, sycl::half>(h, r); + return detail::ConvertToFP8_CPU<5, 2, sycl::half>(h, r); #endif } @@ -912,9 +1005,9 @@ template class fp8_e5m2_x { if (r == rounding::to_even) return b; const bfloat16 yi = __builtin_spirv_ConvertE5M2ToBF16EXT(b); - return round(r, b, yi, h); + return detail::round(r, b, yi, h); #else - return ConvertToFP8_CPU<5, 2, bfloat16>(h, r); + return detail::ConvertToFP8_CPU<5, 2, bfloat16>(h, r); #endif } @@ -923,7 +1016,7 @@ template class fp8_e5m2_x { sycl::half hi = __builtin_spirv_ConvertE5M2ToFP16EXT(v); return static_cast(hi); #else - return ConvertFromFP8_CPU<5, 2, T>(v); + return detail::ConvertFromFP8_CPU<5, 2, T>(v); #endif } @@ -931,7 +1024,7 @@ template class fp8_e5m2_x { #ifdef __SYCL_DEVICE_ONLY__ return __builtin_spirv_ConvertE5M2ToBF16EXT(v); #else - return ConvertFromFP8_CPU<5, 2, bfloat16>(v); + return detail::ConvertFromFP8_CPU<5, 2, bfloat16>(v); #endif } @@ -957,11 +1050,10 @@ template class fp8_e5m2_x { template , half> || - std::is_same_v, bfloat16> || - std::is_same_v, float> || - std::is_same_v, double>) && - ...))>> + (((std::is_same_v, half>) && ...) || + ((std::is_same_v, bfloat16>) && ...) || + ((std::is_same_v, float>) && ...) || + ((std::is_same_v, double>) && ...))>> explicit fp8_e5m2_x(Types... v) { static_assert(N == 1 || N == 2, "fp8_e5m2_x: Template argument N must be 1 or 2"); @@ -1440,97 +1532,6 @@ template class fp8_e5m2_x { uint8_t vals[N]; }; -static inline uint8_t ConvertToE8M0_CPU(float x, rounding R, - saturation S) noexcept { - // E8M0: unsigned 8-bit exponent code, bias 127. - // Code 0xFF reserved for NaN. No Inf, no subnormals, no signed zero. - constexpr int Bias = 127; - constexpr int Emin = -127; - constexpr int Emax = 127; - constexpr uint8_t NaNCode = 0xFF; - constexpr uint8_t MaxFiniteCode = 0xFE; - - if (std::isnan(x)) - return NaNCode; - - // No sign bit: negative inputs are treated as their magnitude. - float ax = std::fabs(x); - - // Infinity handling: depends on saturation. - if (std::isinf(ax)) - return (S == saturation::finite) ? MaxFiniteCode : NaNCode; - - // Zero and underflow: map to min normal (code 0). - // Min normal = 2^-127. - const float min_normal = std::ldexp(1.0f, Emin); - if (ax == 0.0f || ax < min_normal) - return 0x00; - - // Overflow and "too large": clamp or NaN depending on saturation. - const float max_normal = std::ldexp(1.0f, Emax); // 2^127 - if (ax >= max_normal) - return (S == saturation::finite) ? MaxFiniteCode : NaNCode; - - // Determine E such that 2^E <= ax < 2^(E+1). - int e2 = 0; - float m = std::frexp(ax, &e2); // ax = m * 2^e2, m in [0.5, 1) - int E = e2 - 1; - - // With no mantissa, representables are exact powers of two. - // Choose between 2^E and 2^(E+1) based on rounding mode. - const bool is_exact_power_of_two = (m == 0.5f); - - switch (R) { - case rounding::upward: - // toward +inf; with no sign, this is "ceil in magnitude". - if (!is_exact_power_of_two && E < Emax) - ++E; - break; - case rounding::toward_zero: - // toward -inf / toward 0: both pick the lower power for non-exact. - break; - case rounding::to_even: - default: { - if (!is_exact_power_of_two) { - // Nearest of {2^E, 2^(E+1)} w/ ties-to-even (even exponent on tie). - float lo = std::ldexp(1.0f, E); - float hi = std::ldexp(1.0f, E + 1); - float dlo = ax - lo; - float dhi = hi - ax; - if (dhi < dlo) { - if (E < Emax) - ++E; - } else if (dhi == dlo) { - // tie -> even exponent - if ((E & 1) != 0 && E < Emax) - ++E; - } - } - break; - } - } - - if (E < Emin) - E = Emin; - if (E > Emax) - E = Emax; - - uint8_t code = static_cast(E + Bias); // 0..254 - return code; -} - -template -static inline ToT ConvertFromE8M0_CPU(uint8_t code) noexcept { - constexpr int Bias = 127; - if (code == 0xFF) { - float qn = std::numeric_limits::quiet_NaN(); - return static_cast(qn); - } - int E = static_cast(code) - Bias; // includes code==0 -> -127 - float v = std::ldexp(1.0f, E); - return ConvertFloatToTarget(v, rounding::to_even); -} - template class fp8_e8m0_x { void CheckConstraints(rounding r) const { static_assert(N == 1 || N == 2, @@ -1549,11 +1550,10 @@ template class fp8_e8m0_x { template , half> || - std::is_same_v, bfloat16> || - std::is_same_v, float> || - std::is_same_v, double>) && - ...))>> + (((std::is_same_v, half>) && ...) || + ((std::is_same_v, bfloat16>) && ...) || + ((std::is_same_v, float>) && ...) || + ((std::is_same_v, double>) && ...))>> explicit fp8_e8m0_x(Types... v) { #ifdef __SYCL_DEVICE_ONLY__ static_assert(N == 1 || N == 2, @@ -1562,7 +1562,7 @@ template class fp8_e8m0_x { using InT = std::common_type_t...>; const InT in[N] = {v...}; for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToE8M0_CPU(static_cast(in[i]), rounding::upward, + vals[i] = detail::ConvertToE8M0_CPU(static_cast(in[i]), rounding::upward, saturation::finite); } @@ -1570,27 +1570,27 @@ template class fp8_e8m0_x { CheckConstraints(r); for (size_t i = 0; i < N; ++i) vals[i] = - ConvertToE8M0_CPU(static_cast(in[i]), r, saturation::finite); + detail::ConvertToE8M0_CPU(static_cast(in[i]), r, saturation::finite); } explicit fp8_e8m0_x(bfloat16 const (&in)[N], rounding r = rounding::upward) { CheckConstraints(r); for (size_t i = 0; i < N; ++i) vals[i] = - ConvertToE8M0_CPU(static_cast(in[i]), r, saturation::finite); + detail::ConvertToE8M0_CPU(static_cast(in[i]), r, saturation::finite); } explicit fp8_e8m0_x(float const (&in)[N], rounding r = rounding::upward) { CheckConstraints(r); for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToE8M0_CPU(in[i], r, saturation::finite); + vals[i] = detail::ConvertToE8M0_CPU(in[i], r, saturation::finite); } explicit fp8_e8m0_x(double const (&in)[N]) { static_assert(N == 1 || N == 2, "fp8_e8m0_x: Template argument N must be 1 or 2 on device"); for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToE8M0_CPU(static_cast(in[i]), rounding::upward, + vals[i] = detail::ConvertToE8M0_CPU(static_cast(in[i]), rounding::upward, saturation::finite); } @@ -1599,7 +1599,7 @@ template class fp8_e8m0_x { CheckConstraints(r); for (size_t i = 0; i < N; ++i) vals[i] = - ConvertToE8M0_CPU(static_cast(in[i]), r, saturation::finite); + detail::ConvertToE8M0_CPU(static_cast(in[i]), r, saturation::finite); } explicit fp8_e8m0_x(const marray &in, @@ -1607,21 +1607,21 @@ template class fp8_e8m0_x { CheckConstraints(r); for (size_t i = 0; i < N; ++i) vals[i] = - ConvertToE8M0_CPU(static_cast(in[i]), r, saturation::finite); + detail::ConvertToE8M0_CPU(static_cast(in[i]), r, saturation::finite); } explicit fp8_e8m0_x(const marray &in, rounding r = rounding::upward) { CheckConstraints(r); for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToE8M0_CPU(in[i], r, saturation::finite); + vals[i] = detail::ConvertToE8M0_CPU(in[i], r, saturation::finite); } explicit fp8_e8m0_x(const marray &in) { static_assert(N == 1 || N == 2, "fp8_e8m0_x: Template argument N must be 1 or 2 on device"); for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToE8M0_CPU(static_cast(in[i]), rounding::upward, + vals[i] = detail::ConvertToE8M0_CPU(static_cast(in[i]), rounding::upward, saturation::finite); } @@ -1630,7 +1630,7 @@ template class fp8_e8m0_x { explicit fp8_e8m0_x(short val) { static_assert(N == 1 && "fp8_e8m0_x: N must be 1 for short constructor"); - vals[0] = ConvertToE8M0_CPU(static_cast(val), rounding::upward, + vals[0] = detail::ConvertToE8M0_CPU(static_cast(val), rounding::upward, saturation::finite); } explicit fp8_e8m0_x(int val) : fp8_e8m0_x(static_cast(val)) {} @@ -1646,19 +1646,19 @@ template class fp8_e8m0_x { fp8_e8m0_x &operator=(half val) { static_assert(N == 1, "fp8_e8m0_x: N must be 1 for scalar assignment"); - vals[0] = ConvertToE8M0_CPU(static_cast(val), rounding::upward, + vals[0] = detail::ConvertToE8M0_CPU(static_cast(val), rounding::upward, saturation::finite); return *this; } fp8_e8m0_x &operator=(bfloat16 val) { static_assert(N == 1, "fp8_e8m0_x: N must be 1 for scalar assignment"); - vals[0] = ConvertToE8M0_CPU(static_cast(val), rounding::upward, + vals[0] = detail::ConvertToE8M0_CPU(static_cast(val), rounding::upward, saturation::finite); return *this; } fp8_e8m0_x &operator=(float val) { static_assert(N == 1, "fp8_e8m0_x: N must be 1 for scalar assignment"); - vals[0] = ConvertToE8M0_CPU(val, rounding::upward, saturation::finite); + vals[0] = detail::ConvertToE8M0_CPU(val, rounding::upward, saturation::finite); return *this; } fp8_e8m0_x &operator=(double val) { @@ -1685,19 +1685,19 @@ template class fp8_e8m0_x { explicit operator half() const { static_assert(N == 1, "fp8_e8m0_x: N must be 1 for scalar conversion"); - return ConvertFromE8M0_CPU(vals[0]); + return detail::ConvertFromE8M0_CPU(vals[0]); } explicit operator bfloat16() const { static_assert(N == 1, "fp8_e8m0_x: N must be 1 for scalar conversion"); - return ConvertFromE8M0_CPU(vals[0]); + return detail::ConvertFromE8M0_CPU(vals[0]); } explicit operator float() const { static_assert(N == 1, "fp8_e8m0_x: N must be 1 for scalar conversion"); - return ConvertFromE8M0_CPU(vals[0]); + return detail::ConvertFromE8M0_CPU(vals[0]); } explicit operator double() const { static_assert(N == 1, "fp8_e8m0_x: N must be 1 for scalar conversion"); - return ConvertFromE8M0_CPU(vals[0]); + return detail::ConvertFromE8M0_CPU(vals[0]); } explicit operator char() const { @@ -1743,19 +1743,19 @@ template class fp8_e8m0_x { explicit operator sycl::marray() const { sycl::marray out; for (size_t i = 0; i < N; ++i) - out[i] = ConvertFromE8M0_CPU(vals[i]); + out[i] = detail::ConvertFromE8M0_CPU(vals[i]); return out; } explicit operator sycl::marray() const { sycl::marray out; for (size_t i = 0; i < N; ++i) - out[i] = ConvertFromE8M0_CPU(vals[i]); + out[i] = detail::ConvertFromE8M0_CPU(vals[i]); return out; } explicit operator sycl::marray() const { sycl::marray out; for (size_t i = 0; i < N; ++i) - out[i] = ConvertFromE8M0_CPU(vals[i]); + out[i] = detail::ConvertFromE8M0_CPU(vals[i]); return out; } diff --git a/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp b/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp index 4aecd26ff1eb2..ed90cead2d43f 100644 --- a/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp +++ b/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp @@ -405,3 +405,8 @@ TEST(FP8E4M3Test, MarrayDoubleToEven) { EXPECT_EQ(a.vals[0], 0x06); EXPECT_EQ(a.vals[1], 0x38); } + +TEST(FP8E4M3Test, VariadicRejectsMixedTypes) { + EXPECT_FALSE((std::is_constructible_v)); + EXPECT_FALSE((std::is_constructible_v)); +} diff --git a/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp b/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp index e73e4f1a24624..6a5123b5f3cc6 100644 --- a/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp +++ b/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp @@ -446,3 +446,9 @@ TEST(FP8E5M2Test, BoolOperatorWithNaN) { EXPECT_TRUE(static_cast(nanv)); // not +0 or -0 EXPECT_EQ(nanv.vals[0], 0x7F); // NaN encoding remains S.11111.11 } + +TEST(FP8E5M2Test, VariadicRejectsMixedTypes) { + EXPECT_FALSE((std::is_constructible_v)); + EXPECT_FALSE( + (std::is_constructible_v)); +} \ No newline at end of file diff --git a/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp b/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp index fdb6a09018edf..b64970a8e6522 100644 --- a/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp +++ b/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp @@ -104,10 +104,10 @@ TEST(FP8E8M0Test, CArrayConstructorFloatRoundingModes) { } TEST(FP8E8M0Test, RoundClipZeroFractionNegativeAndTieCases) { - EXPECT_EQ(RoundClip(0.25f, 0, rounding::upward, 0u), 1u); - EXPECT_EQ(RoundClip(0.25f, 0, rounding::upward, 1u), 0u); - EXPECT_EQ(RoundClip(0.5f, 0, rounding::to_even, 0u), 0u); - EXPECT_EQ(RoundClip(0.75f, 0, rounding::to_even, 0u), 1u); + EXPECT_EQ(detail::RoundClip(0.25f, 0, rounding::upward, 0u), 1u); + EXPECT_EQ(detail::RoundClip(0.25f, 0, rounding::upward, 1u), 0u); + EXPECT_EQ(detail::RoundClip(0.5f, 0, rounding::to_even, 0u), 0u); + EXPECT_EQ(detail::RoundClip(0.75f, 0, rounding::to_even, 0u), 1u); } TEST(FP8E8M0Test, CArrayConstructorHalfHostUpwardFinite) { @@ -311,3 +311,9 @@ TEST(FP8E8M0Test, MarrayConversionOperators) { EXPECT_EQ(fo[0], 1.0f); EXPECT_EQ(fo[1], 4.0f); } + +TEST(FP8E8M0Test, VariadicRejectsMixedTypes) { + EXPECT_FALSE((std::is_constructible_v)); + EXPECT_FALSE((std::is_constructible_v)); +} \ No newline at end of file From b8cf8b0b673ea6128ca79a866037e3b7d957a887 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Wed, 8 Apr 2026 12:02:13 +0200 Subject: [PATCH 20/89] [SYCL] limit operators with SFINAE, do not use asserts --- .../oneapi/experimental/float_8bit/types.hpp | 454 ++++++++---------- 1 file changed, 206 insertions(+), 248 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index 1a63e91594618..2cc9ee8342414 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -442,7 +442,8 @@ uint8_t round(rounding r, uint8_t b, sycl::half yi, T vi) { return b; } -static inline uint8_t ConvertToE8M0_CPU(float x, rounding R, +template +static inline uint8_t ConvertToE8M0_CPU(T x, rounding R, saturation S) noexcept { // E8M0: unsigned 8-bit exponent code, bias 127. // Code 0xFF reserved for NaN. No Inf, no subnormals, no signed zero. @@ -452,30 +453,44 @@ static inline uint8_t ConvertToE8M0_CPU(float x, rounding R, constexpr uint8_t NaNCode = 0xFF; constexpr uint8_t MaxFiniteCode = 0xFE; - if (std::isnan(x)) - return NaNCode; + // NaN and Inf checks only apply to non-integral types. + if constexpr (!std::is_integral_v) { + if (std::isnan(static_cast(x))) + return NaNCode; + if (std::isinf(static_cast(x))) + return (S == saturation::finite) ? MaxFiniteCode : NaNCode; + } - // No sign bit: negative inputs are treated as their magnitude. - float ax = std::fabs(x); + // Compute absolute value in the natural type T. + T ax; + if constexpr (std::is_unsigned_v) + ax = x; + else if constexpr (std::is_signed_v && std::is_integral_v) + ax = x < T(0) ? static_cast(-x) : x; + else + ax = static_cast(std::fabs(static_cast(x))); - // Infinity handling: depends on saturation. - if (std::isinf(ax)) - return (S == saturation::finite) ? MaxFiniteCode : NaNCode; + // Zero check in natural type. + if (ax == T(0)) + return 0x00; + + // Convert to float for frexp/ldexp-based exponent extraction. + float fax = static_cast(ax); - // Zero and underflow: map to min normal (code 0). + // Underflow: map to min normal (code 0). // Min normal = 2^-127. const float min_normal = std::ldexp(1.0f, Emin); - if (ax == 0.0f || ax < min_normal) + if (fax < min_normal) return 0x00; // Overflow and "too large": clamp or NaN depending on saturation. const float max_normal = std::ldexp(1.0f, Emax); // 2^127 - if (ax >= max_normal) + if (fax >= max_normal) return (S == saturation::finite) ? MaxFiniteCode : NaNCode; - // Determine E such that 2^E <= ax < 2^(E+1). + // Determine E such that 2^E <= fax < 2^(E+1). int e2 = 0; - float m = std::frexp(ax, &e2); // ax = m * 2^e2, m in [0.5, 1) + float m = std::frexp(fax, &e2); // fax = m * 2^e2, m in [0.5, 1) int E = e2 - 1; // With no mantissa, representables are exact powers of two. @@ -497,8 +512,8 @@ static inline uint8_t ConvertToE8M0_CPU(float x, rounding R, // Nearest of {2^E, 2^(E+1)} w/ ties-to-even (even exponent on tie). float lo = std::ldexp(1.0f, E); float hi = std::ldexp(1.0f, E + 1); - float dlo = ax - lo; - float dhi = hi - ax; + float dlo = fax - lo; + float dhi = hi - fax; if (dhi < dlo) { if (E < Emax) ++E; @@ -544,6 +559,9 @@ template class fp8_e4m3_x { static constexpr uint8_t MaxFiniteCode = 0x7E; // 0.1111.110 (positive max normal) + static_assert(N == 1 || N == 2, + "fp8_e4m3_x: Template argument N must be 1 or 2"); + template uint8_t ConvertToFP8(T h, rounding r, saturation s) { sycl::half hi = static_cast(h); #ifdef __SYCL_DEVICE_ONLY__ @@ -598,8 +616,6 @@ template class fp8_e4m3_x { } void CheckConstraints(rounding r) const { - static_assert(N == 1 || N == 2, - "fp8_e4m3_x: Template argument N must be 1 or 2"); if (r != rounding::to_even) throw std::invalid_argument( "fp8_e4m3_x: only rounding::to_even is supported"); @@ -623,8 +639,6 @@ template class fp8_e4m3_x { ((std::is_same_v, float>) && ...) || ((std::is_same_v, double>) && ...))>> explicit fp8_e4m3_x(Types... v) { - static_assert(N == 1 || N == 2, - "fp8_e4m3_x: Template argument N must be 1 or 2"); if constexpr (((std::is_same_v, bfloat16>) && ...)) { const bfloat16 in[N] = {static_cast(v)...}; for (size_t i = 0; i < N; ++i) @@ -658,8 +672,6 @@ template class fp8_e4m3_x { } explicit fp8_e4m3_x(double const (&v)[N]) { - static_assert(N == 1 || N == 2, - "fp8_e4m3_x: Template argument N must be 1 or 2"); for (size_t i = 0; i < N; ++i) vals[i] = ConvertToFP8(v[i], rounding::to_even, saturation::finite); } @@ -694,138 +706,117 @@ template class fp8_e4m3_x { // Construct from integer types. // Available only when N==1. + template > explicit fp8_e4m3_x(short val) { - static_assert(N == 1 && "fp8_e4m3_x: N must be 1 for short constructor"); vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); } + template > explicit fp8_e4m3_x(int val) { - static_assert(N == 1 && "fp8_e4m3_x: N must be 1 for int constructor"); vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); } + template > explicit fp8_e4m3_x(long val) { - static_assert(N == 1 && "fp8_e4m3_x: N must be 1 for long constructor"); vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); } + template > explicit fp8_e4m3_x(long long val) { - static_assert(N == 1 && - "fp8_e4m3_x: N must be 1 for long long constructor"); vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); } + template > explicit fp8_e4m3_x(unsigned short val) { - static_assert(N == 1 && - "fp8_e4m3_x: N must be 1 for unsigned short constructor"); vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); } + template > explicit fp8_e4m3_x(unsigned int val) { - static_assert(N == 1 && - "fp8_e4m3_x: N must be 1 for unsigned int constructor"); vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); } + template > explicit fp8_e4m3_x(unsigned long val) { - static_assert(N == 1 && - "fp8_e4m3_x: N must be 1 for unsigned long constructor"); vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); } + template > explicit fp8_e4m3_x(unsigned long long val) { - static_assert(N == 1 && - "fp8_e4m3_x: N must be 1 for unsigned long long constructor"); vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); } // Assign (operator) from half, bfloat16, float, double, and integer types. // Available only when N==1. + template > fp8_e4m3_x &operator=(sycl::half val) { - static_assert(N == 1 && - "fp8_e4m3_x: N must be 1 for half assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } + template > fp8_e4m3_x &operator=(bfloat16 val) { - static_assert(N == 1 && - "fp8_e4m3_x: N must be 1 for bfloat16 assignment operator"); vals[0] = ConvertBF16ToFP8(val, rounding::to_even, saturation::finite); return *this; } + template > fp8_e4m3_x &operator=(float val) { - static_assert(N == 1 && - "fp8_e4m3_x: N must be 1 for float assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } + template > fp8_e4m3_x &operator=(double val) { - static_assert(N == 1 && - "fp8_e4m3_x: N must be 1 for double assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } + template > fp8_e4m3_x &operator=(short val) { - static_assert(N == 1 && - "fp8_e4m3_x: N must be 1 for short assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } + template > fp8_e4m3_x &operator=(int val) { - static_assert(N == 1 && - "fp8_e4m3_x: N must be 1 for int assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } + template > fp8_e4m3_x &operator=(long val) { - static_assert(N == 1 && - "fp8_e4m3_x: N must be 1 for long assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } + template > fp8_e4m3_x &operator=(long long val) { - static_assert(N == 1 && - "fp8_e4m3_x: N must be 1 for long long assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } + template > fp8_e4m3_x &operator=(unsigned short val) { - static_assert( - N == 1 && - "fp8_e4m3_x: N must be 1 for unsigned short assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } + template > fp8_e4m3_x &operator=(unsigned int val) { - static_assert( - N == 1 && - "fp8_e4m3_x: N must be 1 for unsigned int assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } + template > fp8_e4m3_x &operator=(unsigned long val) { - static_assert( - N == 1 && - "fp8_e4m3_x: N must be 1 for unsigned long assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } + template > fp8_e4m3_x &operator=(unsigned long long val) { - static_assert( - N == 1 && - "fp8_e4m3_x: N must be 1 for unsigned long long assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } @@ -833,107 +824,86 @@ template class fp8_e4m3_x { // Convert to half, bfloat16, float, double. // Available only when N==1. + template > explicit operator half() const { - static_assert(N == 1 && - "fp8_e4m3_x: N must be 1 for half conversion operator"); return ConvertFromFP8(vals[0]); } + template > explicit operator bfloat16() const { - static_assert(N == 1 && - "fp8_e4m3_x: N must be 1 for bfloat16 conversion operator"); return ConvertBF16FromFP8(vals[0]); } + template > explicit operator float() const { - static_assert(N == 1 && - "fp8_e4m3_x: N must be 1 for float conversion operator"); return ConvertFromFP8(vals[0]); } + template > explicit operator double() const { - static_assert(N == 1 && - "fp8_e4m3_x: N must be 1 for double conversion operator"); return ConvertFromFP8(vals[0]); } // Convert to integer types. // Available only when N==1. + template > explicit operator char() const { - static_assert(N == 1 && - "fp8_e4m3_x: N must be 1 for char conversion operator"); return ConvertFromFP8(vals[0]); } + template > explicit operator signed char() const { - static_assert( - N == 1 && - "fp8_e4m3_x: N must be 1 for signed char conversion operator"); return ConvertFromFP8(vals[0]); } + template > explicit operator short() const { - static_assert(N == 1 && - "fp8_e4m3_x: N must be 1 for short conversion operator"); return ConvertFromFP8(vals[0]); } + template > explicit operator int() const { - static_assert(N == 1 && - "fp8_e4m3_x: N must be 1 for int conversion operator"); return ConvertFromFP8(vals[0]); } + template > explicit operator long() const { - static_assert(N == 1 && - "fp8_e4m3_x: N must be 1 for long conversion operator"); return ConvertFromFP8(vals[0]); } + template > explicit operator long long() const { - static_assert(N == 1 && - "fp8_e4m3_x: N must be 1 for long long conversion operator"); return ConvertFromFP8(vals[0]); } + template > explicit operator unsigned char() const { - static_assert( - N == 1 && - "fp8_e4m3_x: N must be 1 for unsigned char conversion operator"); return ConvertFromFP8(vals[0]); } + template > explicit operator unsigned short() const { - static_assert( - N == 1 && - "fp8_e4m3_x: N must be 1 for unsigned short conversion operator"); return ConvertFromFP8(vals[0]); } + template > explicit operator unsigned int() const { - static_assert( - N == 1 && - "fp8_e4m3_x: N must be 1 for unsigned int conversion operator"); return ConvertFromFP8(vals[0]); } + template > explicit operator unsigned long() const { - static_assert( - N == 1 && - "fp8_e4m3_x: N must be 1 for unsigned long conversion operator"); return ConvertFromFP8(vals[0]); } + template > explicit operator unsigned long long() const { - static_assert( - N == 1 && - "fp8_e4m3_x: N must be 1 for unsigned long long conversion operator"); return ConvertFromFP8(vals[0]); } // Convert to bool // Available only when N==1. + template > explicit operator bool() const { - static_assert(N == 1, "fp8_e4m3_x: operator() requires size N=1"); #ifdef __SYCL_DEVICE_ONLY__ // detect +0 / -0 sycl::half h = __builtin_spirv_ConvertE4M3ToFP16EXT(vals[0]); @@ -978,6 +948,9 @@ template class fp8_e5m2_x { static constexpr float MinSubnormal = 0.0000152587890625f; // 2^-16 static constexpr uint8_t MaxFiniteCode = 0x7C; // 0.11111.00 + static_assert(N == 1 || N == 2, + "fp8_e5m2_x: Template argument N must be 1 or 2"); + uint8_t ConvertToFP8(sycl::half h, rounding r, saturation s) { #ifdef __SYCL_DEVICE_ONLY__ uint8_t b = 0; @@ -1029,8 +1002,6 @@ template class fp8_e5m2_x { } void CheckConstraints(rounding r, saturation s) const { - static_assert(N == 1 || N == 2, - "fp8_e5m2_x: Template argument N must be 1 or 2"); if (r != rounding::to_even) throw std::invalid_argument( "fp8_e5m2_x: only rounding::to_even is supported"); @@ -1055,8 +1026,6 @@ template class fp8_e5m2_x { ((std::is_same_v, float>) && ...) || ((std::is_same_v, double>) && ...))>> explicit fp8_e5m2_x(Types... v) { - static_assert(N == 1 || N == 2, - "fp8_e5m2_x: Template argument N must be 1 or 2"); if constexpr (((std::is_same_v, bfloat16>) && ...)) { const bfloat16 in[N] = {static_cast(v)...}; for (size_t i = 0; i < N; ++i) @@ -1136,8 +1105,6 @@ template class fp8_e5m2_x { explicit fp8_e5m2_x([[maybe_unused]] half const (&in)[N], [[maybe_unused]] const stochastic_seed &seed, [[maybe_unused]] saturation s = saturation::finite) { - static_assert(N == 1 || N == 2, - "fp8_e5m2_x: Template argument N must be 1 or 2"); #ifdef __SYCL_DEVICE_ONLY__ uint32_t current_seed = *seed.pseed; for (size_t i = 0; i < N; ++i) { @@ -1156,8 +1123,6 @@ template class fp8_e5m2_x { explicit fp8_e5m2_x([[maybe_unused]] bfloat16 const (&in)[N], [[maybe_unused]] const stochastic_seed &seed, [[maybe_unused]] saturation s = saturation::finite) { - static_assert(N == 1 || N == 2, - "fp8_e5m2_x: Template argument N must be 1 or 2"); #ifdef __SYCL_DEVICE_ONLY__ uint32_t current_seed = *seed.pseed; for (size_t i = 0; i < N; ++i) { @@ -1176,8 +1141,6 @@ template class fp8_e5m2_x { explicit fp8_e5m2_x([[maybe_unused]] float const (&in)[N], [[maybe_unused]] const stochastic_seed &seed, [[maybe_unused]] saturation s = saturation::finite) { - static_assert(N == 1 || N == 2, - "fp8_e5m2_x: Template argument N must be 1 or 2"); #ifdef __SYCL_DEVICE_ONLY__ uint32_t current_seed = *seed.pseed; for (size_t i = 0; i < N; ++i) { @@ -1200,8 +1163,6 @@ template class fp8_e5m2_x { explicit fp8_e5m2_x([[maybe_unused]] const sycl::marray &in, [[maybe_unused]] const stochastic_seed &seed, [[maybe_unused]] saturation s = saturation::finite) { - static_assert(N == 1 || N == 2, - "fp8_e5m2_x: Template argument N must be 1 or 2"); #ifdef __SYCL_DEVICE_ONLY__ uint32_t current_seed = *seed.pseed; for (size_t i = 0; i < N; ++i) { @@ -1220,8 +1181,6 @@ template class fp8_e5m2_x { explicit fp8_e5m2_x([[maybe_unused]] const sycl::marray &in, [[maybe_unused]] const stochastic_seed &seed, [[maybe_unused]] saturation s = saturation::finite) { - static_assert(N == 1 || N == 2, - "fp8_e5m2_x: Template argument N must be 1 or 2"); #ifdef __SYCL_DEVICE_ONLY__ uint32_t current_seed = *seed.pseed; for (size_t i = 0; i < N; ++i) { @@ -1240,8 +1199,6 @@ template class fp8_e5m2_x { explicit fp8_e5m2_x([[maybe_unused]] const sycl::marray &in, [[maybe_unused]] const stochastic_seed &seed, [[maybe_unused]] saturation s = saturation::finite) { - static_assert(N == 1 || N == 2, - "fp8_e5m2_x: Template argument N must be 1 or 2"); #ifdef __SYCL_DEVICE_ONLY__ uint32_t current_seed = *seed.pseed; for (size_t i = 0; i < N; ++i) { @@ -1261,138 +1218,117 @@ template class fp8_e5m2_x { // Construct from integer types. // Available only when N==1. + template > explicit fp8_e5m2_x(short val) { - static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for short constructor"); vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); } + template > explicit fp8_e5m2_x(int val) { - static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for int constructor"); vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); } + template > explicit fp8_e5m2_x(long val) { - static_assert(N == 1 && "fp8_e5m2_x: N must be 1 for long constructor"); vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); } + template > explicit fp8_e5m2_x(long long val) { - static_assert(N == 1 && - "fp8_e5m2_x: N must be 1 for long long constructor"); vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); } + template > explicit fp8_e5m2_x(unsigned short val) { - static_assert(N == 1 && - "fp8_e5m2_x: N must be 1 for unsigned short constructor"); vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); } + template > explicit fp8_e5m2_x(unsigned int val) { - static_assert(N == 1 && - "fp8_e5m2_x: N must be 1 for unsigned int constructor"); vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); } + template > explicit fp8_e5m2_x(unsigned long val) { - static_assert(N == 1 && - "fp8_e5m2_x: N must be 1 for unsigned long constructor"); vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); } + template > explicit fp8_e5m2_x(unsigned long long val) { - static_assert(N == 1 && - "fp8_e5m2_x: N must be 1 for unsigned long long constructor"); vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); } // Assign (operator) from half, bfloat16, float, double, and integer types. // Available only when N==1. + template > fp8_e5m2_x &operator=(sycl::half val) { - static_assert(N == 1 && - "fp8_e5m2_x: N must be 1 for half assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } + template > fp8_e5m2_x &operator=(bfloat16 val) { - static_assert(N == 1 && - "fp8_e5m2_x: N must be 1 for half bfloat16 operator"); vals[0] = ConvertBF16ToFP8(val, rounding::to_even, saturation::finite); return *this; } + template > fp8_e5m2_x &operator=(float val) { - static_assert(N == 1 && - "fp8_e5m2_x: N must be 1 for float assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } + template > fp8_e5m2_x &operator=(double val) { - static_assert(N == 1 && - "fp8_e5m2_x: N must be 1 for double assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } + template > fp8_e5m2_x &operator=(short val) { - static_assert(N == 1 && - "fp8_e5m2_x: N must be 1 for short assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } + template > fp8_e5m2_x &operator=(int val) { - static_assert(N == 1 && - "fp8_e5m2_x: N must be 1 for int assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } + template > fp8_e5m2_x &operator=(long val) { - static_assert(N == 1 && - "fp8_e5m2_x: N must be 1 for long assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } + template > fp8_e5m2_x &operator=(long long val) { - static_assert(N == 1 && - "fp8_e5m2_x: N must be 1 for long long assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } + template > fp8_e5m2_x &operator=(unsigned short val) { - static_assert( - N == 1 && - "fp8_e5m2_x: N must be 1 for unsigned short assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } + template > fp8_e5m2_x &operator=(unsigned int val) { - static_assert( - N == 1 && - "fp8_e5m2_x: N must be 1 for unsigned int assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } + template > fp8_e5m2_x &operator=(unsigned long val) { - static_assert( - N == 1 && - "fp8_e5m2_x: N must be 1 for unsigned long assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } + template > fp8_e5m2_x &operator=(unsigned long long val) { - static_assert( - N == 1 && - "fp8_e5m2_x: N must be 1 for unsigned long long assignment operator"); vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); return *this; } @@ -1400,110 +1336,89 @@ template class fp8_e5m2_x { // Convert to half, bfloat16, float, double. // Available only when N==1. + template > explicit operator half() const { - static_assert(N == 1 && - "fp8_e5m2_x: N must be 1 for half conversion operator"); return ConvertFromFP8(vals[0]); } + template > explicit operator bfloat16() const { - static_assert(N == 1 && - "fp8_e5m2_x: N must be 1 for bfloat16 conversion operator"); return ConvertBF16FromFP8(vals[0]); } + template > explicit operator float() const { - static_assert(N == 1 && - "fp8_e5m2_x: N must be 1 for float conversion operator"); return ConvertFromFP8(vals[0]); } + template > explicit operator double() const { - static_assert(N == 1 && - "fp8_e5m2_x: N must be 1 for double conversion operator"); return ConvertFromFP8(vals[0]); } // Convert to integer types. // Available only when N==1. + template > explicit operator char() const { - static_assert(N == 1 && - "fp8_e5m2_x: N must be 1 for char conversion operator"); return ConvertFromFP8(vals[0]); } + template > explicit operator signed char() const { - static_assert( - N == 1 && - "fp8_e5m2_x: N must be 1 for signed char conversion operator"); return ConvertFromFP8(vals[0]); } + template > explicit operator short() const { - static_assert(N == 1 && - "fp8_e5m2_x: N must be 1 for short conversion operator"); return ConvertFromFP8(vals[0]); } + template > explicit operator int() const { - static_assert(N == 1 && - "fp8_e5m2_x: N must be 1 for int conversion operator"); return ConvertFromFP8(vals[0]); } + template > explicit operator long() const { - static_assert(N == 1 && - "fp8_e5m2_x: N must be 1 for long conversion operator"); return ConvertFromFP8(vals[0]); } + template > explicit operator long long() const { - static_assert(N == 1 && - "fp8_e5m2_x: N must be 1 for long long conversion operator"); return ConvertFromFP8(vals[0]); } + template > explicit operator unsigned char() const { - static_assert( - N == 1 && - "fp8_e5m2_x: N must be 1 for unsigned char conversion operator"); return ConvertFromFP8(vals[0]); } + template > explicit operator unsigned short() const { - static_assert( - N == 1 && - "fp8_e5m2_x: N must be 1 for unsigned short conversion operator"); return ConvertFromFP8(vals[0]); } + template > explicit operator unsigned int() const { - static_assert( - N == 1 && - "fp8_e5m2_x: N must be 1 for unsigned int conversion operator"); return ConvertFromFP8(vals[0]); } + template > explicit operator unsigned long() const { - static_assert( - N == 1 && - "fp8_e5m2_x: N must be 1 for unsigned long conversion operator"); return ConvertFromFP8(vals[0]); } + template > explicit operator unsigned long long() const { - static_assert( - N == 1 && - "fp8_e5m2_x: N must be 1 for unsigned long long conversion operator"); return ConvertFromFP8(vals[0]); } // Convert to bool // Available only when N==1. + template > explicit operator bool() const { - static_assert(N == 1, "fp8_e5m2_x: operator() requires size N=1"); // false iff +0 or -0; otherwise true. return vals[0] != 0x00 && vals[0] != 0x80; } @@ -1533,9 +1448,11 @@ template class fp8_e5m2_x { }; template class fp8_e8m0_x { + static_assert(N == 1 || N == 2, + "fp8_e8m0_x: Template argument N must be 1 or 2"); + void CheckConstraints(rounding r) const { - static_assert(N == 1 || N == 2, - "fp8_e8m0_x: Template argument N must be 1 or 2"); + if (r != rounding::upward && r != rounding::toward_zero) throw std::invalid_argument("fp8_e8m0_x: only rounding::upward and " "rounding::toward_zero are supported"); @@ -1555,29 +1472,23 @@ template class fp8_e8m0_x { ((std::is_same_v, float>) && ...) || ((std::is_same_v, double>) && ...))>> explicit fp8_e8m0_x(Types... v) { -#ifdef __SYCL_DEVICE_ONLY__ - static_assert(N == 1 || N == 2, - "fp8_e8m0_x: Template argument N must be 1 or 2 on device"); -#endif using InT = std::common_type_t...>; const InT in[N] = {v...}; for (size_t i = 0; i < N; ++i) - vals[i] = detail::ConvertToE8M0_CPU(static_cast(in[i]), rounding::upward, - saturation::finite); + vals[i] = detail::ConvertToE8M0_CPU(in[i], rounding::upward, + saturation::finite); } explicit fp8_e8m0_x(half const (&in)[N], rounding r = rounding::upward) { CheckConstraints(r); for (size_t i = 0; i < N; ++i) - vals[i] = - detail::ConvertToE8M0_CPU(static_cast(in[i]), r, saturation::finite); + vals[i] = detail::ConvertToE8M0_CPU(in[i], r, saturation::finite); } explicit fp8_e8m0_x(bfloat16 const (&in)[N], rounding r = rounding::upward) { CheckConstraints(r); for (size_t i = 0; i < N; ++i) - vals[i] = - detail::ConvertToE8M0_CPU(static_cast(in[i]), r, saturation::finite); + vals[i] = detail::ConvertToE8M0_CPU(in[i], r, saturation::finite); } explicit fp8_e8m0_x(float const (&in)[N], rounding r = rounding::upward) { @@ -1587,27 +1498,23 @@ template class fp8_e8m0_x { } explicit fp8_e8m0_x(double const (&in)[N]) { - static_assert(N == 1 || N == 2, - "fp8_e8m0_x: Template argument N must be 1 or 2 on device"); for (size_t i = 0; i < N; ++i) - vals[i] = detail::ConvertToE8M0_CPU(static_cast(in[i]), rounding::upward, - saturation::finite); + vals[i] = detail::ConvertToE8M0_CPU(in[i], rounding::upward, + saturation::finite); } explicit fp8_e8m0_x(const marray &in, rounding r = rounding::upward) { CheckConstraints(r); for (size_t i = 0; i < N; ++i) - vals[i] = - detail::ConvertToE8M0_CPU(static_cast(in[i]), r, saturation::finite); + vals[i] = detail::ConvertToE8M0_CPU(in[i], r, saturation::finite); } explicit fp8_e8m0_x(const marray &in, rounding r = rounding::upward) { CheckConstraints(r); for (size_t i = 0; i < N; ++i) - vals[i] = - detail::ConvertToE8M0_CPU(static_cast(in[i]), r, saturation::finite); + vals[i] = detail::ConvertToE8M0_CPU(in[i], r, saturation::finite); } explicit fp8_e8m0_x(const marray &in, @@ -1618,125 +1525,176 @@ template class fp8_e8m0_x { } explicit fp8_e8m0_x(const marray &in) { - static_assert(N == 1 || N == 2, - "fp8_e8m0_x: Template argument N must be 1 or 2 on device"); for (size_t i = 0; i < N; ++i) - vals[i] = detail::ConvertToE8M0_CPU(static_cast(in[i]), rounding::upward, - saturation::finite); + vals[i] = detail::ConvertToE8M0_CPU(in[i], rounding::upward, + saturation::finite); } // Construct from integer types. // Available only when N==1. + template > explicit fp8_e8m0_x(short val) { - static_assert(N == 1 && "fp8_e8m0_x: N must be 1 for short constructor"); - vals[0] = detail::ConvertToE8M0_CPU(static_cast(val), rounding::upward, - saturation::finite); - } - explicit fp8_e8m0_x(int val) : fp8_e8m0_x(static_cast(val)) {} - explicit fp8_e8m0_x(long val) : fp8_e8m0_x(static_cast(val)) {} - explicit fp8_e8m0_x(long long val) : fp8_e8m0_x(static_cast(val)) {} - explicit fp8_e8m0_x(unsigned short val) - : fp8_e8m0_x(static_cast(val)) {} - explicit fp8_e8m0_x(unsigned int val) : fp8_e8m0_x(static_cast(val)) {} - explicit fp8_e8m0_x(unsigned long val) - : fp8_e8m0_x(static_cast(val)) {} - explicit fp8_e8m0_x(unsigned long long val) - : fp8_e8m0_x(static_cast(val)) {} + vals[0] = + detail::ConvertToE8M0_CPU(val, rounding::upward, saturation::finite); + } + + template > + explicit fp8_e8m0_x(int val) { + vals[0] = + detail::ConvertToE8M0_CPU(val, rounding::upward, saturation::finite); + } + template > + explicit fp8_e8m0_x(long val) { + vals[0] = + detail::ConvertToE8M0_CPU(val, rounding::upward, saturation::finite); + } + template > + explicit fp8_e8m0_x(long long val) { + vals[0] = + detail::ConvertToE8M0_CPU(val, rounding::upward, saturation::finite); + } + template > + explicit fp8_e8m0_x(unsigned short val) { + vals[0] = + detail::ConvertToE8M0_CPU(val, rounding::upward, saturation::finite); + } + template > + explicit fp8_e8m0_x(unsigned int val) { + vals[0] = + detail::ConvertToE8M0_CPU(val, rounding::upward, saturation::finite); + } + template > + explicit fp8_e8m0_x(unsigned long val) { + vals[0] = + detail::ConvertToE8M0_CPU(val, rounding::upward, saturation::finite); + } + template > + explicit fp8_e8m0_x(unsigned long long val) { + vals[0] = + detail::ConvertToE8M0_CPU(val, rounding::upward, saturation::finite); + } + + template > fp8_e8m0_x &operator=(half val) { - static_assert(N == 1, "fp8_e8m0_x: N must be 1 for scalar assignment"); - vals[0] = detail::ConvertToE8M0_CPU(static_cast(val), rounding::upward, - saturation::finite); + vals[0] = + detail::ConvertToE8M0_CPU(val, rounding::upward, saturation::finite); return *this; } + template > fp8_e8m0_x &operator=(bfloat16 val) { - static_assert(N == 1, "fp8_e8m0_x: N must be 1 for scalar assignment"); - vals[0] = detail::ConvertToE8M0_CPU(static_cast(val), rounding::upward, - saturation::finite); + vals[0] = + detail::ConvertToE8M0_CPU(val, rounding::upward, saturation::finite); return *this; } + template > fp8_e8m0_x &operator=(float val) { - static_assert(N == 1, "fp8_e8m0_x: N must be 1 for scalar assignment"); - vals[0] = detail::ConvertToE8M0_CPU(val, rounding::upward, saturation::finite); + vals[0] = + detail::ConvertToE8M0_CPU(val, rounding::upward, saturation::finite); return *this; } + template > fp8_e8m0_x &operator=(double val) { return (*this = static_cast(val)); } - fp8_e8m0_x &operator=(short val) { return (*this = static_cast(val)); } - fp8_e8m0_x &operator=(int val) { return (*this = static_cast(val)); } - fp8_e8m0_x &operator=(long val) { return (*this = static_cast(val)); } + template > + fp8_e8m0_x &operator=(short val) { + return (*this = static_cast(val)); + } + template > + fp8_e8m0_x &operator=(int val) { + return (*this = static_cast(val)); + } + template > + fp8_e8m0_x &operator=(long val) { + return (*this = static_cast(val)); + } + template > fp8_e8m0_x &operator=(long long val) { return (*this = static_cast(val)); } + template > fp8_e8m0_x &operator=(unsigned short val) { return (*this = static_cast(val)); } + template > fp8_e8m0_x &operator=(unsigned int val) { return (*this = static_cast(val)); } + template > fp8_e8m0_x &operator=(unsigned long val) { return (*this = static_cast(val)); } + template > fp8_e8m0_x &operator=(unsigned long long val) { return (*this = static_cast(val)); } + template > explicit operator half() const { - static_assert(N == 1, "fp8_e8m0_x: N must be 1 for scalar conversion"); return detail::ConvertFromE8M0_CPU(vals[0]); } + template > explicit operator bfloat16() const { - static_assert(N == 1, "fp8_e8m0_x: N must be 1 for scalar conversion"); return detail::ConvertFromE8M0_CPU(vals[0]); } + template > explicit operator float() const { - static_assert(N == 1, "fp8_e8m0_x: N must be 1 for scalar conversion"); return detail::ConvertFromE8M0_CPU(vals[0]); } + template > explicit operator double() const { - static_assert(N == 1, "fp8_e8m0_x: N must be 1 for scalar conversion"); return detail::ConvertFromE8M0_CPU(vals[0]); } + template > explicit operator char() const { - static_assert(N == 1, "fp8_e8m0_x: N must be 1 for scalar conversion"); return static_cast(static_cast(*this)); } + template > explicit operator signed char() const { return static_cast(static_cast(*this)); } + template > explicit operator short() const { return static_cast(static_cast(*this)); } + template > explicit operator int() const { return static_cast(static_cast(*this)); } + template > explicit operator long() const { return static_cast(static_cast(*this)); } + template > explicit operator long long() const { return static_cast(static_cast(*this)); } + template > explicit operator unsigned char() const { return static_cast(static_cast(*this)); } + template > explicit operator unsigned short() const { return static_cast(static_cast(*this)); } + template > explicit operator unsigned int() const { return static_cast(static_cast(*this)); } + template > explicit operator unsigned long() const { return static_cast(static_cast(*this)); } + template > explicit operator unsigned long long() const { return static_cast(static_cast(*this)); } + template > explicit operator bool() const { - static_assert(N == 1, "fp8_e8m0_x: operator bool requires size N=1"); return true; } From a77259184bf5666531f000dec7ed316331b891ec Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Wed, 8 Apr 2026 14:17:42 +0200 Subject: [PATCH 21/89] [SYCL] do not cast to float --- .../oneapi/experimental/float_8bit/types.hpp | 56 +++++++++++++------ 1 file changed, 38 insertions(+), 18 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index 2cc9ee8342414..1b1e8d3dc30a4 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -537,7 +537,8 @@ static inline uint8_t ConvertToE8M0_CPU(T x, rounding R, } template -static inline ToT ConvertFromE8M0_CPU(uint8_t code) noexcept { +static inline ToT ConvertFromE8M0_CPU(uint8_t code, + rounding R) noexcept { constexpr int Bias = 127; if (code == 0xFF) { float qn = std::numeric_limits::quiet_NaN(); @@ -545,7 +546,7 @@ static inline ToT ConvertFromE8M0_CPU(uint8_t code) noexcept { } int E = static_cast(code) - Bias; // includes code==0 -> -127 float v = std::ldexp(1.0f, E); - return ConvertFloatToTarget(v, rounding::to_even); + return ConvertFloatToTarget(v, R); } } // namespace detail @@ -1596,56 +1597,74 @@ template class fp8_e8m0_x { } template > fp8_e8m0_x &operator=(double val) { - return (*this = static_cast(val)); + vals[0] = + detail::ConvertToE8M0_CPU(val, rounding::upward, saturation::finite); + return *this; } template > fp8_e8m0_x &operator=(short val) { - return (*this = static_cast(val)); + vals[0] = + detail::ConvertToE8M0_CPU(val, rounding::upward, saturation::finite); + return *this; } template > fp8_e8m0_x &operator=(int val) { - return (*this = static_cast(val)); + vals[0] = + detail::ConvertToE8M0_CPU(val, rounding::upward, saturation::finite); + return *this; } template > fp8_e8m0_x &operator=(long val) { - return (*this = static_cast(val)); + vals[0] = + detail::ConvertToE8M0_CPU(val, rounding::upward, saturation::finite); + return *this; } template > fp8_e8m0_x &operator=(long long val) { - return (*this = static_cast(val)); + vals[0] = + detail::ConvertToE8M0_CPU(val, rounding::upward, saturation::finite); + return *this; } template > fp8_e8m0_x &operator=(unsigned short val) { - return (*this = static_cast(val)); + vals[0] = + detail::ConvertToE8M0_CPU(val, rounding::upward, saturation::finite); + return *this; } template > fp8_e8m0_x &operator=(unsigned int val) { - return (*this = static_cast(val)); + vals[0] = + detail::ConvertToE8M0_CPU(val, rounding::upward, saturation::finite); + return *this; } template > fp8_e8m0_x &operator=(unsigned long val) { - return (*this = static_cast(val)); + vals[0] = + detail::ConvertToE8M0_CPU(val, rounding::upward, saturation::finite); + return *this; } template > fp8_e8m0_x &operator=(unsigned long long val) { - return (*this = static_cast(val)); + vals[0] = + detail::ConvertToE8M0_CPU(val, rounding::upward, saturation::finite); + return *this; } template > explicit operator half() const { - return detail::ConvertFromE8M0_CPU(vals[0]); + return detail::ConvertFromE8M0_CPU(vals[0], rounding::to_even); } template > explicit operator bfloat16() const { - return detail::ConvertFromE8M0_CPU(vals[0]); + return detail::ConvertFromE8M0_CPU(vals[0], rounding::to_even); } template > explicit operator float() const { - return detail::ConvertFromE8M0_CPU(vals[0]); + return detail::ConvertFromE8M0_CPU(vals[0], rounding::to_even); } template > explicit operator double() const { - return detail::ConvertFromE8M0_CPU(vals[0]); + return detail::ConvertFromE8M0_CPU(vals[0], rounding::to_even); } template > @@ -1701,19 +1720,20 @@ template class fp8_e8m0_x { explicit operator sycl::marray() const { sycl::marray out; for (size_t i = 0; i < N; ++i) - out[i] = detail::ConvertFromE8M0_CPU(vals[i]); + out[i] = detail::ConvertFromE8M0_CPU(vals[i], rounding::to_even); return out; } explicit operator sycl::marray() const { sycl::marray out; for (size_t i = 0; i < N; ++i) - out[i] = detail::ConvertFromE8M0_CPU(vals[i]); + out[i] = + detail::ConvertFromE8M0_CPU(vals[i], rounding::to_even); return out; } explicit operator sycl::marray() const { sycl::marray out; for (size_t i = 0; i < N; ++i) - out[i] = detail::ConvertFromE8M0_CPU(vals[i]); + out[i] = detail::ConvertFromE8M0_CPU(vals[i], rounding::to_even); return out; } From 2d57fd00119bbbeb0a17dfd0c0a331b106941016 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Wed, 8 Apr 2026 16:30:36 +0200 Subject: [PATCH 22/89] [SYCL] rework fp8 to avoid casts to float --- .../oneapi/experimental/float_8bit/types.hpp | 202 ++++++++++++------ 1 file changed, 135 insertions(+), 67 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index 1b1e8d3dc30a4..61bc73a612cb9 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -131,36 +131,114 @@ static inline uint8_t RoundClip(float x, uint8_t max, rounding R, return RneClip(x, max); } +static inline int BitWidth(uint32_t x) noexcept { + int width = 0; + while (x != 0u) { + ++width; + x >>= 1; + } + return width; +} + +template struct DirectBinary16Traits; + +template <> struct DirectBinary16Traits { + static constexpr uint16_t SignMask = 0x8000u; + static constexpr uint16_t FracMask = 0x03FFu; + static constexpr uint16_t InfBits = 0x7C00u; + static constexpr uint16_t MaxFiniteBits = 0x7BFFu; + static constexpr uint16_t QuietNaNBits = 0x7E00u; + static constexpr int FracBits = 10; + static constexpr int Bias = 15; + static constexpr int Emin = -14; + static constexpr int Emax = 15; +}; + +template <> struct DirectBinary16Traits { + static constexpr uint16_t SignMask = 0x8000u; + static constexpr uint16_t FracMask = 0x007Fu; + static constexpr uint16_t InfBits = 0x7F80u; + static constexpr uint16_t MaxFiniteBits = 0x7F7Fu; + static constexpr uint16_t QuietNaNBits = 0x7FC0u; + static constexpr int FracBits = 7; + static constexpr int Bias = 127; + static constexpr int Emin = -126; + static constexpr int Emax = 127; +}; + +template static inline ToT MakeDirectNaN() noexcept { + if constexpr (std::is_same_v || + std::is_same_v) { + return sycl::bit_cast(DirectBinary16Traits::QuietNaNBits); + } else if constexpr (std::numeric_limits::has_quiet_NaN) { + return std::numeric_limits::quiet_NaN(); + } else { + return ToT{}; + } +} + template -static inline ToT ConvertFloatToTarget(float v, rounding R) noexcept { +static inline ToT ConvertFloatToTarget(bool negative, uint32_t significand, + int exp2, int srcFracBits, + rounding R) noexcept { + if (significand == 0u) + return negative ? -ToT{0} : ToT{0}; + if constexpr (std::is_same_v || std::is_same_v) { - ToT cand = static_cast(v); - if (R == rounding::toward_zero) { - // If cast increased magnitude, step the 16-bit encoding toward zero. - float fcand = static_cast(cand); - if (std::fabs(fcand) > std::fabs(v)) { - uint16_t bits = sycl::bit_cast(cand); - // Order-preserving transform: sign-bit mapped to MSB ordering - uint16_t ord = (bits & 0x8000u) ? static_cast(~bits) - : static_cast(bits ^ 0x8000u); - if (v >= 0.0f) { - if (ord != 0u) - --ord; // step toward smaller positive - } else { - if (ord != 0xFFFFu) - ++ord; // step toward smaller magnitude for negative numbers - } - uint16_t newbits = (ord & 0x8000u) - ? static_cast(~ord) - : static_cast(ord ^ 0x8000u); - cand = sycl::bit_cast(newbits); - } + using Traits = DirectBinary16Traits; + const uint16_t sign = negative ? Traits::SignMask : 0u; + const int sigBits = BitWidth(significand); + const int unbiasedExp = exp2 + sigBits - 1 - srcFracBits; + + if (unbiasedExp > Traits::Emax) { + return sycl::bit_cast(static_cast( + sign | (R == rounding::toward_zero ? Traits::MaxFiniteBits + : Traits::InfBits))); } - return cand; + + if (unbiasedExp >= Traits::Emin) { + const int shift = Traits::FracBits - (sigBits - 1); + const uint32_t aligned = significand << shift; + const uint16_t expField = + static_cast(unbiasedExp + Traits::Bias) << Traits::FracBits; + const uint16_t fracField = + static_cast(aligned & Traits::FracMask); + return sycl::bit_cast( + static_cast(sign | expField | fracField)); + } + + const int subShift = exp2 - srcFracBits - Traits::Emin + Traits::FracBits; + if (subShift < 0) + return sycl::bit_cast(sign); + + const uint32_t fracField = significand << subShift; + if (fracField == 0u || fracField > Traits::FracMask) + return sycl::bit_cast(sign); + + return sycl::bit_cast( + static_cast(sign | static_cast(fracField))); + } else if constexpr (std::is_floating_point_v) { + ToT magnitude = + std::ldexp(static_cast(significand), exp2 - srcFracBits); + return negative ? -magnitude : magnitude; + } else if constexpr (std::is_integral_v) { + const int shift = exp2 - srcFracBits; + uint64_t magnitude = significand; + if (shift >= 0) + magnitude <<= shift; + else if (-shift < 64) + magnitude >>= -shift; + else + magnitude = 0u; + + if constexpr (std::is_signed_v) { + int64_t signedMagnitude = static_cast(magnitude); + return static_cast(negative ? -signedMagnitude : signedMagnitude); + } else + return static_cast(magnitude); } else - // For float/double/integral targets just use normal cast - return static_cast(v); + return ToT{}; } template @@ -186,10 +264,7 @@ static inline ToT ConvertFromFP8_CPU(uint8_t b, exp = b; } - auto make_nan = [&]() -> ToT { - float qn = std::numeric_limits::quiet_NaN(); - return static_cast(qn); - }; + auto make_nan = [&]() -> ToT { return MakeDirectNaN(); }; // Handle exp = all ones (custom finite-only rules). if (exp == ExpMaskAll) { @@ -211,35 +286,21 @@ static inline ToT ConvertFromFP8_CPU(uint8_t b, if (exp == 0) { if constexpr (Mbits == 0) { // E8M0: exp==0 is the smallest normal (no subnormals) - int E = -Bias; - float v = std::ldexp(1.0f, E); - return ConvertFloatToTarget(v, R); + return ConvertFloatToTarget(false, 1u, -Bias, 0, R); } else { if (frac == 0) { - float zf = std::copysign(0.0f, sign_bit ? -1.0f : 1.0f); - if constexpr (std::is_same_v || - std::is_same_v) - return ConvertFloatToTarget(zf, R); - else - return static_cast(zf); + return ConvertFloatToTarget(sign_bit != 0u, 0u, 0, 0, R); } // Subnormal: value = sign * (frac / 2^Mbits) * 2^(Emin) - float m = static_cast(frac) / static_cast(FracDen); - float v = std::ldexp(m, Emin); - return ConvertFloatToTarget((sign_bit ? -v : v), R); + return ConvertFloatToTarget(sign_bit != 0u, frac, Emin, Mbits, R); } } // Normal number. int E = static_cast(exp) - Bias; - float m; - if constexpr (Mbits == 0) - // E8M0: mantissa == 1 always - m = 1.0f; - else - m = 1.0f + static_cast(frac) / static_cast(FracDen); - float v = std::ldexp(m, E); - return ConvertFloatToTarget((sign_bit ? -v : v), R); + const uint32_t significand = + (Mbits == 0) ? 1u : (static_cast(FracDen) + frac); + return ConvertFloatToTarget(sign_bit != 0u, significand, E, Mbits, R); } /// \brief Converts a given value to fp8 floating point with a rounding @@ -296,6 +357,7 @@ ConvertToFP8_CPU(T h, rounding R = rounding::to_even) noexcept { // Determine exponent E such that 2^E <= ax < 2^{E+1} int e2; + float m = std::frexp(ax, &e2); int E = e2 - 1; // Upward rounding semantics: @@ -537,16 +599,13 @@ static inline uint8_t ConvertToE8M0_CPU(T x, rounding R, } template -static inline ToT ConvertFromE8M0_CPU(uint8_t code, - rounding R) noexcept { +static inline ToT ConvertFromE8M0_CPU(uint8_t code, rounding R) noexcept { constexpr int Bias = 127; if (code == 0xFF) { - float qn = std::numeric_limits::quiet_NaN(); - return static_cast(qn); + return MakeDirectNaN(); } - int E = static_cast(code) - Bias; // includes code==0 -> -127 - float v = std::ldexp(1.0f, E); - return ConvertFloatToTarget(v, R); + return ConvertFloatToTarget(false, 1u, static_cast(code) - Bias, 0, + R); } } // namespace detail @@ -1669,47 +1728,56 @@ template class fp8_e8m0_x { template > explicit operator char() const { - return static_cast(static_cast(*this)); + return detail::ConvertFromE8M0_CPU(vals[0], rounding::toward_zero); } template > explicit operator signed char() const { - return static_cast(static_cast(*this)); + return detail::ConvertFromE8M0_CPU(vals[0], + rounding::toward_zero); } template > explicit operator short() const { - return static_cast(static_cast(*this)); + return detail::ConvertFromE8M0_CPU(vals[0], rounding::toward_zero); } template > explicit operator int() const { - return static_cast(static_cast(*this)); + return detail::ConvertFromE8M0_CPU(vals[0], rounding::toward_zero); } template > explicit operator long() const { - return static_cast(static_cast(*this)); + return detail::ConvertFromE8M0_CPU(vals[0], rounding::toward_zero); } template > explicit operator long long() const { - return static_cast(static_cast(*this)); + return detail::ConvertFromE8M0_CPU(vals[0], + rounding::toward_zero); } template > explicit operator unsigned char() const { - return static_cast(static_cast(*this)); + return detail::ConvertFromE8M0_CPU(vals[0], + rounding::toward_zero); } template > explicit operator unsigned short() const { - return static_cast(static_cast(*this)); + return detail::ConvertFromE8M0_CPU(vals[0], + rounding::toward_zero); } template > explicit operator unsigned int() const { - return static_cast(static_cast(*this)); + return detail::ConvertFromE8M0_CPU(vals[0], + rounding::toward_zero); } + template > explicit operator unsigned long() const { - return static_cast(static_cast(*this)); + return detail::ConvertFromE8M0_CPU(vals[0], + rounding::toward_zero); } + template > explicit operator unsigned long long() const { - return static_cast(static_cast(*this)); + return detail::ConvertFromE8M0_CPU( + vals[0], rounding::toward_zero); } template > From 479f0115f380bf08a7f7147f79bba29f73e1b4c3 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Wed, 8 Apr 2026 16:50:11 +0200 Subject: [PATCH 23/89] [SYCL] remove extra check from assert --- sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index 61bc73a612cb9..1549687de726d 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -245,7 +245,7 @@ template static inline ToT ConvertFromFP8_CPU(uint8_t b, rounding R = rounding::to_even) noexcept { static_assert((Ebits == 4 && Mbits == 3) || (Ebits == 5 && Mbits == 2) || - (Ebits == 5 && Mbits == 3) || (Ebits == 8 && Mbits == 0), + (Ebits == 8 && Mbits == 0), "Unsupported FP8 (Ebits,Mbits) combination"); constexpr int Bias = (1 << (Ebits - 1)) - 1; From 181d92cb6ae86f1c0cba317a100da299115d6ea2 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Wed, 8 Apr 2026 17:43:21 +0200 Subject: [PATCH 24/89] [SYCL] do not cast to half during convertion --- .../sycl/ext/oneapi/experimental/float_8bit/types.hpp | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index 1549687de726d..14bf59840c199 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -483,8 +483,8 @@ inline uint8_t nextE4M3(uint8_t b, bool up) { : static_cast(~ord); } -template -uint8_t round(rounding r, uint8_t b, sycl::half yi, T vi) { +template +uint8_t round(rounding r, uint8_t b, YiT yi, T vi) { switch (r) { case rounding::upward: { if (yi < vi) @@ -623,7 +623,6 @@ template class fp8_e4m3_x { "fp8_e4m3_x: Template argument N must be 1 or 2"); template uint8_t ConvertToFP8(T h, rounding r, saturation s) { - sycl::half hi = static_cast(h); #ifdef __SYCL_DEVICE_ONLY__ // TODO: optimize with vectorized builtin calls uint8_t b = 0; @@ -635,10 +634,10 @@ template class fp8_e4m3_x { return b; const sycl::half yi = __builtin_spirv_ConvertE4M3ToFP16EXT(b); - return detail::round(r, b, yi, hi); + return detail::round(r, b, yi, h); #else - return detail::ConvertToFP8_CPU<4, 3, sycl::half>(hi, r); + return detail::ConvertToFP8_CPU<4, 3, sycl::half>(h, r); #endif } From e4051c6c44f08d78821ff94c964c20a3a2ea55c1 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Wed, 8 Apr 2026 18:18:00 +0200 Subject: [PATCH 25/89] [SYCL] do not use extra checks of saturation and rounding for e5m2 and e4m3 types --- .../oneapi/experimental/float_8bit/types.hpp | 186 +++++++----------- 1 file changed, 76 insertions(+), 110 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index 14bf59840c199..03161682d5adb 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -622,38 +622,19 @@ template class fp8_e4m3_x { static_assert(N == 1 || N == 2, "fp8_e4m3_x: Template argument N must be 1 or 2"); - template uint8_t ConvertToFP8(T h, rounding r, saturation s) { + template uint8_t ConvertToFP8(T h) { #ifdef __SYCL_DEVICE_ONLY__ - // TODO: optimize with vectorized builtin calls - uint8_t b = 0; - if (s == saturation::finite) - b = __builtin_spirv_ClampConvertFP16ToE4M3INTEL(h); - else - b = __builtin_spirv_ConvertFP16ToE4M3EXT(h); - if (r == rounding::to_even) - return b; - - const sycl::half yi = __builtin_spirv_ConvertE4M3ToFP16EXT(b); - return detail::round(r, b, yi, h); - + return __builtin_spirv_ClampConvertFP16ToE4M3INTEL(h); #else - return detail::ConvertToFP8_CPU<4, 3, sycl::half>(h, r); + return detail::ConvertToFP8_CPU<4, 3, sycl::half>(h, rounding::to_even); #endif } - uint8_t ConvertBF16ToFP8(bfloat16 h, rounding r, saturation s) { + uint8_t ConvertBF16ToFP8(bfloat16 h) { #ifdef __SYCL_DEVICE_ONLY__ - uint8_t b = 0; - if (s == saturation::finite) - b = __builtin_spirv_ClampConvertBF16ToE4M3INTEL(h); - else - b = __builtin_spirv_ConvertBF16ToE4M3EXT(h); - if (r == rounding::to_even) - return b; - const half yi = __builtin_spirv_ConvertE4M3ToFP16EXT(b); - return detail::round(r, b, yi, h); + return __builtin_spirv_ClampConvertBF16ToE4M3INTEL(h); #else - return detail::ConvertToFP8_CPU<4, 3, bfloat16>(h, r); + return detail::ConvertToFP8_CPU<4, 3, bfloat16>(h, rounding::to_even); #endif } @@ -701,13 +682,12 @@ template class fp8_e4m3_x { if constexpr (((std::is_same_v, bfloat16>) && ...)) { const bfloat16 in[N] = {static_cast(v)...}; for (size_t i = 0; i < N; ++i) - vals[i] = - ConvertBF16ToFP8(in[i], rounding::to_even, saturation::finite); + vals[i] = ConvertBF16ToFP8(in[i]); return; } const sycl::half in[N] = {v...}; for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(in[i], rounding::to_even, saturation::finite); + vals[i] = ConvertToFP8(in[i]); } // Construct from an array of half, bfloat16, float, double. @@ -715,24 +695,24 @@ template class fp8_e4m3_x { rounding r = rounding::to_even) { CheckConstraints(r); for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i], r, saturation::finite); + vals[i] = ConvertToFP8(v[i]); } explicit fp8_e4m3_x(bfloat16 const (&v)[N], rounding r = rounding::to_even) { CheckConstraints(r); for (size_t i = 0; i < N; ++i) - vals[i] = ConvertBF16ToFP8(v[i], r, saturation::finite); + vals[i] = ConvertBF16ToFP8(v[i]); } explicit fp8_e4m3_x(float const (&v)[N], rounding r = rounding::to_even) { CheckConstraints(r); for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i], r, saturation::finite); + vals[i] = ConvertToFP8(v[i]); } explicit fp8_e4m3_x(double const (&v)[N]) { for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i], rounding::to_even, saturation::finite); + vals[i] = ConvertToFP8(v[i]); } // Construct from an marray of half, bfloat16, float, double. @@ -740,26 +720,26 @@ template class fp8_e4m3_x { rounding r = rounding::to_even) { CheckConstraints(r); for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i], r, saturation::finite); + vals[i] = ConvertToFP8(v[i]); } explicit fp8_e4m3_x(const sycl::marray &v, rounding r = rounding::to_even) { CheckConstraints(r); for (size_t i = 0; i < N; ++i) - vals[i] = ConvertBF16ToFP8(v[i], r, saturation::finite); + vals[i] = ConvertBF16ToFP8(v[i]); } explicit fp8_e4m3_x(const sycl::marray &v, rounding r = rounding::to_even) { CheckConstraints(r); for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i], r, saturation::finite); + vals[i] = ConvertToFP8(v[i]); } explicit fp8_e4m3_x(const sycl::marray &v) { for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i], rounding::to_even, saturation::finite); + vals[i] = ConvertToFP8(v[i]); } // Construct from integer types. @@ -767,42 +747,42 @@ template class fp8_e4m3_x { template > explicit fp8_e4m3_x(short val) { - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val); } template > explicit fp8_e4m3_x(int val) { - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val); } template > explicit fp8_e4m3_x(long val) { - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val); } template > explicit fp8_e4m3_x(long long val) { - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val); } template > explicit fp8_e4m3_x(unsigned short val) { - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val); } template > explicit fp8_e4m3_x(unsigned int val) { - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val); } template > explicit fp8_e4m3_x(unsigned long val) { - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val); } template > explicit fp8_e4m3_x(unsigned long long val) { - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val); } // Assign (operator) from half, bfloat16, float, double, and integer types. @@ -810,73 +790,73 @@ template class fp8_e4m3_x { template > fp8_e4m3_x &operator=(sycl::half val) { - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val); return *this; } template > fp8_e4m3_x &operator=(bfloat16 val) { - vals[0] = ConvertBF16ToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertBF16ToFP8(val); return *this; } template > fp8_e4m3_x &operator=(float val) { - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val); return *this; } template > fp8_e4m3_x &operator=(double val) { - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val); return *this; } template > fp8_e4m3_x &operator=(short val) { - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val); return *this; } template > fp8_e4m3_x &operator=(int val) { - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val); return *this; } template > fp8_e4m3_x &operator=(long val) { - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val); return *this; } template > fp8_e4m3_x &operator=(long long val) { - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val); return *this; } template > fp8_e4m3_x &operator=(unsigned short val) { - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val); return *this; } template > fp8_e4m3_x &operator=(unsigned int val) { - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val); return *this; } template > fp8_e4m3_x &operator=(unsigned long val) { - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val); return *this; } template > fp8_e4m3_x &operator=(unsigned long long val) { - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val); return *this; } @@ -1010,36 +990,23 @@ template class fp8_e5m2_x { static_assert(N == 1 || N == 2, "fp8_e5m2_x: Template argument N must be 1 or 2"); - uint8_t ConvertToFP8(sycl::half h, rounding r, saturation s) { + uint8_t ConvertToFP8(sycl::half h, saturation s) { #ifdef __SYCL_DEVICE_ONLY__ - uint8_t b = 0; - if (s == saturation::finite) - b = __builtin_spirv_ClampConvertFP16ToE5M2INTEL(h); - else - b = __builtin_spirv_ConvertFP16ToE5M2EXT(h); - if (r == rounding::to_even) - return b; - const sycl::half yi = __builtin_spirv_ConvertE5M2ToFP16EXT(b); - return detail::round(r, b, yi, h); - + return s == saturation::finite + ? __builtin_spirv_ClampConvertFP16ToE5M2INTEL(h) + : __builtin_spirv_ConvertFP16ToE5M2EXT(h); #else - return detail::ConvertToFP8_CPU<5, 2, sycl::half>(h, r); + return detail::ConvertToFP8_CPU<5, 2, sycl::half>(h, rounding::to_even); #endif } - uint8_t ConvertBF16ToFP8(bfloat16 h, rounding r, saturation s) { + uint8_t ConvertBF16ToFP8(bfloat16 h, saturation s) { #ifdef __SYCL_DEVICE_ONLY__ - uint8_t b = 0; - if (s == saturation::finite) - b = __builtin_spirv_ClampConvertBF16ToE5M2INTEL(h); - else - b = __builtin_spirv_ConvertBF16ToE5M2EXT(h); - if (r == rounding::to_even) - return b; - const bfloat16 yi = __builtin_spirv_ConvertE5M2ToBF16EXT(b); - return detail::round(r, b, yi, h); + return s == saturation::finite + ? __builtin_spirv_ClampConvertBF16ToE5M2INTEL(h) + : __builtin_spirv_ConvertBF16ToE5M2EXT(h); #else - return detail::ConvertToFP8_CPU<5, 2, bfloat16>(h, r); + return detail::ConvertToFP8_CPU<5, 2, bfloat16>(h, rounding::to_even); #endif } @@ -1088,13 +1055,12 @@ template class fp8_e5m2_x { if constexpr (((std::is_same_v, bfloat16>) && ...)) { const bfloat16 in[N] = {static_cast(v)...}; for (size_t i = 0; i < N; ++i) - vals[i] = - ConvertBF16ToFP8(in[i], rounding::to_even, saturation::finite); + vals[i] = ConvertBF16ToFP8(in[i], saturation::finite); return; } const sycl::half in[N] = {v...}; for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(in[i], rounding::to_even, saturation::finite); + vals[i] = ConvertToFP8(in[i], saturation::finite); } // Construct from an array of half, bfloat16, float, double. @@ -1104,7 +1070,7 @@ template class fp8_e5m2_x { CheckConstraints(r, s); // TODO: optimize with vectorized builtin calls for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i], r, s); + vals[i] = ConvertToFP8(v[i], s); } explicit fp8_e5m2_x(bfloat16 const (&v)[N], rounding r = rounding::to_even, @@ -1112,19 +1078,19 @@ template class fp8_e5m2_x { CheckConstraints(r, s); // TODO: optimize with vectorized builtin calls for (size_t i = 0; i < N; ++i) - vals[i] = ConvertBF16ToFP8(v[i], r, s); + vals[i] = ConvertBF16ToFP8(v[i], s); } explicit fp8_e5m2_x(float const (&v)[N], rounding r = rounding::to_even, saturation s = saturation::finite) { CheckConstraints(r, s); for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i], r, s); + vals[i] = ConvertToFP8(v[i], s); } explicit fp8_e5m2_x(double const (&v)[N]) { for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i], rounding::to_even, saturation::finite); + vals[i] = ConvertToFP8(v[i], saturation::finite); } // Construct from an marray of half, bfloat16, float, double. @@ -1134,7 +1100,7 @@ template class fp8_e5m2_x { saturation s = saturation::finite) { CheckConstraints(r, s); for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i], r, s); + vals[i] = ConvertToFP8(v[i], s); } explicit fp8_e5m2_x(const sycl::marray &v, @@ -1142,7 +1108,7 @@ template class fp8_e5m2_x { saturation s = saturation::finite) { CheckConstraints(r, s); for (size_t i = 0; i < N; ++i) - vals[i] = ConvertBF16ToFP8(v[i], r, s); + vals[i] = ConvertBF16ToFP8(v[i], s); } explicit fp8_e5m2_x(const sycl::marray &v, @@ -1150,12 +1116,12 @@ template class fp8_e5m2_x { saturation s = saturation::finite) { CheckConstraints(r, s); for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i], r, s); + vals[i] = ConvertToFP8(v[i], s); } explicit fp8_e5m2_x(const sycl::marray &v) { for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i], rounding::to_even, saturation::finite); + vals[i] = ConvertToFP8(v[i], saturation::finite); } // Construct with stochastic rounding with user provided seed from an array of @@ -1279,42 +1245,42 @@ template class fp8_e5m2_x { template > explicit fp8_e5m2_x(short val) { - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val, saturation::finite); } template > explicit fp8_e5m2_x(int val) { - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val, saturation::finite); } template > explicit fp8_e5m2_x(long val) { - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val, saturation::finite); } template > explicit fp8_e5m2_x(long long val) { - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val, saturation::finite); } template > explicit fp8_e5m2_x(unsigned short val) { - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val, saturation::finite); } template > explicit fp8_e5m2_x(unsigned int val) { - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val, saturation::finite); } template > explicit fp8_e5m2_x(unsigned long val) { - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val, saturation::finite); } template > explicit fp8_e5m2_x(unsigned long long val) { - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val, saturation::finite); } // Assign (operator) from half, bfloat16, float, double, and integer types. @@ -1322,73 +1288,73 @@ template class fp8_e5m2_x { template > fp8_e5m2_x &operator=(sycl::half val) { - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val, saturation::finite); return *this; } template > fp8_e5m2_x &operator=(bfloat16 val) { - vals[0] = ConvertBF16ToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertBF16ToFP8(val, saturation::finite); return *this; } template > fp8_e5m2_x &operator=(float val) { - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val, saturation::finite); return *this; } template > fp8_e5m2_x &operator=(double val) { - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val, saturation::finite); return *this; } template > fp8_e5m2_x &operator=(short val) { - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val, saturation::finite); return *this; } template > fp8_e5m2_x &operator=(int val) { - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val, saturation::finite); return *this; } template > fp8_e5m2_x &operator=(long val) { - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val, saturation::finite); return *this; } template > fp8_e5m2_x &operator=(long long val) { - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val, saturation::finite); return *this; } template > fp8_e5m2_x &operator=(unsigned short val) { - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val, saturation::finite); return *this; } template > fp8_e5m2_x &operator=(unsigned int val) { - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val, saturation::finite); return *this; } template > fp8_e5m2_x &operator=(unsigned long val) { - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val, saturation::finite); return *this; } template > fp8_e5m2_x &operator=(unsigned long long val) { - vals[0] = ConvertToFP8(val, rounding::to_even, saturation::finite); + vals[0] = ConvertToFP8(val, saturation::finite); return *this; } From f8aa6fb4a5f86d527fac20220e22f01fd1cba52c Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Wed, 8 Apr 2026 18:41:57 +0200 Subject: [PATCH 26/89] [SYCL] replace exceptions with asserts --- .../oneapi/experimental/float_8bit/types.hpp | 33 +++++++++---------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index 03161682d5adb..2566de71d21cd 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -656,9 +656,8 @@ template class fp8_e4m3_x { } void CheckConstraints(rounding r) const { - if (r != rounding::to_even) - throw std::invalid_argument( - "fp8_e4m3_x: only rounding::to_even is supported"); + assert(r == rounding::to_even && + "fp8_e4m3_x: only rounding::to_even is supported"); } public: @@ -1027,10 +1026,9 @@ template class fp8_e5m2_x { #endif } - void CheckConstraints(rounding r, saturation s) const { - if (r != rounding::to_even) - throw std::invalid_argument( - "fp8_e5m2_x: only rounding::to_even is supported"); + void CheckConstraints(rounding r) const { + assert(r == rounding::to_even && + "fp8_e5m2_x: only rounding::to_even is supported"); } public: @@ -1067,7 +1065,7 @@ template class fp8_e5m2_x { explicit fp8_e5m2_x(half const (&v)[N], rounding r = rounding::to_even, saturation s = saturation::finite) { - CheckConstraints(r, s); + CheckConstraints(r); // TODO: optimize with vectorized builtin calls for (size_t i = 0; i < N; ++i) vals[i] = ConvertToFP8(v[i], s); @@ -1075,7 +1073,7 @@ template class fp8_e5m2_x { explicit fp8_e5m2_x(bfloat16 const (&v)[N], rounding r = rounding::to_even, saturation s = saturation::finite) { - CheckConstraints(r, s); + CheckConstraints(r); // TODO: optimize with vectorized builtin calls for (size_t i = 0; i < N; ++i) vals[i] = ConvertBF16ToFP8(v[i], s); @@ -1083,7 +1081,7 @@ template class fp8_e5m2_x { explicit fp8_e5m2_x(float const (&v)[N], rounding r = rounding::to_even, saturation s = saturation::finite) { - CheckConstraints(r, s); + CheckConstraints(r); for (size_t i = 0; i < N; ++i) vals[i] = ConvertToFP8(v[i], s); } @@ -1098,7 +1096,7 @@ template class fp8_e5m2_x { explicit fp8_e5m2_x(const sycl::marray &v, rounding r = rounding::to_even, saturation s = saturation::finite) { - CheckConstraints(r, s); + CheckConstraints(r); for (size_t i = 0; i < N; ++i) vals[i] = ConvertToFP8(v[i], s); } @@ -1106,7 +1104,7 @@ template class fp8_e5m2_x { explicit fp8_e5m2_x(const sycl::marray &v, rounding r = rounding::to_even, saturation s = saturation::finite) { - CheckConstraints(r, s); + CheckConstraints(r); for (size_t i = 0; i < N; ++i) vals[i] = ConvertBF16ToFP8(v[i], s); } @@ -1114,7 +1112,7 @@ template class fp8_e5m2_x { explicit fp8_e5m2_x(const sycl::marray &v, rounding r = rounding::to_even, saturation s = saturation::finite) { - CheckConstraints(r, s); + CheckConstraints(r); for (size_t i = 0; i < N; ++i) vals[i] = ConvertToFP8(v[i], s); } @@ -1477,10 +1475,11 @@ template class fp8_e8m0_x { "fp8_e8m0_x: Template argument N must be 1 or 2"); void CheckConstraints(rounding r) const { - - if (r != rounding::upward && r != rounding::toward_zero) - throw std::invalid_argument("fp8_e8m0_x: only rounding::upward and " - "rounding::toward_zero are supported"); + assert( + r == rounding::upward || + r == rounding::toward_zero && + "fp8_e8m0_x: only rounding::upward and rounding::toward_zero are " + "supported"); } public: From a92fff87cf955a436f785aec5c8faacb047d163e Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Thu, 9 Apr 2026 11:12:38 +0200 Subject: [PATCH 27/89] [SYCL] remove unused functions --- .../oneapi/experimental/float_8bit/types.hpp | 50 +------------------ 1 file changed, 2 insertions(+), 48 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index 2566de71d21cd..53c4ed25ee604 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -458,52 +458,6 @@ ConvertToFP8_CPU(T h, rounding R = rounding::to_even) noexcept { return ret; } -// Map E4M3 byte to integer -// then "nextUp" in that order, and map back. -// E4M3 finite-only: exp=0xF & frac!=0 => NaN (no Inf). -inline uint8_t nextE4M3(uint8_t b, bool up) { - uint8_t exp = (b >> 3) & 0x0F; - uint8_t frac = b & 0x07; - // NaN -> NaN - if (exp == 0x0F && frac) - return b; - uint8_t ord = - (b & 0x80) ? static_cast(~b) : static_cast(b ^ 0x80); - - if (up) { - if (ord == 0xFF) - return b; - ++ord; - } else { - if (ord == 0x00) - return b; - --ord; - } - return (ord & 0x80) ? static_cast(ord ^ 0x80) - : static_cast(~ord); -} - -template -uint8_t round(rounding r, uint8_t b, YiT yi, T vi) { - switch (r) { - case rounding::upward: { - if (yi < vi) - return nextE4M3(b, /*up=*/true); - break; - } - case rounding::toward_zero: - if (vi > 0.0f && yi > vi) { - return nextE4M3(b, /*up=*/false); - } else if (vi < 0.0f && yi < vi) { - return nextE4M3(b, /*up=*/true); - } - break; - default: - break; - } - return b; -} - template static inline uint8_t ConvertToE8M0_CPU(T x, rounding R, saturation S) noexcept { @@ -1476,8 +1430,8 @@ template class fp8_e8m0_x { void CheckConstraints(rounding r) const { assert( - r == rounding::upward || - r == rounding::toward_zero && + (r == rounding::upward || + r == rounding::toward_zero) && "fp8_e8m0_x: only rounding::upward and rounding::toward_zero are " "supported"); } From d99b83f6a0c5511e90c8cd2cd14ff55b1fab0bc3 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Thu, 9 Apr 2026 16:54:05 +0200 Subject: [PATCH 28/89] [SYCL] add tests to check rouning constraints --- .../oneapi/experimental/float_8bit/types.hpp | 8 +- sycl/unittests/Extensions/fp8/fp8_e4m3.cpp | 253 +++++++++++++++- sycl/unittests/Extensions/fp8/fp8_e5m2.cpp | 273 +++++++++++++++++- sycl/unittests/Extensions/fp8/fp8_e8m0.cpp | 213 +++++++++++++- 4 files changed, 710 insertions(+), 37 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index 53c4ed25ee604..e0ae4fd34d5aa 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -1429,11 +1429,9 @@ template class fp8_e8m0_x { "fp8_e8m0_x: Template argument N must be 1 or 2"); void CheckConstraints(rounding r) const { - assert( - (r == rounding::upward || - r == rounding::toward_zero) && - "fp8_e8m0_x: only rounding::upward and rounding::toward_zero are " - "supported"); + assert((r == rounding::upward || r == rounding::toward_zero) && + "fp8_e8m0_x: only rounding::upward and rounding::toward_zero are " + "supported"); } public: diff --git a/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp b/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp index ed90cead2d43f..341e039df8547 100644 --- a/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp +++ b/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp @@ -4,6 +4,7 @@ #include #include #include +#include /* Unit tests check only CPU versions. Most of the constraints related to device @@ -12,7 +13,7 @@ code thus unit tests check only API using namespace sycl::ext::oneapi::experimental; -TEST(FP8E4M3Test, VariadicConstructorHalf) { +TEST(FP8E4M3Test, VariadicHalf) { fp8_e4m3_x2 a(sycl::half(1.0f), sycl::half(2.0f)); EXPECT_EQ(sizeof(a.vals), 2u); @@ -24,7 +25,7 @@ TEST(FP8E4M3Test, VariadicConstructorHalf) { EXPECT_EQ(b.vals[0], 0x39); // 1.1 rounds to 1.125 -> frac=1 } -TEST(FP8E4M3Test, VariadicConstructorBFloat16) { +TEST(FP8E4M3Test, VariadicBFloat16) { fp8_e4m3_x2 a(sycl::ext::oneapi::bfloat16(1.0f), sycl::ext::oneapi::bfloat16(2.0f)); @@ -37,7 +38,7 @@ TEST(FP8E4M3Test, VariadicConstructorBFloat16) { EXPECT_EQ(b.vals[0], 0x39); } -TEST(FP8E4M3Test, VariadicConstructorFloat) { +TEST(FP8E4M3Test, VariadicFloat) { fp8_e4m3_x2 a(1.0f, 2.0f); EXPECT_EQ(sizeof(a.vals), 2u); @@ -410,3 +411,249 @@ TEST(FP8E4M3Test, VariadicRejectsMixedTypes) { EXPECT_FALSE((std::is_constructible_v)); EXPECT_FALSE((std::is_constructible_v)); } + +TEST(FP8E4M3Test, X2NotConstructibleFromSingleShort) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP8E4M3Test, X2NotConstructibleFromSingleInt) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP8E4M3Test, X2NotConstructibleFromSingleLong) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP8E4M3Test, X2NotConstructibleFromSingleLL) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP8E4M3Test, X2NotConstructibleFromSingleUShort) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP8E4M3Test, X2NotConstructibleFromSingleUInt) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP8E4M3Test, X2NotConstructibleFromSingleUL) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP8E4M3Test, X2NotConstructibleFromSingleULL) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP8E4M3Test, X2NotConstructibleFromSingleFloat) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP8E4M3Test, X2NotConstructibleFromSingleDouble) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP8E4M3Test, X2NotConstructibleFromSingleBFloat16) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP8E4M3Test, X2NotConstructibleFromSingleHalf) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP8E4M3Test, X2NotConstructibleFromSingleChar) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP8E4M3Test, X2NotConstructibleFromSingleUChar) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP8E4M3Test, X2NotAssignableFromSingleHalf) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP8E4M3Test, X2NotAssignableFromSingleBFloat16) { + EXPECT_FALSE( + (std::is_assignable_v)); +} + +TEST(FP8E4M3Test, X2NotAssignableFromSingleFloat) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP8E4M3Test, X2NotAssignableFromSingleDouble) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP8E4M3Test, X2NotAssignableFromSingleChar) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP8E4M3Test, X2NotAssignableFromSingleSignedChar) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP8E4M3Test, X2NotAssignableFromSingleUChar) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP8E4M3Test, X2NotAssignableFromSingleShort) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP8E4M3Test, X2NotAssignableFromSingleInt) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP8E4M3Test, X2NotAssignableFromSingleLong) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP8E4M3Test, X2NotAssignableFromSingleLL) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP8E4M3Test, X2NotAssignableFromSingleUShort) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP8E4M3Test, X2NotAssignableFromSingleUInt) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP8E4M3Test, X2NotAssignableFromSingleUL) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP8E4M3Test, X2NotAssignableFromSingleULL) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP8E4M3Test, CArrayHalfRejectsUpwardRounding) { + const sycl::half in[2] = {sycl::half(1.0f), sycl::half(2.0f)}; + EXPECT_DEATH( + { + fp8_e4m3_x2 value(in, rounding::upward); + (void)value; + }, + "fp8_e4m3_x: only rounding::to_even is supported"); +} + +TEST(FP8E4M3Test, CArrayHalfRejectsTowardZeroRounding) { + const sycl::half in[2] = {sycl::half(1.0f), sycl::half(2.0f)}; + EXPECT_DEATH( + { + fp8_e4m3_x2 value(in, rounding::toward_zero); + (void)value; + }, + "fp8_e4m3_x: only rounding::to_even is supported"); +} + +TEST(FP8E4M3Test, CArrayBFloat16RejectsUpwardRounding) { + const sycl::ext::oneapi::bfloat16 in[2] = { + sycl::ext::oneapi::bfloat16(1.0f), + sycl::ext::oneapi::bfloat16(2.0f)}; + EXPECT_DEATH( + { + fp8_e4m3_x2 value(in, rounding::upward); + (void)value; + }, + "fp8_e4m3_x: only rounding::to_even is supported"); +} + +TEST(FP8E4M3Test, CArrayBFloat16RejectsTowardZeroRounding) { + const sycl::ext::oneapi::bfloat16 in[2] = { + sycl::ext::oneapi::bfloat16(1.0f), + sycl::ext::oneapi::bfloat16(2.0f)}; + EXPECT_DEATH( + { + fp8_e4m3_x2 value(in, rounding::toward_zero); + (void)value; + }, + "fp8_e4m3_x: only rounding::to_even is supported"); +} + +TEST(FP8E4M3Test, CArrayFloatRejectsUpwardRounding) { + const float in[2] = {1.0f, 2.0f}; + EXPECT_DEATH( + { + fp8_e4m3_x2 value(in, rounding::upward); + (void)value; + }, + "fp8_e4m3_x: only rounding::to_even is supported"); +} + +TEST(FP8E4M3Test, CArrayFloatRejectsTowardZeroRounding) { + const float in[2] = {1.0f, 2.0f}; + EXPECT_DEATH( + { + fp8_e4m3_x2 value(in, rounding::toward_zero); + (void)value; + }, + "fp8_e4m3_x: only rounding::to_even is supported"); +} + +TEST(FP8E4M3Test, MarrayHalfRejectsUpwardRounding) { + const sycl::marray in = {sycl::half(1.0f), sycl::half(2.0f)}; + EXPECT_DEATH( + { + fp8_e4m3_x2 value(in, rounding::upward); + (void)value; + }, + "fp8_e4m3_x: only rounding::to_even is supported"); +} + +TEST(FP8E4M3Test, MarrayHalfRejectsTowardZeroRounding) { + const sycl::marray in = {sycl::half(1.0f), sycl::half(2.0f)}; + EXPECT_DEATH( + { + fp8_e4m3_x2 value(in, rounding::toward_zero); + (void)value; + }, + "fp8_e4m3_x: only rounding::to_even is supported"); +} + +TEST(FP8E4M3Test, MarrayBFloat16RejectsUpwardRounding) { + const sycl::marray in = { + sycl::ext::oneapi::bfloat16(1.0f), + sycl::ext::oneapi::bfloat16(2.0f)}; + EXPECT_DEATH( + { + fp8_e4m3_x2 value(in, rounding::upward); + (void)value; + }, + "fp8_e4m3_x: only rounding::to_even is supported"); +} + +TEST(FP8E4M3Test, MarrayBFloat16RejectsTowardZeroRounding) { + const sycl::marray in = { + sycl::ext::oneapi::bfloat16(1.0f), + sycl::ext::oneapi::bfloat16(2.0f)}; + EXPECT_DEATH( + { + fp8_e4m3_x2 value(in, rounding::toward_zero); + (void)value; + }, + "fp8_e4m3_x: only rounding::to_even is supported"); +} + +TEST(FP8E4M3Test, MarrayFloatRejectsUpwardRounding) { + const sycl::marray in = {1.0f, 2.0f}; + EXPECT_DEATH( + { + fp8_e4m3_x2 value(in, rounding::upward); + (void)value; + }, + "fp8_e4m3_x: only rounding::to_even is supported"); +} + +TEST(FP8E4M3Test, MarrayFloatRejectsTowardZeroRounding) { + const sycl::marray in = {1.0f, 2.0f}; + EXPECT_DEATH( + { + fp8_e4m3_x2 value(in, rounding::toward_zero); + (void)value; + }, + "fp8_e4m3_x: only rounding::to_even is supported"); +} diff --git a/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp b/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp index 6a5123b5f3cc6..3ca8a3651cc49 100644 --- a/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp +++ b/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp @@ -4,6 +4,7 @@ #include #include #include +#include /* Unit tests check only CPU versions. Most of the constraints related to device @@ -12,7 +13,7 @@ code thus unit tests check only API using namespace sycl::ext::oneapi::experimental; -TEST(FP8E5M2Test, VariadicConstructorHalf) { +TEST(FP8E5M2Test, VariadicHalf) { fp8_e5m2_x2 a(sycl::half(1.0f), sycl::half(2.0f)); EXPECT_EQ(sizeof(a.vals), 2u); @@ -24,7 +25,7 @@ TEST(FP8E5M2Test, VariadicConstructorHalf) { EXPECT_EQ(b.vals[0], 0x3C); // 1.1 rounds to 1.0 } -TEST(FP8E5M2Test, VariadicConstructorBFloat16) { +TEST(FP8E5M2Test, VariadicBFloat16) { fp8_e5m2_x2 a(sycl::ext::oneapi::bfloat16(1.0f), sycl::ext::oneapi::bfloat16(2.0f)); @@ -37,7 +38,7 @@ TEST(FP8E5M2Test, VariadicConstructorBFloat16) { EXPECT_EQ(b.vals[0], 0x3C); } -TEST(FP8E5M2Test, VariadicConstructorFloat) { +TEST(FP8E5M2Test, VariadicFloat) { fp8_e5m2_x2 a(1.0f, 2.0f); EXPECT_EQ(sizeof(a.vals), 2u); @@ -49,7 +50,7 @@ TEST(FP8E5M2Test, VariadicConstructorFloat) { EXPECT_EQ(b.vals[0], 0x3C); } -TEST(FP8E5M2Test, VariadicConstructorBoundaryEncodingsFloat) { +TEST(FP8E5M2Test, VariadicBoundaryEncodingsFloat) { fp8_e5m2_x2 a(57344.0f, // max normal -> S.11110.11 0.00006103515625f // min normal -> S.00001.00 (2^-14) ); @@ -75,7 +76,7 @@ TEST(FP8E5M2Test, VariadicConstructorBoundaryEncodingsFloat) { EXPECT_EQ(a2.vals[1], 0x80); // -0 -> 0b1_00000_00 } -TEST(FP8E5M2Test, VariadicConstructorNaNEncodingFloat) { +TEST(FP8E5M2Test, VariadicNaNEncodingFloat) { fp8_e5m2_x2 a(std::numeric_limits::quiet_NaN(), -std::numeric_limits::quiet_NaN()); @@ -174,7 +175,7 @@ TEST(FP8E5M2Test, BoolOperatorZeroRules) { EXPECT_TRUE(static_cast(sub)); } -TEST(FP8E5M2Test, VariadicConstructorSaturatesFinite) { +TEST(FP8E5M2Test, VariadicSaturatesFinite) { // Variadic constructors: to_even + finite saturation (CPU). fp8_e5m2_x2 a(1.0f, 100000.0f // above max normal: clamp to +57344 @@ -191,7 +192,7 @@ TEST(FP8E5M2Test, VariadicConstructorSaturatesFinite) { EXPECT_EQ(a1.vals[1], 0x80); // -0 } -TEST(FP8E5M2Test, VariadicConstructorToEvenTie) { +TEST(FP8E5M2Test, VariadicToEvenTie) { // Tie case: between 1.0 (0x3C) and 1.25 (0x3D) is 1.125 exactly. // to_even => choose 1.0 because its LSB (fraction) is even (0). // Tie between 1.25 (0x3D) and 1.5 (0x3E) is 1.375 exactly => choose 1.5. @@ -201,7 +202,7 @@ TEST(FP8E5M2Test, VariadicConstructorToEvenTie) { EXPECT_EQ(a.vals[1], 0xBE); } -TEST(FP8E5M2Test, CArrayConstructorFloatHostToEvenFinite) { +TEST(FP8E5M2Test, CArrayFloatHostToEvenFinite) { // Host code supports only rounding::to_even and saturation::finite. const float in[2] = {1.0f, 1.1f}; const float in1[2] = {1.125f, 100000.0f}; @@ -216,7 +217,7 @@ TEST(FP8E5M2Test, CArrayConstructorFloatHostToEvenFinite) { EXPECT_EQ(a1.vals[1], 0x7B); // finite saturation => +57344 } -TEST(FP8E5M2Test, CArrayConstructorDoubleToEvenFinite) { +TEST(FP8E5M2Test, CArrayDoubleToEvenFinite) { // Double c-array: to_even + finite saturation. const double in[2] = {57344.0, 60000.0}; const double in1[2] = {0.00006103515625, 0.0000457763671875}; @@ -237,7 +238,7 @@ TEST(FP8E5M2Test, CArrayConstructorDoubleToEvenFinite) { EXPECT_EQ(a2.vals[1], 0x7F); // NaN } -TEST(FP8E5M2Test, CArrayConstructorHalfHostToEvenFinite) { +TEST(FP8E5M2Test, CArrayHalfHostToEvenFinite) { const sycl::half in[2] = {sycl::half(1.0f), sycl::half(2.0f)}; const sycl::half in1[2] = {sycl::half(1.125f), sycl::half(-0.0f)}; fp8_e5m2_x2 a(in); @@ -250,7 +251,7 @@ TEST(FP8E5M2Test, CArrayConstructorHalfHostToEvenFinite) { EXPECT_EQ(a1.vals[1], 0x80); } -TEST(FP8E5M2Test, CArrayConstructorBFloat16HostToEvenFinite) { +TEST(FP8E5M2Test, CArrayBFloat16HostToEvenFinite) { const sycl::ext::oneapi::bfloat16 in[2] = {sycl::ext::oneapi::bfloat16(1.0f), sycl::ext::oneapi::bfloat16(2.0f)}; const sycl::ext::oneapi::bfloat16 in1[2] = { @@ -266,7 +267,7 @@ TEST(FP8E5M2Test, CArrayConstructorBFloat16HostToEvenFinite) { EXPECT_EQ(a1.vals[1], 0x80); } -TEST(FP8E5M2Test, MarrayConstructorAndOperators) { +TEST(FP8E5M2Test, MarrayAndOperators) { sycl::marray in = {1.0f, 2.0f}; sycl::marray in1 = {0.0f, -0.0f}; sycl::marray in2 = {57344.0f, 100000.0f}; @@ -305,7 +306,7 @@ TEST(FP8E5M2Test, MarrayConstructorAndOperators) { EXPECT_EQ(out3[1], -1.5f); } -TEST(FP8E5M2Test, MarrayConstructorDouble) { +TEST(FP8E5M2Test, MarrayDouble) { sycl::marray dvals = {1.0, 2.0}; sycl::marray dvals1 = {57344.0, -0.0}; @@ -447,8 +448,250 @@ TEST(FP8E5M2Test, BoolOperatorWithNaN) { EXPECT_EQ(nanv.vals[0], 0x7F); // NaN encoding remains S.11111.11 } -TEST(FP8E5M2Test, VariadicRejectsMixedTypes) { +TEST(FP8E5M2Test, VariadicMixedScalarTypes) { EXPECT_FALSE((std::is_constructible_v)); EXPECT_FALSE( (std::is_constructible_v)); -} \ No newline at end of file +} + +TEST(FP8E5M2Test, X2NotConstructibleFromSingleShort) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP8E5M2Test, X2NotConstructibleFromSingleInt) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP8E5M2Test, X2NotConstructibleFromSingleLong) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP8E5M2Test, X2NotConstructibleFromSingleLL) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP8E5M2Test, X2NotConstructibleFromSingleUShort) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP8E5M2Test, X2NotConstructibleFromSingleUInt) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP8E5M2Test, X2NotConstructibleFromSingleUL) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP8E5M2Test, X2NotConstructibleFromSingleULL) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP8E5M2Test, X2NotConstructibleFromSingleFloat) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP8E5M2Test, X2NotConstructibleFromSingleDouble) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP8E5M2Test, X2NotConstructibleFromSingleBFloat16) { + EXPECT_FALSE( + (std::is_constructible_v)); +} + +TEST(FP8E5M2Test, X2NotConstructibleFromSingleHalf) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP8E5M2Test, X2NotConstructibleFromSingleChar) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP8E5M2Test, X2NotConstructibleFromSingleUChar) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP8E5M2Test, X2NotAssignableFromSingleHalf) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP8E5M2Test, X2NotAssignableFromSingleBFloat16) { + EXPECT_FALSE( + (std::is_assignable_v)); +} + +TEST(FP8E5M2Test, X2NotAssignableFromSingleFloat) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP8E5M2Test, X2NotAssignableFromSingleDouble) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP8E5M2Test, X2NotAssignableFromSingleChar) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP8E5M2Test, X2NotAssignableFromSingleSignedChar) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP8E5M2Test, X2NotAssignableFromSingleUChar) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP8E5M2Test, X2NotAssignableFromSingleShort) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP8E5M2Test, X2NotAssignableFromSingleInt) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP8E5M2Test, X2NotAssignableFromSingleLong) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP8E5M2Test, X2NotAssignableFromSingleLL) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP8E5M2Test, X2NotAssignableFromSingleUShort) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP8E5M2Test, X2NotAssignableFromSingleUInt) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP8E5M2Test, X2NotAssignableFromSingleUL) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP8E5M2Test, X2NotAssignableFromSingleULL) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP8E5M2Test, CArrayHalfUpwardRounding) { + const sycl::half in[2] = {sycl::half(1.0f), sycl::half(2.0f)}; + EXPECT_DEATH( + { + fp8_e5m2_x2 value(in, rounding::upward); + (void)value; + }, + "fp8_e5m2_x: only rounding::to_even is supported"); +} + +TEST(FP8E5M2Test, CArrayHalfTowardZeroRounding) { + const sycl::half in[2] = {sycl::half(1.0f), sycl::half(2.0f)}; + EXPECT_DEATH( + { + fp8_e5m2_x2 value(in, rounding::toward_zero); + (void)value; + }, + "fp8_e5m2_x: only rounding::to_even is supported"); +} + +TEST(FP8E5M2Test, CArrayBFloat16UpwardRounding) { + const sycl::ext::oneapi::bfloat16 in[2] = {sycl::ext::oneapi::bfloat16(1.0f), + sycl::ext::oneapi::bfloat16(2.0f)}; + EXPECT_DEATH( + { + fp8_e5m2_x2 value(in, rounding::upward); + (void)value; + }, + "fp8_e5m2_x: only rounding::to_even is supported"); +} + +TEST(FP8E5M2Test, CArrayBFloat16TowardZeroRounding) { + const sycl::ext::oneapi::bfloat16 in[2] = {sycl::ext::oneapi::bfloat16(1.0f), + sycl::ext::oneapi::bfloat16(2.0f)}; + EXPECT_DEATH( + { + fp8_e5m2_x2 value(in, rounding::toward_zero); + (void)value; + }, + "fp8_e5m2_x: only rounding::to_even is supported"); +} + +TEST(FP8E5M2Test, CArrayFloatUpwardRounding) { + const float in[2] = {1.0f, 2.0f}; + EXPECT_DEATH( + { + fp8_e5m2_x2 value(in, rounding::upward); + (void)value; + }, + "fp8_e5m2_x: only rounding::to_even is supported"); +} + +TEST(FP8E5M2Test, CArrayFloatTowardZeroRounding) { + const float in[2] = {1.0f, 2.0f}; + EXPECT_DEATH( + { + fp8_e5m2_x2 value(in, rounding::toward_zero); + (void)value; + }, + "fp8_e5m2_x: only rounding::to_even is supported"); +} + +TEST(FP8E5M2Test, MarrayHalfUpwardRounding) { + const sycl::marray in = {sycl::half(1.0f), sycl::half(2.0f)}; + EXPECT_DEATH( + { + fp8_e5m2_x2 value(in, rounding::upward); + (void)value; + }, + "fp8_e5m2_x: only rounding::to_even is supported"); +} + +TEST(FP8E5M2Test, MarrayHalfTowardZeroRounding) { + const sycl::marray in = {sycl::half(1.0f), sycl::half(2.0f)}; + EXPECT_DEATH( + { + fp8_e5m2_x2 value(in, rounding::toward_zero); + (void)value; + }, + "fp8_e5m2_x: only rounding::to_even is supported"); +} + +TEST(FP8E5M2Test, MarrayBFloat16UpwardRounding) { + const sycl::marray in = { + sycl::ext::oneapi::bfloat16(1.0f), sycl::ext::oneapi::bfloat16(2.0f)}; + EXPECT_DEATH( + { + fp8_e5m2_x2 value(in, rounding::upward); + (void)value; + }, + "fp8_e5m2_x: only rounding::to_even is supported"); +} + +TEST(FP8E5M2Test, MarrayBFloat16TowardZeroRounding) { + const sycl::marray in = { + sycl::ext::oneapi::bfloat16(1.0f), sycl::ext::oneapi::bfloat16(2.0f)}; + EXPECT_DEATH( + { + fp8_e5m2_x2 value(in, rounding::toward_zero); + (void)value; + }, + "fp8_e5m2_x: only rounding::to_even is supported"); +} + +TEST(FP8E5M2Test, MarrayFloatUpwardRounding) { + const sycl::marray in = {1.0f, 2.0f}; + EXPECT_DEATH( + { + fp8_e5m2_x2 value(in, rounding::upward); + (void)value; + }, + "fp8_e5m2_x: only rounding::to_even is supported"); +} + +TEST(FP8E5M2Test, MarrayFloatTowardZeroRounding) { + const sycl::marray in = {1.0f, 2.0f}; + EXPECT_DEATH( + { + fp8_e5m2_x2 value(in, rounding::toward_zero); + (void)value; + }, + "fp8_e5m2_x: only rounding::to_even is supported"); +} diff --git a/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp b/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp index b64970a8e6522..81a1d8d6c0db3 100644 --- a/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp +++ b/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp @@ -4,6 +4,7 @@ #include #include #include +#include /* Unit tests check only CPU versions. Most of the constraints related to device @@ -14,6 +15,10 @@ using namespace sycl::ext::oneapi::experimental; namespace { +constexpr const char *UnsupportedRoundingAssertRegex = + "fp8_e8m0_x: only rounding::upward and rounding::toward_zero are " + "\" \"supported"; + bool checkCode(float Input, rounding Mode, uint8_t Expected) { const float Values[1] = {Input}; const fp8_e8m0 Encoded(Values, Mode); @@ -22,7 +27,7 @@ bool checkCode(float Input, rounding Mode, uint8_t Expected) { } // namespace -TEST(FP8E8M0Test, VariadicConstructorFloat) { +TEST(FP8E8M0Test, VariadicFloat) { fp8_e8m0_x2 a(1.0f, 2.0f); fp8_e8m0_x2 a1(1.1f, 0.0f); @@ -34,7 +39,7 @@ TEST(FP8E8M0Test, VariadicConstructorFloat) { EXPECT_EQ(a1.vals[1], 0x00); // 0.0 -> min normal } -TEST(FP8E8M0Test, VariadicConstructorHalf) { +TEST(FP8E8M0Test, VariadicHalf) { fp8_e8m0_x2 a(sycl::half(1.0f), sycl::half(3.0f)); EXPECT_EQ(sizeof(a.vals), 2u); @@ -42,7 +47,7 @@ TEST(FP8E8M0Test, VariadicConstructorHalf) { EXPECT_EQ(a.vals[1], 0x81); // 3.0 -> upward to 4.0 } -TEST(FP8E8M0Test, VariadicConstructorBFloat16) { +TEST(FP8E8M0Test, VariadicBFloat16) { fp8_e8m0_x2 a(sycl::ext::oneapi::bfloat16(1.0f), sycl::ext::oneapi::bfloat16(2.0f)); @@ -51,7 +56,7 @@ TEST(FP8E8M0Test, VariadicConstructorBFloat16) { EXPECT_EQ(a.vals[1], 0x80); } -TEST(FP8E8M0Test, VariadicConstructorDouble) { +TEST(FP8E8M0Test, VariadicDouble) { fp8_e8m0_x2 a(1.0, 3.0); EXPECT_EQ(sizeof(a.vals), 2u); @@ -59,7 +64,7 @@ TEST(FP8E8M0Test, VariadicConstructorDouble) { EXPECT_EQ(a.vals[1], 0x81); } -TEST(FP8E8M0Test, VariadicConstructorBoundaryEncodings) { +TEST(FP8E8M0Test, VariadicBoundaryEncodings) { fp8_e8m0_x2 a(std::ldexp(1.0f, -127), std::numeric_limits::quiet_NaN()); @@ -68,7 +73,7 @@ TEST(FP8E8M0Test, VariadicConstructorBoundaryEncodings) { EXPECT_EQ(a.vals[1], 0xFF); // NaN } -TEST(FP8E8M0Test, CArrayConstructorFloatHostUpwardFinite) { +TEST(FP8E8M0Test, CArrayFloatHostUpwardFinite) { const float in[2] = {1.0f, 1.1f}; const float in1[2] = {3.0f, 1000.0f}; fp8_e8m0_x2 a(in, rounding::upward); @@ -82,7 +87,7 @@ TEST(FP8E8M0Test, CArrayConstructorFloatHostUpwardFinite) { EXPECT_EQ(a1.vals[1], 0x89); // upward to 2^10 = 1024 } -TEST(FP8E8M0Test, CArrayConstructorFloatRoundingModes) { +TEST(FP8E8M0Test, CArrayFloatRoundingModes) { EXPECT_TRUE(checkCode(3.0f, rounding::upward, 0x81)); EXPECT_TRUE(checkCode(3.0f, rounding::toward_zero, 0x80)); @@ -110,7 +115,7 @@ TEST(FP8E8M0Test, RoundClipZeroFractionNegativeAndTieCases) { EXPECT_EQ(detail::RoundClip(0.75f, 0, rounding::to_even, 0u), 1u); } -TEST(FP8E8M0Test, CArrayConstructorHalfHostUpwardFinite) { +TEST(FP8E8M0Test, CArrayHalfHostUpwardFinite) { const sycl::half in[2] = {sycl::half(1.0f), sycl::half(1.1f)}; const sycl::half in1[2] = {sycl::half(3.0f), sycl::half(0.0f)}; @@ -125,7 +130,7 @@ TEST(FP8E8M0Test, CArrayConstructorHalfHostUpwardFinite) { EXPECT_EQ(a1.vals[1], 0x00); } -TEST(FP8E8M0Test, CArrayConstructorBFloat16HostUpwardFinite) { +TEST(FP8E8M0Test, CArrayBFloat16HostUpwardFinite) { const sycl::ext::oneapi::bfloat16 in[2] = {sycl::ext::oneapi::bfloat16(1.0f), sycl::ext::oneapi::bfloat16(2.0f)}; fp8_e8m0_x2 a(in, rounding::upward); @@ -135,7 +140,7 @@ TEST(FP8E8M0Test, CArrayConstructorBFloat16HostUpwardFinite) { EXPECT_EQ(a.vals[1], 0x80); } -TEST(FP8E8M0Test, CArrayConstructorDoubleDefaultUpwardFinite) { +TEST(FP8E8M0Test, CArrayDoubleDefaultUpwardFinite) { const double in[2] = {1.0, 3.0}; fp8_e8m0_x2 a(in); @@ -144,7 +149,7 @@ TEST(FP8E8M0Test, CArrayConstructorDoubleDefaultUpwardFinite) { EXPECT_EQ(a.vals[1], 0x81); } -TEST(FP8E8M0Test, MarrayConstructorAndOperatorsFloat) { +TEST(FP8E8M0Test, MarrayAndOperatorsFloat) { sycl::marray in = {1.0f, 2.0f}; sycl::marray in1 = {3.0f, 0.0f}; @@ -166,7 +171,7 @@ TEST(FP8E8M0Test, MarrayConstructorAndOperatorsFloat) { EXPECT_EQ(out1[1], std::ldexp(1.0f, -127)); } -TEST(FP8E8M0Test, MarrayConstructorHalfBFloat16Double) { +TEST(FP8E8M0Test, MarrayHalfBFloat16Double) { sycl::marray hvals = {sycl::half(1.0f), sycl::half(3.0f)}; sycl::marray bvals = { sycl::ext::oneapi::bfloat16(1.0f), sycl::ext::oneapi::bfloat16(2.0f)}; @@ -312,8 +317,188 @@ TEST(FP8E8M0Test, MarrayConversionOperators) { EXPECT_EQ(fo[1], 4.0f); } -TEST(FP8E8M0Test, VariadicRejectsMixedTypes) { +TEST(FP8E8M0Test, VariadicMixedTypes) { EXPECT_FALSE((std::is_constructible_v)); EXPECT_FALSE((std::is_constructible_v)); -} \ No newline at end of file +} + +TEST(FP8E8M0Test, X2NotConstructibleFromSingleShort) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP8E8M0Test, X2NotConstructibleFromSingleInt) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP8E8M0Test, X2NotConstructibleFromSingleLong) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP8E8M0Test, X2NotConstructibleFromSingleLL) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP8E8M0Test, X2NotConstructibleFromSingleUShort) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP8E8M0Test, X2NotConstructibleFromSingleUInt) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP8E8M0Test, X2NotConstructibleFromSingleUL) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP8E8M0Test, X2NotConstructibleFromSingleULL) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP8E8M0Test, X2NotConstructibleFromSingleFloat) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP8E8M0Test, X2NotConstructibleFromSingleDouble) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP8E8M0Test, X2NotConstructibleFromSingleBFloat16) { + EXPECT_FALSE( + (std::is_constructible_v)); +} + +TEST(FP8E8M0Test, X2NotConstructibleFromSingleHalf) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP8E8M0Test, X2NotConstructibleFromSingleChar) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP8E8M0Test, X2NotConstructibleFromSingleUChar) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP8E8M0Test, X2NotAssignableFromSingleHalf) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP8E8M0Test, X2NotAssignableFromSingleBFloat16) { + EXPECT_FALSE( + (std::is_assignable_v)); +} + +TEST(FP8E8M0Test, X2NotAssignableFromSingleFloat) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP8E8M0Test, X2NotAssignableFromSingleDouble) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP8E8M0Test, X2NotAssignableFromSingleChar) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP8E8M0Test, X2NotAssignableFromSingleSignedChar) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP8E8M0Test, X2NotAssignableFromSingleUChar) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP8E8M0Test, X2NotAssignableFromSingleShort) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP8E8M0Test, X2NotAssignableFromSingleInt) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP8E8M0Test, X2NotAssignableFromSingleLong) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP8E8M0Test, X2NotAssignableFromSingleLL) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP8E8M0Test, X2NotAssignableFromSingleUShort) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP8E8M0Test, X2NotAssignableFromSingleUInt) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP8E8M0Test, X2NotAssignableFromSingleUL) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP8E8M0Test, X2NotAssignableFromSingleULL) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP8E8M0Test, CArrayHalfToEvenRounding) { + const sycl::half in[2] = {sycl::half(1.0f), sycl::half(2.0f)}; + EXPECT_DEATH( + { + fp8_e8m0_x2 value(in, rounding::to_even); + (void)value; + }, + UnsupportedRoundingAssertRegex); +} + +TEST(FP8E8M0Test, CArrayBFloat16ToEvenRounding) { + const sycl::ext::oneapi::bfloat16 in[2] = {sycl::ext::oneapi::bfloat16(1.0f), + sycl::ext::oneapi::bfloat16(2.0f)}; + EXPECT_DEATH( + { + fp8_e8m0_x2 value(in, rounding::to_even); + (void)value; + }, + UnsupportedRoundingAssertRegex); +} + +TEST(FP8E8M0Test, CArrayFloatToEvenRounding) { + const float in[2] = {1.0f, 2.0f}; + EXPECT_DEATH( + { + fp8_e8m0_x2 value(in, rounding::to_even); + (void)value; + }, + UnsupportedRoundingAssertRegex); +} + +TEST(FP8E8M0Test, MarrayHalfToEvenRounding) { + const sycl::marray in = {sycl::half(1.0f), sycl::half(2.0f)}; + EXPECT_DEATH( + { + fp8_e8m0_x2 value(in, rounding::to_even); + (void)value; + }, + UnsupportedRoundingAssertRegex); +} + +TEST(FP8E8M0Test, MarrayBFloat16ToEvenRounding) { + const sycl::marray in = { + sycl::ext::oneapi::bfloat16(1.0f), sycl::ext::oneapi::bfloat16(2.0f)}; + EXPECT_DEATH( + { + fp8_e8m0_x2 value(in, rounding::to_even); + (void)value; + }, + UnsupportedRoundingAssertRegex); +} + +TEST(FP8E8M0Test, MarrayFloatToEvenRounding) { + const sycl::marray in = {1.0f, 2.0f}; + EXPECT_DEATH( + { + fp8_e8m0_x2 value(in, rounding::to_even); + (void)value; + }, + UnsupportedRoundingAssertRegex); +} From aa7c77657c4cf78f88b1c2e006fc046a2db8afee Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Thu, 9 Apr 2026 17:00:08 +0200 Subject: [PATCH 29/89] [SYCL] fix formatting --- sycl/unittests/Extensions/fp8/fp8_e4m3.cpp | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp b/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp index 341e039df8547..d231cf93f7a92 100644 --- a/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp +++ b/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp @@ -453,8 +453,8 @@ TEST(FP8E4M3Test, X2NotConstructibleFromSingleDouble) { } TEST(FP8E4M3Test, X2NotConstructibleFromSingleBFloat16) { - EXPECT_FALSE((std::is_constructible_v)); + EXPECT_FALSE( + (std::is_constructible_v)); } TEST(FP8E4M3Test, X2NotConstructibleFromSingleHalf) { @@ -551,9 +551,8 @@ TEST(FP8E4M3Test, CArrayHalfRejectsTowardZeroRounding) { } TEST(FP8E4M3Test, CArrayBFloat16RejectsUpwardRounding) { - const sycl::ext::oneapi::bfloat16 in[2] = { - sycl::ext::oneapi::bfloat16(1.0f), - sycl::ext::oneapi::bfloat16(2.0f)}; + const sycl::ext::oneapi::bfloat16 in[2] = {sycl::ext::oneapi::bfloat16(1.0f), + sycl::ext::oneapi::bfloat16(2.0f)}; EXPECT_DEATH( { fp8_e4m3_x2 value(in, rounding::upward); @@ -563,9 +562,8 @@ TEST(FP8E4M3Test, CArrayBFloat16RejectsUpwardRounding) { } TEST(FP8E4M3Test, CArrayBFloat16RejectsTowardZeroRounding) { - const sycl::ext::oneapi::bfloat16 in[2] = { - sycl::ext::oneapi::bfloat16(1.0f), - sycl::ext::oneapi::bfloat16(2.0f)}; + const sycl::ext::oneapi::bfloat16 in[2] = {sycl::ext::oneapi::bfloat16(1.0f), + sycl::ext::oneapi::bfloat16(2.0f)}; EXPECT_DEATH( { fp8_e4m3_x2 value(in, rounding::toward_zero); @@ -616,8 +614,7 @@ TEST(FP8E4M3Test, MarrayHalfRejectsTowardZeroRounding) { TEST(FP8E4M3Test, MarrayBFloat16RejectsUpwardRounding) { const sycl::marray in = { - sycl::ext::oneapi::bfloat16(1.0f), - sycl::ext::oneapi::bfloat16(2.0f)}; + sycl::ext::oneapi::bfloat16(1.0f), sycl::ext::oneapi::bfloat16(2.0f)}; EXPECT_DEATH( { fp8_e4m3_x2 value(in, rounding::upward); @@ -628,8 +625,7 @@ TEST(FP8E4M3Test, MarrayBFloat16RejectsUpwardRounding) { TEST(FP8E4M3Test, MarrayBFloat16RejectsTowardZeroRounding) { const sycl::marray in = { - sycl::ext::oneapi::bfloat16(1.0f), - sycl::ext::oneapi::bfloat16(2.0f)}; + sycl::ext::oneapi::bfloat16(1.0f), sycl::ext::oneapi::bfloat16(2.0f)}; EXPECT_DEATH( { fp8_e4m3_x2 value(in, rounding::toward_zero); From 8fbe4601db702ffd47eb33cba70f7a39c0dce1de Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Thu, 9 Apr 2026 17:21:03 +0200 Subject: [PATCH 30/89] [SYCL] do not cast seed --- .../oneapi/experimental/float_8bit/types.hpp | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index e0ae4fd34d5aa..d69a754d3b13e 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -1087,10 +1087,10 @@ template class fp8_e5m2_x { for (size_t i = 0; i < N; ++i) { if (s == saturation::finite) { vals[i] = __builtin_spirv_ClampStochasticRoundFP16ToE5M2INTEL( - in[i], static_cast(current_seed), seed.pseed); + in[i], current_seed, seed.pseed); } else { vals[i] = __builtin_spirv_StochasticRoundFP16ToE5M2INTEL( - in[i], static_cast(current_seed), seed.pseed); + in[i], current_seed, seed.pseed); } current_seed = *seed.pseed; } @@ -1105,10 +1105,10 @@ template class fp8_e5m2_x { for (size_t i = 0; i < N; ++i) { if (s == saturation::finite) { vals[i] = __builtin_spirv_ClampStochasticRoundBF16ToE5M2INTEL( - in[i], static_cast(current_seed), seed.pseed); + in[i], current_seed, seed.pseed); } else { vals[i] = __builtin_spirv_StochasticRoundBF16ToE5M2INTEL( - in[i], static_cast(current_seed), seed.pseed); + in[i], current_seed, seed.pseed); } current_seed = *seed.pseed; } @@ -1124,10 +1124,10 @@ template class fp8_e5m2_x { sycl::half h = static_cast(in[i]); if (s == saturation::finite) { vals[i] = __builtin_spirv_ClampStochasticRoundFP16ToE5M2INTEL( - h, static_cast(current_seed), seed.pseed); + h, current_seed, seed.pseed); } else { vals[i] = __builtin_spirv_StochasticRoundFP16ToE5M2INTEL( - h, static_cast(current_seed), seed.pseed); + h, current_seed, seed.pseed); } current_seed = *seed.pseed; } @@ -1145,10 +1145,10 @@ template class fp8_e5m2_x { for (size_t i = 0; i < N; ++i) { if (s == saturation::finite) { vals[i] = __builtin_spirv_ClampStochasticRoundFP16ToE5M2INTEL( - in[i], static_cast(current_seed), seed.pseed); + in[i], current_seed, seed.pseed); } else { vals[i] = __builtin_spirv_StochasticRoundFP16ToE5M2INTEL( - in[i], static_cast(current_seed), seed.pseed); + in[i], current_seed, seed.pseed); } current_seed = *seed.pseed; } @@ -1163,10 +1163,10 @@ template class fp8_e5m2_x { for (size_t i = 0; i < N; ++i) { if (s == saturation::finite) { vals[i] = __builtin_spirv_ClampStochasticRoundBF16ToE5M2INTEL( - in[i], static_cast(current_seed), seed.pseed); + in[i], current_seed, seed.pseed); } else { vals[i] = __builtin_spirv_StochasticRoundBF16ToE5M2INTEL( - in[i], static_cast(current_seed), seed.pseed); + in[i], current_seed, seed.pseed); } current_seed = *seed.pseed; } @@ -1182,10 +1182,10 @@ template class fp8_e5m2_x { sycl::half h = static_cast(in[i]); if (s == saturation::finite) { vals[i] = __builtin_spirv_ClampStochasticRoundFP16ToE5M2INTEL( - h, static_cast(current_seed), seed.pseed); + h, current_seed, seed.pseed); } else { vals[i] = __builtin_spirv_StochasticRoundFP16ToE5M2INTEL( - h, static_cast(current_seed), seed.pseed); + h, current_seed, seed.pseed); } current_seed = *seed.pseed; } From f4945f075442563efed113e2e7b86b26d0c471a2 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Thu, 9 Apr 2026 18:04:25 +0200 Subject: [PATCH 31/89] [SYCL] remove unused variable --- .../sycl/ext/oneapi/experimental/float_8bit/types.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index d69a754d3b13e..b29631ccaf68f 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -356,8 +356,8 @@ ConvertToFP8_CPU(T h, rounding R = rounding::to_even) noexcept { return static_cast(sign | (MaxExpField)); // E = +127 // Determine exponent E such that 2^E <= ax < 2^{E+1} - int e2; - float m = std::frexp(ax, &e2); + int e2 = 0; + std::frexp(ax, &e2); int E = e2 - 1; // Upward rounding semantics: @@ -428,7 +428,7 @@ ConvertToFP8_CPU(T h, rounding R = rounding::to_even) noexcept { if (ax < min_sub) return sign; // underflow - int e2; + int e2 = 0; float m = std::frexp(ax, &e2); int E = e2 - 1; From bcbe8d7eed6af47837d117f04ed4483471641363 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Wed, 15 Apr 2026 12:22:34 +0200 Subject: [PATCH 32/89] [SYCL] use memcpy to convert to e8m0 instead of std library --- .../oneapi/experimental/float_8bit/types.hpp | 468 +++++++++++++----- 1 file changed, 340 insertions(+), 128 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index b29631ccaf68f..b314132123da8 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -17,6 +17,7 @@ #include #include #include +#include #ifdef __SYCL_DEVICE_ONLY__ // FP8 builtins @@ -140,6 +141,15 @@ static inline int BitWidth(uint32_t x) noexcept { return width; } +static inline int BitWidth(uint64_t x) noexcept { + int width = 0; + while (x != 0u) { + ++width; + x >>= 1; + } + return width; +} + template struct DirectBinary16Traits; template <> struct DirectBinary16Traits { @@ -303,6 +313,290 @@ static inline ToT ConvertFromFP8_CPU(uint8_t b, return ConvertFloatToTarget(sign_bit != 0u, significand, E, Mbits, R); } +template struct E8M0SourceTraits; + +template <> struct E8M0SourceTraits { + using UInt = uint32_t; + static constexpr size_t ExpBits = 8; + static constexpr size_t FracBits = 23; + static constexpr int Bias = 127; +}; + +template <> struct E8M0SourceTraits { + using UInt = uint16_t; + static constexpr size_t ExpBits = 5; + static constexpr size_t FracBits = 10; + static constexpr int Bias = 15; +}; + +template <> struct E8M0SourceTraits { + using UInt = uint16_t; + static constexpr size_t ExpBits = 8; + static constexpr size_t FracBits = 7; + static constexpr int Bias = 127; +}; + +template <> struct E8M0SourceTraits { + using UInt = uint64_t; + static constexpr size_t ExpBits = 11; + static constexpr size_t FracBits = 52; + static constexpr int Bias = 1023; +}; + +template <> struct E8M0SourceTraits { + using UInt = uint16_t; + using UnsignedT = std::make_unsigned_t; + + static constexpr bool IsIntegral = true; + static constexpr bool IsSigned = true; + static constexpr int ValueBits = std::numeric_limits::digits; +}; + +template <> struct E8M0SourceTraits { + using UInt = uint32_t; + using UnsignedT = std::make_unsigned_t; + + static constexpr bool IsIntegral = true; + static constexpr bool IsSigned = true; + static constexpr int ValueBits = std::numeric_limits::digits; +}; + +template <> struct E8M0SourceTraits { + using UInt = std::make_unsigned_t; + using UnsignedT = std::make_unsigned_t; + + static constexpr bool IsIntegral = true; + static constexpr bool IsSigned = true; + static constexpr int ValueBits = std::numeric_limits::digits; +}; + +template <> struct E8M0SourceTraits { + using UInt = uint64_t; + using UnsignedT = std::make_unsigned_t; + + static constexpr bool IsIntegral = true; + static constexpr bool IsSigned = true; + static constexpr int ValueBits = std::numeric_limits::digits; +}; + +template <> struct E8M0SourceTraits { + using UInt = uint16_t; + using UnsignedT = std::make_unsigned_t; + + static constexpr bool IsIntegral = true; + static constexpr bool IsSigned = false; + static constexpr int ValueBits = std::numeric_limits::digits; +}; + +template <> struct E8M0SourceTraits { + using UInt = uint32_t; + using UnsignedT = std::make_unsigned_t; + + static constexpr bool IsIntegral = true; + static constexpr bool IsSigned = false; + static constexpr int ValueBits = std::numeric_limits::digits; +}; + +template <> struct E8M0SourceTraits { + using UInt = std::make_unsigned_t; + using UnsignedT = std::make_unsigned_t; + + static constexpr bool IsIntegral = true; + static constexpr bool IsSigned = false; + static constexpr int ValueBits = std::numeric_limits::digits; +}; + +template <> struct E8M0SourceTraits { + using UInt = uint64_t; + using UnsignedT = std::make_unsigned_t; + + static constexpr bool IsIntegral = true; + static constexpr bool IsSigned = false; + static constexpr int ValueBits = std::numeric_limits::digits; +}; + +template > +static inline uint8_t ConvertIntToE8M0_CPU(T f, rounding R, + saturation S) noexcept { + using UnsignedT = typename Traits::UnsignedT; + UnsignedT magnitude = f < 0 ? -f : f; + + if (magnitude == 0) + return 0x00u; + + int lowerExp = BitWidth(static_cast(magnitude)) - 1; + bool isExactPowerOfTwo = (magnitude & (magnitude - 1)) == 0; + + bool roundUp = false; + switch (R) { + case rounding::toward_zero: + break; + case rounding::upward: + roundUp = !isExactPowerOfTwo; + break; + case rounding::to_even: { + if (!isExactPowerOfTwo) { + const uint64_t leading = uint64_t{1} << lowerExp; + const uint64_t twice = 2ull * static_cast(magnitude); + const uint64_t midpoint = 3ull * leading; + if (twice > midpoint || (twice == midpoint && (lowerExp & 1) != 0)) + roundUp = true; + } + break; + } + } + return static_cast(127 + lowerExp + (roundUp ? 1 : 0)); +} + +template > +static inline uint8_t ConvertFloatToE8M0_CPU(T f, rounding R, + saturation S) noexcept { + using UInt = typename Traits::UInt; + constexpr UInt SignMask = UInt{1} << (Traits::ExpBits + Traits::FracBits); + constexpr UInt FracMask = (UInt{1} << Traits::FracBits) - 1; + constexpr UInt ExpMask = ((UInt{1} << Traits::ExpBits) - 1) + << Traits::FracBits; + constexpr UInt ExpAllOnes = (UInt{1} << Traits::ExpBits) - 1; + constexpr uint8_t NaNCode = 0xFF; + constexpr uint8_t MaxFiniteCode = 0xFE; + constexpr int TargetBias = 127; + constexpr int TargetEmin = -127; + constexpr int TargetEmax = 127; + + UInt h; + __builtin_memcpy(&h, &f, sizeof(h)); + h &= ~SignMask; + + UInt exp = (h & ExpMask) >> Traits::FracBits; + UInt frac = h & FracMask; + + if (exp == ExpAllOnes) { + if (frac != 0u) + return NaNCode; + return (S == saturation::finite) ? MaxFiniteCode : NaNCode; + } + + if (exp == 0u && frac == 0u) + return 0x00u; + + uint64_t significand = 0u; + int leadingBit = 0; + int lowerExp = 0; + bool isExactPowerOfTwo = false; + + if (exp != 0u) { + significand = (uint64_t{1} << Traits::FracBits) | static_cast(frac); + leadingBit = static_cast(Traits::FracBits); + lowerExp = static_cast(exp) - Traits::Bias; + isExactPowerOfTwo = frac == 0u; + } else { + significand = static_cast(frac); + leadingBit = BitWidth(significand) - 1; + lowerExp = 1 - Traits::Bias - static_cast(Traits::FracBits) + leadingBit; + isExactPowerOfTwo = (significand & (significand - 1u)) == 0u; + } + + if (lowerExp < TargetEmin) + return 0x00u; + + bool roundUp = false; + + switch (R) { + case rounding::toward_zero: + break; + case rounding::upward: + roundUp = !isExactPowerOfTwo; + break; + case rounding::to_even: { + if (!isExactPowerOfTwo) { + const uint64_t twiceSignificand = 2ull * significand; + const uint64_t midpoint = 3ull * (uint64_t{1} << leadingBit); + if (twiceSignificand > midpoint) { + roundUp = true; + } else if (twiceSignificand == midpoint && (lowerExp & 1) != 0) { + roundUp = true; + } + } + break; + } + } + + int encodedExp = lowerExp + (roundUp ? 1 : 0); + if (encodedExp > TargetEmax) + return (S == saturation::finite) ? MaxFiniteCode : NaNCode; + + return static_cast(encodedExp + TargetBias); +} + +template +struct HasE8M0FloatTraits : std::false_type {}; + +template +struct HasE8M0FloatTraits< + Traits, + std::void_t> : std::true_type {}; + +template +struct HasE8M0IntegralTraits : std::false_type {}; + +template +struct HasE8M0IntegralTraits< + Traits, std::void_t> : std::true_type {}; + +template > +static inline ToT ConvertFromE8M0ToBinaryFloat_CPU(uint8_t code) noexcept { + if constexpr (HasE8M0FloatTraits::value) { + using UInt = typename Traits::UInt; + + constexpr UInt ExpAllOnes = + ((UInt{1} << Traits::ExpBits) - UInt{1}) << Traits::FracBits; + constexpr UInt QuietNaNBit = UInt{1} << (Traits::FracBits - 1); + constexpr int MinNormalExp = 1 - Traits::Bias; + constexpr int MinSubnormalExp = + MinNormalExp - static_cast(Traits::FracBits); + constexpr int MaxNormalExp = + static_cast((UInt{1} << Traits::ExpBits) - UInt{2}) - + Traits::Bias; + + UInt bits = 0; + if (code == 0xFFu) { + bits = ExpAllOnes | QuietNaNBit; + } else { + const int unbiasedExp = static_cast(code) - 127; + if (unbiasedExp > MaxNormalExp) { + bits = ExpAllOnes; + } else if (unbiasedExp >= MinNormalExp) { + bits = static_cast(unbiasedExp + Traits::Bias) + << Traits::FracBits; + } else if (unbiasedExp >= MinSubnormalExp) { + const int fracBit = + unbiasedExp - MinNormalExp + static_cast(Traits::FracBits); + bits = UInt{1} << fracBit; + } + } + + return __builtin_bit_cast(ToT, bits); + } else if constexpr (HasE8M0IntegralTraits::value && + Traits::IsIntegral) { + using UnsignedT = typename Traits::UnsignedT; + + if (code == 0xFFu) + return ToT{}; + + const int shift = static_cast(code) - 127; + if (shift < 0 || shift >= Traits::ValueBits) + return ToT{}; + + const UnsignedT magnitude = UnsignedT{1} << shift; + return static_cast(magnitude); + } else { + return ToT{}; + } +} + /// \brief Converts a given value to fp8 floating point with a rounding /// mode to_even by default and saturation finite for host code. /// \param h The input value to be converted. @@ -330,7 +624,7 @@ ConvertToFP8_CPU(T h, rounding R = rounding::to_even) noexcept { // to the smallest magnitude normal with the input sign preserved // (consistent with prior sign-preserving underflow behavior). // - constexpr int Bias = 127; + constexpr uint32_t Bias = 127; constexpr int Emin = -127; constexpr int Emax = 127; constexpr uint8_t NaNCode = 0xFF; // 11111111 @@ -458,108 +752,22 @@ ConvertToFP8_CPU(T h, rounding R = rounding::to_even) noexcept { return ret; } -template -static inline uint8_t ConvertToE8M0_CPU(T x, rounding R, - saturation S) noexcept { - // E8M0: unsigned 8-bit exponent code, bias 127. - // Code 0xFF reserved for NaN. No Inf, no subnormals, no signed zero. - constexpr int Bias = 127; - constexpr int Emin = -127; - constexpr int Emax = 127; - constexpr uint8_t NaNCode = 0xFF; - constexpr uint8_t MaxFiniteCode = 0xFE; - - // NaN and Inf checks only apply to non-integral types. - if constexpr (!std::is_integral_v) { - if (std::isnan(static_cast(x))) - return NaNCode; - if (std::isinf(static_cast(x))) - return (S == saturation::finite) ? MaxFiniteCode : NaNCode; - } - - // Compute absolute value in the natural type T. - T ax; - if constexpr (std::is_unsigned_v) - ax = x; - else if constexpr (std::is_signed_v && std::is_integral_v) - ax = x < T(0) ? static_cast(-x) : x; - else - ax = static_cast(std::fabs(static_cast(x))); - - // Zero check in natural type. - if (ax == T(0)) - return 0x00; - - // Convert to float for frexp/ldexp-based exponent extraction. - float fax = static_cast(ax); - - // Underflow: map to min normal (code 0). - // Min normal = 2^-127. - const float min_normal = std::ldexp(1.0f, Emin); - if (fax < min_normal) - return 0x00; - - // Overflow and "too large": clamp or NaN depending on saturation. - const float max_normal = std::ldexp(1.0f, Emax); // 2^127 - if (fax >= max_normal) - return (S == saturation::finite) ? MaxFiniteCode : NaNCode; - - // Determine E such that 2^E <= fax < 2^(E+1). - int e2 = 0; - float m = std::frexp(fax, &e2); // fax = m * 2^e2, m in [0.5, 1) - int E = e2 - 1; - - // With no mantissa, representables are exact powers of two. - // Choose between 2^E and 2^(E+1) based on rounding mode. - const bool is_exact_power_of_two = (m == 0.5f); +template +static inline ToT ConvertFromE8M0_CPU(uint8_t code, rounding R) noexcept { + using Traits = E8M0SourceTraits; - switch (R) { - case rounding::upward: - // toward +inf; with no sign, this is "ceil in magnitude". - if (!is_exact_power_of_two && E < Emax) - ++E; - break; - case rounding::toward_zero: - // toward -inf / toward 0: both pick the lower power for non-exact. - break; - case rounding::to_even: - default: { - if (!is_exact_power_of_two) { - // Nearest of {2^E, 2^(E+1)} w/ ties-to-even (even exponent on tie). - float lo = std::ldexp(1.0f, E); - float hi = std::ldexp(1.0f, E + 1); - float dlo = fax - lo; - float dhi = hi - fax; - if (dhi < dlo) { - if (E < Emax) - ++E; - } else if (dhi == dlo) { - // tie -> even exponent - if ((E & 1) != 0 && E < Emax) - ++E; - } - } - break; - } + if constexpr (HasE8M0FloatTraits::value || + HasE8M0IntegralTraits::value) { + (void)R; + return ConvertFromE8M0ToBinaryFloat_CPU(code); } - if (E < Emin) - E = Emin; - if (E > Emax) - E = Emax; - - uint8_t code = static_cast(E + Bias); // 0..254 - return code; -} - -template -static inline ToT ConvertFromE8M0_CPU(uint8_t code, rounding R) noexcept { - constexpr int Bias = 127; + /* constexpr int Bias = 127; if (code == 0xFF) { return MakeDirectNaN(); } return ConvertFloatToTarget(false, 1u, static_cast(code) - Bias, 0, - R); + R);*/ } } // namespace detail @@ -1451,59 +1659,59 @@ template class fp8_e8m0_x { using InT = std::common_type_t...>; const InT in[N] = {v...}; for (size_t i = 0; i < N; ++i) - vals[i] = detail::ConvertToE8M0_CPU(in[i], rounding::upward, - saturation::finite); + vals[i] = detail::ConvertFloatToE8M0_CPU(in[i], rounding::upward, + saturation::finite); } explicit fp8_e8m0_x(half const (&in)[N], rounding r = rounding::upward) { CheckConstraints(r); for (size_t i = 0; i < N; ++i) - vals[i] = detail::ConvertToE8M0_CPU(in[i], r, saturation::finite); + vals[i] = detail::ConvertFloatToE8M0_CPU(in[i], r, saturation::finite); } explicit fp8_e8m0_x(bfloat16 const (&in)[N], rounding r = rounding::upward) { CheckConstraints(r); for (size_t i = 0; i < N; ++i) - vals[i] = detail::ConvertToE8M0_CPU(in[i], r, saturation::finite); + vals[i] = detail::ConvertFloatToE8M0_CPU(in[i], r, saturation::finite); } explicit fp8_e8m0_x(float const (&in)[N], rounding r = rounding::upward) { CheckConstraints(r); for (size_t i = 0; i < N; ++i) - vals[i] = detail::ConvertToE8M0_CPU(in[i], r, saturation::finite); + vals[i] = detail::ConvertFloatToE8M0_CPU(in[i], r, saturation::finite); } explicit fp8_e8m0_x(double const (&in)[N]) { for (size_t i = 0; i < N; ++i) - vals[i] = detail::ConvertToE8M0_CPU(in[i], rounding::upward, - saturation::finite); + vals[i] = detail::ConvertFloatToE8M0_CPU(in[i], rounding::upward, + saturation::finite); } explicit fp8_e8m0_x(const marray &in, rounding r = rounding::upward) { CheckConstraints(r); for (size_t i = 0; i < N; ++i) - vals[i] = detail::ConvertToE8M0_CPU(in[i], r, saturation::finite); + vals[i] = detail::ConvertFloatToE8M0_CPU(in[i], r, saturation::finite); } explicit fp8_e8m0_x(const marray &in, rounding r = rounding::upward) { CheckConstraints(r); for (size_t i = 0; i < N; ++i) - vals[i] = detail::ConvertToE8M0_CPU(in[i], r, saturation::finite); + vals[i] = detail::ConvertFloatToE8M0_CPU(in[i], r, saturation::finite); } explicit fp8_e8m0_x(const marray &in, rounding r = rounding::upward) { CheckConstraints(r); for (size_t i = 0; i < N; ++i) - vals[i] = detail::ConvertToE8M0_CPU(in[i], r, saturation::finite); + vals[i] = detail::ConvertFloatToE8M0_CPU(in[i], r, saturation::finite); } explicit fp8_e8m0_x(const marray &in) { for (size_t i = 0; i < N; ++i) - vals[i] = detail::ConvertToE8M0_CPU(in[i], rounding::upward, - saturation::finite); + vals[i] = detail::ConvertFloatToE8M0_CPU(in[i], rounding::upward, + saturation::finite); } // Construct from integer types. @@ -1512,116 +1720,120 @@ template class fp8_e8m0_x { template > explicit fp8_e8m0_x(short val) { vals[0] = - detail::ConvertToE8M0_CPU(val, rounding::upward, saturation::finite); + detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); } template > explicit fp8_e8m0_x(int val) { vals[0] = - detail::ConvertToE8M0_CPU(val, rounding::upward, saturation::finite); + detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); } + template > explicit fp8_e8m0_x(long val) { vals[0] = - detail::ConvertToE8M0_CPU(val, rounding::upward, saturation::finite); + detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); } template > explicit fp8_e8m0_x(long long val) { vals[0] = - detail::ConvertToE8M0_CPU(val, rounding::upward, saturation::finite); + detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); } template > explicit fp8_e8m0_x(unsigned short val) { vals[0] = - detail::ConvertToE8M0_CPU(val, rounding::upward, saturation::finite); + detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); } template > explicit fp8_e8m0_x(unsigned int val) { vals[0] = - detail::ConvertToE8M0_CPU(val, rounding::upward, saturation::finite); + detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); } template > explicit fp8_e8m0_x(unsigned long val) { vals[0] = - detail::ConvertToE8M0_CPU(val, rounding::upward, saturation::finite); + detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); } template > explicit fp8_e8m0_x(unsigned long long val) { vals[0] = - detail::ConvertToE8M0_CPU(val, rounding::upward, saturation::finite); + detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); } template > fp8_e8m0_x &operator=(half val) { vals[0] = - detail::ConvertToE8M0_CPU(val, rounding::upward, saturation::finite); + detail::ConvertFloatToE8M0_CPU(val, rounding::upward, saturation::finite); return *this; } template > fp8_e8m0_x &operator=(bfloat16 val) { vals[0] = - detail::ConvertToE8M0_CPU(val, rounding::upward, saturation::finite); + detail::ConvertFloatToE8M0_CPU(val, rounding::upward, + saturation::finite); return *this; } template > fp8_e8m0_x &operator=(float val) { vals[0] = - detail::ConvertToE8M0_CPU(val, rounding::upward, saturation::finite); + detail::ConvertFloatToE8M0_CPU(val, rounding::upward, saturation::finite); return *this; } + template > fp8_e8m0_x &operator=(double val) { vals[0] = - detail::ConvertToE8M0_CPU(val, rounding::upward, saturation::finite); + detail::ConvertFloatToE8M0_CPU(val, rounding::upward, saturation::finite); return *this; } + template > fp8_e8m0_x &operator=(short val) { vals[0] = - detail::ConvertToE8M0_CPU(val, rounding::upward, saturation::finite); + detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); return *this; } template > fp8_e8m0_x &operator=(int val) { vals[0] = - detail::ConvertToE8M0_CPU(val, rounding::upward, saturation::finite); + detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); return *this; } template > fp8_e8m0_x &operator=(long val) { vals[0] = - detail::ConvertToE8M0_CPU(val, rounding::upward, saturation::finite); + detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); return *this; } template > fp8_e8m0_x &operator=(long long val) { vals[0] = - detail::ConvertToE8M0_CPU(val, rounding::upward, saturation::finite); + detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); return *this; } template > fp8_e8m0_x &operator=(unsigned short val) { vals[0] = - detail::ConvertToE8M0_CPU(val, rounding::upward, saturation::finite); + detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); return *this; } template > fp8_e8m0_x &operator=(unsigned int val) { vals[0] = - detail::ConvertToE8M0_CPU(val, rounding::upward, saturation::finite); + detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); return *this; } template > fp8_e8m0_x &operator=(unsigned long val) { vals[0] = - detail::ConvertToE8M0_CPU(val, rounding::upward, saturation::finite); + detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); return *this; } template > fp8_e8m0_x &operator=(unsigned long long val) { vals[0] = - detail::ConvertToE8M0_CPU(val, rounding::upward, saturation::finite); + detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); return *this; } From 956f22bb4b5d9fb8430f9de3294a6a711fd96501 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Thu, 16 Apr 2026 16:22:43 +0200 Subject: [PATCH 33/89] [SYCL] convert bytes before fp8 --- .../oneapi/experimental/float_8bit/types.hpp | 546 +++++++++++++++--- sycl/unittests/Extensions/fp8/fp8_e5m2.cpp | 32 + 2 files changed, 494 insertions(+), 84 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index b314132123da8..7bc2a0647789b 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -14,10 +14,10 @@ #include #include #include +#include #include #include #include -#include #ifdef __SYCL_DEVICE_ONLY__ // FP8 builtins @@ -187,6 +187,20 @@ template static inline ToT MakeDirectNaN() noexcept { } } +template static inline ToT MakeDirectInf(bool negative) noexcept { + if constexpr (std::is_same_v || + std::is_same_v) { + using Traits = DirectBinary16Traits; + const uint16_t sign = negative ? Traits::SignMask : 0u; + return sycl::bit_cast(static_cast(sign | Traits::InfBits)); + } else if constexpr (std::numeric_limits::has_infinity) { + return negative ? -std::numeric_limits::infinity() + : std::numeric_limits::infinity(); + } else { + return ToT{}; + } +} + template static inline ToT ConvertFloatToTarget(bool negative, uint32_t significand, int exp2, int srcFracBits, @@ -343,6 +357,33 @@ template <> struct E8M0SourceTraits { static constexpr int Bias = 1023; }; +template <> struct E8M0SourceTraits { + using UInt = uint8_t; + using UnsignedT = std::make_unsigned_t; + + static constexpr bool IsIntegral = true; + static constexpr bool IsSigned = std::numeric_limits::is_signed; + static constexpr int ValueBits = std::numeric_limits::digits; +}; + +template <> struct E8M0SourceTraits { + using UInt = uint8_t; + using UnsignedT = std::make_unsigned_t; + + static constexpr bool IsIntegral = true; + static constexpr bool IsSigned = true; + static constexpr int ValueBits = std::numeric_limits::digits; +}; + +template <> struct E8M0SourceTraits { + using UInt = uint8_t; + using UnsignedT = std::make_unsigned_t; + + static constexpr bool IsIntegral = true; + static constexpr bool IsSigned = false; + static constexpr int ValueBits = std::numeric_limits::digits; +}; + template <> struct E8M0SourceTraits { using UInt = uint16_t; using UnsignedT = std::make_unsigned_t; @@ -448,6 +489,216 @@ static inline uint8_t ConvertIntToE8M0_CPU(T f, rounding R, return static_cast(127 + lowerExp + (roundUp ? 1 : 0)); } +template > +static inline uint8_t ConvertIntToE4M3_CPU(T f, rounding R, + saturation S) noexcept { + using UnsignedT = typename Traits::UnsignedT; + + constexpr uint8_t MaxFiniteCode = 0x7Eu; + constexpr uint8_t NaNCode = 0x7Fu; + constexpr int TargetBias = 7; + constexpr int TargetEmax = 8; + constexpr int TargetFracBits = 3; + + const uint8_t sign = + (Traits::IsSigned && f < 0) ? static_cast(0x80u) : 0u; + UnsignedT magnitude = 0; + + if constexpr (Traits::IsSigned) { + const UnsignedT bits = static_cast(f); + magnitude = f < 0 ? static_cast(UnsignedT{0} - bits) : bits; + } else { + magnitude = static_cast(f); + } + + if (magnitude == 0) + return sign; + + int unbiasedExp = BitWidth(static_cast(magnitude)) - 1; + if (unbiasedExp > TargetEmax) + return static_cast( + sign | (S == saturation::finite ? MaxFiniteCode : NaNCode)); + + const int shift = unbiasedExp - TargetFracBits; + uint64_t mantissa = 0u; + if (shift <= 0) { + mantissa = static_cast(magnitude) << (-shift); + } else { + const uint64_t truncated = static_cast(magnitude) >> shift; + const uint64_t remainderMask = (uint64_t{1} << shift) - 1u; + const uint64_t remainder = static_cast(magnitude) & remainderMask; + + mantissa = truncated; + if (remainder != 0u) { + if (R == rounding::upward) { + if (sign == 0u) + ++mantissa; + } else if (R == rounding::to_even) { + const uint64_t half = uint64_t{1} << (shift - 1); + if (remainder > half || + (remainder == half && (truncated & uint64_t{1}) != 0u)) { + ++mantissa; + } + } + } + } + + if (mantissa >= 16u) { + mantissa = 8u; + ++unbiasedExp; + } + + if (unbiasedExp > TargetEmax) + return static_cast( + sign | (S == saturation::finite ? MaxFiniteCode : NaNCode)); + + if (unbiasedExp == TargetEmax && mantissa > 14u) + return static_cast( + sign | (S == saturation::finite ? MaxFiniteCode : NaNCode)); + + const uint8_t expField = static_cast(unbiasedExp + TargetBias); + const uint8_t fracField = static_cast(mantissa - 8u); + return static_cast(sign | static_cast(expField << 3) | + fracField); +} + +template > +static inline uint8_t ConvertFloatToE4M3_CPU(T f, rounding R, + saturation S) noexcept { + using UInt = typename Traits::UInt; + + constexpr UInt SignMask = UInt{1} << (Traits::ExpBits + Traits::FracBits); + constexpr UInt FracMask = (UInt{1} << Traits::FracBits) - UInt{1}; + constexpr UInt ExpMask = ((UInt{1} << Traits::ExpBits) - UInt{1}) + << Traits::FracBits; + constexpr UInt ExpAllOnes = (UInt{1} << Traits::ExpBits) - UInt{1}; + constexpr uint8_t MaxFiniteCode = 0x7Eu; + constexpr uint8_t NaNCode = 0x7Fu; + constexpr int TargetBias = 7; + constexpr int TargetEmin = -6; + constexpr int TargetEmax = 8; + constexpr int TargetFracBits = 3; + + UInt bits; + __builtin_memcpy(&bits, &f, sizeof(bits)); + + const uint8_t sign = (bits & SignMask) ? 0x80u : 0x00u; + bits &= ~SignMask; + + const UInt exp = (bits & ExpMask) >> Traits::FracBits; + const UInt frac = bits & FracMask; + + if (exp == ExpAllOnes) { + if (frac != 0u) + return static_cast(sign | NaNCode); + return static_cast( + sign | (S == saturation::finite ? MaxFiniteCode : NaNCode)); + } + + if (exp == 0u && frac == 0u) + return sign; + + uint64_t significand = 0u; + int leadingBit = 0; + int unbiasedExp = 0; + + if (exp != 0u) { + significand = + (uint64_t{1} << Traits::FracBits) | static_cast(frac); + leadingBit = static_cast(Traits::FracBits); + unbiasedExp = static_cast(exp) - Traits::Bias; + } else { + significand = static_cast(frac); + uint64_t tmp = significand; + leadingBit = -1; + while (tmp != 0u) { + ++leadingBit; + tmp >>= 1; + } + unbiasedExp = + 1 - Traits::Bias - static_cast(Traits::FracBits) + leadingBit; + } + + auto roundShiftRight = [&](uint64_t value, int shift) -> uint64_t { + if (shift <= 0) + return value; + + if (shift >= 64) { + if (R == rounding::upward && sign == 0u && value != 0u) + return 1u; + return 0u; + } + + const uint64_t truncated = value >> shift; + const uint64_t remainderMask = (uint64_t{1} << shift) - 1u; + const uint64_t remainder = value & remainderMask; + + if (remainder == 0u) + return truncated; + + if (R == rounding::toward_zero) + return truncated; + + if (R == rounding::upward) + return sign == 0u ? truncated + 1u : truncated; + + const uint64_t half = uint64_t{1} << (shift - 1); + if (remainder > half) + return truncated + 1u; + if (remainder < half) + return truncated; + return (truncated & 1u) != 0u ? truncated + 1u : truncated; + }; + + if (unbiasedExp > TargetEmax) + return static_cast( + sign | (S == saturation::finite ? MaxFiniteCode : NaNCode)); + + if (unbiasedExp == TargetEmax) { + const uint64_t lhs = significand << TargetFracBits; + const uint64_t rhs = 14ull << leadingBit; + if (lhs > rhs) + return static_cast( + sign | (S == saturation::finite ? MaxFiniteCode : NaNCode)); + } + + if (unbiasedExp < TargetEmin) { + const int shift = leadingBit - unbiasedExp - 9; + uint64_t mantissa = shift > 0 ? roundShiftRight(significand, shift) + : (significand << (-shift)); + + if (mantissa == 0u) + return sign; + + if (mantissa >= 8u) + return static_cast(sign | 0x08u); + + return static_cast(sign | static_cast(mantissa)); + } + + const int shift = leadingBit - TargetFracBits; + uint64_t mantissa = shift > 0 ? roundShiftRight(significand, shift) + : (significand << (-shift)); + + if (mantissa >= 16u) { + mantissa = 8u; + ++unbiasedExp; + } + + if (unbiasedExp > TargetEmax) + return static_cast( + sign | (S == saturation::finite ? MaxFiniteCode : NaNCode)); + + if (unbiasedExp == TargetEmax && mantissa > 14u) + return static_cast( + sign | (S == saturation::finite ? MaxFiniteCode : NaNCode)); + + const uint8_t expField = static_cast(unbiasedExp + TargetBias); + const uint8_t fracField = static_cast(mantissa - 8u); + return static_cast(sign | static_cast(expField << 3) | + fracField); +} + template > static inline uint8_t ConvertFloatToE8M0_CPU(T f, rounding R, saturation S) noexcept { @@ -485,14 +736,16 @@ static inline uint8_t ConvertFloatToE8M0_CPU(T f, rounding R, bool isExactPowerOfTwo = false; if (exp != 0u) { - significand = (uint64_t{1} << Traits::FracBits) | static_cast(frac); + significand = + (uint64_t{1} << Traits::FracBits) | static_cast(frac); leadingBit = static_cast(Traits::FracBits); lowerExp = static_cast(exp) - Traits::Bias; isExactPowerOfTwo = frac == 0u; } else { significand = static_cast(frac); leadingBit = BitWidth(significand) - 1; - lowerExp = 1 - Traits::Bias - static_cast(Traits::FracBits) + leadingBit; + lowerExp = + 1 - Traits::Bias - static_cast(Traits::FracBits) + leadingBit; isExactPowerOfTwo = (significand & (significand - 1u)) == 0u; } @@ -533,48 +786,121 @@ struct HasE8M0FloatTraits : std::false_type {}; template struct HasE8M0FloatTraits< - Traits, - std::void_t> : std::true_type {}; + Traits, std::void_t> : std::true_type {}; template struct HasE8M0IntegralTraits : std::false_type {}; template struct HasE8M0IntegralTraits< - Traits, std::void_t> : std::true_type {}; + Traits, + std::void_t> + : std::true_type {}; + +template > +static inline ToT ConvertFromE8M0ToBinaryFloat_CPU(uint8_t code, + rounding R) noexcept { + static_assert((Ebits == 8 && Mbits == 0) || (Ebits == 4 && Mbits == 3) || + (Ebits == 5 && Mbits == 2), + "Unsupported FP8 decode combination"); + + constexpr int SrcBias = (1 << (Ebits - 1)) - 1; + constexpr int SrcEmin = 1 - SrcBias; + constexpr uint8_t SrcExpAllOnes = static_cast((1u << Ebits) - 1u); + constexpr uint8_t SrcFracMask = + (Mbits == 0) ? 0u : static_cast((1u << Mbits) - 1u); + + bool negative = false; + uint32_t significand = 0u; + int exp2 = 0; + bool isNaN = false; + bool isInf = false; + + if constexpr (Ebits == 8 && Mbits == 0) { + if (code == 0xFFu) { + isNaN = true; + } else { + significand = 1u; + exp2 = static_cast(code) - SrcBias; + } + } else { + negative = (code & 0x80u) != 0u; + const uint8_t exp = static_cast((code >> Mbits) & SrcExpAllOnes); + const uint8_t frac = static_cast(code & SrcFracMask); + + if (exp == SrcExpAllOnes) { + if constexpr (Ebits == 5 && Mbits == 2) { + if (frac == 0u) + isInf = true; + else + isNaN = true; + } else if (frac == SrcFracMask) { + isNaN = true; + } else { + significand = static_cast((1u << Mbits) + frac); + exp2 = static_cast(exp) - SrcBias; + } + } else if (exp == 0u) { + if (frac == 0u) + significand = 0u; + else { + significand = frac; + exp2 = SrcEmin; + } + } else { + significand = static_cast((1u << Mbits) + frac); + exp2 = static_cast(exp) - SrcBias; + } + } -template > -static inline ToT ConvertFromE8M0ToBinaryFloat_CPU(uint8_t code) noexcept { if constexpr (HasE8M0FloatTraits::value) { using UInt = typename Traits::UInt; - constexpr UInt ExpAllOnes = - ((UInt{1} << Traits::ExpBits) - UInt{1}) << Traits::FracBits; + constexpr UInt ExpAllOnes = ((UInt{1} << Traits::ExpBits) - UInt{1}) + << Traits::FracBits; + constexpr UInt FracMask = (UInt{1} << Traits::FracBits) - UInt{1}; constexpr UInt QuietNaNBit = UInt{1} << (Traits::FracBits - 1); + constexpr UInt MaxFiniteBits = + (ExpAllOnes - (UInt{1} << Traits::FracBits)) | FracMask; constexpr int MinNormalExp = 1 - Traits::Bias; constexpr int MinSubnormalExp = MinNormalExp - static_cast(Traits::FracBits); constexpr int MaxNormalExp = - static_cast((UInt{1} << Traits::ExpBits) - UInt{2}) - - Traits::Bias; + static_cast((UInt{1} << Traits::ExpBits) - UInt{2}) - Traits::Bias; UInt bits = 0; - if (code == 0xFFu) { + if (isNaN) { bits = ExpAllOnes | QuietNaNBit; + } else if (isInf) { + bits = (negative ? (UInt{1} << (Traits::ExpBits + Traits::FracBits)) : 0u) | + ExpAllOnes; + } else if (significand == 0u) { + bits = negative ? (UInt{1} << (Traits::ExpBits + Traits::FracBits)) : 0u; } else { - const int unbiasedExp = static_cast(code) - 127; + const int sigBits = BitWidth(significand); + const int unbiasedExp = exp2 + sigBits - 1 - Mbits; + const UInt signBit = + negative ? (UInt{1} << (Traits::ExpBits + Traits::FracBits)) : 0u; + if (unbiasedExp > MaxNormalExp) { - bits = ExpAllOnes; + bits = signBit | + ((R == rounding::toward_zero) ? MaxFiniteBits : ExpAllOnes); } else if (unbiasedExp >= MinNormalExp) { - bits = static_cast(unbiasedExp + Traits::Bias) - << Traits::FracBits; + const int shift = static_cast(Traits::FracBits) - (sigBits - 1); + const UInt aligned = static_cast(significand) << shift; + const UInt expField = static_cast(unbiasedExp + Traits::Bias) + << Traits::FracBits; + bits = signBit | expField | (aligned & FracMask); } else if (unbiasedExp >= MinSubnormalExp) { - const int fracBit = - unbiasedExp - MinNormalExp + static_cast(Traits::FracBits); - bits = UInt{1} << fracBit; + const int subShift = + exp2 - Mbits - MinNormalExp + static_cast(Traits::FracBits); + const UInt fracField = static_cast(significand) << subShift; + bits = signBit | fracField; + } else if (R == rounding::upward && !negative) { + bits = UInt{1}; } } @@ -583,18 +909,64 @@ static inline ToT ConvertFromE8M0ToBinaryFloat_CPU(uint8_t code) noexcept { Traits::IsIntegral) { using UnsignedT = typename Traits::UnsignedT; - if (code == 0xFFu) + if (isNaN || isInf) return ToT{}; - const int shift = static_cast(code) - 127; - if (shift < 0 || shift >= Traits::ValueBits) + if (significand == 0u) return ToT{}; - const UnsignedT magnitude = UnsignedT{1} << shift; - return static_cast(magnitude); - } else { - return ToT{}; + const int shift = exp2 - Mbits; + uint64_t magnitude = 0u; + + if (shift >= 0) { + if (shift >= 64) + return ToT{}; + magnitude = static_cast(significand) << shift; + } else { + const int rshift = -shift; + if (rshift >= 64) { + if (R == rounding::upward && !negative) + magnitude = 1u; + } else { + magnitude = static_cast(significand) >> rshift; + const uint64_t remainderMask = (uint64_t{1} << rshift) - 1u; + const uint64_t remainder = + static_cast(significand) & remainderMask; + + if (remainder != 0u) { + if (R == rounding::upward) { + if (!negative) + ++magnitude; + } else if (R == rounding::to_even) { + const uint64_t half = uint64_t{1} << (rshift - 1); + if (remainder > half || + (remainder == half && (magnitude & 1u) != 0u)) { + ++magnitude; + } + } + } + } + } + + if (magnitude == 0u) + return ToT{}; + + if (BitWidth(magnitude) > Traits::ValueBits) + return ToT{}; + + const UnsignedT narrowed = static_cast(magnitude); + if constexpr (Traits::IsSigned) + return static_cast(negative ? -static_cast(narrowed) + : static_cast(narrowed)); + return static_cast(narrowed); } + + if (isNaN) + return MakeDirectNaN(); + if (isInf) + return MakeDirectInf(negative); + + return ToT{}; } /// \brief Converts a given value to fp8 floating point with a rounding @@ -758,16 +1130,10 @@ static inline ToT ConvertFromE8M0_CPU(uint8_t code, rounding R) noexcept { if constexpr (HasE8M0FloatTraits::value || HasE8M0IntegralTraits::value) { - (void)R; - return ConvertFromE8M0ToBinaryFloat_CPU(code); + return ConvertFromE8M0ToBinaryFloat_CPU<8, 0, ToT>(code, R); } - /* constexpr int Bias = 127; - if (code == 0xFF) { - return MakeDirectNaN(); - } - return ConvertFloatToTarget(false, 1u, static_cast(code) - Bias, 0, - R);*/ + return ToT{}; } } // namespace detail @@ -788,7 +1154,15 @@ template class fp8_e4m3_x { #ifdef __SYCL_DEVICE_ONLY__ return __builtin_spirv_ClampConvertFP16ToE4M3INTEL(h); #else - return detail::ConvertToFP8_CPU<4, 3, sycl::half>(h, rounding::to_even); + if constexpr (std::is_same_v, sycl::half> || + std::is_same_v, float> || + std::is_same_v, double>) { + return detail::ConvertFloatToE4M3_CPU(h, rounding::to_even, + saturation::finite); + } else if constexpr (std::is_integral_v>) { + return detail::ConvertIntToE4M3_CPU(h, rounding::to_even, + saturation::finite); + } #endif } @@ -796,16 +1170,18 @@ template class fp8_e4m3_x { #ifdef __SYCL_DEVICE_ONLY__ return __builtin_spirv_ClampConvertBF16ToE4M3INTEL(h); #else - return detail::ConvertToFP8_CPU<4, 3, bfloat16>(h, rounding::to_even); + return detail::ConvertFloatToE4M3_CPU(h, rounding::to_even, + saturation::finite); #endif } - template T ConvertFromFP8(uint8_t v) const { + template + T ConvertFromFP8(uint8_t v, rounding r = rounding::to_even) const { #ifdef __SYCL_DEVICE_ONLY__ sycl::half hi = __builtin_spirv_ConvertE4M3ToFP16EXT(v); return static_cast(hi); #else - return detail::ConvertFromFP8_CPU<4, 3, T>(v); + return detail::ConvertFromE8M0ToBinaryFloat_CPU<4, 3, T>(v, r); #endif } @@ -813,7 +1189,8 @@ template class fp8_e4m3_x { #ifdef __SYCL_DEVICE_ONLY__ return __builtin_spirv_ConvertE4M3ToBF16EXT(v); #else - return detail::ConvertFromFP8_CPU<4, 3, bfloat16>(v); + return detail::ConvertFromE8M0ToBinaryFloat_CPU<4, 3, bfloat16>( + v, rounding::to_even); #endif } @@ -1047,56 +1424,56 @@ template class fp8_e4m3_x { template > explicit operator char() const { - return ConvertFromFP8(vals[0]); + return ConvertFromFP8(vals[0], rounding::toward_zero); } template > explicit operator signed char() const { - return ConvertFromFP8(vals[0]); + return ConvertFromFP8(vals[0], rounding::toward_zero); } template > explicit operator short() const { - return ConvertFromFP8(vals[0]); + return ConvertFromFP8(vals[0], rounding::toward_zero); } template > explicit operator int() const { - return ConvertFromFP8(vals[0]); + return ConvertFromFP8(vals[0], rounding::toward_zero); } template > explicit operator long() const { - return ConvertFromFP8(vals[0]); + return ConvertFromFP8(vals[0], rounding::toward_zero); } template > explicit operator long long() const { - return ConvertFromFP8(vals[0]); + return ConvertFromFP8(vals[0], rounding::toward_zero); } template > explicit operator unsigned char() const { - return ConvertFromFP8(vals[0]); + return ConvertFromFP8(vals[0], rounding::toward_zero); } template > explicit operator unsigned short() const { - return ConvertFromFP8(vals[0]); + return ConvertFromFP8(vals[0], rounding::toward_zero); } template > explicit operator unsigned int() const { - return ConvertFromFP8(vals[0]); + return ConvertFromFP8(vals[0], rounding::toward_zero); } template > explicit operator unsigned long() const { - return ConvertFromFP8(vals[0]); + return ConvertFromFP8(vals[0], rounding::toward_zero); } template > explicit operator unsigned long long() const { - return ConvertFromFP8(vals[0]); + return ConvertFromFP8(vals[0], rounding::toward_zero); } // Convert to bool @@ -1171,12 +1548,13 @@ template class fp8_e5m2_x { #endif } - template T ConvertFromFP8(uint8_t v) const { + template + T ConvertFromFP8(uint8_t v, rounding r = rounding::to_even) const { #ifdef __SYCL_DEVICE_ONLY__ sycl::half hi = __builtin_spirv_ConvertE5M2ToFP16EXT(v); return static_cast(hi); #else - return detail::ConvertFromFP8_CPU<5, 2, T>(v); + return detail::ConvertFromE8M0ToBinaryFloat_CPU<5, 2, T>(v, r); #endif } @@ -1184,7 +1562,8 @@ template class fp8_e5m2_x { #ifdef __SYCL_DEVICE_ONLY__ return __builtin_spirv_ConvertE5M2ToBF16EXT(v); #else - return detail::ConvertFromFP8_CPU<5, 2, bfloat16>(v); + return detail::ConvertFromE8M0ToBinaryFloat_CPU<5, 2, bfloat16>( + v, rounding::to_even); #endif } @@ -1546,57 +1925,57 @@ template class fp8_e5m2_x { template > explicit operator char() const { - return ConvertFromFP8(vals[0]); + return ConvertFromFP8(vals[0], rounding::toward_zero); } template > explicit operator signed char() const { - return ConvertFromFP8(vals[0]); + return ConvertFromFP8(vals[0], rounding::toward_zero); } template > explicit operator short() const { - return ConvertFromFP8(vals[0]); + return ConvertFromFP8(vals[0], rounding::toward_zero); } template > explicit operator int() const { - return ConvertFromFP8(vals[0]); + return ConvertFromFP8(vals[0], rounding::toward_zero); } template > explicit operator long() const { - return ConvertFromFP8(vals[0]); + return ConvertFromFP8(vals[0], rounding::toward_zero); } template > explicit operator long long() const { - return ConvertFromFP8(vals[0]); + return ConvertFromFP8(vals[0], rounding::toward_zero); } template > explicit operator unsigned char() const { - return ConvertFromFP8(vals[0]); + return ConvertFromFP8(vals[0], rounding::toward_zero); } template > explicit operator unsigned short() const { - return ConvertFromFP8(vals[0]); + return ConvertFromFP8(vals[0], rounding::toward_zero); } template > explicit operator unsigned int() const { - return ConvertFromFP8(vals[0]); + return ConvertFromFP8(vals[0], rounding::toward_zero); } template > explicit operator unsigned long() const { - return ConvertFromFP8(vals[0]); + return ConvertFromFP8(vals[0], rounding::toward_zero); } template > explicit operator unsigned long long() const { - return ConvertFromFP8(vals[0]); + return ConvertFromFP8(vals[0], rounding::toward_zero); } // Convert to bool @@ -1726,65 +2105,64 @@ template class fp8_e8m0_x { template > explicit fp8_e8m0_x(int val) { vals[0] = - detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); + detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); } template > explicit fp8_e8m0_x(long val) { vals[0] = - detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); + detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); } template > explicit fp8_e8m0_x(long long val) { vals[0] = - detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); + detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); } template > explicit fp8_e8m0_x(unsigned short val) { vals[0] = - detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); + detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); } template > explicit fp8_e8m0_x(unsigned int val) { vals[0] = - detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); + detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); } template > explicit fp8_e8m0_x(unsigned long val) { vals[0] = - detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); + detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); } template > explicit fp8_e8m0_x(unsigned long long val) { vals[0] = - detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); + detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); } template > fp8_e8m0_x &operator=(half val) { - vals[0] = - detail::ConvertFloatToE8M0_CPU(val, rounding::upward, saturation::finite); + vals[0] = detail::ConvertFloatToE8M0_CPU(val, rounding::upward, + saturation::finite); return *this; } template > fp8_e8m0_x &operator=(bfloat16 val) { - vals[0] = - detail::ConvertFloatToE8M0_CPU(val, rounding::upward, - saturation::finite); + vals[0] = detail::ConvertFloatToE8M0_CPU(val, rounding::upward, + saturation::finite); return *this; } template > fp8_e8m0_x &operator=(float val) { - vals[0] = - detail::ConvertFloatToE8M0_CPU(val, rounding::upward, saturation::finite); + vals[0] = detail::ConvertFloatToE8M0_CPU(val, rounding::upward, + saturation::finite); return *this; } template > fp8_e8m0_x &operator=(double val) { - vals[0] = - detail::ConvertFloatToE8M0_CPU(val, rounding::upward, saturation::finite); + vals[0] = detail::ConvertFloatToE8M0_CPU(val, rounding::upward, + saturation::finite); return *this; } diff --git a/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp b/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp index 3ca8a3651cc49..05d5ec75874b8 100644 --- a/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp +++ b/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp @@ -85,6 +85,38 @@ TEST(FP8E5M2Test, VariadicNaNEncodingFloat) { EXPECT_EQ(a.vals[1], 0xFF); // -NaN -> 0b1_11111_11 } +TEST(FP8E5M2Test, RawInfinityAndNaNDecoding) { + fp8_e5m2 pos_inf; + fp8_e5m2 neg_inf; + fp8_e5m2 qnan; + + pos_inf.vals[0] = 0x7C; // +inf -> 0b0_11111_00 + neg_inf.vals[0] = 0xFC; // -inf -> 0b1_11111_00 + qnan.vals[0] = 0x7D; // +NaN -> 0b0_11111_01 + + const float pos_inf_f = static_cast(pos_inf); + const float neg_inf_f = static_cast(neg_inf); + const float qnan_f = static_cast(qnan); + + EXPECT_TRUE(std::isinf(pos_inf_f)); + EXPECT_GT(pos_inf_f, 0.0f); + EXPECT_TRUE(std::isinf(neg_inf_f)); + EXPECT_LT(neg_inf_f, 0.0f); + EXPECT_TRUE(std::isnan(qnan_f)); + + const sycl::half pos_inf_h = static_cast(pos_inf); + const sycl::ext::oneapi::bfloat16 neg_inf_bf16 = + static_cast(neg_inf); + const sycl::ext::oneapi::bfloat16 qnan_bf16 = + static_cast(qnan); + + EXPECT_TRUE(std::isinf(static_cast(pos_inf_h))); + EXPECT_GT(static_cast(pos_inf_h), 0.0f); + EXPECT_TRUE(std::isinf(static_cast(neg_inf_bf16))); + EXPECT_LT(static_cast(neg_inf_bf16), 0.0f); + EXPECT_TRUE(std::isnan(static_cast(qnan_bf16))); +} + TEST(FP8E5M2Test, IntegerConstructorToEvenFiniteAndSize) { fp8_e5m2 a0(0); fp8_e5m2 a1(1); From c5e6d91c41d395ec0568ba6d9b1e6c7deeddb5f7 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Thu, 16 Apr 2026 17:22:36 +0200 Subject: [PATCH 34/89] [SYCL] remove unused function --- .../oneapi/experimental/float_8bit/types.hpp | 302 +++++++----------- 1 file changed, 108 insertions(+), 194 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index 7bc2a0647789b..e462614d2f5b3 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -201,131 +201,6 @@ template static inline ToT MakeDirectInf(bool negative) noexcept } } -template -static inline ToT ConvertFloatToTarget(bool negative, uint32_t significand, - int exp2, int srcFracBits, - rounding R) noexcept { - if (significand == 0u) - return negative ? -ToT{0} : ToT{0}; - - if constexpr (std::is_same_v || - std::is_same_v) { - using Traits = DirectBinary16Traits; - const uint16_t sign = negative ? Traits::SignMask : 0u; - const int sigBits = BitWidth(significand); - const int unbiasedExp = exp2 + sigBits - 1 - srcFracBits; - - if (unbiasedExp > Traits::Emax) { - return sycl::bit_cast(static_cast( - sign | (R == rounding::toward_zero ? Traits::MaxFiniteBits - : Traits::InfBits))); - } - - if (unbiasedExp >= Traits::Emin) { - const int shift = Traits::FracBits - (sigBits - 1); - const uint32_t aligned = significand << shift; - const uint16_t expField = - static_cast(unbiasedExp + Traits::Bias) << Traits::FracBits; - const uint16_t fracField = - static_cast(aligned & Traits::FracMask); - return sycl::bit_cast( - static_cast(sign | expField | fracField)); - } - - const int subShift = exp2 - srcFracBits - Traits::Emin + Traits::FracBits; - if (subShift < 0) - return sycl::bit_cast(sign); - - const uint32_t fracField = significand << subShift; - if (fracField == 0u || fracField > Traits::FracMask) - return sycl::bit_cast(sign); - - return sycl::bit_cast( - static_cast(sign | static_cast(fracField))); - } else if constexpr (std::is_floating_point_v) { - ToT magnitude = - std::ldexp(static_cast(significand), exp2 - srcFracBits); - return negative ? -magnitude : magnitude; - } else if constexpr (std::is_integral_v) { - const int shift = exp2 - srcFracBits; - uint64_t magnitude = significand; - if (shift >= 0) - magnitude <<= shift; - else if (-shift < 64) - magnitude >>= -shift; - else - magnitude = 0u; - - if constexpr (std::is_signed_v) { - int64_t signedMagnitude = static_cast(magnitude); - return static_cast(negative ? -signedMagnitude : signedMagnitude); - } else - return static_cast(magnitude); - } else - return ToT{}; -} - -template -static inline ToT ConvertFromFP8_CPU(uint8_t b, - rounding R = rounding::to_even) noexcept { - static_assert((Ebits == 4 && Mbits == 3) || (Ebits == 5 && Mbits == 2) || - (Ebits == 8 && Mbits == 0), - "Unsupported FP8 (Ebits,Mbits) combination"); - - constexpr int Bias = (1 << (Ebits - 1)) - 1; - constexpr int Emin = 1 - Bias; - constexpr uint8_t ExpMaskAll = static_cast((1u << Ebits) - 1u); - constexpr uint32_t FracDen = (Mbits == 0) ? 1u : (1u << Mbits); - constexpr uint8_t MaxFrac = static_cast(FracDen - 1u); - - // Extract fields. - uint8_t sign_bit = (b & 0x80u) ? 1u : 0u; - uint8_t frac = (Mbits == 0) ? 0u : static_cast(b & MaxFrac); - - uint8_t exp = static_cast((b >> Mbits) & ExpMaskAll); - if constexpr (Ebits == 8 && Mbits == 0) { - sign_bit = 0u; - exp = b; - } - - auto make_nan = [&]() -> ToT { return MakeDirectNaN(); }; - - // Handle exp = all ones (custom finite-only rules). - if (exp == ExpMaskAll) { - if constexpr (Ebits == 4 && Mbits == 3) { - // E4M3: only frac==111 -> NaN, otherwise normal. - if (frac == MaxFrac) - return make_nan(); - // treat as normal finite - } else if constexpr (Ebits == 5 && Mbits == 2) { - // E5M2: NaN when frac in {01,10,11} i.e. frac != 00 - if (frac != 0) - return make_nan(); - // frac==00 -> normal finite - } else // E8M0: exp all ones -> NaN - return make_nan(); - } - - // exp == 0 : zero or subnormal (if Mbits>0) - if (exp == 0) { - if constexpr (Mbits == 0) { - // E8M0: exp==0 is the smallest normal (no subnormals) - return ConvertFloatToTarget(false, 1u, -Bias, 0, R); - } else { - if (frac == 0) { - return ConvertFloatToTarget(sign_bit != 0u, 0u, 0, 0, R); - } - // Subnormal: value = sign * (frac / 2^Mbits) * 2^(Emin) - return ConvertFloatToTarget(sign_bit != 0u, frac, Emin, Mbits, R); - } - } - - // Normal number. - int E = static_cast(exp) - Bias; - const uint32_t significand = - (Mbits == 0) ? 1u : (static_cast(FracDen) + frac); - return ConvertFloatToTarget(sign_bit != 0u, significand, E, Mbits, R); -} template struct E8M0SourceTraits; @@ -456,6 +331,32 @@ template <> struct E8M0SourceTraits { static constexpr int ValueBits = std::numeric_limits::digits; }; + template struct FP8FiniteFormatTraits { + static_assert((Ebits == 4 && Mbits == 3) || (Ebits == 5 && Mbits == 2), + "Unsupported FP8 finite format"); + + static constexpr uint8_t ExpAllOnes = static_cast((1u << Ebits) - 1u); + static constexpr uint8_t MaxFrac = static_cast((1u << Mbits) - 1u); + static constexpr int Bias = (1 << (Ebits - 1)) - 1; + static constexpr int Emin = 1 - Bias; + static constexpr bool HasInfinity = (Ebits == 5 && Mbits == 2); + static constexpr uint8_t MaxFiniteExpField = + HasInfinity ? static_cast(ExpAllOnes - 1u) : ExpAllOnes; + static constexpr uint8_t MaxFiniteFracField = + (Ebits == 4 && Mbits == 3) ? static_cast(MaxFrac - 1u) + : MaxFrac; + static constexpr uint8_t MaxFiniteCode = + static_cast((MaxFiniteExpField << Mbits) | MaxFiniteFracField); + static constexpr uint8_t NaNCode = + static_cast((ExpAllOnes << Mbits) | MaxFrac); + static constexpr uint8_t InfinityCode = static_cast(ExpAllOnes << Mbits); + static constexpr int MaxFiniteExp = static_cast(MaxFiniteExpField) - Bias; + static constexpr uint64_t MinNormalMantissa = uint64_t{1} << Mbits; + static constexpr uint64_t OverflowMantissa = uint64_t{1} << (Mbits + 1); + static constexpr uint64_t MaxFiniteMantissa = + MinNormalMantissa + MaxFiniteFracField; + }; + template > static inline uint8_t ConvertIntToE8M0_CPU(T f, rounding R, saturation S) noexcept { @@ -489,16 +390,20 @@ static inline uint8_t ConvertIntToE8M0_CPU(T f, rounding R, return static_cast(127 + lowerExp + (roundUp ? 1 : 0)); } -template > +template > static inline uint8_t ConvertIntToE4M3_CPU(T f, rounding R, saturation S) noexcept { using UnsignedT = typename Traits::UnsignedT; - - constexpr uint8_t MaxFiniteCode = 0x7Eu; - constexpr uint8_t NaNCode = 0x7Fu; - constexpr int TargetBias = 7; - constexpr int TargetEmax = 8; - constexpr int TargetFracBits = 3; + using Format = FP8FiniteFormatTraits; + + auto getOverflowCode = [&]() -> uint8_t { + if (S == saturation::finite) + return Format::MaxFiniteCode; + if constexpr (Format::HasInfinity) + return Format::InfinityCode; + return Format::NaNCode; + }; const uint8_t sign = (Traits::IsSigned && f < 0) ? static_cast(0x80u) : 0u; @@ -515,11 +420,10 @@ static inline uint8_t ConvertIntToE4M3_CPU(T f, rounding R, return sign; int unbiasedExp = BitWidth(static_cast(magnitude)) - 1; - if (unbiasedExp > TargetEmax) - return static_cast( - sign | (S == saturation::finite ? MaxFiniteCode : NaNCode)); + if (unbiasedExp > Format::MaxFiniteExp) + return static_cast(sign | getOverflowCode()); - const int shift = unbiasedExp - TargetFracBits; + const int shift = unbiasedExp - Mbits; uint64_t mantissa = 0u; if (shift <= 0) { mantissa = static_cast(magnitude) << (-shift); @@ -543,41 +447,45 @@ static inline uint8_t ConvertIntToE4M3_CPU(T f, rounding R, } } - if (mantissa >= 16u) { - mantissa = 8u; + if (mantissa >= Format::OverflowMantissa) { + mantissa = Format::MinNormalMantissa; ++unbiasedExp; } - if (unbiasedExp > TargetEmax) - return static_cast( - sign | (S == saturation::finite ? MaxFiniteCode : NaNCode)); + if (unbiasedExp > Format::MaxFiniteExp) + return static_cast(sign | getOverflowCode()); - if (unbiasedExp == TargetEmax && mantissa > 14u) - return static_cast( - sign | (S == saturation::finite ? MaxFiniteCode : NaNCode)); + if (unbiasedExp == Format::MaxFiniteExp && + mantissa > Format::MaxFiniteMantissa) + return static_cast(sign | getOverflowCode()); - const uint8_t expField = static_cast(unbiasedExp + TargetBias); - const uint8_t fracField = static_cast(mantissa - 8u); - return static_cast(sign | static_cast(expField << 3) | + const uint8_t expField = static_cast(unbiasedExp + Format::Bias); + const uint8_t fracField = + static_cast(mantissa - Format::MinNormalMantissa); + return static_cast(sign | static_cast(expField << Mbits) | fracField); } -template > +template > static inline uint8_t ConvertFloatToE4M3_CPU(T f, rounding R, saturation S) noexcept { using UInt = typename Traits::UInt; + using Format = FP8FiniteFormatTraits; constexpr UInt SignMask = UInt{1} << (Traits::ExpBits + Traits::FracBits); constexpr UInt FracMask = (UInt{1} << Traits::FracBits) - UInt{1}; constexpr UInt ExpMask = ((UInt{1} << Traits::ExpBits) - UInt{1}) << Traits::FracBits; constexpr UInt ExpAllOnes = (UInt{1} << Traits::ExpBits) - UInt{1}; - constexpr uint8_t MaxFiniteCode = 0x7Eu; - constexpr uint8_t NaNCode = 0x7Fu; - constexpr int TargetBias = 7; - constexpr int TargetEmin = -6; - constexpr int TargetEmax = 8; - constexpr int TargetFracBits = 3; + + auto getOverflowCode = [&](uint8_t sign) -> uint8_t { + if (S == saturation::finite) + return static_cast(sign | Format::MaxFiniteCode); + if constexpr (Format::HasInfinity) + return static_cast(sign | Format::InfinityCode); + return static_cast(sign | Format::NaNCode); + }; UInt bits; __builtin_memcpy(&bits, &f, sizeof(bits)); @@ -590,9 +498,8 @@ static inline uint8_t ConvertFloatToE4M3_CPU(T f, rounding R, if (exp == ExpAllOnes) { if (frac != 0u) - return static_cast(sign | NaNCode); - return static_cast( - sign | (S == saturation::finite ? MaxFiniteCode : NaNCode)); + return static_cast(sign | Format::NaNCode); + return getOverflowCode(sign); } if (exp == 0u && frac == 0u) @@ -650,52 +557,50 @@ static inline uint8_t ConvertFloatToE4M3_CPU(T f, rounding R, return (truncated & 1u) != 0u ? truncated + 1u : truncated; }; - if (unbiasedExp > TargetEmax) - return static_cast( - sign | (S == saturation::finite ? MaxFiniteCode : NaNCode)); + if (unbiasedExp > Format::MaxFiniteExp) + return getOverflowCode(sign); - if (unbiasedExp == TargetEmax) { - const uint64_t lhs = significand << TargetFracBits; - const uint64_t rhs = 14ull << leadingBit; + if (unbiasedExp == Format::MaxFiniteExp) { + const uint64_t lhs = significand << Mbits; + const uint64_t rhs = Format::MaxFiniteMantissa << leadingBit; if (lhs > rhs) - return static_cast( - sign | (S == saturation::finite ? MaxFiniteCode : NaNCode)); + return getOverflowCode(sign); } - if (unbiasedExp < TargetEmin) { - const int shift = leadingBit - unbiasedExp - 9; + if (unbiasedExp < Format::Emin) { + const int shift = leadingBit - unbiasedExp - Format::Bias - Mbits + 1; uint64_t mantissa = shift > 0 ? roundShiftRight(significand, shift) : (significand << (-shift)); if (mantissa == 0u) return sign; - if (mantissa >= 8u) - return static_cast(sign | 0x08u); + if (mantissa >= Format::MinNormalMantissa) + return static_cast(sign | (uint8_t{1} << Mbits)); return static_cast(sign | static_cast(mantissa)); } - const int shift = leadingBit - TargetFracBits; + const int shift = leadingBit - Mbits; uint64_t mantissa = shift > 0 ? roundShiftRight(significand, shift) : (significand << (-shift)); - if (mantissa >= 16u) { - mantissa = 8u; + if (mantissa >= Format::OverflowMantissa) { + mantissa = Format::MinNormalMantissa; ++unbiasedExp; } - if (unbiasedExp > TargetEmax) - return static_cast( - sign | (S == saturation::finite ? MaxFiniteCode : NaNCode)); + if (unbiasedExp > Format::MaxFiniteExp) + return getOverflowCode(sign); - if (unbiasedExp == TargetEmax && mantissa > 14u) - return static_cast( - sign | (S == saturation::finite ? MaxFiniteCode : NaNCode)); + if (unbiasedExp == Format::MaxFiniteExp && + mantissa > Format::MaxFiniteMantissa) + return getOverflowCode(sign); - const uint8_t expField = static_cast(unbiasedExp + TargetBias); - const uint8_t fracField = static_cast(mantissa - 8u); - return static_cast(sign | static_cast(expField << 3) | + const uint8_t expField = static_cast(unbiasedExp + Format::Bias); + const uint8_t fracField = + static_cast(mantissa - Format::MinNormalMantissa); + return static_cast(sign | static_cast(expField << Mbits) | fracField); } @@ -1157,11 +1062,11 @@ template class fp8_e4m3_x { if constexpr (std::is_same_v, sycl::half> || std::is_same_v, float> || std::is_same_v, double>) { - return detail::ConvertFloatToE4M3_CPU(h, rounding::to_even, - saturation::finite); + return detail::ConvertFloatToE4M3_CPU<4, 3, T>(h, rounding::to_even, + saturation::finite); } else if constexpr (std::is_integral_v>) { - return detail::ConvertIntToE4M3_CPU(h, rounding::to_even, - saturation::finite); + return detail::ConvertIntToE4M3_CPU<4, 3, T>(h, rounding::to_even, + saturation::finite); } #endif } @@ -1170,8 +1075,8 @@ template class fp8_e4m3_x { #ifdef __SYCL_DEVICE_ONLY__ return __builtin_spirv_ClampConvertBF16ToE4M3INTEL(h); #else - return detail::ConvertFloatToE4M3_CPU(h, rounding::to_even, - saturation::finite); + return detail::ConvertFloatToE4M3_CPU<4, 3, bfloat16>( + h, rounding::to_even, saturation::finite); #endif } @@ -1521,20 +1426,27 @@ template class fp8_e4m3_x { template class fp8_e5m2_x { static constexpr size_t NExpBits = 5; static constexpr size_t NFracBits = 2; - static constexpr float MaxNormal = 114688.0f; // 1.75 * 2^16 + static constexpr float MaxNormal = 57344.0f; // 1.75 * 2^15 static constexpr float MinSubnormal = 0.0000152587890625f; // 2^-16 - static constexpr uint8_t MaxFiniteCode = 0x7C; // 0.11111.00 + static constexpr uint8_t MaxFiniteCode = 0x7B; // 0.11110.11 static_assert(N == 1 || N == 2, "fp8_e5m2_x: Template argument N must be 1 or 2"); - uint8_t ConvertToFP8(sycl::half h, saturation s) { + template uint8_t ConvertToFP8(T h, saturation s) { #ifdef __SYCL_DEVICE_ONLY__ + const sycl::half halfValue = static_cast(h); return s == saturation::finite - ? __builtin_spirv_ClampConvertFP16ToE5M2INTEL(h) - : __builtin_spirv_ConvertFP16ToE5M2EXT(h); + ? __builtin_spirv_ClampConvertFP16ToE5M2INTEL(halfValue) + : __builtin_spirv_ConvertFP16ToE5M2EXT(halfValue); #else - return detail::ConvertToFP8_CPU<5, 2, sycl::half>(h, rounding::to_even); + if constexpr (std::is_same_v, sycl::half> || + std::is_same_v, float> || + std::is_same_v, double>) { + return detail::ConvertFloatToE4M3_CPU<5, 2, T>(h, rounding::to_even, s); + } else if constexpr (std::is_integral_v>) { + return detail::ConvertIntToE4M3_CPU<5, 2, T>(h, rounding::to_even, s); + } #endif } @@ -1544,7 +1456,9 @@ template class fp8_e5m2_x { ? __builtin_spirv_ClampConvertBF16ToE5M2INTEL(h) : __builtin_spirv_ConvertBF16ToE5M2EXT(h); #else - return detail::ConvertToFP8_CPU<5, 2, bfloat16>(h, rounding::to_even); + return detail::ConvertFloatToE4M3_CPU<5, 2, bfloat16>(h, + rounding::to_even, + s); #endif } From 726294add5bd6af20211995d6e7bf32f1f30bb08 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Thu, 16 Apr 2026 18:05:16 +0200 Subject: [PATCH 35/89] [SYCL] rename functions and traits --- .../oneapi/experimental/float_8bit/types.hpp | 72 +++++++++---------- 1 file changed, 36 insertions(+), 36 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index e462614d2f5b3..596815a8f0d07 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -202,37 +202,37 @@ template static inline ToT MakeDirectInf(bool negative) noexcept } -template struct E8M0SourceTraits; +template struct SourceTraits; -template <> struct E8M0SourceTraits { +template <> struct SourceTraits { using UInt = uint32_t; static constexpr size_t ExpBits = 8; static constexpr size_t FracBits = 23; static constexpr int Bias = 127; }; -template <> struct E8M0SourceTraits { +template <> struct SourceTraits { using UInt = uint16_t; static constexpr size_t ExpBits = 5; static constexpr size_t FracBits = 10; static constexpr int Bias = 15; }; -template <> struct E8M0SourceTraits { +template <> struct SourceTraits { using UInt = uint16_t; static constexpr size_t ExpBits = 8; static constexpr size_t FracBits = 7; static constexpr int Bias = 127; }; -template <> struct E8M0SourceTraits { +template <> struct SourceTraits { using UInt = uint64_t; static constexpr size_t ExpBits = 11; static constexpr size_t FracBits = 52; static constexpr int Bias = 1023; }; -template <> struct E8M0SourceTraits { +template <> struct SourceTraits { using UInt = uint8_t; using UnsignedT = std::make_unsigned_t; @@ -241,7 +241,7 @@ template <> struct E8M0SourceTraits { static constexpr int ValueBits = std::numeric_limits::digits; }; -template <> struct E8M0SourceTraits { +template <> struct SourceTraits { using UInt = uint8_t; using UnsignedT = std::make_unsigned_t; @@ -250,7 +250,7 @@ template <> struct E8M0SourceTraits { static constexpr int ValueBits = std::numeric_limits::digits; }; -template <> struct E8M0SourceTraits { +template <> struct SourceTraits { using UInt = uint8_t; using UnsignedT = std::make_unsigned_t; @@ -259,7 +259,7 @@ template <> struct E8M0SourceTraits { static constexpr int ValueBits = std::numeric_limits::digits; }; -template <> struct E8M0SourceTraits { +template <> struct SourceTraits { using UInt = uint16_t; using UnsignedT = std::make_unsigned_t; @@ -268,7 +268,7 @@ template <> struct E8M0SourceTraits { static constexpr int ValueBits = std::numeric_limits::digits; }; -template <> struct E8M0SourceTraits { +template <> struct SourceTraits { using UInt = uint32_t; using UnsignedT = std::make_unsigned_t; @@ -277,7 +277,7 @@ template <> struct E8M0SourceTraits { static constexpr int ValueBits = std::numeric_limits::digits; }; -template <> struct E8M0SourceTraits { +template <> struct SourceTraits { using UInt = std::make_unsigned_t; using UnsignedT = std::make_unsigned_t; @@ -286,7 +286,7 @@ template <> struct E8M0SourceTraits { static constexpr int ValueBits = std::numeric_limits::digits; }; -template <> struct E8M0SourceTraits { +template <> struct SourceTraits { using UInt = uint64_t; using UnsignedT = std::make_unsigned_t; @@ -295,7 +295,7 @@ template <> struct E8M0SourceTraits { static constexpr int ValueBits = std::numeric_limits::digits; }; -template <> struct E8M0SourceTraits { +template <> struct SourceTraits { using UInt = uint16_t; using UnsignedT = std::make_unsigned_t; @@ -304,7 +304,7 @@ template <> struct E8M0SourceTraits { static constexpr int ValueBits = std::numeric_limits::digits; }; -template <> struct E8M0SourceTraits { +template <> struct SourceTraits { using UInt = uint32_t; using UnsignedT = std::make_unsigned_t; @@ -313,7 +313,7 @@ template <> struct E8M0SourceTraits { static constexpr int ValueBits = std::numeric_limits::digits; }; -template <> struct E8M0SourceTraits { +template <> struct SourceTraits { using UInt = std::make_unsigned_t; using UnsignedT = std::make_unsigned_t; @@ -322,7 +322,7 @@ template <> struct E8M0SourceTraits { static constexpr int ValueBits = std::numeric_limits::digits; }; -template <> struct E8M0SourceTraits { +template <> struct SourceTraits { using UInt = uint64_t; using UnsignedT = std::make_unsigned_t; @@ -357,7 +357,7 @@ template <> struct E8M0SourceTraits { MinNormalMantissa + MaxFiniteFracField; }; -template > +template > static inline uint8_t ConvertIntToE8M0_CPU(T f, rounding R, saturation S) noexcept { using UnsignedT = typename Traits::UnsignedT; @@ -391,8 +391,8 @@ static inline uint8_t ConvertIntToE8M0_CPU(T f, rounding R, } template > -static inline uint8_t ConvertIntToE4M3_CPU(T f, rounding R, + typename Traits = SourceTraits> +static inline uint8_t ConvertIntToFP8_CPU(T f, rounding R, saturation S) noexcept { using UnsignedT = typename Traits::UnsignedT; using Format = FP8FiniteFormatTraits; @@ -467,8 +467,8 @@ static inline uint8_t ConvertIntToE4M3_CPU(T f, rounding R, } template > -static inline uint8_t ConvertFloatToE4M3_CPU(T f, rounding R, + typename Traits = SourceTraits> +static inline uint8_t ConvertFloatToFP8_CPU(T f, rounding R, saturation S) noexcept { using UInt = typename Traits::UInt; using Format = FP8FiniteFormatTraits; @@ -604,7 +604,7 @@ static inline uint8_t ConvertFloatToE4M3_CPU(T f, rounding R, fracField); } -template > +template > static inline uint8_t ConvertFloatToE8M0_CPU(T f, rounding R, saturation S) noexcept { using UInt = typename Traits::UInt; @@ -705,8 +705,8 @@ struct HasE8M0IntegralTraits< : std::true_type {}; template > -static inline ToT ConvertFromE8M0ToBinaryFloat_CPU(uint8_t code, + typename Traits = SourceTraits> +static inline ToT ConvertFromFP8ToBinaryFloat_CPU(uint8_t code, rounding R) noexcept { static_assert((Ebits == 8 && Mbits == 0) || (Ebits == 4 && Mbits == 3) || (Ebits == 5 && Mbits == 2), @@ -1031,11 +1031,11 @@ ConvertToFP8_CPU(T h, rounding R = rounding::to_even) noexcept { template static inline ToT ConvertFromE8M0_CPU(uint8_t code, rounding R) noexcept { - using Traits = E8M0SourceTraits; + using Traits = SourceTraits; if constexpr (HasE8M0FloatTraits::value || HasE8M0IntegralTraits::value) { - return ConvertFromE8M0ToBinaryFloat_CPU<8, 0, ToT>(code, R); + return ConvertFromFP8ToBinaryFloat_CPU<8, 0, ToT>(code, R); } return ToT{}; @@ -1062,10 +1062,10 @@ template class fp8_e4m3_x { if constexpr (std::is_same_v, sycl::half> || std::is_same_v, float> || std::is_same_v, double>) { - return detail::ConvertFloatToE4M3_CPU<4, 3, T>(h, rounding::to_even, + return detail::ConvertFloatToFP8_CPU<4, 3, T>(h, rounding::to_even, saturation::finite); } else if constexpr (std::is_integral_v>) { - return detail::ConvertIntToE4M3_CPU<4, 3, T>(h, rounding::to_even, + return detail::ConvertIntToFP8_CPU<4, 3, T>(h, rounding::to_even, saturation::finite); } #endif @@ -1075,7 +1075,7 @@ template class fp8_e4m3_x { #ifdef __SYCL_DEVICE_ONLY__ return __builtin_spirv_ClampConvertBF16ToE4M3INTEL(h); #else - return detail::ConvertFloatToE4M3_CPU<4, 3, bfloat16>( + return detail::ConvertFloatToFP8_CPU<4, 3, bfloat16>( h, rounding::to_even, saturation::finite); #endif } @@ -1086,7 +1086,7 @@ template class fp8_e4m3_x { sycl::half hi = __builtin_spirv_ConvertE4M3ToFP16EXT(v); return static_cast(hi); #else - return detail::ConvertFromE8M0ToBinaryFloat_CPU<4, 3, T>(v, r); + return detail::ConvertFromFP8ToBinaryFloat_CPU<4, 3, T>(v, r); #endif } @@ -1094,7 +1094,7 @@ template class fp8_e4m3_x { #ifdef __SYCL_DEVICE_ONLY__ return __builtin_spirv_ConvertE4M3ToBF16EXT(v); #else - return detail::ConvertFromE8M0ToBinaryFloat_CPU<4, 3, bfloat16>( + return detail::ConvertFromFP8ToBinaryFloat_CPU<4, 3, bfloat16>( v, rounding::to_even); #endif } @@ -1443,9 +1443,9 @@ template class fp8_e5m2_x { if constexpr (std::is_same_v, sycl::half> || std::is_same_v, float> || std::is_same_v, double>) { - return detail::ConvertFloatToE4M3_CPU<5, 2, T>(h, rounding::to_even, s); + return detail::ConvertFloatToFP8_CPU<5, 2, T>(h, rounding::to_even, s); } else if constexpr (std::is_integral_v>) { - return detail::ConvertIntToE4M3_CPU<5, 2, T>(h, rounding::to_even, s); + return detail::ConvertIntToFP8_CPU<5, 2, T>(h, rounding::to_even, s); } #endif } @@ -1456,7 +1456,7 @@ template class fp8_e5m2_x { ? __builtin_spirv_ClampConvertBF16ToE5M2INTEL(h) : __builtin_spirv_ConvertBF16ToE5M2EXT(h); #else - return detail::ConvertFloatToE4M3_CPU<5, 2, bfloat16>(h, + return detail::ConvertFloatToFP8_CPU<5, 2, bfloat16>(h, rounding::to_even, s); #endif @@ -1468,7 +1468,7 @@ template class fp8_e5m2_x { sycl::half hi = __builtin_spirv_ConvertE5M2ToFP16EXT(v); return static_cast(hi); #else - return detail::ConvertFromE8M0ToBinaryFloat_CPU<5, 2, T>(v, r); + return detail::ConvertFromFP8ToBinaryFloat_CPU<5, 2, T>(v, r); #endif } @@ -1476,7 +1476,7 @@ template class fp8_e5m2_x { #ifdef __SYCL_DEVICE_ONLY__ return __builtin_spirv_ConvertE5M2ToBF16EXT(v); #else - return detail::ConvertFromE8M0ToBinaryFloat_CPU<5, 2, bfloat16>( + return detail::ConvertFromFP8ToBinaryFloat_CPU<5, 2, bfloat16>( v, rounding::to_even); #endif } From ddb260d0d8b2780d255e63190657a0f34bce7892 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Fri, 17 Apr 2026 10:42:25 +0200 Subject: [PATCH 36/89] [SYCL] remove unused functions --- .../oneapi/experimental/float_8bit/types.hpp | 341 +++--------------- sycl/unittests/Extensions/fp8/fp8_e8m0.cpp | 7 - 2 files changed, 58 insertions(+), 290 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index 596815a8f0d07..bfe55241184ad 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -80,68 +80,8 @@ struct stochastic_seed { }; namespace detail { -static inline uint8_t RneClip(float x, uint8_t max) noexcept { - float f = std::floor(x); - float frac = x - f; - uint8_t i = static_cast(f); - if (frac > 0.5f) - ++i; - else if (frac == 0.5f) - i += (i & 1u); // ties to even - return i > max ? max : i; -} - -static inline uint8_t RoundClip(float x, uint8_t max, rounding R, - uint8_t sign_bit) noexcept { - if (max == 0) { - // No fraction bits (E8M0 path) - if (R == rounding::upward) { - // For sign-preserving formats, roundTowardPositive increments only for - // positive values with a non-zero residual. Negative values stay at the - // lower-magnitude encoding. - if (!std::isnan(x) && sign_bit == 0u && x > 0.0f) - return 1u; - return 0u; - } - if (R == rounding::toward_zero || std::isnan(x)) - return 0u; - if (x > 0.5f) - return 1u; - if (x == 0.5f) - return 0u; // tie -> even (0) - return 0u; - } - - // Formats with fraction bits (E4M3, E5M2) - if (R == rounding::upward) { - if (sign_bit == 0u) { - // Positive: ceil - uint32_t ci = static_cast(std::ceil(x)); - if (ci > max) - ci = max; - return static_cast(ci); - } else { - // Negative: toward +inf => magnitude decreases -> floor - uint32_t fi = static_cast(std::floor(x)); - if (fi > max) - fi = max; - return static_cast(fi); - } - } - // default: round-to-nearest-even - return RneClip(x, max); -} - -static inline int BitWidth(uint32_t x) noexcept { - int width = 0; - while (x != 0u) { - ++width; - x >>= 1; - } - return width; -} -static inline int BitWidth(uint64_t x) noexcept { +template static inline int BitWidth(T x) noexcept { int width = 0; while (x != 0u) { ++width; @@ -154,26 +94,14 @@ template struct DirectBinary16Traits; template <> struct DirectBinary16Traits { static constexpr uint16_t SignMask = 0x8000u; - static constexpr uint16_t FracMask = 0x03FFu; static constexpr uint16_t InfBits = 0x7C00u; - static constexpr uint16_t MaxFiniteBits = 0x7BFFu; static constexpr uint16_t QuietNaNBits = 0x7E00u; - static constexpr int FracBits = 10; - static constexpr int Bias = 15; - static constexpr int Emin = -14; - static constexpr int Emax = 15; }; template <> struct DirectBinary16Traits { static constexpr uint16_t SignMask = 0x8000u; - static constexpr uint16_t FracMask = 0x007Fu; static constexpr uint16_t InfBits = 0x7F80u; - static constexpr uint16_t MaxFiniteBits = 0x7F7Fu; static constexpr uint16_t QuietNaNBits = 0x7FC0u; - static constexpr int FracBits = 7; - static constexpr int Bias = 127; - static constexpr int Emin = -126; - static constexpr int Emax = 127; }; template static inline ToT MakeDirectNaN() noexcept { @@ -187,7 +115,8 @@ template static inline ToT MakeDirectNaN() noexcept { } } -template static inline ToT MakeDirectInf(bool negative) noexcept { +template +static inline ToT MakeDirectInf(bool negative) noexcept { if constexpr (std::is_same_v || std::is_same_v) { using Traits = DirectBinary16Traits; @@ -201,7 +130,6 @@ template static inline ToT MakeDirectInf(bool negative) noexcept } } - template struct SourceTraits; template <> struct SourceTraits { @@ -331,31 +259,33 @@ template <> struct SourceTraits { static constexpr int ValueBits = std::numeric_limits::digits; }; - template struct FP8FiniteFormatTraits { - static_assert((Ebits == 4 && Mbits == 3) || (Ebits == 5 && Mbits == 2), - "Unsupported FP8 finite format"); - - static constexpr uint8_t ExpAllOnes = static_cast((1u << Ebits) - 1u); - static constexpr uint8_t MaxFrac = static_cast((1u << Mbits) - 1u); - static constexpr int Bias = (1 << (Ebits - 1)) - 1; - static constexpr int Emin = 1 - Bias; - static constexpr bool HasInfinity = (Ebits == 5 && Mbits == 2); - static constexpr uint8_t MaxFiniteExpField = +template struct FP8FiniteFormatTraits { + static_assert((Ebits == 4 && Mbits == 3) || (Ebits == 5 && Mbits == 2), + "Unsupported FP8 finite format"); + + static constexpr uint8_t ExpAllOnes = + static_cast((1u << Ebits) - 1u); + static constexpr uint8_t MaxFrac = static_cast((1u << Mbits) - 1u); + static constexpr int Bias = (1 << (Ebits - 1)) - 1; + static constexpr int Emin = 1 - Bias; + static constexpr bool HasInfinity = (Ebits == 5 && Mbits == 2); + static constexpr uint8_t MaxFiniteExpField = HasInfinity ? static_cast(ExpAllOnes - 1u) : ExpAllOnes; - static constexpr uint8_t MaxFiniteFracField = - (Ebits == 4 && Mbits == 3) ? static_cast(MaxFrac - 1u) - : MaxFrac; - static constexpr uint8_t MaxFiniteCode = + static constexpr uint8_t MaxFiniteFracField = + (Ebits == 4 && Mbits == 3) ? static_cast(MaxFrac - 1u) : MaxFrac; + static constexpr uint8_t MaxFiniteCode = static_cast((MaxFiniteExpField << Mbits) | MaxFiniteFracField); - static constexpr uint8_t NaNCode = + static constexpr uint8_t NaNCode = static_cast((ExpAllOnes << Mbits) | MaxFrac); - static constexpr uint8_t InfinityCode = static_cast(ExpAllOnes << Mbits); - static constexpr int MaxFiniteExp = static_cast(MaxFiniteExpField) - Bias; - static constexpr uint64_t MinNormalMantissa = uint64_t{1} << Mbits; - static constexpr uint64_t OverflowMantissa = uint64_t{1} << (Mbits + 1); - static constexpr uint64_t MaxFiniteMantissa = + static constexpr uint8_t InfinityCode = + static_cast(ExpAllOnes << Mbits); + static constexpr int MaxFiniteExp = + static_cast(MaxFiniteExpField) - Bias; + static constexpr uint64_t MinNormalMantissa = uint64_t{1} << Mbits; + static constexpr uint64_t OverflowMantissa = uint64_t{1} << (Mbits + 1); + static constexpr uint64_t MaxFiniteMantissa = MinNormalMantissa + MaxFiniteFracField; - }; +}; template > static inline uint8_t ConvertIntToE8M0_CPU(T f, rounding R, @@ -393,7 +323,7 @@ static inline uint8_t ConvertIntToE8M0_CPU(T f, rounding R, template > static inline uint8_t ConvertIntToFP8_CPU(T f, rounding R, - saturation S) noexcept { + saturation S) noexcept { using UnsignedT = typename Traits::UnsignedT; using Format = FP8FiniteFormatTraits; @@ -469,7 +399,7 @@ static inline uint8_t ConvertIntToFP8_CPU(T f, rounding R, template > static inline uint8_t ConvertFloatToFP8_CPU(T f, rounding R, - saturation S) noexcept { + saturation S) noexcept { using UInt = typename Traits::UInt; using Format = FP8FiniteFormatTraits; @@ -707,7 +637,7 @@ struct HasE8M0IntegralTraits< template > static inline ToT ConvertFromFP8ToBinaryFloat_CPU(uint8_t code, - rounding R) noexcept { + rounding R) noexcept { static_assert((Ebits == 8 && Mbits == 0) || (Ebits == 4 && Mbits == 3) || (Ebits == 5 && Mbits == 2), "Unsupported FP8 decode combination"); @@ -780,8 +710,9 @@ static inline ToT ConvertFromFP8ToBinaryFloat_CPU(uint8_t code, if (isNaN) { bits = ExpAllOnes | QuietNaNBit; } else if (isInf) { - bits = (negative ? (UInt{1} << (Traits::ExpBits + Traits::FracBits)) : 0u) | - ExpAllOnes; + bits = + (negative ? (UInt{1} << (Traits::ExpBits + Traits::FracBits)) : 0u) | + ExpAllOnes; } else if (significand == 0u) { bits = negative ? (UInt{1} << (Traits::ExpBits + Traits::FracBits)) : 0u; } else { @@ -874,161 +805,6 @@ static inline ToT ConvertFromFP8ToBinaryFloat_CPU(uint8_t code, return ToT{}; } -/// \brief Converts a given value to fp8 floating point with a rounding -/// mode to_even by default and saturation finite for host code. -/// \param h The input value to be converted. -/// \param R The rounding mode to be used during conversion. -/// \return uint8_t The converted 8-bit floating point value, MSB is sign bit, -/// Ebits bits exponent, Mbits bits mantissa. -template -static inline uint8_t -ConvertToFP8_CPU(T h, rounding R = rounding::to_even) noexcept { - // Specialized implementation for fp8_e8m0_x (Ebits=8, Mbits=0) - if constexpr (Ebits == 8 && Mbits == 0) { - // Format characteristics (finite-only, no zero, no infinity): - // - Bias: 127 - // - Exponent field range used for normals: 0 .. 254 (E = ecode - 127 -> - // [-127, +127]) - // - Encoding with exp==255 (0xFF) reserved for NaN (single payload 0xFF) - // - Value encoded when exponent field == 0: +/- 2^{-127} - // - Max normal: +/- 2^{127} (~1.7014118e+38) - // - // Rounding mode: the public API restricts this format to rounding::upward. - // Here we honor upward if passed; any other mode falls back to upward - // behavior. - // - // Note: The format cannot represent zero; inputs with |x| < 2^{-127} map - // to the smallest magnitude normal with the input sign preserved - // (consistent with prior sign-preserving underflow behavior). - // - constexpr uint32_t Bias = 127; - constexpr int Emin = -127; - constexpr int Emax = 127; - constexpr uint8_t NaNCode = 0xFF; // 11111111 - constexpr uint8_t MaxExpField = 254; // 255 reserved for NaN - const float min_normal = std::ldexp(1.0f, Emin); // 2^{-127} - const float max_normal = std::ldexp(1.0f, Emax); // 2^{127} - - float x = static_cast(h); - - if (std::isnan(x)) - return NaNCode; - - uint8_t sign = std::signbit(x) ? 0x80 : 0x00; - float ax = std::fabs(x); - - // Handle underflow (|x| < min_normal) and x == 0: encode smallest normal - // with sign. - if (ax == 0.0f || ax < min_normal) - return sign; // exp field = 0 -> E = -127 - - // Handle overflow (|x| >= max_normal * (anything beyond representable)): - if (ax >= max_normal) - return static_cast(sign | (MaxExpField)); // E = +127 - - // Determine exponent E such that 2^E <= ax < 2^{E+1} - int e2 = 0; - std::frexp(ax, &e2); - int E = e2 - 1; - - // Upward rounding semantics: - // - For positive numbers: if not exact power-of-two, round up to next - // power (E+1) if within range. - // - For negative numbers: rounding toward +inf moves value toward zero, so - // keep current E. - - if (R == rounding::upward) { - if (sign == 0x00) { - // Round up (increase exponent) if possible. - if (E < Emax) - ++E; - else - E = Emax; - } - } - - // Clamp exponent just in case. - if (E < Emin) - E = Emin; - if (E > Emax) - E = Emax; - - uint8_t ecode = static_cast(E + Bias); // 0 .. 254 - // ecode must never be 255 here. - return static_cast(sign | ecode); - } - - constexpr int bias = (1 << (Ebits - 1)) - 1; - // allow the top exponent field (ExpAllOnes) as a normal exponent except when - // frac==MaxFrac (NaN) - int emax = 0; - int emin = 0; - if constexpr (Ebits == 8) - emax = 127; - else { - emax = (1 << Ebits) - 1 - bias; // ExpAllOnes - bias - emin = 1 - bias; - } - constexpr uint8_t ExpAllOnes = static_cast((1 << Ebits) - 1); - constexpr uint8_t MaxFrac = static_cast((1 << Mbits) - 1); - constexpr uint8_t MaxFracForMaxNormal = - (Ebits == 4 && Mbits == 3) || (Ebits == 5 && Mbits == 3) - ? static_cast(MaxFrac - 1u) - : MaxFrac; - constexpr uint8_t MaxExpForMaxNormal = - (Ebits == 5 && Mbits == 2) ? static_cast(ExpAllOnes - 1u) - : ExpAllOnes; - constexpr uint8_t MaxFracMask = MaxFrac; - - float x = static_cast(h); - uint8_t sign = std::signbit(x) ? 0x80 : 0x00; - if (std::isnan(x)) - return static_cast( - sign | ((ExpAllOnes << Mbits) | MaxFracMask)); // S.1111.111 -> NaN - uint8_t sign_bit = sign ? 1u : 0u; - float ax = std::fabs(x); - const float max_finite = - (2.0f - std::ldexp(1.0f, 1 - Mbits)) * std::ldexp(1.0f, emax); - const float min_sub = std::ldexp(1.0f, emin - Mbits); - - if (ax > max_finite) { - return static_cast( - sign | ((MaxExpForMaxNormal << Mbits) | MaxFracForMaxNormal)); - } - - if (ax < min_sub) - return sign; // underflow - - int e2 = 0; - float m = std::frexp(ax, &e2); - int E = e2 - 1; - - if (E < emin) { - float scaled = std::ldexp(ax, -emin) * static_cast(1 << Mbits); - uint32_t k = RoundClip(scaled, MaxFrac, R, sign_bit); - if (k == 0) - return sign; - return static_cast(sign | static_cast(k)); - } - - float y = m * 2.0f; - float frac_scaled = (y - 1.0f) * static_cast(1 << Mbits); - uint32_t frac = RoundClip(frac_scaled, MaxFrac, R, sign_bit); - if (frac == (1u << Mbits)) { - frac = 0; - ++E; - } - if (E > emax) { - auto ret = static_cast( - sign | ((MaxExpForMaxNormal << Mbits) | MaxFracForMaxNormal)); - return ret; - } - uint8_t ecode = static_cast(E + bias); - auto ret = static_cast(sign | (ecode << Mbits) | - static_cast(frac)); - return ret; -} - template static inline ToT ConvertFromE8M0_CPU(uint8_t code, rounding R) noexcept { using Traits = SourceTraits; @@ -1046,11 +822,6 @@ static inline ToT ConvertFromE8M0_CPU(uint8_t code, rounding R) noexcept { template class fp8_e4m3_x { static constexpr size_t NExpBits = 4; static constexpr size_t NFracBits = 3; - static constexpr float MaxNormal = 448.0f; - static constexpr float MinSubnormal = 0.00000762939453125f; // 2^-17 - static constexpr uint8_t NaNCode = 0xFF; - static constexpr uint8_t MaxFiniteCode = - 0x7E; // 0.1111.110 (positive max normal) static_assert(N == 1 || N == 2, "fp8_e4m3_x: Template argument N must be 1 or 2"); @@ -1062,11 +833,11 @@ template class fp8_e4m3_x { if constexpr (std::is_same_v, sycl::half> || std::is_same_v, float> || std::is_same_v, double>) { - return detail::ConvertFloatToFP8_CPU<4, 3, T>(h, rounding::to_even, - saturation::finite); + return detail::ConvertFloatToFP8_CPU( + h, rounding::to_even, saturation::finite); } else if constexpr (std::is_integral_v>) { - return detail::ConvertIntToFP8_CPU<4, 3, T>(h, rounding::to_even, - saturation::finite); + return detail::ConvertIntToFP8_CPU( + h, rounding::to_even, saturation::finite); } #endif } @@ -1075,8 +846,8 @@ template class fp8_e4m3_x { #ifdef __SYCL_DEVICE_ONLY__ return __builtin_spirv_ClampConvertBF16ToE4M3INTEL(h); #else - return detail::ConvertFloatToFP8_CPU<4, 3, bfloat16>( - h, rounding::to_even, saturation::finite); + return detail::ConvertFloatToFP8_CPU( + h, rounding::to_even, saturation::finite); #endif } @@ -1086,7 +857,8 @@ template class fp8_e4m3_x { sycl::half hi = __builtin_spirv_ConvertE4M3ToFP16EXT(v); return static_cast(hi); #else - return detail::ConvertFromFP8ToBinaryFloat_CPU<4, 3, T>(v, r); + return detail::ConvertFromFP8ToBinaryFloat_CPU(v, + r); #endif } @@ -1094,8 +866,9 @@ template class fp8_e4m3_x { #ifdef __SYCL_DEVICE_ONLY__ return __builtin_spirv_ConvertE4M3ToBF16EXT(v); #else - return detail::ConvertFromFP8ToBinaryFloat_CPU<4, 3, bfloat16>( - v, rounding::to_even); + return detail::ConvertFromFP8ToBinaryFloat_CPU(v, + rounding::to_even); #endif } @@ -1128,7 +901,8 @@ template class fp8_e4m3_x { vals[i] = ConvertBF16ToFP8(in[i]); return; } - const sycl::half in[N] = {v...}; + using InT = std::common_type_t...>; + const InT in[N] = {v...}; for (size_t i = 0; i < N; ++i) vals[i] = ConvertToFP8(in[i]); } @@ -1426,9 +1200,6 @@ template class fp8_e4m3_x { template class fp8_e5m2_x { static constexpr size_t NExpBits = 5; static constexpr size_t NFracBits = 2; - static constexpr float MaxNormal = 57344.0f; // 1.75 * 2^15 - static constexpr float MinSubnormal = 0.0000152587890625f; // 2^-16 - static constexpr uint8_t MaxFiniteCode = 0x7B; // 0.11110.11 static_assert(N == 1 || N == 2, "fp8_e5m2_x: Template argument N must be 1 or 2"); @@ -1443,9 +1214,11 @@ template class fp8_e5m2_x { if constexpr (std::is_same_v, sycl::half> || std::is_same_v, float> || std::is_same_v, double>) { - return detail::ConvertFloatToFP8_CPU<5, 2, T>(h, rounding::to_even, s); + return detail::ConvertFloatToFP8_CPU( + h, rounding::to_even, s); } else if constexpr (std::is_integral_v>) { - return detail::ConvertIntToFP8_CPU<5, 2, T>(h, rounding::to_even, s); + return detail::ConvertIntToFP8_CPU( + h, rounding::to_even, s); } #endif } @@ -1456,9 +1229,8 @@ template class fp8_e5m2_x { ? __builtin_spirv_ClampConvertBF16ToE5M2INTEL(h) : __builtin_spirv_ConvertBF16ToE5M2EXT(h); #else - return detail::ConvertFloatToFP8_CPU<5, 2, bfloat16>(h, - rounding::to_even, - s); + return detail::ConvertFloatToFP8_CPU( + h, rounding::to_even, s); #endif } @@ -1468,7 +1240,8 @@ template class fp8_e5m2_x { sycl::half hi = __builtin_spirv_ConvertE5M2ToFP16EXT(v); return static_cast(hi); #else - return detail::ConvertFromFP8ToBinaryFloat_CPU<5, 2, T>(v, r); + return detail::ConvertFromFP8ToBinaryFloat_CPU(v, + r); #endif } @@ -1476,8 +1249,9 @@ template class fp8_e5m2_x { #ifdef __SYCL_DEVICE_ONLY__ return __builtin_spirv_ConvertE5M2ToBF16EXT(v); #else - return detail::ConvertFromFP8ToBinaryFloat_CPU<5, 2, bfloat16>( - v, rounding::to_even); + return detail::ConvertFromFP8ToBinaryFloat_CPU(v, + rounding::to_even); #endif } @@ -1511,7 +1285,8 @@ template class fp8_e5m2_x { vals[i] = ConvertBF16ToFP8(in[i], saturation::finite); return; } - const sycl::half in[N] = {v...}; + using InT = std::common_type_t...>; + const InT in[N] = {v...}; for (size_t i = 0; i < N; ++i) vals[i] = ConvertToFP8(in[i], saturation::finite); } diff --git a/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp b/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp index 81a1d8d6c0db3..72876d674474e 100644 --- a/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp +++ b/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp @@ -108,13 +108,6 @@ TEST(FP8E8M0Test, CArrayFloatRoundingModes) { rounding::upward, 0xFF)); } -TEST(FP8E8M0Test, RoundClipZeroFractionNegativeAndTieCases) { - EXPECT_EQ(detail::RoundClip(0.25f, 0, rounding::upward, 0u), 1u); - EXPECT_EQ(detail::RoundClip(0.25f, 0, rounding::upward, 1u), 0u); - EXPECT_EQ(detail::RoundClip(0.5f, 0, rounding::to_even, 0u), 0u); - EXPECT_EQ(detail::RoundClip(0.75f, 0, rounding::to_even, 0u), 1u); -} - TEST(FP8E8M0Test, CArrayHalfHostUpwardFinite) { const sycl::half in[2] = {sycl::half(1.0f), sycl::half(1.1f)}; const sycl::half in1[2] = {sycl::half(3.0f), sycl::half(0.0f)}; From b596110fddeff57fe714bdb82e5c06394ba1b79e Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Mon, 20 Apr 2026 16:51:04 +0200 Subject: [PATCH 37/89] [SYCL] rework fp8 types to avoid copy-paste --- .../oneapi/experimental/float_8bit/types.hpp | 1030 +++-------------- 1 file changed, 142 insertions(+), 888 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index bfe55241184ad..5b3c8ad390138 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -104,7 +104,7 @@ template <> struct DirectBinary16Traits { static constexpr uint16_t QuietNaNBits = 0x7FC0u; }; -template static inline ToT MakeDirectNaN() noexcept { +template static constexpr inline ToT MakeDirectNaN() noexcept { if constexpr (std::is_same_v || std::is_same_v) { return sycl::bit_cast(DirectBinary16Traits::QuietNaNBits); @@ -116,7 +116,7 @@ template static inline ToT MakeDirectNaN() noexcept { } template -static inline ToT MakeDirectInf(bool negative) noexcept { +static constexpr inline ToT MakeDirectInf(bool negative) noexcept { if constexpr (std::is_same_v || std::is_same_v) { using Traits = DirectBinary16Traits; @@ -817,6 +817,16 @@ static inline ToT ConvertFromE8M0_CPU(uint8_t code, rounding R) noexcept { return ToT{}; } +template +struct IsOneOf : std::disjunction...> {}; + +template +struct IsSyclFpType : IsOneOf, sycl::half, + sycl::ext::oneapi::bfloat16, float, double> {}; + +template +inline constexpr bool IsSyclFpTypeV = IsSyclFpType::value; + } // namespace detail template class fp8_e4m3_x { @@ -828,11 +838,12 @@ template class fp8_e4m3_x { template uint8_t ConvertToFP8(T h) { #ifdef __SYCL_DEVICE_ONLY__ - return __builtin_spirv_ClampConvertFP16ToE4M3INTEL(h); + if constexpr (std::is_same_v, bfloat16>) + return __builtin_spirv_ClampConvertBF16ToE4M3INTEL(h); + else + return __builtin_spirv_ClampConvertFP16ToE4M3INTEL(h); #else - if constexpr (std::is_same_v, sycl::half> || - std::is_same_v, float> || - std::is_same_v, double>) { + if constexpr (detail::IsSyclFpTypeV) { return detail::ConvertFloatToFP8_CPU( h, rounding::to_even, saturation::finite); } else if constexpr (std::is_integral_v>) { @@ -842,36 +853,21 @@ template class fp8_e4m3_x { #endif } - uint8_t ConvertBF16ToFP8(bfloat16 h) { -#ifdef __SYCL_DEVICE_ONLY__ - return __builtin_spirv_ClampConvertBF16ToE4M3INTEL(h); -#else - return detail::ConvertFloatToFP8_CPU( - h, rounding::to_even, saturation::finite); -#endif - } - template T ConvertFromFP8(uint8_t v, rounding r = rounding::to_even) const { #ifdef __SYCL_DEVICE_ONLY__ - sycl::half hi = __builtin_spirv_ConvertE4M3ToFP16EXT(v); - return static_cast(hi); + if constexpr (std::is_same_v, bfloat16>) + return __builtin_spirv_ConvertE4M3ToBF16EXT(v); + else { + sycl::half hi = __builtin_spirv_ConvertE4M3ToFP16EXT(v); + return static_cast(hi); + } #else return detail::ConvertFromFP8ToBinaryFloat_CPU(v, r); #endif } - bfloat16 ConvertBF16FromFP8(uint8_t v) const { -#ifdef __SYCL_DEVICE_ONLY__ - return __builtin_spirv_ConvertE4M3ToBF16EXT(v); -#else - return detail::ConvertFromFP8ToBinaryFloat_CPU(v, - rounding::to_even); -#endif - } - void CheckConstraints(rounding r) const { assert(r == rounding::to_even && "fp8_e4m3_x: only rounding::to_even is supported"); @@ -895,12 +891,6 @@ template class fp8_e4m3_x { ((std::is_same_v, float>) && ...) || ((std::is_same_v, double>) && ...))>> explicit fp8_e4m3_x(Types... v) { - if constexpr (((std::is_same_v, bfloat16>) && ...)) { - const bfloat16 in[N] = {static_cast(v)...}; - for (size_t i = 0; i < N; ++i) - vals[i] = ConvertBF16ToFP8(in[i]); - return; - } using InT = std::common_type_t...>; const InT in[N] = {v...}; for (size_t i = 0; i < N; ++i) @@ -908,253 +898,52 @@ template class fp8_e4m3_x { } // Construct from an array of half, bfloat16, float, double. - explicit fp8_e4m3_x(sycl::half const (&v)[N], - rounding r = rounding::to_even) { - CheckConstraints(r); - for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i]); - } - - explicit fp8_e4m3_x(bfloat16 const (&v)[N], rounding r = rounding::to_even) { - CheckConstraints(r); - for (size_t i = 0; i < N; ++i) - vals[i] = ConvertBF16ToFP8(v[i]); - } - - explicit fp8_e4m3_x(float const (&v)[N], rounding r = rounding::to_even) { + template >> + explicit fp8_e4m3_x(T const (&v)[N], rounding r = rounding::to_even) { CheckConstraints(r); for (size_t i = 0; i < N; ++i) vals[i] = ConvertToFP8(v[i]); } - explicit fp8_e4m3_x(double const (&v)[N]) { - for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i]); - } - // Construct from an marray of half, bfloat16, float, double. - explicit fp8_e4m3_x(const sycl::marray &v, - rounding r = rounding::to_even) { - CheckConstraints(r); - for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i]); - } - - explicit fp8_e4m3_x(const sycl::marray &v, - rounding r = rounding::to_even) { - CheckConstraints(r); - for (size_t i = 0; i < N; ++i) - vals[i] = ConvertBF16ToFP8(v[i]); - } - - explicit fp8_e4m3_x(const sycl::marray &v, + template >> + explicit fp8_e4m3_x(const sycl::marray &v, rounding r = rounding::to_even) { CheckConstraints(r); for (size_t i = 0; i < N; ++i) vals[i] = ConvertToFP8(v[i]); } - explicit fp8_e4m3_x(const sycl::marray &v) { - for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i]); - } - // Construct from integer types. // Available only when N==1. - template > - explicit fp8_e4m3_x(short val) { - vals[0] = ConvertToFP8(val); - } - - template > - explicit fp8_e4m3_x(int val) { - vals[0] = ConvertToFP8(val); - } - - template > - explicit fp8_e4m3_x(long val) { - vals[0] = ConvertToFP8(val); - } - - template > - explicit fp8_e4m3_x(long long val) { - vals[0] = ConvertToFP8(val); - } - - template > - explicit fp8_e4m3_x(unsigned short val) { - vals[0] = ConvertToFP8(val); - } - - template > - explicit fp8_e4m3_x(unsigned int val) { - vals[0] = ConvertToFP8(val); - } - - template > - explicit fp8_e4m3_x(unsigned long val) { - vals[0] = ConvertToFP8(val); - } - - template > - explicit fp8_e4m3_x(unsigned long long val) { + template >> + explicit fp8_e4m3_x(T val) { vals[0] = ConvertToFP8(val); } - // Assign (operator) from half, bfloat16, float, double, and integer types. // Available only when N==1. - template > - fp8_e4m3_x &operator=(sycl::half val) { - vals[0] = ConvertToFP8(val); - return *this; - } - - template > - fp8_e4m3_x &operator=(bfloat16 val) { - vals[0] = ConvertBF16ToFP8(val); - return *this; - } - - template > - fp8_e4m3_x &operator=(float val) { - vals[0] = ConvertToFP8(val); - return *this; - } - - template > - fp8_e4m3_x &operator=(double val) { - vals[0] = ConvertToFP8(val); - return *this; - } - - template > - fp8_e4m3_x &operator=(short val) { - vals[0] = ConvertToFP8(val); - return *this; - } - - template > - fp8_e4m3_x &operator=(int val) { - vals[0] = ConvertToFP8(val); - return *this; - } - - template > - fp8_e4m3_x &operator=(long val) { - vals[0] = ConvertToFP8(val); - return *this; - } - - template > - fp8_e4m3_x &operator=(long long val) { - vals[0] = ConvertToFP8(val); - return *this; - } - - template > - fp8_e4m3_x &operator=(unsigned short val) { - vals[0] = ConvertToFP8(val); - return *this; - } - - template > - fp8_e4m3_x &operator=(unsigned int val) { - vals[0] = ConvertToFP8(val); - return *this; - } - - template > - fp8_e4m3_x &operator=(unsigned long val) { - vals[0] = ConvertToFP8(val); - return *this; - } - - template > - fp8_e4m3_x &operator=(unsigned long long val) { + template || + std::is_integral_v)>> + fp8_e4m3_x &operator=(T val) { vals[0] = ConvertToFP8(val); return *this; } - - // Convert to half, bfloat16, float, double. - // Available only when N==1. - - template > - explicit operator half() const { - return ConvertFromFP8(vals[0]); - } - - template > - explicit operator bfloat16() const { - return ConvertBF16FromFP8(vals[0]); - } - template > - explicit operator float() const { - return ConvertFromFP8(vals[0]); - } - template > - explicit operator double() const { - return ConvertFromFP8(vals[0]); - } - - // Convert to integer types. + // Convert to half, bfloat16, float, double and integer types // Available only when N==1. - template > - explicit operator char() const { - return ConvertFromFP8(vals[0], rounding::toward_zero); - } - - template > - explicit operator signed char() const { - return ConvertFromFP8(vals[0], rounding::toward_zero); - } - - template > - explicit operator short() const { - return ConvertFromFP8(vals[0], rounding::toward_zero); - } - - template > - explicit operator int() const { - return ConvertFromFP8(vals[0], rounding::toward_zero); - } - - template > - explicit operator long() const { - return ConvertFromFP8(vals[0], rounding::toward_zero); - } - - template > - explicit operator long long() const { - return ConvertFromFP8(vals[0], rounding::toward_zero); - } - template > - explicit operator unsigned char() const { - return ConvertFromFP8(vals[0], rounding::toward_zero); - } - - template > - explicit operator unsigned short() const { - return ConvertFromFP8(vals[0], rounding::toward_zero); - } - - template > - explicit operator unsigned int() const { - return ConvertFromFP8(vals[0], rounding::toward_zero); - } - - template > - explicit operator unsigned long() const { - return ConvertFromFP8(vals[0], rounding::toward_zero); - } - - template > - explicit operator unsigned long long() const { - return ConvertFromFP8(vals[0], rounding::toward_zero); + template || + std::is_integral_v)>> + explicit operator T() const { + if constexpr (std::is_integral_v) + return ConvertFromFP8(vals[0], rounding::toward_zero); + else + return ConvertFromFP8(vals[0]); } - // Convert to bool // Available only when N==1. @@ -1171,28 +960,15 @@ template class fp8_e4m3_x { } // Convert to marray of half, bfloat16, float - - explicit operator sycl::marray() const { - sycl::marray ret; - for (size_t i = 0; i < N; ++i) - ret[i] = ConvertFromFP8(vals[i]); - return ret; - } - - explicit operator sycl::marray() const { - sycl::marray ret; - for (size_t i = 0; i < N; ++i) - ret[i] = ConvertBF16FromFP8(vals[i]); - return ret; - } - - explicit operator sycl::marray() const { - sycl::marray ret; + template , sycl::half, + sycl::ext::oneapi::bfloat16, float>::value>> + explicit operator sycl::marray() const { + sycl::marray ret; for (size_t i = 0; i < N; ++i) - ret[i] = ConvertFromFP8(vals[i]); + ret[i] = ConvertFromFP8(vals[i]); return ret; } - // Intentionally public to allow access to the raw values. uint8_t vals[N]; }; @@ -1206,14 +982,16 @@ template class fp8_e5m2_x { template uint8_t ConvertToFP8(T h, saturation s) { #ifdef __SYCL_DEVICE_ONLY__ + if constexpr (std::is_same_v, bfloat16>) + return s == saturation::finite + ? __builtin_spirv_ClampConvertBF16ToE5M2INTEL(h) + : __builtin_spirv_ConvertBF16ToE5M2EXT(h); const sycl::half halfValue = static_cast(h); return s == saturation::finite ? __builtin_spirv_ClampConvertFP16ToE5M2INTEL(halfValue) : __builtin_spirv_ConvertFP16ToE5M2EXT(halfValue); #else - if constexpr (std::is_same_v, sycl::half> || - std::is_same_v, float> || - std::is_same_v, double>) { + if constexpr (detail::IsSyclFpTypeV) { return detail::ConvertFloatToFP8_CPU( h, rounding::to_even, s); } else if constexpr (std::is_integral_v>) { @@ -1223,20 +1001,35 @@ template class fp8_e5m2_x { #endif } - uint8_t ConvertBF16ToFP8(bfloat16 h, saturation s) { + template + void StochasticConvertToFP8(T h, uint32_t current_seed, uint32_t *pseed, + saturation s, uint8_t i) { #ifdef __SYCL_DEVICE_ONLY__ - return s == saturation::finite - ? __builtin_spirv_ClampConvertBF16ToE5M2INTEL(h) - : __builtin_spirv_ConvertBF16ToE5M2EXT(h); -#else - return detail::ConvertFloatToFP8_CPU( - h, rounding::to_even, s); + if constexpr (std::is_same_v) { + if (s == saturation::finite) { + vals[i] = __builtin_spirv_ClampStochasticRoundBF16ToE5M2INTEL( + h, current_seed, pseed); + } else { + vals[i] = __builtin_spirv_StochasticRoundBF16ToE5M2INTEL( + h, current_seed, pseed); + } + } else { + if (s == saturation::finite) { + vals[i] = __builtin_spirv_ClampStochasticRoundFP16ToE5M2INTEL( + h, current_seed, pseed); + } else { + vals[i] = __builtin_spirv_StochasticRoundFP16ToE5M2INTEL( + h, current_seed, pseed); + } + } #endif } template T ConvertFromFP8(uint8_t v, rounding r = rounding::to_even) const { #ifdef __SYCL_DEVICE_ONLY__ + if constexpr (std::is_same_v, bfloat16>) + return __builtin_spirv_ConvertE5M2ToBF16EXT(v); sycl::half hi = __builtin_spirv_ConvertE5M2ToFP16EXT(v); return static_cast(hi); #else @@ -1245,16 +1038,6 @@ template class fp8_e5m2_x { #endif } - bfloat16 ConvertBF16FromFP8(uint8_t v) const { -#ifdef __SYCL_DEVICE_ONLY__ - return __builtin_spirv_ConvertE5M2ToBF16EXT(v); -#else - return detail::ConvertFromFP8ToBinaryFloat_CPU(v, - rounding::to_even); -#endif - } - void CheckConstraints(rounding r) const { assert(r == rounding::to_even && "fp8_e5m2_x: only rounding::to_even is supported"); @@ -1279,67 +1062,23 @@ template class fp8_e5m2_x { ((std::is_same_v, float>) && ...) || ((std::is_same_v, double>) && ...))>> explicit fp8_e5m2_x(Types... v) { - if constexpr (((std::is_same_v, bfloat16>) && ...)) { - const bfloat16 in[N] = {static_cast(v)...}; - for (size_t i = 0; i < N; ++i) - vals[i] = ConvertBF16ToFP8(in[i], saturation::finite); - return; - } using InT = std::common_type_t...>; const InT in[N] = {v...}; for (size_t i = 0; i < N; ++i) vals[i] = ConvertToFP8(in[i], saturation::finite); } - // Construct from an array of half, bfloat16, float, double. - - explicit fp8_e5m2_x(half const (&v)[N], rounding r = rounding::to_even, - saturation s = saturation::finite) { - CheckConstraints(r); - // TODO: optimize with vectorized builtin calls - for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i], s); - } - - explicit fp8_e5m2_x(bfloat16 const (&v)[N], rounding r = rounding::to_even, + template >> + explicit fp8_e5m2_x(T const (&v)[N], rounding r = rounding::to_even, saturation s = saturation::finite) { CheckConstraints(r); // TODO: optimize with vectorized builtin calls - for (size_t i = 0; i < N; ++i) - vals[i] = ConvertBF16ToFP8(v[i], s); - } - - explicit fp8_e5m2_x(float const (&v)[N], rounding r = rounding::to_even, - saturation s = saturation::finite) { - CheckConstraints(r); - for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i], s); - } - - explicit fp8_e5m2_x(double const (&v)[N]) { - for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i], saturation::finite); - } - - // Construct from an marray of half, bfloat16, float, double. - - explicit fp8_e5m2_x(const sycl::marray &v, - rounding r = rounding::to_even, - saturation s = saturation::finite) { - CheckConstraints(r); for (size_t i = 0; i < N; ++i) vals[i] = ConvertToFP8(v[i], s); } - explicit fp8_e5m2_x(const sycl::marray &v, - rounding r = rounding::to_even, - saturation s = saturation::finite) { - CheckConstraints(r); - for (size_t i = 0; i < N; ++i) - vals[i] = ConvertBF16ToFP8(v[i], s); - } - - explicit fp8_e5m2_x(const sycl::marray &v, + template >> + explicit fp8_e5m2_x(const sycl::marray &v, rounding r = rounding::to_even, saturation s = saturation::finite) { CheckConstraints(r); @@ -1347,324 +1086,59 @@ template class fp8_e5m2_x { vals[i] = ConvertToFP8(v[i], s); } - explicit fp8_e5m2_x(const sycl::marray &v) { - for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i], saturation::finite); - } - - // Construct with stochastic rounding with user provided seed from an array of - // half, bfloat16, float. - - explicit fp8_e5m2_x([[maybe_unused]] half const (&in)[N], + template , sycl::half, bfloat16, float>::value>> + explicit fp8_e5m2_x([[maybe_unused]] T const (&in)[N], [[maybe_unused]] const stochastic_seed &seed, [[maybe_unused]] saturation s = saturation::finite) { #ifdef __SYCL_DEVICE_ONLY__ uint32_t current_seed = *seed.pseed; for (size_t i = 0; i < N; ++i) { - if (s == saturation::finite) { - vals[i] = __builtin_spirv_ClampStochasticRoundFP16ToE5M2INTEL( - in[i], current_seed, seed.pseed); - } else { - vals[i] = __builtin_spirv_StochasticRoundFP16ToE5M2INTEL( - in[i], current_seed, seed.pseed); - } + StochasticConvertToFP8(in[i], current_seed, seed.pseed, s, i); current_seed = *seed.pseed; } #endif } - explicit fp8_e5m2_x([[maybe_unused]] bfloat16 const (&in)[N], + template , sycl::half, bfloat16, float>::value>> + explicit fp8_e5m2_x([[maybe_unused]] const sycl::marray &in, [[maybe_unused]] const stochastic_seed &seed, [[maybe_unused]] saturation s = saturation::finite) { + #ifdef __SYCL_DEVICE_ONLY__ uint32_t current_seed = *seed.pseed; for (size_t i = 0; i < N; ++i) { - if (s == saturation::finite) { - vals[i] = __builtin_spirv_ClampStochasticRoundBF16ToE5M2INTEL( - in[i], current_seed, seed.pseed); - } else { - vals[i] = __builtin_spirv_StochasticRoundBF16ToE5M2INTEL( - in[i], current_seed, seed.pseed); - } + StochasticConvertToFP8(in[i], current_seed, seed.pseed, s, i); current_seed = *seed.pseed; } #endif } - explicit fp8_e5m2_x([[maybe_unused]] float const (&in)[N], - [[maybe_unused]] const stochastic_seed &seed, - [[maybe_unused]] saturation s = saturation::finite) { -#ifdef __SYCL_DEVICE_ONLY__ - uint32_t current_seed = *seed.pseed; - for (size_t i = 0; i < N; ++i) { - sycl::half h = static_cast(in[i]); - if (s == saturation::finite) { - vals[i] = __builtin_spirv_ClampStochasticRoundFP16ToE5M2INTEL( - h, current_seed, seed.pseed); - } else { - vals[i] = __builtin_spirv_StochasticRoundFP16ToE5M2INTEL( - h, current_seed, seed.pseed); - } - current_seed = *seed.pseed; - } -#endif - } - - // Construct with stochastic rounding with user provided seed from an marray - // of half, bfloat16, float. - - explicit fp8_e5m2_x([[maybe_unused]] const sycl::marray &in, - [[maybe_unused]] const stochastic_seed &seed, - [[maybe_unused]] saturation s = saturation::finite) { -#ifdef __SYCL_DEVICE_ONLY__ - uint32_t current_seed = *seed.pseed; - for (size_t i = 0; i < N; ++i) { - if (s == saturation::finite) { - vals[i] = __builtin_spirv_ClampStochasticRoundFP16ToE5M2INTEL( - in[i], current_seed, seed.pseed); - } else { - vals[i] = __builtin_spirv_StochasticRoundFP16ToE5M2INTEL( - in[i], current_seed, seed.pseed); - } - current_seed = *seed.pseed; - } -#endif - } - - explicit fp8_e5m2_x([[maybe_unused]] const sycl::marray &in, - [[maybe_unused]] const stochastic_seed &seed, - [[maybe_unused]] saturation s = saturation::finite) { -#ifdef __SYCL_DEVICE_ONLY__ - uint32_t current_seed = *seed.pseed; - for (size_t i = 0; i < N; ++i) { - if (s == saturation::finite) { - vals[i] = __builtin_spirv_ClampStochasticRoundBF16ToE5M2INTEL( - in[i], current_seed, seed.pseed); - } else { - vals[i] = __builtin_spirv_StochasticRoundBF16ToE5M2INTEL( - in[i], current_seed, seed.pseed); - } - current_seed = *seed.pseed; - } -#endif - } - - explicit fp8_e5m2_x([[maybe_unused]] const sycl::marray &in, - [[maybe_unused]] const stochastic_seed &seed, - [[maybe_unused]] saturation s = saturation::finite) { -#ifdef __SYCL_DEVICE_ONLY__ - uint32_t current_seed = *seed.pseed; - for (size_t i = 0; i < N; ++i) { - sycl::half h = static_cast(in[i]); - if (s == saturation::finite) { - vals[i] = __builtin_spirv_ClampStochasticRoundFP16ToE5M2INTEL( - h, current_seed, seed.pseed); - } else { - vals[i] = __builtin_spirv_StochasticRoundFP16ToE5M2INTEL( - h, current_seed, seed.pseed); - } - current_seed = *seed.pseed; - } -#endif - } - - // Construct from integer types. - // Available only when N==1. - - template > - explicit fp8_e5m2_x(short val) { - vals[0] = ConvertToFP8(val, saturation::finite); - } - - template > - explicit fp8_e5m2_x(int val) { - vals[0] = ConvertToFP8(val, saturation::finite); - } - - template > - explicit fp8_e5m2_x(long val) { - vals[0] = ConvertToFP8(val, saturation::finite); - } - - template > - explicit fp8_e5m2_x(long long val) { - vals[0] = ConvertToFP8(val, saturation::finite); - } - - template > - explicit fp8_e5m2_x(unsigned short val) { - vals[0] = ConvertToFP8(val, saturation::finite); - } - - template > - explicit fp8_e5m2_x(unsigned int val) { - vals[0] = ConvertToFP8(val, saturation::finite); - } - - template > - explicit fp8_e5m2_x(unsigned long val) { - vals[0] = ConvertToFP8(val, saturation::finite); - } - - template > - explicit fp8_e5m2_x(unsigned long long val) { - vals[0] = ConvertToFP8(val, saturation::finite); - } - - // Assign (operator) from half, bfloat16, float, double, and integer types. - // Available only when N==1. - - template > - fp8_e5m2_x &operator=(sycl::half val) { - vals[0] = ConvertToFP8(val, saturation::finite); - return *this; - } - - template > - fp8_e5m2_x &operator=(bfloat16 val) { - vals[0] = ConvertBF16ToFP8(val, saturation::finite); - return *this; - } - - template > - fp8_e5m2_x &operator=(float val) { - vals[0] = ConvertToFP8(val, saturation::finite); - return *this; - } - - template > - fp8_e5m2_x &operator=(double val) { - vals[0] = ConvertToFP8(val, saturation::finite); - return *this; - } - - template > - fp8_e5m2_x &operator=(short val) { - vals[0] = ConvertToFP8(val, saturation::finite); - return *this; - } - - template > - fp8_e5m2_x &operator=(int val) { - vals[0] = ConvertToFP8(val, saturation::finite); - return *this; - } - - template > - fp8_e5m2_x &operator=(long val) { - vals[0] = ConvertToFP8(val, saturation::finite); - return *this; - } - - template > - fp8_e5m2_x &operator=(long long val) { - vals[0] = ConvertToFP8(val, saturation::finite); - return *this; - } - - template > - fp8_e5m2_x &operator=(unsigned short val) { + template >> + explicit fp8_e5m2_x(T val) { vals[0] = ConvertToFP8(val, saturation::finite); - return *this; - } - - template > - fp8_e5m2_x &operator=(unsigned int val) { - vals[0] = ConvertToFP8(val, saturation::finite); - return *this; } - template > - fp8_e5m2_x &operator=(unsigned long val) { - vals[0] = ConvertToFP8(val, saturation::finite); - return *this; - } - - template > - fp8_e5m2_x &operator=(unsigned long long val) { + template || + std::is_integral_v)>> + fp8_e5m2_x &operator=(T val) { vals[0] = ConvertToFP8(val, saturation::finite); return *this; } - // Convert to half, bfloat16, float, double. - // Available only when N==1. - - template > - explicit operator half() const { - return ConvertFromFP8(vals[0]); - } - - template > - explicit operator bfloat16() const { - return ConvertBF16FromFP8(vals[0]); - } - - template > - explicit operator float() const { - return ConvertFromFP8(vals[0]); - } - - template > - explicit operator double() const { - return ConvertFromFP8(vals[0]); - } - - // Convert to integer types. - // Available only when N==1. - - template > - explicit operator char() const { - return ConvertFromFP8(vals[0], rounding::toward_zero); - } - - template > - explicit operator signed char() const { - return ConvertFromFP8(vals[0], rounding::toward_zero); - } - - template > - explicit operator short() const { - return ConvertFromFP8(vals[0], rounding::toward_zero); - } - - template > - explicit operator int() const { - return ConvertFromFP8(vals[0], rounding::toward_zero); - } - - template > - explicit operator long() const { - return ConvertFromFP8(vals[0], rounding::toward_zero); - } - - template > - explicit operator long long() const { - return ConvertFromFP8(vals[0], rounding::toward_zero); - } - - template > - explicit operator unsigned char() const { - return ConvertFromFP8(vals[0], rounding::toward_zero); - } - - template > - explicit operator unsigned short() const { - return ConvertFromFP8(vals[0], rounding::toward_zero); - } - - template > - explicit operator unsigned int() const { - return ConvertFromFP8(vals[0], rounding::toward_zero); - } - - template > - explicit operator unsigned long() const { - return ConvertFromFP8(vals[0], rounding::toward_zero); - } - - template > - explicit operator unsigned long long() const { - return ConvertFromFP8(vals[0], rounding::toward_zero); + template || + std::is_integral_v)>> + explicit operator T() const { + if constexpr (std::is_integral_v) + return ConvertFromFP8(vals[0], rounding::toward_zero); + else + return ConvertFromFP8(vals[0]); } // Convert to bool @@ -1676,22 +1150,13 @@ template class fp8_e5m2_x { return vals[0] != 0x00 && vals[0] != 0x80; } - explicit operator sycl::marray() const { - sycl::marray out; - for (size_t i = 0; i < N; ++i) - out[i] = ConvertFromFP8(vals[i]); - return out; - } - explicit operator sycl::marray() const { - sycl::marray out; + template , sycl::half, + sycl::ext::oneapi::bfloat16, float>::value>> + explicit operator sycl::marray() const { + sycl::marray out; for (size_t i = 0; i < N; ++i) - out[i] = ConvertBF16FromFP8(vals[i]); - return out; - } - explicit operator sycl::marray() const { - sycl::marray out; - for (size_t i = 0; i < N; ++i) - out[i] = ConvertFromFP8(vals[i]); + out[i] = ConvertFromFP8(vals[i]); return out; } @@ -1731,248 +1196,48 @@ template class fp8_e8m0_x { saturation::finite); } - explicit fp8_e8m0_x(half const (&in)[N], rounding r = rounding::upward) { - CheckConstraints(r); - for (size_t i = 0; i < N; ++i) - vals[i] = detail::ConvertFloatToE8M0_CPU(in[i], r, saturation::finite); - } - - explicit fp8_e8m0_x(bfloat16 const (&in)[N], rounding r = rounding::upward) { - CheckConstraints(r); - for (size_t i = 0; i < N; ++i) - vals[i] = detail::ConvertFloatToE8M0_CPU(in[i], r, saturation::finite); - } - - explicit fp8_e8m0_x(float const (&in)[N], rounding r = rounding::upward) { - CheckConstraints(r); - for (size_t i = 0; i < N; ++i) - vals[i] = detail::ConvertFloatToE8M0_CPU(in[i], r, saturation::finite); - } - - explicit fp8_e8m0_x(double const (&in)[N]) { - for (size_t i = 0; i < N; ++i) - vals[i] = detail::ConvertFloatToE8M0_CPU(in[i], rounding::upward, - saturation::finite); - } - - explicit fp8_e8m0_x(const marray &in, - rounding r = rounding::upward) { - CheckConstraints(r); - for (size_t i = 0; i < N; ++i) - vals[i] = detail::ConvertFloatToE8M0_CPU(in[i], r, saturation::finite); - } - - explicit fp8_e8m0_x(const marray &in, - rounding r = rounding::upward) { + template >> + explicit fp8_e8m0_x(T const (&in)[N], rounding r = rounding::upward) { CheckConstraints(r); for (size_t i = 0; i < N; ++i) vals[i] = detail::ConvertFloatToE8M0_CPU(in[i], r, saturation::finite); } - explicit fp8_e8m0_x(const marray &in, - rounding r = rounding::upward) { + template >> + explicit fp8_e8m0_x(const marray &in, rounding r = rounding::upward) { CheckConstraints(r); for (size_t i = 0; i < N; ++i) vals[i] = detail::ConvertFloatToE8M0_CPU(in[i], r, saturation::finite); } - explicit fp8_e8m0_x(const marray &in) { - for (size_t i = 0; i < N; ++i) - vals[i] = detail::ConvertFloatToE8M0_CPU(in[i], rounding::upward, - saturation::finite); - } - - // Construct from integer types. - // Available only when N==1. - - template > - explicit fp8_e8m0_x(short val) { + template >> + explicit fp8_e8m0_x(T val) { vals[0] = detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); } - template > - explicit fp8_e8m0_x(int val) { - vals[0] = - detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); - } - - template > - explicit fp8_e8m0_x(long val) { - vals[0] = - detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); - } - - template > - explicit fp8_e8m0_x(long long val) { - vals[0] = - detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); - } - template > - explicit fp8_e8m0_x(unsigned short val) { - vals[0] = - detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); - } - template > - explicit fp8_e8m0_x(unsigned int val) { - vals[0] = - detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); - } - template > - explicit fp8_e8m0_x(unsigned long val) { - vals[0] = - detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); - } - template > - explicit fp8_e8m0_x(unsigned long long val) { - vals[0] = - detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); - } - - template > - fp8_e8m0_x &operator=(half val) { - vals[0] = detail::ConvertFloatToE8M0_CPU(val, rounding::upward, - saturation::finite); - return *this; - } - template > - fp8_e8m0_x &operator=(bfloat16 val) { - vals[0] = detail::ConvertFloatToE8M0_CPU(val, rounding::upward, - saturation::finite); - return *this; - } - template > - fp8_e8m0_x &operator=(float val) { - vals[0] = detail::ConvertFloatToE8M0_CPU(val, rounding::upward, - saturation::finite); - return *this; - } - - template > - fp8_e8m0_x &operator=(double val) { - vals[0] = detail::ConvertFloatToE8M0_CPU(val, rounding::upward, + template || + std::is_integral_v)>> + fp8_e8m0_x &operator=(T val) { + if constexpr (std::is_integral_v) + vals[0] = detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); + else + vals[0] = detail::ConvertFloatToE8M0_CPU(val, rounding::upward, + saturation::finite); return *this; } - template > - fp8_e8m0_x &operator=(short val) { - vals[0] = - detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); - return *this; - } - template > - fp8_e8m0_x &operator=(int val) { - vals[0] = - detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); - return *this; - } - template > - fp8_e8m0_x &operator=(long val) { - vals[0] = - detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); - return *this; - } - template > - fp8_e8m0_x &operator=(long long val) { - vals[0] = - detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); - return *this; - } - template > - fp8_e8m0_x &operator=(unsigned short val) { - vals[0] = - detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); - return *this; - } - template > - fp8_e8m0_x &operator=(unsigned int val) { - vals[0] = - detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); - return *this; - } - template > - fp8_e8m0_x &operator=(unsigned long val) { - vals[0] = - detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); - return *this; - } - template > - fp8_e8m0_x &operator=(unsigned long long val) { - vals[0] = - detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); - return *this; - } - - template > - explicit operator half() const { - return detail::ConvertFromE8M0_CPU(vals[0], rounding::to_even); - } - template > - explicit operator bfloat16() const { - return detail::ConvertFromE8M0_CPU(vals[0], rounding::to_even); - } - template > - explicit operator float() const { - return detail::ConvertFromE8M0_CPU(vals[0], rounding::to_even); - } - template > - explicit operator double() const { - return detail::ConvertFromE8M0_CPU(vals[0], rounding::to_even); - } - - template > - explicit operator char() const { - return detail::ConvertFromE8M0_CPU(vals[0], rounding::toward_zero); - } - template > - explicit operator signed char() const { - return detail::ConvertFromE8M0_CPU(vals[0], - rounding::toward_zero); - } - template > - explicit operator short() const { - return detail::ConvertFromE8M0_CPU(vals[0], rounding::toward_zero); - } - template > - explicit operator int() const { - return detail::ConvertFromE8M0_CPU(vals[0], rounding::toward_zero); - } - template > - explicit operator long() const { - return detail::ConvertFromE8M0_CPU(vals[0], rounding::toward_zero); - } - template > - explicit operator long long() const { - return detail::ConvertFromE8M0_CPU(vals[0], - rounding::toward_zero); - } - template > - explicit operator unsigned char() const { - return detail::ConvertFromE8M0_CPU(vals[0], - rounding::toward_zero); - } - template > - explicit operator unsigned short() const { - return detail::ConvertFromE8M0_CPU(vals[0], - rounding::toward_zero); - } - template > - explicit operator unsigned int() const { - return detail::ConvertFromE8M0_CPU(vals[0], - rounding::toward_zero); - } - - template > - explicit operator unsigned long() const { - return detail::ConvertFromE8M0_CPU(vals[0], - rounding::toward_zero); - } - - template > - explicit operator unsigned long long() const { - return detail::ConvertFromE8M0_CPU( - vals[0], rounding::toward_zero); + template || + std::is_integral_v)>> + explicit operator T() const { + if constexpr (std::is_integral_v) + return detail::ConvertFromE8M0_CPU(vals[0], rounding::toward_zero); + else + return detail::ConvertFromE8M0_CPU(vals[0], rounding::to_even); } template > @@ -1980,26 +1245,15 @@ template class fp8_e8m0_x { return true; } - explicit operator sycl::marray() const { - sycl::marray out; + template , sycl::half, + sycl::ext::oneapi::bfloat16, float>::value>> + explicit operator sycl::marray() const { + sycl::marray out; for (size_t i = 0; i < N; ++i) - out[i] = detail::ConvertFromE8M0_CPU(vals[i], rounding::to_even); + out[i] = detail::ConvertFromE8M0_CPU(vals[i], rounding::to_even); return out; } - explicit operator sycl::marray() const { - sycl::marray out; - for (size_t i = 0; i < N; ++i) - out[i] = - detail::ConvertFromE8M0_CPU(vals[i], rounding::to_even); - return out; - } - explicit operator sycl::marray() const { - sycl::marray out; - for (size_t i = 0; i < N; ++i) - out[i] = detail::ConvertFromE8M0_CPU(vals[i], rounding::to_even); - return out; - } - // Intentionally public to allow access to the raw values. uint8_t vals[N]; From edbb6fddcd5bf2373aafa801dee02250a00c665e Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Mon, 20 Apr 2026 17:09:26 +0200 Subject: [PATCH 38/89] [SYCL] add separate trait for variadic constructors --- .../oneapi/experimental/float_8bit/types.hpp | 33 +++++++++---------- 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index 5b3c8ad390138..26fccf387ca15 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -827,6 +827,18 @@ struct IsSyclFpType : IsOneOf, sycl::half, template inline constexpr bool IsSyclFpTypeV = IsSyclFpType::value; +template +struct SyclfpVariadic + : std::bool_constant< + (sizeof...(Types) == N) && + (((std::is_same_v, half>) && ...) || + ((std::is_same_v, bfloat16>) && ...) || + ((std::is_same_v, float>) && ...) || + ((std::is_same_v, double>) && ...))> {}; + +template +inline constexpr bool SyclfpVariadicV = SyclfpVariadic::value; + } // namespace detail template class fp8_e4m3_x { @@ -884,12 +896,7 @@ template class fp8_e4m3_x { // Available only when the size of the pack is equal to N. template , half>) && ...) || - ((std::is_same_v, bfloat16>) && ...) || - ((std::is_same_v, float>) && ...) || - ((std::is_same_v, double>) && ...))>> + typename = std::enable_if_t>> explicit fp8_e4m3_x(Types... v) { using InT = std::common_type_t...>; const InT in[N] = {v...}; @@ -1055,12 +1062,7 @@ template class fp8_e5m2_x { // Available only when each type in the pack is half. template , half>) && ...) || - ((std::is_same_v, bfloat16>) && ...) || - ((std::is_same_v, float>) && ...) || - ((std::is_same_v, double>) && ...))>> + typename = std::enable_if_t>> explicit fp8_e5m2_x(Types... v) { using InT = std::common_type_t...>; const InT in[N] = {v...}; @@ -1182,12 +1184,7 @@ template class fp8_e8m0_x { fp8_e8m0_x &operator=(const fp8_e8m0_x &) = default; template , half>) && ...) || - ((std::is_same_v, bfloat16>) && ...) || - ((std::is_same_v, float>) && ...) || - ((std::is_same_v, double>) && ...))>> + typename = std::enable_if_t>> explicit fp8_e8m0_x(Types... v) { using InT = std::common_type_t...>; const InT in[N] = {v...}; From cfd3a3a588792b3d95a97cab90b0b93f65858ad7 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Mon, 20 Apr 2026 17:28:06 +0200 Subject: [PATCH 39/89] [SYCL] pass references in test --- .../ext/oneapi/experimental/float_8bit/types.hpp | 12 ++++++------ sycl/unittests/Extensions/fp8/fp8_e4m3.cpp | 13 +++++++++++++ sycl/unittests/Extensions/fp8/fp8_e5m2.cpp | 13 +++++++++++++ sycl/unittests/Extensions/fp8/fp8_e8m0.cpp | 13 +++++++++++++ 4 files changed, 45 insertions(+), 6 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index 26fccf387ca15..a7275c10fc1db 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -829,12 +829,12 @@ inline constexpr bool IsSyclFpTypeV = IsSyclFpType::value; template struct SyclfpVariadic - : std::bool_constant< - (sizeof...(Types) == N) && - (((std::is_same_v, half>) && ...) || - ((std::is_same_v, bfloat16>) && ...) || - ((std::is_same_v, float>) && ...) || - ((std::is_same_v, double>) && ...))> {}; + : std::bool_constant< + (sizeof...(Types) == N) && + (((std::is_same_v, half>) && ...) || + ((std::is_same_v, bfloat16>) && ...) || + ((std::is_same_v, float>) && ...) || + ((std::is_same_v, double>) && ...))> {}; template inline constexpr bool SyclfpVariadicV = SyclfpVariadic::value; diff --git a/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp b/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp index d231cf93f7a92..961f5adc79cb1 100644 --- a/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp +++ b/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp @@ -50,6 +50,19 @@ TEST(FP8E4M3Test, VariadicFloat) { EXPECT_EQ(b.vals[0], 0x39); } +TEST(FP8E4M3Test, VariadicFloatReferences) { + float x = 1.0f; + float y = 2.0f; + float &xf = x; + float &yf = y; + + fp8_e4m3_x2 a(xf, yf); + + EXPECT_EQ(sizeof(a.vals), 2u); + EXPECT_EQ(a.vals[0], 0x38); + EXPECT_EQ(a.vals[1], 0x40); +} + TEST(FP8E4M3Test, VariadicBoundaryEncodingsFloat) { // CPU host path: variadic constructors use rounding::to_even and // saturation::finite. diff --git a/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp b/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp index 05d5ec75874b8..3459bbdbf7e2b 100644 --- a/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp +++ b/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp @@ -117,6 +117,19 @@ TEST(FP8E5M2Test, RawInfinityAndNaNDecoding) { EXPECT_TRUE(std::isnan(static_cast(qnan_bf16))); } +TEST(FP8E5M2Test, VariadicFloatReferences) { + float x = 1.0f; + float y = 2.0f; + float &xf = x; + float &yf = y; + + fp8_e4m3_x2 a(xf, yf); + + EXPECT_EQ(sizeof(a.vals), 2u); + EXPECT_EQ(a.vals[0], 0x38); + EXPECT_EQ(a.vals[1], 0x40); +} + TEST(FP8E5M2Test, IntegerConstructorToEvenFiniteAndSize) { fp8_e5m2 a0(0); fp8_e5m2 a1(1); diff --git a/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp b/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp index 72876d674474e..a799c953aa044 100644 --- a/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp +++ b/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp @@ -123,6 +123,19 @@ TEST(FP8E8M0Test, CArrayHalfHostUpwardFinite) { EXPECT_EQ(a1.vals[1], 0x00); } +TEST(FP8E8M0Test, VariadicFloatReferences) { + float x = 1.0f; + float y = 2.0f; + float &xf = x; + float &yf = y; + + fp8_e4m3_x2 a(xf, yf); + + EXPECT_EQ(sizeof(a.vals), 2u); + EXPECT_EQ(a.vals[0], 0x38); + EXPECT_EQ(a.vals[1], 0x40); +} + TEST(FP8E8M0Test, CArrayBFloat16HostUpwardFinite) { const sycl::ext::oneapi::bfloat16 in[2] = {sycl::ext::oneapi::bfloat16(1.0f), sycl::ext::oneapi::bfloat16(2.0f)}; From ad9b49f5a3176f592af03d849352cd31529148f0 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Tue, 21 Apr 2026 09:09:49 +0200 Subject: [PATCH 40/89] Revert "[SYCL] pass references in test" This reverts commit cfd3a3a588792b3d95a97cab90b0b93f65858ad7. --- .../ext/oneapi/experimental/float_8bit/types.hpp | 12 ++++++------ sycl/unittests/Extensions/fp8/fp8_e4m3.cpp | 13 ------------- sycl/unittests/Extensions/fp8/fp8_e5m2.cpp | 13 ------------- sycl/unittests/Extensions/fp8/fp8_e8m0.cpp | 13 ------------- 4 files changed, 6 insertions(+), 45 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index a7275c10fc1db..26fccf387ca15 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -829,12 +829,12 @@ inline constexpr bool IsSyclFpTypeV = IsSyclFpType::value; template struct SyclfpVariadic - : std::bool_constant< - (sizeof...(Types) == N) && - (((std::is_same_v, half>) && ...) || - ((std::is_same_v, bfloat16>) && ...) || - ((std::is_same_v, float>) && ...) || - ((std::is_same_v, double>) && ...))> {}; + : std::bool_constant< + (sizeof...(Types) == N) && + (((std::is_same_v, half>) && ...) || + ((std::is_same_v, bfloat16>) && ...) || + ((std::is_same_v, float>) && ...) || + ((std::is_same_v, double>) && ...))> {}; template inline constexpr bool SyclfpVariadicV = SyclfpVariadic::value; diff --git a/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp b/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp index 961f5adc79cb1..d231cf93f7a92 100644 --- a/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp +++ b/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp @@ -50,19 +50,6 @@ TEST(FP8E4M3Test, VariadicFloat) { EXPECT_EQ(b.vals[0], 0x39); } -TEST(FP8E4M3Test, VariadicFloatReferences) { - float x = 1.0f; - float y = 2.0f; - float &xf = x; - float &yf = y; - - fp8_e4m3_x2 a(xf, yf); - - EXPECT_EQ(sizeof(a.vals), 2u); - EXPECT_EQ(a.vals[0], 0x38); - EXPECT_EQ(a.vals[1], 0x40); -} - TEST(FP8E4M3Test, VariadicBoundaryEncodingsFloat) { // CPU host path: variadic constructors use rounding::to_even and // saturation::finite. diff --git a/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp b/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp index 3459bbdbf7e2b..05d5ec75874b8 100644 --- a/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp +++ b/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp @@ -117,19 +117,6 @@ TEST(FP8E5M2Test, RawInfinityAndNaNDecoding) { EXPECT_TRUE(std::isnan(static_cast(qnan_bf16))); } -TEST(FP8E5M2Test, VariadicFloatReferences) { - float x = 1.0f; - float y = 2.0f; - float &xf = x; - float &yf = y; - - fp8_e4m3_x2 a(xf, yf); - - EXPECT_EQ(sizeof(a.vals), 2u); - EXPECT_EQ(a.vals[0], 0x38); - EXPECT_EQ(a.vals[1], 0x40); -} - TEST(FP8E5M2Test, IntegerConstructorToEvenFiniteAndSize) { fp8_e5m2 a0(0); fp8_e5m2 a1(1); diff --git a/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp b/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp index a799c953aa044..72876d674474e 100644 --- a/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp +++ b/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp @@ -123,19 +123,6 @@ TEST(FP8E8M0Test, CArrayHalfHostUpwardFinite) { EXPECT_EQ(a1.vals[1], 0x00); } -TEST(FP8E8M0Test, VariadicFloatReferences) { - float x = 1.0f; - float y = 2.0f; - float &xf = x; - float &yf = y; - - fp8_e4m3_x2 a(xf, yf); - - EXPECT_EQ(sizeof(a.vals), 2u); - EXPECT_EQ(a.vals[0], 0x38); - EXPECT_EQ(a.vals[1], 0x40); -} - TEST(FP8E8M0Test, CArrayBFloat16HostUpwardFinite) { const sycl::ext::oneapi::bfloat16 in[2] = {sycl::ext::oneapi::bfloat16(1.0f), sycl::ext::oneapi::bfloat16(2.0f)}; From e1ec70d3829738749af062788a103a4b5342032f Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Tue, 21 Apr 2026 09:10:07 +0200 Subject: [PATCH 41/89] Revert "[SYCL] add separate trait for variadic constructors" This reverts commit edbb6fddcd5bf2373aafa801dee02250a00c665e. --- .../oneapi/experimental/float_8bit/types.hpp | 33 ++++++++++--------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index 26fccf387ca15..5b3c8ad390138 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -827,18 +827,6 @@ struct IsSyclFpType : IsOneOf, sycl::half, template inline constexpr bool IsSyclFpTypeV = IsSyclFpType::value; -template -struct SyclfpVariadic - : std::bool_constant< - (sizeof...(Types) == N) && - (((std::is_same_v, half>) && ...) || - ((std::is_same_v, bfloat16>) && ...) || - ((std::is_same_v, float>) && ...) || - ((std::is_same_v, double>) && ...))> {}; - -template -inline constexpr bool SyclfpVariadicV = SyclfpVariadic::value; - } // namespace detail template class fp8_e4m3_x { @@ -896,7 +884,12 @@ template class fp8_e4m3_x { // Available only when the size of the pack is equal to N. template >> + typename = std::enable_if_t< + (sizeof...(Types) == N) && + (((std::is_same_v, half>) && ...) || + ((std::is_same_v, bfloat16>) && ...) || + ((std::is_same_v, float>) && ...) || + ((std::is_same_v, double>) && ...))>> explicit fp8_e4m3_x(Types... v) { using InT = std::common_type_t...>; const InT in[N] = {v...}; @@ -1062,7 +1055,12 @@ template class fp8_e5m2_x { // Available only when each type in the pack is half. template >> + typename = std::enable_if_t< + (sizeof...(Types) == N) && + (((std::is_same_v, half>) && ...) || + ((std::is_same_v, bfloat16>) && ...) || + ((std::is_same_v, float>) && ...) || + ((std::is_same_v, double>) && ...))>> explicit fp8_e5m2_x(Types... v) { using InT = std::common_type_t...>; const InT in[N] = {v...}; @@ -1184,7 +1182,12 @@ template class fp8_e8m0_x { fp8_e8m0_x &operator=(const fp8_e8m0_x &) = default; template >> + typename = std::enable_if_t< + (sizeof...(Types) == N) && + (((std::is_same_v, half>) && ...) || + ((std::is_same_v, bfloat16>) && ...) || + ((std::is_same_v, float>) && ...) || + ((std::is_same_v, double>) && ...))>> explicit fp8_e8m0_x(Types... v) { using InT = std::common_type_t...>; const InT in[N] = {v...}; From 0843e6bda98fe41d67b3408d883c3e37fd5e67e7 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Tue, 21 Apr 2026 09:10:32 +0200 Subject: [PATCH 42/89] Revert "[SYCL] rework fp8 types to avoid copy-paste" This reverts commit b596110fddeff57fe714bdb82e5c06394ba1b79e. --- .../oneapi/experimental/float_8bit/types.hpp | 1038 ++++++++++++++--- 1 file changed, 892 insertions(+), 146 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index 5b3c8ad390138..bfe55241184ad 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -104,7 +104,7 @@ template <> struct DirectBinary16Traits { static constexpr uint16_t QuietNaNBits = 0x7FC0u; }; -template static constexpr inline ToT MakeDirectNaN() noexcept { +template static inline ToT MakeDirectNaN() noexcept { if constexpr (std::is_same_v || std::is_same_v) { return sycl::bit_cast(DirectBinary16Traits::QuietNaNBits); @@ -116,7 +116,7 @@ template static constexpr inline ToT MakeDirectNaN() noexcept { } template -static constexpr inline ToT MakeDirectInf(bool negative) noexcept { +static inline ToT MakeDirectInf(bool negative) noexcept { if constexpr (std::is_same_v || std::is_same_v) { using Traits = DirectBinary16Traits; @@ -817,16 +817,6 @@ static inline ToT ConvertFromE8M0_CPU(uint8_t code, rounding R) noexcept { return ToT{}; } -template -struct IsOneOf : std::disjunction...> {}; - -template -struct IsSyclFpType : IsOneOf, sycl::half, - sycl::ext::oneapi::bfloat16, float, double> {}; - -template -inline constexpr bool IsSyclFpTypeV = IsSyclFpType::value; - } // namespace detail template class fp8_e4m3_x { @@ -838,12 +828,11 @@ template class fp8_e4m3_x { template uint8_t ConvertToFP8(T h) { #ifdef __SYCL_DEVICE_ONLY__ - if constexpr (std::is_same_v, bfloat16>) - return __builtin_spirv_ClampConvertBF16ToE4M3INTEL(h); - else - return __builtin_spirv_ClampConvertFP16ToE4M3INTEL(h); + return __builtin_spirv_ClampConvertFP16ToE4M3INTEL(h); #else - if constexpr (detail::IsSyclFpTypeV) { + if constexpr (std::is_same_v, sycl::half> || + std::is_same_v, float> || + std::is_same_v, double>) { return detail::ConvertFloatToFP8_CPU( h, rounding::to_even, saturation::finite); } else if constexpr (std::is_integral_v>) { @@ -853,21 +842,36 @@ template class fp8_e4m3_x { #endif } + uint8_t ConvertBF16ToFP8(bfloat16 h) { +#ifdef __SYCL_DEVICE_ONLY__ + return __builtin_spirv_ClampConvertBF16ToE4M3INTEL(h); +#else + return detail::ConvertFloatToFP8_CPU( + h, rounding::to_even, saturation::finite); +#endif + } + template T ConvertFromFP8(uint8_t v, rounding r = rounding::to_even) const { #ifdef __SYCL_DEVICE_ONLY__ - if constexpr (std::is_same_v, bfloat16>) - return __builtin_spirv_ConvertE4M3ToBF16EXT(v); - else { - sycl::half hi = __builtin_spirv_ConvertE4M3ToFP16EXT(v); - return static_cast(hi); - } + sycl::half hi = __builtin_spirv_ConvertE4M3ToFP16EXT(v); + return static_cast(hi); #else return detail::ConvertFromFP8ToBinaryFloat_CPU(v, r); #endif } + bfloat16 ConvertBF16FromFP8(uint8_t v) const { +#ifdef __SYCL_DEVICE_ONLY__ + return __builtin_spirv_ConvertE4M3ToBF16EXT(v); +#else + return detail::ConvertFromFP8ToBinaryFloat_CPU(v, + rounding::to_even); +#endif + } + void CheckConstraints(rounding r) const { assert(r == rounding::to_even && "fp8_e4m3_x: only rounding::to_even is supported"); @@ -891,6 +895,12 @@ template class fp8_e4m3_x { ((std::is_same_v, float>) && ...) || ((std::is_same_v, double>) && ...))>> explicit fp8_e4m3_x(Types... v) { + if constexpr (((std::is_same_v, bfloat16>) && ...)) { + const bfloat16 in[N] = {static_cast(v)...}; + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertBF16ToFP8(in[i]); + return; + } using InT = std::common_type_t...>; const InT in[N] = {v...}; for (size_t i = 0; i < N; ++i) @@ -898,52 +908,253 @@ template class fp8_e4m3_x { } // Construct from an array of half, bfloat16, float, double. - template >> - explicit fp8_e4m3_x(T const (&v)[N], rounding r = rounding::to_even) { + explicit fp8_e4m3_x(sycl::half const (&v)[N], + rounding r = rounding::to_even) { + CheckConstraints(r); + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertToFP8(v[i]); + } + + explicit fp8_e4m3_x(bfloat16 const (&v)[N], rounding r = rounding::to_even) { + CheckConstraints(r); + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertBF16ToFP8(v[i]); + } + + explicit fp8_e4m3_x(float const (&v)[N], rounding r = rounding::to_even) { CheckConstraints(r); for (size_t i = 0; i < N; ++i) vals[i] = ConvertToFP8(v[i]); } + explicit fp8_e4m3_x(double const (&v)[N]) { + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertToFP8(v[i]); + } + // Construct from an marray of half, bfloat16, float, double. - template >> - explicit fp8_e4m3_x(const sycl::marray &v, + explicit fp8_e4m3_x(const sycl::marray &v, + rounding r = rounding::to_even) { + CheckConstraints(r); + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertToFP8(v[i]); + } + + explicit fp8_e4m3_x(const sycl::marray &v, + rounding r = rounding::to_even) { + CheckConstraints(r); + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertBF16ToFP8(v[i]); + } + + explicit fp8_e4m3_x(const sycl::marray &v, rounding r = rounding::to_even) { CheckConstraints(r); for (size_t i = 0; i < N; ++i) vals[i] = ConvertToFP8(v[i]); } + explicit fp8_e4m3_x(const sycl::marray &v) { + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertToFP8(v[i]); + } + // Construct from integer types. // Available only when N==1. - template >> - explicit fp8_e4m3_x(T val) { + template > + explicit fp8_e4m3_x(short val) { + vals[0] = ConvertToFP8(val); + } + + template > + explicit fp8_e4m3_x(int val) { + vals[0] = ConvertToFP8(val); + } + + template > + explicit fp8_e4m3_x(long val) { + vals[0] = ConvertToFP8(val); + } + + template > + explicit fp8_e4m3_x(long long val) { + vals[0] = ConvertToFP8(val); + } + + template > + explicit fp8_e4m3_x(unsigned short val) { + vals[0] = ConvertToFP8(val); + } + + template > + explicit fp8_e4m3_x(unsigned int val) { + vals[0] = ConvertToFP8(val); + } + + template > + explicit fp8_e4m3_x(unsigned long val) { + vals[0] = ConvertToFP8(val); + } + + template > + explicit fp8_e4m3_x(unsigned long long val) { vals[0] = ConvertToFP8(val); } + // Assign (operator) from half, bfloat16, float, double, and integer types. // Available only when N==1. - template || - std::is_integral_v)>> - fp8_e4m3_x &operator=(T val) { + template > + fp8_e4m3_x &operator=(sycl::half val) { + vals[0] = ConvertToFP8(val); + return *this; + } + + template > + fp8_e4m3_x &operator=(bfloat16 val) { + vals[0] = ConvertBF16ToFP8(val); + return *this; + } + + template > + fp8_e4m3_x &operator=(float val) { + vals[0] = ConvertToFP8(val); + return *this; + } + + template > + fp8_e4m3_x &operator=(double val) { + vals[0] = ConvertToFP8(val); + return *this; + } + + template > + fp8_e4m3_x &operator=(short val) { + vals[0] = ConvertToFP8(val); + return *this; + } + + template > + fp8_e4m3_x &operator=(int val) { + vals[0] = ConvertToFP8(val); + return *this; + } + + template > + fp8_e4m3_x &operator=(long val) { + vals[0] = ConvertToFP8(val); + return *this; + } + + template > + fp8_e4m3_x &operator=(long long val) { + vals[0] = ConvertToFP8(val); + return *this; + } + + template > + fp8_e4m3_x &operator=(unsigned short val) { + vals[0] = ConvertToFP8(val); + return *this; + } + + template > + fp8_e4m3_x &operator=(unsigned int val) { + vals[0] = ConvertToFP8(val); + return *this; + } + + template > + fp8_e4m3_x &operator=(unsigned long val) { + vals[0] = ConvertToFP8(val); + return *this; + } + + template > + fp8_e4m3_x &operator=(unsigned long long val) { vals[0] = ConvertToFP8(val); return *this; } - // Convert to half, bfloat16, float, double and integer types + + // Convert to half, bfloat16, float, double. + // Available only when N==1. + + template > + explicit operator half() const { + return ConvertFromFP8(vals[0]); + } + + template > + explicit operator bfloat16() const { + return ConvertBF16FromFP8(vals[0]); + } + template > + explicit operator float() const { + return ConvertFromFP8(vals[0]); + } + template > + explicit operator double() const { + return ConvertFromFP8(vals[0]); + } + + // Convert to integer types. // Available only when N==1. - template || - std::is_integral_v)>> - explicit operator T() const { - if constexpr (std::is_integral_v) - return ConvertFromFP8(vals[0], rounding::toward_zero); - else - return ConvertFromFP8(vals[0]); + template > + explicit operator char() const { + return ConvertFromFP8(vals[0], rounding::toward_zero); + } + + template > + explicit operator signed char() const { + return ConvertFromFP8(vals[0], rounding::toward_zero); + } + + template > + explicit operator short() const { + return ConvertFromFP8(vals[0], rounding::toward_zero); + } + + template > + explicit operator int() const { + return ConvertFromFP8(vals[0], rounding::toward_zero); + } + + template > + explicit operator long() const { + return ConvertFromFP8(vals[0], rounding::toward_zero); + } + + template > + explicit operator long long() const { + return ConvertFromFP8(vals[0], rounding::toward_zero); + } + template > + explicit operator unsigned char() const { + return ConvertFromFP8(vals[0], rounding::toward_zero); + } + + template > + explicit operator unsigned short() const { + return ConvertFromFP8(vals[0], rounding::toward_zero); + } + + template > + explicit operator unsigned int() const { + return ConvertFromFP8(vals[0], rounding::toward_zero); + } + + template > + explicit operator unsigned long() const { + return ConvertFromFP8(vals[0], rounding::toward_zero); + } + + template > + explicit operator unsigned long long() const { + return ConvertFromFP8(vals[0], rounding::toward_zero); } + // Convert to bool // Available only when N==1. @@ -960,15 +1171,28 @@ template class fp8_e4m3_x { } // Convert to marray of half, bfloat16, float - template , sycl::half, - sycl::ext::oneapi::bfloat16, float>::value>> - explicit operator sycl::marray() const { - sycl::marray ret; + + explicit operator sycl::marray() const { + sycl::marray ret; + for (size_t i = 0; i < N; ++i) + ret[i] = ConvertFromFP8(vals[i]); + return ret; + } + + explicit operator sycl::marray() const { + sycl::marray ret; + for (size_t i = 0; i < N; ++i) + ret[i] = ConvertBF16FromFP8(vals[i]); + return ret; + } + + explicit operator sycl::marray() const { + sycl::marray ret; for (size_t i = 0; i < N; ++i) - ret[i] = ConvertFromFP8(vals[i]); + ret[i] = ConvertFromFP8(vals[i]); return ret; } + // Intentionally public to allow access to the raw values. uint8_t vals[N]; }; @@ -982,16 +1206,14 @@ template class fp8_e5m2_x { template uint8_t ConvertToFP8(T h, saturation s) { #ifdef __SYCL_DEVICE_ONLY__ - if constexpr (std::is_same_v, bfloat16>) - return s == saturation::finite - ? __builtin_spirv_ClampConvertBF16ToE5M2INTEL(h) - : __builtin_spirv_ConvertBF16ToE5M2EXT(h); const sycl::half halfValue = static_cast(h); return s == saturation::finite ? __builtin_spirv_ClampConvertFP16ToE5M2INTEL(halfValue) : __builtin_spirv_ConvertFP16ToE5M2EXT(halfValue); #else - if constexpr (detail::IsSyclFpTypeV) { + if constexpr (std::is_same_v, sycl::half> || + std::is_same_v, float> || + std::is_same_v, double>) { return detail::ConvertFloatToFP8_CPU( h, rounding::to_even, s); } else if constexpr (std::is_integral_v>) { @@ -1001,35 +1223,20 @@ template class fp8_e5m2_x { #endif } - template - void StochasticConvertToFP8(T h, uint32_t current_seed, uint32_t *pseed, - saturation s, uint8_t i) { + uint8_t ConvertBF16ToFP8(bfloat16 h, saturation s) { #ifdef __SYCL_DEVICE_ONLY__ - if constexpr (std::is_same_v) { - if (s == saturation::finite) { - vals[i] = __builtin_spirv_ClampStochasticRoundBF16ToE5M2INTEL( - h, current_seed, pseed); - } else { - vals[i] = __builtin_spirv_StochasticRoundBF16ToE5M2INTEL( - h, current_seed, pseed); - } - } else { - if (s == saturation::finite) { - vals[i] = __builtin_spirv_ClampStochasticRoundFP16ToE5M2INTEL( - h, current_seed, pseed); - } else { - vals[i] = __builtin_spirv_StochasticRoundFP16ToE5M2INTEL( - h, current_seed, pseed); - } - } + return s == saturation::finite + ? __builtin_spirv_ClampConvertBF16ToE5M2INTEL(h) + : __builtin_spirv_ConvertBF16ToE5M2EXT(h); +#else + return detail::ConvertFloatToFP8_CPU( + h, rounding::to_even, s); #endif } template T ConvertFromFP8(uint8_t v, rounding r = rounding::to_even) const { #ifdef __SYCL_DEVICE_ONLY__ - if constexpr (std::is_same_v, bfloat16>) - return __builtin_spirv_ConvertE5M2ToBF16EXT(v); sycl::half hi = __builtin_spirv_ConvertE5M2ToFP16EXT(v); return static_cast(hi); #else @@ -1038,6 +1245,16 @@ template class fp8_e5m2_x { #endif } + bfloat16 ConvertBF16FromFP8(uint8_t v) const { +#ifdef __SYCL_DEVICE_ONLY__ + return __builtin_spirv_ConvertE5M2ToBF16EXT(v); +#else + return detail::ConvertFromFP8ToBinaryFloat_CPU(v, + rounding::to_even); +#endif + } + void CheckConstraints(rounding r) const { assert(r == rounding::to_even && "fp8_e5m2_x: only rounding::to_even is supported"); @@ -1062,23 +1279,67 @@ template class fp8_e5m2_x { ((std::is_same_v, float>) && ...) || ((std::is_same_v, double>) && ...))>> explicit fp8_e5m2_x(Types... v) { + if constexpr (((std::is_same_v, bfloat16>) && ...)) { + const bfloat16 in[N] = {static_cast(v)...}; + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertBF16ToFP8(in[i], saturation::finite); + return; + } using InT = std::common_type_t...>; const InT in[N] = {v...}; for (size_t i = 0; i < N; ++i) vals[i] = ConvertToFP8(in[i], saturation::finite); } - template >> - explicit fp8_e5m2_x(T const (&v)[N], rounding r = rounding::to_even, + // Construct from an array of half, bfloat16, float, double. + + explicit fp8_e5m2_x(half const (&v)[N], rounding r = rounding::to_even, + saturation s = saturation::finite) { + CheckConstraints(r); + // TODO: optimize with vectorized builtin calls + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertToFP8(v[i], s); + } + + explicit fp8_e5m2_x(bfloat16 const (&v)[N], rounding r = rounding::to_even, saturation s = saturation::finite) { CheckConstraints(r); // TODO: optimize with vectorized builtin calls + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertBF16ToFP8(v[i], s); + } + + explicit fp8_e5m2_x(float const (&v)[N], rounding r = rounding::to_even, + saturation s = saturation::finite) { + CheckConstraints(r); + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertToFP8(v[i], s); + } + + explicit fp8_e5m2_x(double const (&v)[N]) { + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertToFP8(v[i], saturation::finite); + } + + // Construct from an marray of half, bfloat16, float, double. + + explicit fp8_e5m2_x(const sycl::marray &v, + rounding r = rounding::to_even, + saturation s = saturation::finite) { + CheckConstraints(r); for (size_t i = 0; i < N; ++i) vals[i] = ConvertToFP8(v[i], s); } - template >> - explicit fp8_e5m2_x(const sycl::marray &v, + explicit fp8_e5m2_x(const sycl::marray &v, + rounding r = rounding::to_even, + saturation s = saturation::finite) { + CheckConstraints(r); + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertBF16ToFP8(v[i], s); + } + + explicit fp8_e5m2_x(const sycl::marray &v, rounding r = rounding::to_even, saturation s = saturation::finite) { CheckConstraints(r); @@ -1086,59 +1347,324 @@ template class fp8_e5m2_x { vals[i] = ConvertToFP8(v[i], s); } - template , sycl::half, bfloat16, float>::value>> - explicit fp8_e5m2_x([[maybe_unused]] T const (&in)[N], + explicit fp8_e5m2_x(const sycl::marray &v) { + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertToFP8(v[i], saturation::finite); + } + + // Construct with stochastic rounding with user provided seed from an array of + // half, bfloat16, float. + + explicit fp8_e5m2_x([[maybe_unused]] half const (&in)[N], [[maybe_unused]] const stochastic_seed &seed, [[maybe_unused]] saturation s = saturation::finite) { #ifdef __SYCL_DEVICE_ONLY__ uint32_t current_seed = *seed.pseed; for (size_t i = 0; i < N; ++i) { - StochasticConvertToFP8(in[i], current_seed, seed.pseed, s, i); + if (s == saturation::finite) { + vals[i] = __builtin_spirv_ClampStochasticRoundFP16ToE5M2INTEL( + in[i], current_seed, seed.pseed); + } else { + vals[i] = __builtin_spirv_StochasticRoundFP16ToE5M2INTEL( + in[i], current_seed, seed.pseed); + } current_seed = *seed.pseed; } #endif } - template , sycl::half, bfloat16, float>::value>> - explicit fp8_e5m2_x([[maybe_unused]] const sycl::marray &in, + explicit fp8_e5m2_x([[maybe_unused]] bfloat16 const (&in)[N], [[maybe_unused]] const stochastic_seed &seed, [[maybe_unused]] saturation s = saturation::finite) { - #ifdef __SYCL_DEVICE_ONLY__ uint32_t current_seed = *seed.pseed; for (size_t i = 0; i < N; ++i) { - StochasticConvertToFP8(in[i], current_seed, seed.pseed, s, i); + if (s == saturation::finite) { + vals[i] = __builtin_spirv_ClampStochasticRoundBF16ToE5M2INTEL( + in[i], current_seed, seed.pseed); + } else { + vals[i] = __builtin_spirv_StochasticRoundBF16ToE5M2INTEL( + in[i], current_seed, seed.pseed); + } current_seed = *seed.pseed; } #endif } - template >> - explicit fp8_e5m2_x(T val) { - vals[0] = ConvertToFP8(val, saturation::finite); - } - - template || - std::is_integral_v)>> - fp8_e5m2_x &operator=(T val) { - vals[0] = ConvertToFP8(val, saturation::finite); + explicit fp8_e5m2_x([[maybe_unused]] float const (&in)[N], + [[maybe_unused]] const stochastic_seed &seed, + [[maybe_unused]] saturation s = saturation::finite) { +#ifdef __SYCL_DEVICE_ONLY__ + uint32_t current_seed = *seed.pseed; + for (size_t i = 0; i < N; ++i) { + sycl::half h = static_cast(in[i]); + if (s == saturation::finite) { + vals[i] = __builtin_spirv_ClampStochasticRoundFP16ToE5M2INTEL( + h, current_seed, seed.pseed); + } else { + vals[i] = __builtin_spirv_StochasticRoundFP16ToE5M2INTEL( + h, current_seed, seed.pseed); + } + current_seed = *seed.pseed; + } +#endif + } + + // Construct with stochastic rounding with user provided seed from an marray + // of half, bfloat16, float. + + explicit fp8_e5m2_x([[maybe_unused]] const sycl::marray &in, + [[maybe_unused]] const stochastic_seed &seed, + [[maybe_unused]] saturation s = saturation::finite) { +#ifdef __SYCL_DEVICE_ONLY__ + uint32_t current_seed = *seed.pseed; + for (size_t i = 0; i < N; ++i) { + if (s == saturation::finite) { + vals[i] = __builtin_spirv_ClampStochasticRoundFP16ToE5M2INTEL( + in[i], current_seed, seed.pseed); + } else { + vals[i] = __builtin_spirv_StochasticRoundFP16ToE5M2INTEL( + in[i], current_seed, seed.pseed); + } + current_seed = *seed.pseed; + } +#endif + } + + explicit fp8_e5m2_x([[maybe_unused]] const sycl::marray &in, + [[maybe_unused]] const stochastic_seed &seed, + [[maybe_unused]] saturation s = saturation::finite) { +#ifdef __SYCL_DEVICE_ONLY__ + uint32_t current_seed = *seed.pseed; + for (size_t i = 0; i < N; ++i) { + if (s == saturation::finite) { + vals[i] = __builtin_spirv_ClampStochasticRoundBF16ToE5M2INTEL( + in[i], current_seed, seed.pseed); + } else { + vals[i] = __builtin_spirv_StochasticRoundBF16ToE5M2INTEL( + in[i], current_seed, seed.pseed); + } + current_seed = *seed.pseed; + } +#endif + } + + explicit fp8_e5m2_x([[maybe_unused]] const sycl::marray &in, + [[maybe_unused]] const stochastic_seed &seed, + [[maybe_unused]] saturation s = saturation::finite) { +#ifdef __SYCL_DEVICE_ONLY__ + uint32_t current_seed = *seed.pseed; + for (size_t i = 0; i < N; ++i) { + sycl::half h = static_cast(in[i]); + if (s == saturation::finite) { + vals[i] = __builtin_spirv_ClampStochasticRoundFP16ToE5M2INTEL( + h, current_seed, seed.pseed); + } else { + vals[i] = __builtin_spirv_StochasticRoundFP16ToE5M2INTEL( + h, current_seed, seed.pseed); + } + current_seed = *seed.pseed; + } +#endif + } + + // Construct from integer types. + // Available only when N==1. + + template > + explicit fp8_e5m2_x(short val) { + vals[0] = ConvertToFP8(val, saturation::finite); + } + + template > + explicit fp8_e5m2_x(int val) { + vals[0] = ConvertToFP8(val, saturation::finite); + } + + template > + explicit fp8_e5m2_x(long val) { + vals[0] = ConvertToFP8(val, saturation::finite); + } + + template > + explicit fp8_e5m2_x(long long val) { + vals[0] = ConvertToFP8(val, saturation::finite); + } + + template > + explicit fp8_e5m2_x(unsigned short val) { + vals[0] = ConvertToFP8(val, saturation::finite); + } + + template > + explicit fp8_e5m2_x(unsigned int val) { + vals[0] = ConvertToFP8(val, saturation::finite); + } + + template > + explicit fp8_e5m2_x(unsigned long val) { + vals[0] = ConvertToFP8(val, saturation::finite); + } + + template > + explicit fp8_e5m2_x(unsigned long long val) { + vals[0] = ConvertToFP8(val, saturation::finite); + } + + // Assign (operator) from half, bfloat16, float, double, and integer types. + // Available only when N==1. + + template > + fp8_e5m2_x &operator=(sycl::half val) { + vals[0] = ConvertToFP8(val, saturation::finite); + return *this; + } + + template > + fp8_e5m2_x &operator=(bfloat16 val) { + vals[0] = ConvertBF16ToFP8(val, saturation::finite); + return *this; + } + + template > + fp8_e5m2_x &operator=(float val) { + vals[0] = ConvertToFP8(val, saturation::finite); + return *this; + } + + template > + fp8_e5m2_x &operator=(double val) { + vals[0] = ConvertToFP8(val, saturation::finite); + return *this; + } + + template > + fp8_e5m2_x &operator=(short val) { + vals[0] = ConvertToFP8(val, saturation::finite); + return *this; + } + + template > + fp8_e5m2_x &operator=(int val) { + vals[0] = ConvertToFP8(val, saturation::finite); + return *this; + } + + template > + fp8_e5m2_x &operator=(long val) { + vals[0] = ConvertToFP8(val, saturation::finite); + return *this; + } + + template > + fp8_e5m2_x &operator=(long long val) { + vals[0] = ConvertToFP8(val, saturation::finite); + return *this; + } + + template > + fp8_e5m2_x &operator=(unsigned short val) { + vals[0] = ConvertToFP8(val, saturation::finite); + return *this; + } + + template > + fp8_e5m2_x &operator=(unsigned int val) { + vals[0] = ConvertToFP8(val, saturation::finite); + return *this; + } + + template > + fp8_e5m2_x &operator=(unsigned long val) { + vals[0] = ConvertToFP8(val, saturation::finite); return *this; } - template || - std::is_integral_v)>> - explicit operator T() const { - if constexpr (std::is_integral_v) - return ConvertFromFP8(vals[0], rounding::toward_zero); - else - return ConvertFromFP8(vals[0]); + template > + fp8_e5m2_x &operator=(unsigned long long val) { + vals[0] = ConvertToFP8(val, saturation::finite); + return *this; + } + + // Convert to half, bfloat16, float, double. + // Available only when N==1. + + template > + explicit operator half() const { + return ConvertFromFP8(vals[0]); + } + + template > + explicit operator bfloat16() const { + return ConvertBF16FromFP8(vals[0]); + } + + template > + explicit operator float() const { + return ConvertFromFP8(vals[0]); + } + + template > + explicit operator double() const { + return ConvertFromFP8(vals[0]); + } + + // Convert to integer types. + // Available only when N==1. + + template > + explicit operator char() const { + return ConvertFromFP8(vals[0], rounding::toward_zero); + } + + template > + explicit operator signed char() const { + return ConvertFromFP8(vals[0], rounding::toward_zero); + } + + template > + explicit operator short() const { + return ConvertFromFP8(vals[0], rounding::toward_zero); + } + + template > + explicit operator int() const { + return ConvertFromFP8(vals[0], rounding::toward_zero); + } + + template > + explicit operator long() const { + return ConvertFromFP8(vals[0], rounding::toward_zero); + } + + template > + explicit operator long long() const { + return ConvertFromFP8(vals[0], rounding::toward_zero); + } + + template > + explicit operator unsigned char() const { + return ConvertFromFP8(vals[0], rounding::toward_zero); + } + + template > + explicit operator unsigned short() const { + return ConvertFromFP8(vals[0], rounding::toward_zero); + } + + template > + explicit operator unsigned int() const { + return ConvertFromFP8(vals[0], rounding::toward_zero); + } + + template > + explicit operator unsigned long() const { + return ConvertFromFP8(vals[0], rounding::toward_zero); + } + + template > + explicit operator unsigned long long() const { + return ConvertFromFP8(vals[0], rounding::toward_zero); } // Convert to bool @@ -1150,13 +1676,22 @@ template class fp8_e5m2_x { return vals[0] != 0x00 && vals[0] != 0x80; } - template , sycl::half, - sycl::ext::oneapi::bfloat16, float>::value>> - explicit operator sycl::marray() const { - sycl::marray out; + explicit operator sycl::marray() const { + sycl::marray out; + for (size_t i = 0; i < N; ++i) + out[i] = ConvertFromFP8(vals[i]); + return out; + } + explicit operator sycl::marray() const { + sycl::marray out; + for (size_t i = 0; i < N; ++i) + out[i] = ConvertBF16FromFP8(vals[i]); + return out; + } + explicit operator sycl::marray() const { + sycl::marray out; for (size_t i = 0; i < N; ++i) - out[i] = ConvertFromFP8(vals[i]); + out[i] = ConvertFromFP8(vals[i]); return out; } @@ -1196,48 +1731,248 @@ template class fp8_e8m0_x { saturation::finite); } - template >> - explicit fp8_e8m0_x(T const (&in)[N], rounding r = rounding::upward) { + explicit fp8_e8m0_x(half const (&in)[N], rounding r = rounding::upward) { + CheckConstraints(r); + for (size_t i = 0; i < N; ++i) + vals[i] = detail::ConvertFloatToE8M0_CPU(in[i], r, saturation::finite); + } + + explicit fp8_e8m0_x(bfloat16 const (&in)[N], rounding r = rounding::upward) { + CheckConstraints(r); + for (size_t i = 0; i < N; ++i) + vals[i] = detail::ConvertFloatToE8M0_CPU(in[i], r, saturation::finite); + } + + explicit fp8_e8m0_x(float const (&in)[N], rounding r = rounding::upward) { + CheckConstraints(r); + for (size_t i = 0; i < N; ++i) + vals[i] = detail::ConvertFloatToE8M0_CPU(in[i], r, saturation::finite); + } + + explicit fp8_e8m0_x(double const (&in)[N]) { + for (size_t i = 0; i < N; ++i) + vals[i] = detail::ConvertFloatToE8M0_CPU(in[i], rounding::upward, + saturation::finite); + } + + explicit fp8_e8m0_x(const marray &in, + rounding r = rounding::upward) { + CheckConstraints(r); + for (size_t i = 0; i < N; ++i) + vals[i] = detail::ConvertFloatToE8M0_CPU(in[i], r, saturation::finite); + } + + explicit fp8_e8m0_x(const marray &in, + rounding r = rounding::upward) { CheckConstraints(r); for (size_t i = 0; i < N; ++i) vals[i] = detail::ConvertFloatToE8M0_CPU(in[i], r, saturation::finite); } - template >> - explicit fp8_e8m0_x(const marray &in, rounding r = rounding::upward) { + explicit fp8_e8m0_x(const marray &in, + rounding r = rounding::upward) { CheckConstraints(r); for (size_t i = 0; i < N; ++i) vals[i] = detail::ConvertFloatToE8M0_CPU(in[i], r, saturation::finite); } - template >> - explicit fp8_e8m0_x(T val) { + explicit fp8_e8m0_x(const marray &in) { + for (size_t i = 0; i < N; ++i) + vals[i] = detail::ConvertFloatToE8M0_CPU(in[i], rounding::upward, + saturation::finite); + } + + // Construct from integer types. + // Available only when N==1. + + template > + explicit fp8_e8m0_x(short val) { + vals[0] = + detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); + } + + template > + explicit fp8_e8m0_x(int val) { + vals[0] = + detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); + } + + template > + explicit fp8_e8m0_x(long val) { + vals[0] = + detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); + } + + template > + explicit fp8_e8m0_x(long long val) { + vals[0] = + detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); + } + template > + explicit fp8_e8m0_x(unsigned short val) { + vals[0] = + detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); + } + template > + explicit fp8_e8m0_x(unsigned int val) { + vals[0] = + detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); + } + template > + explicit fp8_e8m0_x(unsigned long val) { + vals[0] = + detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); + } + template > + explicit fp8_e8m0_x(unsigned long long val) { vals[0] = detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); } - template || - std::is_integral_v)>> - fp8_e8m0_x &operator=(T val) { - if constexpr (std::is_integral_v) - vals[0] = detail::ConvertIntToE8M0_CPU(val, rounding::upward, + template > + fp8_e8m0_x &operator=(half val) { + vals[0] = detail::ConvertFloatToE8M0_CPU(val, rounding::upward, saturation::finite); - else - vals[0] = detail::ConvertFloatToE8M0_CPU(val, rounding::upward, - saturation::finite); return *this; } + template > + fp8_e8m0_x &operator=(bfloat16 val) { + vals[0] = detail::ConvertFloatToE8M0_CPU(val, rounding::upward, + saturation::finite); + return *this; + } + template > + fp8_e8m0_x &operator=(float val) { + vals[0] = detail::ConvertFloatToE8M0_CPU(val, rounding::upward, + saturation::finite); + return *this; + } + + template > + fp8_e8m0_x &operator=(double val) { + vals[0] = detail::ConvertFloatToE8M0_CPU(val, rounding::upward, + saturation::finite); + return *this; + } + + template > + fp8_e8m0_x &operator=(short val) { + vals[0] = + detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); + return *this; + } + template > + fp8_e8m0_x &operator=(int val) { + vals[0] = + detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); + return *this; + } + template > + fp8_e8m0_x &operator=(long val) { + vals[0] = + detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); + return *this; + } + template > + fp8_e8m0_x &operator=(long long val) { + vals[0] = + detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); + return *this; + } + template > + fp8_e8m0_x &operator=(unsigned short val) { + vals[0] = + detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); + return *this; + } + template > + fp8_e8m0_x &operator=(unsigned int val) { + vals[0] = + detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); + return *this; + } + template > + fp8_e8m0_x &operator=(unsigned long val) { + vals[0] = + detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); + return *this; + } + template > + fp8_e8m0_x &operator=(unsigned long long val) { + vals[0] = + detail::ConvertIntToE8M0_CPU(val, rounding::upward, saturation::finite); + return *this; + } + + template > + explicit operator half() const { + return detail::ConvertFromE8M0_CPU(vals[0], rounding::to_even); + } + template > + explicit operator bfloat16() const { + return detail::ConvertFromE8M0_CPU(vals[0], rounding::to_even); + } + template > + explicit operator float() const { + return detail::ConvertFromE8M0_CPU(vals[0], rounding::to_even); + } + template > + explicit operator double() const { + return detail::ConvertFromE8M0_CPU(vals[0], rounding::to_even); + } + + template > + explicit operator char() const { + return detail::ConvertFromE8M0_CPU(vals[0], rounding::toward_zero); + } + template > + explicit operator signed char() const { + return detail::ConvertFromE8M0_CPU(vals[0], + rounding::toward_zero); + } + template > + explicit operator short() const { + return detail::ConvertFromE8M0_CPU(vals[0], rounding::toward_zero); + } + template > + explicit operator int() const { + return detail::ConvertFromE8M0_CPU(vals[0], rounding::toward_zero); + } + template > + explicit operator long() const { + return detail::ConvertFromE8M0_CPU(vals[0], rounding::toward_zero); + } + template > + explicit operator long long() const { + return detail::ConvertFromE8M0_CPU(vals[0], + rounding::toward_zero); + } + template > + explicit operator unsigned char() const { + return detail::ConvertFromE8M0_CPU(vals[0], + rounding::toward_zero); + } + template > + explicit operator unsigned short() const { + return detail::ConvertFromE8M0_CPU(vals[0], + rounding::toward_zero); + } + template > + explicit operator unsigned int() const { + return detail::ConvertFromE8M0_CPU(vals[0], + rounding::toward_zero); + } - template || - std::is_integral_v)>> - explicit operator T() const { - if constexpr (std::is_integral_v) - return detail::ConvertFromE8M0_CPU(vals[0], rounding::toward_zero); - else - return detail::ConvertFromE8M0_CPU(vals[0], rounding::to_even); + template > + explicit operator unsigned long() const { + return detail::ConvertFromE8M0_CPU(vals[0], + rounding::toward_zero); + } + + template > + explicit operator unsigned long long() const { + return detail::ConvertFromE8M0_CPU( + vals[0], rounding::toward_zero); } template > @@ -1245,15 +1980,26 @@ template class fp8_e8m0_x { return true; } - template , sycl::half, - sycl::ext::oneapi::bfloat16, float>::value>> - explicit operator sycl::marray() const { - sycl::marray out; + explicit operator sycl::marray() const { + sycl::marray out; for (size_t i = 0; i < N; ++i) - out[i] = detail::ConvertFromE8M0_CPU(vals[i], rounding::to_even); + out[i] = detail::ConvertFromE8M0_CPU(vals[i], rounding::to_even); return out; } + explicit operator sycl::marray() const { + sycl::marray out; + for (size_t i = 0; i < N; ++i) + out[i] = + detail::ConvertFromE8M0_CPU(vals[i], rounding::to_even); + return out; + } + explicit operator sycl::marray() const { + sycl::marray out; + for (size_t i = 0; i < N; ++i) + out[i] = detail::ConvertFromE8M0_CPU(vals[i], rounding::to_even); + return out; + } + // Intentionally public to allow access to the raw values. uint8_t vals[N]; From bce68947861d4f8330e4dd5cc58ea9302b64e39d Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Tue, 21 Apr 2026 09:35:04 +0200 Subject: [PATCH 43/89] [SYCL] remove stochastic float constructors --- .../oneapi/experimental/float_8bit/types.hpp | 42 +------------------ .../Extensions/fp8/builtin_call_tests.cpp | 26 ------------ 2 files changed, 2 insertions(+), 66 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index bfe55241184ad..c3502554af684 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -1353,7 +1353,7 @@ template class fp8_e5m2_x { } // Construct with stochastic rounding with user provided seed from an array of - // half, bfloat16, float. + // half, bfloat16. explicit fp8_e5m2_x([[maybe_unused]] half const (&in)[N], [[maybe_unused]] const stochastic_seed &seed, @@ -1391,27 +1391,8 @@ template class fp8_e5m2_x { #endif } - explicit fp8_e5m2_x([[maybe_unused]] float const (&in)[N], - [[maybe_unused]] const stochastic_seed &seed, - [[maybe_unused]] saturation s = saturation::finite) { -#ifdef __SYCL_DEVICE_ONLY__ - uint32_t current_seed = *seed.pseed; - for (size_t i = 0; i < N; ++i) { - sycl::half h = static_cast(in[i]); - if (s == saturation::finite) { - vals[i] = __builtin_spirv_ClampStochasticRoundFP16ToE5M2INTEL( - h, current_seed, seed.pseed); - } else { - vals[i] = __builtin_spirv_StochasticRoundFP16ToE5M2INTEL( - h, current_seed, seed.pseed); - } - current_seed = *seed.pseed; - } -#endif - } - // Construct with stochastic rounding with user provided seed from an marray - // of half, bfloat16, float. + // of half, bfloat16. explicit fp8_e5m2_x([[maybe_unused]] const sycl::marray &in, [[maybe_unused]] const stochastic_seed &seed, @@ -1449,25 +1430,6 @@ template class fp8_e5m2_x { #endif } - explicit fp8_e5m2_x([[maybe_unused]] const sycl::marray &in, - [[maybe_unused]] const stochastic_seed &seed, - [[maybe_unused]] saturation s = saturation::finite) { -#ifdef __SYCL_DEVICE_ONLY__ - uint32_t current_seed = *seed.pseed; - for (size_t i = 0; i < N; ++i) { - sycl::half h = static_cast(in[i]); - if (s == saturation::finite) { - vals[i] = __builtin_spirv_ClampStochasticRoundFP16ToE5M2INTEL( - h, current_seed, seed.pseed); - } else { - vals[i] = __builtin_spirv_StochasticRoundFP16ToE5M2INTEL( - h, current_seed, seed.pseed); - } - current_seed = *seed.pseed; - } -#endif - } - // Construct from integer types. // Available only when N==1. diff --git a/sycl/unittests/Extensions/fp8/builtin_call_tests.cpp b/sycl/unittests/Extensions/fp8/builtin_call_tests.cpp index 461b0dfd0b644..be8c031ffad14 100644 --- a/sycl/unittests/Extensions/fp8/builtin_call_tests.cpp +++ b/sycl/unittests/Extensions/fp8/builtin_call_tests.cpp @@ -244,30 +244,4 @@ TEST_F(Fp8BuiltinCallTest, E5M2StochasticBf16NoneCallsNonClampStochastic) { EXPECT_EQ(fp8_builtin_mock::getCounters().StochasticRoundBF16ToE5M2INTEL, 1); } -TEST_F(Fp8BuiltinCallTest, E5M2StochasticFloatFiniteCallsClampStochastic) { - float Input[2] = {3.0f, 4.0f}; - uint32_t SeedValue = 50; - stochastic_seed Seed(&SeedValue); - - fp8_e5m2_x2 Value(Input, Seed, saturation::finite); - (void)Value; - - EXPECT_EQ(fp8_builtin_mock::getCounters().ClampStochasticRoundFP16ToE5M2INTEL, - 2); - EXPECT_EQ(SeedValue, 52u); -} - -TEST_F(Fp8BuiltinCallTest, - E5M2StochasticMarrayFloatNoneCallsNonClampStochastic) { - sycl::marray Input = {3.0f, 4.0f}; - uint32_t SeedValue = 60; - stochastic_seed Seed(&SeedValue); - - fp8_e5m2_x2 Value(Input, Seed, saturation::none); - (void)Value; - - EXPECT_EQ(fp8_builtin_mock::getCounters().StochasticRoundFP16ToE5M2INTEL, 2); - EXPECT_EQ(SeedValue, 62u); -} - } // namespace From dec9931edca1ccee5798a4052901568a0a30cb2c Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Tue, 21 Apr 2026 09:55:39 +0200 Subject: [PATCH 44/89] [SYCL] add tests to check decay --- sycl/unittests/Extensions/fp8/fp8_e4m3.cpp | 13 +++++++++++++ sycl/unittests/Extensions/fp8/fp8_e5m2.cpp | 13 +++++++++++++ sycl/unittests/Extensions/fp8/fp8_e8m0.cpp | 13 +++++++++++++ 3 files changed, 39 insertions(+) diff --git a/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp b/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp index d231cf93f7a92..cea92e9b9076d 100644 --- a/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp +++ b/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp @@ -653,3 +653,16 @@ TEST(FP8E4M3Test, MarrayFloatRejectsTowardZeroRounding) { }, "fp8_e4m3_x: only rounding::to_even is supported"); } + +TEST(FP8E4M3Test, VariadicFloatReferences) { + float x = 1.0f; + float y = 2.0f; + float &xf = x; + float &yf = y; + + fp8_e4m3_x2 a(xf, yf); + + EXPECT_EQ(sizeof(a.vals), 2u); + EXPECT_EQ(a.vals[0], 0x38); + EXPECT_EQ(a.vals[1], 0x40); +} \ No newline at end of file diff --git a/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp b/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp index 05d5ec75874b8..6f67cae20189c 100644 --- a/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp +++ b/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp @@ -727,3 +727,16 @@ TEST(FP8E5M2Test, MarrayFloatTowardZeroRounding) { }, "fp8_e5m2_x: only rounding::to_even is supported"); } + +TEST(FP8E5M2Test, VariadicFloatReferences) { + float x = 1.0f; + float y = 2.0f; + float &xf = x; + float &yf = y; + + fp8_e5m2_x2 a(xf, yf); + + EXPECT_EQ(sizeof(a.vals), 2u); + EXPECT_EQ(a.vals[0], 0x3C); + EXPECT_EQ(a.vals[1], 0x40); +} \ No newline at end of file diff --git a/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp b/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp index 72876d674474e..fad814697b6e2 100644 --- a/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp +++ b/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp @@ -495,3 +495,16 @@ TEST(FP8E8M0Test, MarrayFloatToEvenRounding) { }, UnsupportedRoundingAssertRegex); } + +TEST(FP8E8M0Test, VariadicFloatReferences) { + float x = 1.0f; + float y = 2.0f; + float &xf = x; + float &yf = y; + + fp8_e8m0_x2 a(xf, yf); + + EXPECT_EQ(sizeof(a.vals), 2u); + EXPECT_EQ(a.vals[0], 0x7F); + EXPECT_EQ(a.vals[1], 0x80); +} \ No newline at end of file From 7d12a235da3584adf090fe6af259b63b6237c818 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Wed, 29 Apr 2026 17:06:19 +0200 Subject: [PATCH 45/89] [SYCL][E2E] add simple e2e test of fp_e4m3 --- .../oneapi/experimental/float_8bit/types.hpp | 122 ++++++++++-------- sycl/test-e2e/Experimental/fp8/e4m3_cri.cpp | 58 +++++++++ .../Experimental/fp8/lit.local.cfg.py | 10 ++ .../Extensions/fp8/builtin_mocks.hpp | 63 ++++----- 4 files changed, 162 insertions(+), 91 deletions(-) create mode 100644 sycl/test-e2e/Experimental/fp8/e4m3_cri.cpp create mode 100644 sycl/test-e2e/Experimental/fp8/lit.local.cfg.py diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index c3502554af684..c5f89a14d0623 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -21,45 +21,41 @@ #ifdef __SYCL_DEVICE_ONLY__ // FP8 builtins -extern __DPCPP_SYCL_EXTERNAL - uint8_t __builtin_spirv_ClampConvertFP16ToE4M3INTEL(sycl::half) noexcept; -extern __DPCPP_SYCL_EXTERNAL - uint8_t __builtin_spirv_ConvertFP16ToE4M3EXT(sycl::half) noexcept; -extern __DPCPP_SYCL_EXTERNAL sycl::half -__builtin_spirv_ConvertE4M3ToFP16EXT(uint8_t) noexcept; -extern __DPCPP_SYCL_EXTERNAL - uint8_t __builtin_spirv_ClampConvertBF16ToE4M3INTEL( - sycl::ext::oneapi::bfloat16) noexcept; + extern __DPCPP_SYCL_EXTERNAL uint8_t - __builtin_spirv_ConvertBF16ToE4M3EXT(sycl::ext::oneapi::bfloat16) noexcept; -extern __DPCPP_SYCL_EXTERNAL sycl::ext::oneapi::bfloat16 +__builtin_spirv_ClampConvertFP16ToE4M3INTEL(_Float16) noexcept; + +extern __DPCPP_SYCL_EXTERNAL _Float16 +__builtin_spirv_ConvertE4M3ToFP16EXT(char) noexcept; + +extern __DPCPP_SYCL_EXTERNAL uint8_t +__builtin_spirv_ClampConvertBF16ToE4M3INTEL(__bf16) noexcept; +extern __DPCPP_SYCL_EXTERNAL __bf16 __builtin_spirv_ConvertE4M3ToBF16EXT(uint8_t) noexcept; -extern __DPCPP_SYCL_EXTERNAL - uint8_t __builtin_spirv_ClampConvertFP16ToE5M2INTEL(sycl::half) noexcept; -extern __DPCPP_SYCL_EXTERNAL - uint8_t __builtin_spirv_ConvertFP16ToE5M2EXT(sycl::half) noexcept; -extern __DPCPP_SYCL_EXTERNAL sycl::half +extern __DPCPP_SYCL_EXTERNAL uint8_t +__builtin_spirv_ClampConvertFP16ToE5M2INTEL(_Float16) noexcept; +extern __DPCPP_SYCL_EXTERNAL uint8_t +__builtin_spirv_ConvertFP16ToE5M2EXT(_Float16) noexcept; +extern __DPCPP_SYCL_EXTERNAL _Float16 __builtin_spirv_ConvertE5M2ToFP16EXT(uint8_t) noexcept; -extern __DPCPP_SYCL_EXTERNAL - uint8_t __builtin_spirv_ClampConvertBF16ToE5M2INTEL( - sycl::ext::oneapi::bfloat16) noexcept; extern __DPCPP_SYCL_EXTERNAL uint8_t - __builtin_spirv_ConvertBF16ToE5M2EXT(sycl::ext::oneapi::bfloat16) noexcept; -extern __DPCPP_SYCL_EXTERNAL sycl::ext::oneapi::bfloat16 +__builtin_spirv_ClampConvertBF16ToE5M2INTEL(__bf16) noexcept; +extern __DPCPP_SYCL_EXTERNAL uint8_t +__builtin_spirv_ConvertBF16ToE5M2EXT(__bf16) noexcept; +extern __DPCPP_SYCL_EXTERNAL __bf16 __builtin_spirv_ConvertE5M2ToBF16EXT(uint8_t) noexcept; extern __DPCPP_SYCL_EXTERNAL uint8_t -__builtin_spirv_ClampStochasticRoundFP16ToE5M2INTEL(sycl::half, uint32_t, +__builtin_spirv_ClampStochasticRoundFP16ToE5M2INTEL(_Float16, uint32_t, uint32_t *) noexcept; extern __DPCPP_SYCL_EXTERNAL uint8_t -__builtin_spirv_StochasticRoundFP16ToE5M2INTEL(sycl::half, uint32_t, +__builtin_spirv_StochasticRoundFP16ToE5M2INTEL(_Float16, uint32_t, uint32_t *) noexcept; extern __DPCPP_SYCL_EXTERNAL uint8_t -__builtin_spirv_ClampStochasticRoundBF16ToE5M2INTEL(sycl::ext::oneapi::bfloat16, - uint32_t, +__builtin_spirv_ClampStochasticRoundBF16ToE5M2INTEL(__bf16, uint32_t, uint32_t *) noexcept; extern __DPCPP_SYCL_EXTERNAL uint8_t -__builtin_spirv_StochasticRoundBF16ToE5M2INTEL(sycl::ext::oneapi::bfloat16, - uint32_t, uint32_t *) noexcept; +__builtin_spirv_StochasticRoundBF16ToE5M2INTEL(__bf16, uint32_t, + uint32_t *) noexcept; #endif // __SYCL_DEVICE_ONLY__ namespace sycl { @@ -828,7 +824,12 @@ template class fp8_e4m3_x { template uint8_t ConvertToFP8(T h) { #ifdef __SYCL_DEVICE_ONLY__ - return __builtin_spirv_ClampConvertFP16ToE4M3INTEL(h); + _Float16 v{0}; + if constexpr (std::is_same_v, sycl::half>) + v = static_cast<_Float16>(static_cast(h)); + else + v = static_cast<_Float16>(h); + return __builtin_spirv_ClampConvertFP16ToE4M3INTEL(v); #else if constexpr (std::is_same_v, sycl::half> || std::is_same_v, float> || @@ -844,7 +845,8 @@ template class fp8_e4m3_x { uint8_t ConvertBF16ToFP8(bfloat16 h) { #ifdef __SYCL_DEVICE_ONLY__ - return __builtin_spirv_ClampConvertBF16ToE4M3INTEL(h); + return __builtin_spirv_ClampConvertBF16ToE4M3INTEL( + sycl::bit_cast<__bf16>(h)); #else return detail::ConvertFloatToFP8_CPU( h, rounding::to_even, saturation::finite); @@ -864,7 +866,7 @@ template class fp8_e4m3_x { bfloat16 ConvertBF16FromFP8(uint8_t v) const { #ifdef __SYCL_DEVICE_ONLY__ - return __builtin_spirv_ConvertE4M3ToBF16EXT(v); + return sycl::bit_cast(__builtin_spirv_ConvertE4M3ToBF16EXT(v)); #else return detail::ConvertFromFP8ToBinaryFloat_CPU(v, @@ -899,12 +901,12 @@ template class fp8_e4m3_x { const bfloat16 in[N] = {static_cast(v)...}; for (size_t i = 0; i < N; ++i) vals[i] = ConvertBF16ToFP8(in[i]); - return; + } else { + using InT = std::common_type_t...>; + const InT in[N] = {v...}; + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertToFP8(in[i]); } - using InT = std::common_type_t...>; - const InT in[N] = {v...}; - for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(in[i]); } // Construct from an array of half, bfloat16, float, double. @@ -1162,7 +1164,8 @@ template class fp8_e4m3_x { explicit operator bool() const { #ifdef __SYCL_DEVICE_ONLY__ // detect +0 / -0 - sycl::half h = __builtin_spirv_ConvertE4M3ToFP16EXT(vals[0]); + sycl::half h = + __builtin_spirv_ConvertE4M3ToFP16EXT(sycl::bit_cast(vals[0])); return h != 0; #else // no need to convert, just check sign bit and 0s @@ -1206,10 +1209,14 @@ template class fp8_e5m2_x { template uint8_t ConvertToFP8(T h, saturation s) { #ifdef __SYCL_DEVICE_ONLY__ - const sycl::half halfValue = static_cast(h); + _Float16 v{0}; + if constexpr (std::is_same_v, sycl::half>) + v = static_cast<_Float16>(static_cast(h)); + else + v = static_cast<_Float16>(h); return s == saturation::finite - ? __builtin_spirv_ClampConvertFP16ToE5M2INTEL(halfValue) - : __builtin_spirv_ConvertFP16ToE5M2EXT(halfValue); + ? __builtin_spirv_ClampConvertFP16ToE5M2INTEL(v) + : __builtin_spirv_ConvertFP16ToE5M2EXT(v); #else if constexpr (std::is_same_v, sycl::half> || std::is_same_v, float> || @@ -1226,8 +1233,10 @@ template class fp8_e5m2_x { uint8_t ConvertBF16ToFP8(bfloat16 h, saturation s) { #ifdef __SYCL_DEVICE_ONLY__ return s == saturation::finite - ? __builtin_spirv_ClampConvertBF16ToE5M2INTEL(h) - : __builtin_spirv_ConvertBF16ToE5M2EXT(h); + ? __builtin_spirv_ClampConvertBF16ToE5M2INTEL( + sycl::bit_cast<__bf16>(h)) + : __builtin_spirv_ConvertBF16ToE5M2EXT( + sycl::bit_cast<__bf16>(h)); #else return detail::ConvertFloatToFP8_CPU( h, rounding::to_even, s); @@ -1247,7 +1256,7 @@ template class fp8_e5m2_x { bfloat16 ConvertBF16FromFP8(uint8_t v) const { #ifdef __SYCL_DEVICE_ONLY__ - return __builtin_spirv_ConvertE5M2ToBF16EXT(v); + return sycl::bit_cast(__builtin_spirv_ConvertE5M2ToBF16EXT(v)); #else return detail::ConvertFromFP8ToBinaryFloat_CPU(v, @@ -1283,12 +1292,12 @@ template class fp8_e5m2_x { const bfloat16 in[N] = {static_cast(v)...}; for (size_t i = 0; i < N; ++i) vals[i] = ConvertBF16ToFP8(in[i], saturation::finite); - return; + } else { + using InT = std::common_type_t...>; + const InT in[N] = {v...}; + for (size_t i = 0; i < N; ++i) + vals[i] = ConvertToFP8(in[i], saturation::finite); } - using InT = std::common_type_t...>; - const InT in[N] = {v...}; - for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(in[i], saturation::finite); } // Construct from an array of half, bfloat16, float, double. @@ -1361,12 +1370,13 @@ template class fp8_e5m2_x { #ifdef __SYCL_DEVICE_ONLY__ uint32_t current_seed = *seed.pseed; for (size_t i = 0; i < N; ++i) { + const _Float16 v = static_cast<_Float16>(static_cast(in[i])); if (s == saturation::finite) { vals[i] = __builtin_spirv_ClampStochasticRoundFP16ToE5M2INTEL( - in[i], current_seed, seed.pseed); + v, current_seed, seed.pseed); } else { vals[i] = __builtin_spirv_StochasticRoundFP16ToE5M2INTEL( - in[i], current_seed, seed.pseed); + v, current_seed, seed.pseed); } current_seed = *seed.pseed; } @@ -1381,10 +1391,10 @@ template class fp8_e5m2_x { for (size_t i = 0; i < N; ++i) { if (s == saturation::finite) { vals[i] = __builtin_spirv_ClampStochasticRoundBF16ToE5M2INTEL( - in[i], current_seed, seed.pseed); + sycl::bit_cast<__bf16>(in[i]), current_seed, seed.pseed); } else { vals[i] = __builtin_spirv_StochasticRoundBF16ToE5M2INTEL( - in[i], current_seed, seed.pseed); + sycl::bit_cast<__bf16>(in[i]), current_seed, seed.pseed); } current_seed = *seed.pseed; } @@ -1400,12 +1410,14 @@ template class fp8_e5m2_x { #ifdef __SYCL_DEVICE_ONLY__ uint32_t current_seed = *seed.pseed; for (size_t i = 0; i < N; ++i) { + + _Float16 v = static_cast<_Float16>(static_cast(in[i])); if (s == saturation::finite) { vals[i] = __builtin_spirv_ClampStochasticRoundFP16ToE5M2INTEL( - in[i], current_seed, seed.pseed); + v, current_seed, seed.pseed); } else { vals[i] = __builtin_spirv_StochasticRoundFP16ToE5M2INTEL( - in[i], current_seed, seed.pseed); + v, current_seed, seed.pseed); } current_seed = *seed.pseed; } @@ -1420,10 +1432,10 @@ template class fp8_e5m2_x { for (size_t i = 0; i < N; ++i) { if (s == saturation::finite) { vals[i] = __builtin_spirv_ClampStochasticRoundBF16ToE5M2INTEL( - in[i], current_seed, seed.pseed); + sycl::bit_cast<__bf16>(in[i]), current_seed, seed.pseed); } else { vals[i] = __builtin_spirv_StochasticRoundBF16ToE5M2INTEL( - in[i], current_seed, seed.pseed); + sycl::bit_cast<__bf16>(in[i]), current_seed, seed.pseed); } current_seed = *seed.pseed; } diff --git a/sycl/test-e2e/Experimental/fp8/e4m3_cri.cpp b/sycl/test-e2e/Experimental/fp8/e4m3_cri.cpp new file mode 100644 index 0000000000000..a7c750aa5019e --- /dev/null +++ b/sycl/test-e2e/Experimental/fp8/e4m3_cri.cpp @@ -0,0 +1,58 @@ +// RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out +// RUN: %{run} SYCL_UR_TRACE=1 %t.out + +// Warning! This test requires CRI simulator run to communicate via TCP socket +// with port 60999, or any other from config + +#include +#include +#include +#include +#include + +using namespace sycl::ext::oneapi::experimental; + +template int run_basic_fp8_test(sycl::queue &queue) { + auto *data = sycl::malloc_shared(1, queue); + data[0] = fp8_e4m3(static_cast(1.5)); + + queue.single_task([=]() { + fp8_e4m3 value = data[0]; + T f = static_cast(value); + f += 1.0f; + data[0] = fp8_e4m3(f); + }); + queue.wait_and_throw(); + + fp8_e4m3 expected(2.5f); + T out = static_cast(data[0]); + T expected_out = static_cast(expected); + + sycl::free(data, queue); + if (std::fabs(out - expected_out) > 0.0f) + return 1; + + return 0; +} + +int main() { + auto async_handler = [](sycl::exception_list exceptions) { + for (const std::exception_ptr &e : exceptions) { + try { + std::rethrow_exception(e); + } catch (const sycl::exception &ex) { + std::cerr << "Async SYCL exception: " << ex.what() << '\n'; + std::terminate(); + } + } + }; + + sycl::queue queue{async_handler}; + + int ret = run_basic_fp8_test(queue); + ret |= run_basic_fp8_test(queue); + ret |= run_basic_fp8_test(queue); + ret |= run_basic_fp8_test(queue); + + return ret; +} diff --git a/sycl/test-e2e/Experimental/fp8/lit.local.cfg.py b/sycl/test-e2e/Experimental/fp8/lit.local.cfg.py new file mode 100644 index 0000000000000..605551f377933 --- /dev/null +++ b/sycl/test-e2e/Experimental/fp8/lit.local.cfg.py @@ -0,0 +1,10 @@ +config.environment["NEOReadDebugKeys"] = "1" +config.environment["ProductFamilyOverride"] = "cri" +config.environment["HardwareInfoOverride"] = "1x8x8" +config.environment["SetCommandStreamReceiver"] = "2" +config.environment["TbxPort"] = "60999" +config.environment["RebuildPrecompiledKernels"] = "1" +config.environment["EnableDirectSubmission"] = "0" +config.environment["EnableBlitterOperationsSupport"] = "1" +config.environment["BlitterEnableMaskOverride"] = "6" +config.environment["Enable64BitAddressing"] = "1" \ No newline at end of file diff --git a/sycl/unittests/Extensions/fp8/builtin_mocks.hpp b/sycl/unittests/Extensions/fp8/builtin_mocks.hpp index ac89f27cfe614..0d5f9cee1f7c7 100644 --- a/sycl/unittests/Extensions/fp8/builtin_mocks.hpp +++ b/sycl/unittests/Extensions/fp8/builtin_mocks.hpp @@ -43,76 +43,67 @@ inline void resetCounters() { getCounters() = Counters{}; } } // namespace fp8_builtin_mock // Builtin mocks (do not replace helpers.hpp; provide symbols here). -inline sycl::half __builtin_spirv_ConvertE4M3ToFP16EXT(uint8_t) noexcept { +inline _Float16 __builtin_spirv_ConvertE4M3ToFP16EXT(char) noexcept { ++fp8_builtin_mock::getCounters().ConvertE4M3ToFP16EXT; - return static_cast(2.0f); + return static_cast<_Float16>(2.0f); } -inline sycl::half __builtin_spirv_ConvertE5M2ToFP16EXT(uint8_t) noexcept { +inline _Float16 __builtin_spirv_ConvertE5M2ToFP16EXT(uint8_t) noexcept { ++fp8_builtin_mock::getCounters().ConvertE5M2ToFP16EXT; - return static_cast(3.0f); + return static_cast<_Float16>(3.0f); } -inline sycl::ext::oneapi::bfloat16 -__builtin_spirv_ConvertE4M3ToBF16EXT(uint8_t) noexcept { +inline __bf16 __builtin_spirv_ConvertE4M3ToBF16EXT(uint8_t) noexcept { ++fp8_builtin_mock::getCounters().ConvertE4M3ToBF16EXT; - return static_cast(4.0f); + return static_cast<__bf16>(4.0f); } -inline sycl::ext::oneapi::bfloat16 -__builtin_spirv_ConvertE5M2ToBF16EXT(uint8_t) noexcept { +inline __bf16 __builtin_spirv_ConvertE5M2ToBF16EXT(uint8_t) noexcept { ++fp8_builtin_mock::getCounters().ConvertE5M2ToBF16EXT; - return static_cast(5.0f); + return static_cast<__bf16>(5.0f); } -inline uint8_t __builtin_spirv_ConvertFP16ToE4M3EXT(sycl::half) noexcept { +inline uint8_t __builtin_spirv_ConvertFP16ToE4M3EXT(_Float16) noexcept { ++fp8_builtin_mock::getCounters().ConvertFP16ToE4M3EXT; return 0x01; } -inline uint8_t -__builtin_spirv_ConvertBF16ToE4M3EXT(sycl::ext::oneapi::bfloat16) noexcept { +inline uint8_t __builtin_spirv_ConvertBF16ToE4M3EXT(__bf16) noexcept { ++fp8_builtin_mock::getCounters().ConvertBF16ToE4M3EXT; return 0x02; } - -inline uint8_t -__builtin_spirv_ClampConvertFP16ToE4M3INTEL(sycl::half) noexcept { +inline uint8_t __builtin_spirv_ClampConvertFP16ToE4M3INTEL(_Float16) noexcept { ++fp8_builtin_mock::getCounters().ClampConvertFP16ToE4M3INTEL; return 0x11; } -inline uint8_t __builtin_spirv_ClampConvertBF16ToE4M3INTEL( - sycl::ext::oneapi::bfloat16) noexcept { +inline uint8_t __builtin_spirv_ClampConvertBF16ToE4M3INTEL(__bf16) noexcept { ++fp8_builtin_mock::getCounters().ClampConvertBF16ToE4M3INTEL; return 0x12; } -inline uint8_t __builtin_spirv_ConvertFP16ToE5M2EXT(sycl::half) noexcept { +inline uint8_t __builtin_spirv_ConvertFP16ToE5M2EXT(_Float16) noexcept { ++fp8_builtin_mock::getCounters().ConvertFP16ToE5M2EXT; return 0x03; } -inline uint8_t -__builtin_spirv_ClampConvertFP16ToE5M2INTEL(sycl::half) noexcept { +inline uint8_t __builtin_spirv_ClampConvertFP16ToE5M2INTEL(_Float16) noexcept { ++fp8_builtin_mock::getCounters().ClampConvertFP16ToE5M2INTEL; return 0x21; } -inline uint8_t -__builtin_spirv_ConvertBF16ToE5M2EXT(sycl::ext::oneapi::bfloat16) noexcept { +inline uint8_t __builtin_spirv_ConvertBF16ToE5M2EXT(__bf16) noexcept { ++fp8_builtin_mock::getCounters().ConvertBF16ToE5M2EXT; return 0x04; } -inline uint8_t __builtin_spirv_ClampConvertBF16ToE5M2INTEL( - sycl::ext::oneapi::bfloat16) noexcept { +inline uint8_t __builtin_spirv_ClampConvertBF16ToE5M2INTEL(__bf16) noexcept { ++fp8_builtin_mock::getCounters().ClampConvertBF16ToE5M2INTEL; return 0x22; } inline uint8_t -__builtin_spirv_StochasticRoundFP16ToE5M2INTEL(sycl::half, uint32_t Seed, +__builtin_spirv_StochasticRoundFP16ToE5M2INTEL(_Float16, uint32_t Seed, uint32_t *NextSeed) noexcept { ++fp8_builtin_mock::getCounters().StochasticRoundFP16ToE5M2INTEL; if (NextSeed) @@ -121,25 +112,25 @@ __builtin_spirv_StochasticRoundFP16ToE5M2INTEL(sycl::half, uint32_t Seed, } inline uint8_t -__builtin_spirv_StochasticRoundFP16ToE4M3INTEL(sycl::half) noexcept { +__builtin_spirv_StochasticRoundFP16ToE4M3INTEL(_Float16) noexcept { return 0x00; } -inline uint8_t __builtin_spirv_StochasticRoundBF16ToE5M2INTEL( - sycl::ext::oneapi::bfloat16, uint32_t Seed, uint32_t *NextSeed) noexcept { +inline uint8_t +__builtin_spirv_StochasticRoundBF16ToE5M2INTEL(__bf16, uint32_t Seed, + uint32_t *NextSeed) noexcept { ++fp8_builtin_mock::getCounters().StochasticRoundBF16ToE5M2INTEL; if (NextSeed) *NextSeed = Seed + 1; return 0x32; } -inline uint8_t __builtin_spirv_StochasticRoundBF16ToE4M3INTEL( - sycl::ext::oneapi::bfloat16) noexcept { +inline uint8_t __builtin_spirv_StochasticRoundBF16ToE4M3INTEL(__bf16) noexcept { return 0x00; } inline uint8_t __builtin_spirv_ClampStochasticRoundFP16ToE5M2INTEL( - sycl::half, uint32_t Seed, uint32_t *NextSeed) noexcept { + _Float16, uint32_t Seed, uint32_t *NextSeed) noexcept { ++fp8_builtin_mock::getCounters().ClampStochasticRoundFP16ToE5M2INTEL; if (NextSeed) *NextSeed = Seed + 1; @@ -147,19 +138,19 @@ inline uint8_t __builtin_spirv_ClampStochasticRoundFP16ToE5M2INTEL( } inline uint8_t -__builtin_spirv_ClampStochasticRoundFP16ToE4M3INTEL(sycl::half) noexcept { +__builtin_spirv_ClampStochasticRoundFP16ToE4M3INTEL(_Float16) noexcept { return 0x00; } inline uint8_t __builtin_spirv_ClampStochasticRoundBF16ToE5M2INTEL( - sycl::ext::oneapi::bfloat16, uint32_t Seed, uint32_t *NextSeed) noexcept { + __bf16, uint32_t Seed, uint32_t *NextSeed) noexcept { ++fp8_builtin_mock::getCounters().ClampStochasticRoundBF16ToE5M2INTEL; if (NextSeed) *NextSeed = Seed + 1; return 0x42; } -inline uint8_t __builtin_spirv_ClampStochasticRoundBF16ToE4M3INTEL( - sycl::ext::oneapi::bfloat16) noexcept { +inline uint8_t +__builtin_spirv_ClampStochasticRoundBF16ToE4M3INTEL(__bf16) noexcept { return 0x00; } From 2b7918eee716549c905b348b1b365834a4bd1f29 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Tue, 12 May 2026 18:13:13 +0200 Subject: [PATCH 46/89] [SYCL] fix warnings of LLVm translator --- llvm/lib/SYCLPostLink/ModuleSplitter.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/llvm/lib/SYCLPostLink/ModuleSplitter.cpp b/llvm/lib/SYCLPostLink/ModuleSplitter.cpp index c95ff7f8235ff..26c3c0e9eb6be 100644 --- a/llvm/lib/SYCLPostLink/ModuleSplitter.cpp +++ b/llvm/lib/SYCLPostLink/ModuleSplitter.cpp @@ -93,7 +93,8 @@ bool isSpirvSyclBuiltin(StringRef FName) { // now skip the digits FName = FName.drop_while([](char C) { return std::isdigit(C); }); - return FName.starts_with("__spirv_") || FName.starts_with("__sycl_"); + return FName.starts_with("__spirv_") || FName.starts_with("__sycl_") || + FName.starts_with("__builtin_spirv_"); } // Return true if the function is a ESIMD builtin From 9caabe8ce5d60dba35dc7667b3ce844c69af9252 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Tue, 12 May 2026 18:14:23 +0200 Subject: [PATCH 47/89] [SYCL] remove doubles and add more tests Co-authored-by: Copilot --- .../oneapi/experimental/float_8bit/types.hpp | 112 +----- sycl/test-e2e/Experimental/fp8/e4m3_cri.cpp | 58 ---- .../Experimental/fp8/e4m3_cri_conversion.cpp | 148 ++++++++ .../fp8/e4m3_x2_cri_conversion.cpp | 320 ++++++++++++++++++ sycl/unittests/Extensions/fp8/fp8_e4m3.cpp | 146 ++++++-- sycl/unittests/Extensions/fp8/fp8_e5m2.cpp | 52 --- sycl/unittests/Extensions/fp8/fp8_e8m0.cpp | 39 +-- 7 files changed, 597 insertions(+), 278 deletions(-) delete mode 100644 sycl/test-e2e/Experimental/fp8/e4m3_cri.cpp create mode 100644 sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp create mode 100644 sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index c5f89a14d0623..7cd35a62a6ec6 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -149,13 +149,6 @@ template <> struct SourceTraits { static constexpr int Bias = 127; }; -template <> struct SourceTraits { - using UInt = uint64_t; - static constexpr size_t ExpBits = 11; - static constexpr size_t FracBits = 52; - static constexpr int Bias = 1023; -}; - template <> struct SourceTraits { using UInt = uint8_t; using UnsignedT = std::make_unsigned_t; @@ -832,8 +825,7 @@ template class fp8_e4m3_x { return __builtin_spirv_ClampConvertFP16ToE4M3INTEL(v); #else if constexpr (std::is_same_v, sycl::half> || - std::is_same_v, float> || - std::is_same_v, double>) { + std::is_same_v, float>) { return detail::ConvertFloatToFP8_CPU( h, rounding::to_even, saturation::finite); } else if constexpr (std::is_integral_v>) { @@ -886,7 +878,7 @@ template class fp8_e4m3_x { ~fp8_e4m3_x() = default; fp8_e4m3_x &operator=(const fp8_e4m3_x &) = default; - // Construct from pack of half, float, double. + // Construct from pack of half, float. // Available only when the size of the pack is equal to N. template class fp8_e4m3_x { (sizeof...(Types) == N) && (((std::is_same_v, half>) && ...) || ((std::is_same_v, bfloat16>) && ...) || - ((std::is_same_v, float>) && ...) || - ((std::is_same_v, double>) && ...))>> + ((std::is_same_v, float>) && ...))>> explicit fp8_e4m3_x(Types... v) { if constexpr (((std::is_same_v, bfloat16>) && ...)) { const bfloat16 in[N] = {static_cast(v)...}; @@ -909,7 +900,7 @@ template class fp8_e4m3_x { } } - // Construct from an array of half, bfloat16, float, double. + // Construct from an array of half, bfloat16, float. explicit fp8_e4m3_x(sycl::half const (&v)[N], rounding r = rounding::to_even) { CheckConstraints(r); @@ -929,12 +920,7 @@ template class fp8_e4m3_x { vals[i] = ConvertToFP8(v[i]); } - explicit fp8_e4m3_x(double const (&v)[N]) { - for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i]); - } - - // Construct from an marray of half, bfloat16, float, double. + // Construct from an marray of half, bfloat16, float. explicit fp8_e4m3_x(const sycl::marray &v, rounding r = rounding::to_even) { CheckConstraints(r); @@ -956,11 +942,6 @@ template class fp8_e4m3_x { vals[i] = ConvertToFP8(v[i]); } - explicit fp8_e4m3_x(const sycl::marray &v) { - for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i]); - } - // Construct from integer types. // Available only when N==1. @@ -1004,7 +985,7 @@ template class fp8_e4m3_x { vals[0] = ConvertToFP8(val); } - // Assign (operator) from half, bfloat16, float, double, and integer types. + // Assign (operator) from half, bfloat16, float, and integer types. // Available only when N==1. template > @@ -1025,12 +1006,6 @@ template class fp8_e4m3_x { return *this; } - template > - fp8_e4m3_x &operator=(double val) { - vals[0] = ConvertToFP8(val); - return *this; - } - template > fp8_e4m3_x &operator=(short val) { vals[0] = ConvertToFP8(val); @@ -1079,7 +1054,7 @@ template class fp8_e4m3_x { return *this; } - // Convert to half, bfloat16, float, double. + // Convert to half, bfloat16, float. // Available only when N==1. template > @@ -1095,10 +1070,6 @@ template class fp8_e4m3_x { explicit operator float() const { return ConvertFromFP8(vals[0]); } - template > - explicit operator double() const { - return ConvertFromFP8(vals[0]); - } // Convert to integer types. // Available only when N==1. @@ -1162,15 +1133,8 @@ template class fp8_e4m3_x { template > explicit operator bool() const { -#ifdef __SYCL_DEVICE_ONLY__ - // detect +0 / -0 - sycl::half h = - __builtin_spirv_ConvertE4M3ToFP16EXT(sycl::bit_cast(vals[0])); - return h != 0; -#else // no need to convert, just check sign bit and 0s return vals[0] != 0 && vals[0] != 0x80; -#endif } // Convert to marray of half, bfloat16, float @@ -1219,8 +1183,7 @@ template class fp8_e5m2_x { : __builtin_spirv_ConvertFP16ToE5M2EXT(v); #else if constexpr (std::is_same_v, sycl::half> || - std::is_same_v, float> || - std::is_same_v, double>) { + std::is_same_v, float>) { return detail::ConvertFloatToFP8_CPU( h, rounding::to_even, s); } else if constexpr (std::is_integral_v>) { @@ -1275,7 +1238,7 @@ template class fp8_e5m2_x { ~fp8_e5m2_x() = default; fp8_e5m2_x &operator=(const fp8_e5m2_x &) = default; - // Construct from pack of half, bfloat16, float, double. + // Construct from pack of half, bfloat16, float. // Available only when the size of the pack is equal to N. // Available only when each type in the pack is half. @@ -1285,8 +1248,7 @@ template class fp8_e5m2_x { (sizeof...(Types) == N) && (((std::is_same_v, half>) && ...) || ((std::is_same_v, bfloat16>) && ...) || - ((std::is_same_v, float>) && ...) || - ((std::is_same_v, double>) && ...))>> + ((std::is_same_v, float>) && ...))>> explicit fp8_e5m2_x(Types... v) { if constexpr (((std::is_same_v, bfloat16>) && ...)) { const bfloat16 in[N] = {static_cast(v)...}; @@ -1300,7 +1262,7 @@ template class fp8_e5m2_x { } } - // Construct from an array of half, bfloat16, float, double. + // Construct from an array of half, bfloat16, float. explicit fp8_e5m2_x(half const (&v)[N], rounding r = rounding::to_even, saturation s = saturation::finite) { @@ -1325,12 +1287,7 @@ template class fp8_e5m2_x { vals[i] = ConvertToFP8(v[i], s); } - explicit fp8_e5m2_x(double const (&v)[N]) { - for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i], saturation::finite); - } - - // Construct from an marray of half, bfloat16, float, double. + // Construct from an marray of half, bfloat16, float. explicit fp8_e5m2_x(const sycl::marray &v, rounding r = rounding::to_even, @@ -1356,11 +1313,6 @@ template class fp8_e5m2_x { vals[i] = ConvertToFP8(v[i], s); } - explicit fp8_e5m2_x(const sycl::marray &v) { - for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i], saturation::finite); - } - // Construct with stochastic rounding with user provided seed from an array of // half, bfloat16. @@ -1485,7 +1437,7 @@ template class fp8_e5m2_x { vals[0] = ConvertToFP8(val, saturation::finite); } - // Assign (operator) from half, bfloat16, float, double, and integer types. + // Assign (operator) from half, bfloat16, float, and integer types. // Available only when N==1. template > @@ -1506,11 +1458,6 @@ template class fp8_e5m2_x { return *this; } - template > - fp8_e5m2_x &operator=(double val) { - vals[0] = ConvertToFP8(val, saturation::finite); - return *this; - } template > fp8_e5m2_x &operator=(short val) { @@ -1560,7 +1507,7 @@ template class fp8_e5m2_x { return *this; } - // Convert to half, bfloat16, float, double. + // Convert to half, bfloat16, float. // Available only when N==1. template > @@ -1578,10 +1525,6 @@ template class fp8_e5m2_x { return ConvertFromFP8(vals[0]); } - template > - explicit operator double() const { - return ConvertFromFP8(vals[0]); - } // Convert to integer types. // Available only when N==1. @@ -1695,8 +1638,7 @@ template class fp8_e8m0_x { (sizeof...(Types) == N) && (((std::is_same_v, half>) && ...) || ((std::is_same_v, bfloat16>) && ...) || - ((std::is_same_v, float>) && ...) || - ((std::is_same_v, double>) && ...))>> + ((std::is_same_v, float>) && ...))>> explicit fp8_e8m0_x(Types... v) { using InT = std::common_type_t...>; const InT in[N] = {v...}; @@ -1723,12 +1665,6 @@ template class fp8_e8m0_x { vals[i] = detail::ConvertFloatToE8M0_CPU(in[i], r, saturation::finite); } - explicit fp8_e8m0_x(double const (&in)[N]) { - for (size_t i = 0; i < N; ++i) - vals[i] = detail::ConvertFloatToE8M0_CPU(in[i], rounding::upward, - saturation::finite); - } - explicit fp8_e8m0_x(const marray &in, rounding r = rounding::upward) { CheckConstraints(r); @@ -1750,12 +1686,6 @@ template class fp8_e8m0_x { vals[i] = detail::ConvertFloatToE8M0_CPU(in[i], r, saturation::finite); } - explicit fp8_e8m0_x(const marray &in) { - for (size_t i = 0; i < N; ++i) - vals[i] = detail::ConvertFloatToE8M0_CPU(in[i], rounding::upward, - saturation::finite); - } - // Construct from integer types. // Available only when N==1. @@ -1822,13 +1752,6 @@ template class fp8_e8m0_x { return *this; } - template > - fp8_e8m0_x &operator=(double val) { - vals[0] = detail::ConvertFloatToE8M0_CPU(val, rounding::upward, - saturation::finite); - return *this; - } - template > fp8_e8m0_x &operator=(short val) { vals[0] = @@ -1886,14 +1809,11 @@ template class fp8_e8m0_x { explicit operator bfloat16() const { return detail::ConvertFromE8M0_CPU(vals[0], rounding::to_even); } + template > explicit operator float() const { return detail::ConvertFromE8M0_CPU(vals[0], rounding::to_even); } - template > - explicit operator double() const { - return detail::ConvertFromE8M0_CPU(vals[0], rounding::to_even); - } template > explicit operator char() const { diff --git a/sycl/test-e2e/Experimental/fp8/e4m3_cri.cpp b/sycl/test-e2e/Experimental/fp8/e4m3_cri.cpp deleted file mode 100644 index a7c750aa5019e..0000000000000 --- a/sycl/test-e2e/Experimental/fp8/e4m3_cri.cpp +++ /dev/null @@ -1,58 +0,0 @@ -// RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out -// RUN: %{run} SYCL_UR_TRACE=1 %t.out - -// Warning! This test requires CRI simulator run to communicate via TCP socket -// with port 60999, or any other from config - -#include -#include -#include -#include -#include - -using namespace sycl::ext::oneapi::experimental; - -template int run_basic_fp8_test(sycl::queue &queue) { - auto *data = sycl::malloc_shared(1, queue); - data[0] = fp8_e4m3(static_cast(1.5)); - - queue.single_task([=]() { - fp8_e4m3 value = data[0]; - T f = static_cast(value); - f += 1.0f; - data[0] = fp8_e4m3(f); - }); - queue.wait_and_throw(); - - fp8_e4m3 expected(2.5f); - T out = static_cast(data[0]); - T expected_out = static_cast(expected); - - sycl::free(data, queue); - if (std::fabs(out - expected_out) > 0.0f) - return 1; - - return 0; -} - -int main() { - auto async_handler = [](sycl::exception_list exceptions) { - for (const std::exception_ptr &e : exceptions) { - try { - std::rethrow_exception(e); - } catch (const sycl::exception &ex) { - std::cerr << "Async SYCL exception: " << ex.what() << '\n'; - std::terminate(); - } - } - }; - - sycl::queue queue{async_handler}; - - int ret = run_basic_fp8_test(queue); - ret |= run_basic_fp8_test(queue); - ret |= run_basic_fp8_test(queue); - ret |= run_basic_fp8_test(queue); - - return ret; -} diff --git a/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp new file mode 100644 index 0000000000000..ac7594d42ab8d --- /dev/null +++ b/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp @@ -0,0 +1,148 @@ + +// RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out +// RUN: %{run} SYCL_UR_TRACE=1 %t.out + +// Warning! This test requires CRI device or its simulator run to communicate +// via TCP socket with port 60999, or any other from config + +// TODO need to set requirement of intel_feature_gpu_cri + +#include +#include +#include +#include + +using namespace sycl::ext::oneapi::experimental; + +template int test_fp8_simple_type_conversion(sycl::queue &queue) { + auto *data = sycl::malloc_shared(1, queue); + data[0] = fp8_e4m3(static_cast(1.5)); + + queue.single_task([=]() { + fp8_e4m3 value = data[0]; + T f = static_cast(value); + f += static_cast(1.0f); + data[0] = fp8_e4m3(f); + }); + queue.wait_and_throw(); + + fp8_e4m3 expected(2.5f); + T out = static_cast(data[0]); + T expected_out = static_cast(expected); + + sycl::free(data, queue); + if (std::fabs(out - expected_out) > 0.0f) + return 1; + + return 0; +} + +int test_boolean_conversion(sycl::queue &queue, float test_value, + bool expected) { + auto *data = sycl::malloc_shared(1, queue); + auto *res = sycl::malloc_shared(1, queue); + data[0] = fp8_e4m3(test_value); // we do not care if float or any other type; + queue.single_task([=]() { + fp8_e4m3 value = data[0]; + res[0] = static_cast(value); + }); + queue.wait_and_throw(); + int ret = res[0] == expected ? 0 : 1; + sycl::free(data, queue); + sycl::free(res, queue); + return ret; +} + +template int test_marray_conversion(sycl::queue &queue) { + sycl::marray input(static_cast(1.25f)); + auto *data = sycl::malloc_shared(1, queue); + data[0] = fp8_e4m3(input); + + queue.single_task([=]() { + fp8_e4m3 value = data[0]; + sycl::marray f = static_cast>(value); + f[0] += static_cast(1.0f); + data[0] = fp8_e4m3(f); + }); + queue.wait_and_throw(); + /* + sycl::marray expected_input(static_cast(2.25f)); + fp8_e4m3 expected(expected_input); + sycl::marray out = static_cast>(data[0]); + sycl::marray expected_out = static_cast>(expected); + + sycl::free(data, queue); + if (std::fabs(out[0] - expected_out[0]) > 0.0f) + return 1; + */ + return 0; +} + +template int test_carray_conversion(sycl::queue &queue) { + auto *data = sycl::malloc_shared(1, queue); + data[0] = fp8_e4m3(static_cast(1.25f)); + + queue.single_task([=]() { + fp8_e4m3 value = data[0]; + T f = {static_cast(value)}; + f += static_cast(1.0f); + data[0] = fp8_e4m3(f); + }); + queue.wait_and_throw(); + + fp8_e4m3 expected(static_cast(2.25f)); + T out = {static_cast(data[0])}; + T expected_out = {static_cast(expected)}; + + sycl::free(data, queue); + if (std::fabs(out - expected_out) > 0.0f) + return 1; + return 0; +} + +int main() { + auto async_handler = [](sycl::exception_list exceptions) { + for (const std::exception_ptr &e : exceptions) { + try { + std::rethrow_exception(e); + } catch (const sycl::exception &ex) { + std::cerr << "Async SYCL exception: " << ex.what() << '\n'; + std::terminate(); + } + } + }; + + sycl::queue queue{async_handler}; + + int ret = test_fp8_simple_type_conversion(queue); + ret |= test_fp8_simple_type_conversion(queue); + ret |= test_fp8_simple_type_conversion(queue); + ret |= test_fp8_simple_type_conversion(queue); + ret |= test_fp8_simple_type_conversion(queue); + ret |= test_fp8_simple_type_conversion(queue); + ret |= test_fp8_simple_type_conversion(queue); + ret |= test_fp8_simple_type_conversion(queue); + ret |= test_fp8_simple_type_conversion(queue); + ret |= test_fp8_simple_type_conversion(queue); + ret |= test_fp8_simple_type_conversion(queue); + ret |= test_fp8_simple_type_conversion(queue); + ret |= test_fp8_simple_type_conversion(queue); + ret |= test_fp8_simple_type_conversion(queue); + // check special requirement for boolean conversion - only +0.0 and -0.0 + // should be converted to false, all other values should be converted to true + ret |= test_boolean_conversion(queue, 0.0f, false); + ret |= test_boolean_conversion(queue, -0.0f, false); + ret |= test_boolean_conversion(queue, 1.0f, true); + ret |= test_boolean_conversion(queue, -1.0f, true); + + ret |= test_marray_conversion(queue); + ret |= test_marray_conversion(queue); + // TODO: uncomment when bfloat16 conversion is fixed + //ret |= test_marray_conversion(queue); + + ret |= test_carray_conversion(queue); + ret |= test_carray_conversion(queue); + // TODO: uncomment when bfloat16 conversion is fixed + //ret |= test_carray_conversion(queue); + return ret; +} diff --git a/sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp new file mode 100644 index 0000000000000..36f991296986a --- /dev/null +++ b/sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp @@ -0,0 +1,320 @@ + +// RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out +// RUN: %{run} SYCL_UR_TRACE=1 %t.out + +// Warning! This test requires CRI device or its simulator run to communicate +// via TCP socket with port 60999, or any other from config + +// TODO need to set requirement of intel_feature_gpu_cri + +#include +#include +#include +#include +#include + +using namespace sycl::ext::oneapi::experimental; + +namespace { + +bool equal_or_both_nan(float actual, float expected) { + if (std::isnan(expected)) + return std::isnan(actual); + return actual == expected; +} + +bool equal_with_zero_sign(float actual, float expected) { + if (!equal_or_both_nan(actual, expected)) + return false; + if (expected == 0.0f) + return std::signbit(actual) == std::signbit(expected); + return true; +} + +template +int test_explicit_to_even_carray_constructor(sycl::queue &queue) { + T input[2] = {static_cast(0.01171875f), static_cast(-5.5f)}; + auto *data = sycl::malloc_shared(1, queue); + auto *out = sycl::malloc_shared(2, queue); + data[0] = fp8_e4m3_x2(input, rounding::to_even); + + queue.single_task([=]() { + sycl::marray unpacked = static_cast>(data[0]); + out[0] = unpacked[0]; + out[1] = unpacked[1]; + }); + queue.wait_and_throw(); + + int ret = 0; + if (static_cast(out[0]) != 0.01171875f) + ret = 1; + if (static_cast(out[1]) != -5.5f) + ret = 1; + + sycl::free(data, queue); + sycl::free(out, queue); + return ret; +} + +template +int test_explicit_to_even_marray_constructor(sycl::queue &queue) { + sycl::marray input(static_cast(3.25f), + static_cast(-0.009765625f)); + auto *data = sycl::malloc_shared(1, queue); + auto *out = sycl::malloc_shared(2, queue); + data[0] = fp8_e4m3_x2(input, rounding::to_even); + + queue.single_task([=]() { + sycl::marray unpacked = static_cast>(data[0]); + out[0] = unpacked[0]; + out[1] = unpacked[1]; + }); + queue.wait_and_throw(); + + int ret = 0; + if (static_cast(out[0]) != 3.25f) + ret = 1; + if (static_cast(out[1]) != -0.009765625f) + ret = 1; + + sycl::free(data, queue); + sycl::free(out, queue); + return ret; +} + +int test_boundary_round_trip_nan(sycl::queue &queue) { + const float input[2] = {std::numeric_limits::quiet_NaN(), + -std::numeric_limits::quiet_NaN()}; + auto *data = sycl::malloc_shared(1, queue); + auto *out = sycl::malloc_shared(2, queue); + data[0] = fp8_e4m3_x2(input, rounding::to_even); + + queue.single_task([=]() { + fp8_e4m3_x2 value = data[0]; + sycl::marray unpacked = static_cast>(value); + data[0] = fp8_e4m3_x2(unpacked, rounding::to_even); + sycl::marray round_tripped = + static_cast>(data[0]); + out[0] = round_tripped[0]; + out[1] = round_tripped[1]; + }); + queue.wait_and_throw(); + + int ret = !(std::isnan(out[0]) && std::isnan(out[1])); + sycl::free(data, queue); + sycl::free(out, queue); + return ret; +} + +int test_boundary_round_trip_negative_zero(sycl::queue &queue) { + const float input[2] = {-0.0f, 7.0f}; + auto *data = sycl::malloc_shared(1, queue); + auto *out = sycl::malloc_shared(2, queue); + data[0] = fp8_e4m3_x2(input, rounding::to_even); + + queue.single_task([=]() { + fp8_e4m3_x2 value = data[0]; + sycl::marray unpacked = static_cast>(value); + data[0] = fp8_e4m3_x2(unpacked, rounding::to_even); + sycl::marray round_tripped = + static_cast>(data[0]); + out[0] = round_tripped[0]; + out[1] = round_tripped[1]; + }); + queue.wait_and_throw(); + + int ret = 0; + if (!equal_with_zero_sign(out[0], -0.0f)) + ret = 1; + if (out[1] != 7.0f) + ret = 1; + + sycl::free(data, queue); + sycl::free(out, queue); + return ret; +} + +int test_boundary_round_trip_subnormals(sycl::queue &queue) { + const float input[2] = {0.01171875f, -0.009765625f}; + auto *data = sycl::malloc_shared(1, queue); + auto *out = sycl::malloc_shared(2, queue); + data[0] = fp8_e4m3_x2(input, rounding::to_even); + + queue.single_task([=]() { + fp8_e4m3_x2 value = data[0]; + sycl::marray unpacked = static_cast>(value); + data[0] = fp8_e4m3_x2(unpacked, rounding::to_even); + sycl::marray round_tripped = + static_cast>(data[0]); + out[0] = round_tripped[0]; + out[1] = round_tripped[1]; + }); + queue.wait_and_throw(); + + int ret = 0; + if (out[0] != 0.01171875f) + ret = 1; + if (out[1] != -0.009765625f) + ret = 1; + + sycl::free(data, queue); + sycl::free(out, queue); + return ret; +} + +int test_boundary_round_trip_saturation_and_infinity_clamp(sycl::queue &queue) { + const float input[2] = {600.0f, -std::numeric_limits::infinity()}; + auto *data = sycl::malloc_shared(1, queue); + auto *out = sycl::malloc_shared(2, queue); + data[0] = fp8_e4m3_x2(input, rounding::to_even); + + queue.single_task([=]() { + fp8_e4m3_x2 value = data[0]; + sycl::marray unpacked = static_cast>(value); + data[0] = fp8_e4m3_x2(unpacked, rounding::to_even); + sycl::marray round_tripped = + static_cast>(data[0]); + out[0] = round_tripped[0]; + out[1] = round_tripped[1]; + }); + queue.wait_and_throw(); + + int ret = 0; + if (out[0] != 448.0f) + ret = 1; + if (out[1] != -448.0f) + ret = 1; + + sycl::free(data, queue); + sycl::free(out, queue); + return ret; +} + +} // namespace + +template int test_fp8_simple_type_conversion(sycl::queue &queue) { + auto *data = sycl::malloc_shared(1, queue); + data[0] = fp8_e4m3_x2(static_cast(1.5f), static_cast(2.5f)); + + queue.single_task([=]() { + fp8_e4m3_x2 value = data[0]; + sycl::marray f = static_cast>(value); + f[0] += static_cast(1.0f); + f[1] += static_cast(1.0f); + data[0] = fp8_e4m3_x2(f); + }); + queue.wait_and_throw(); + + sycl::marray expected_input(static_cast(2.5f), static_cast(3.5f)); + fp8_e4m3_x2 expected(expected_input); + sycl::marray out = static_cast>(data[0]); + sycl::marray expected_out = static_cast>(expected); + + sycl::free(data, queue); + for (size_t i = 0; i < 2; ++i) { + if (std::fabs(static_cast(out[i]) - static_cast(expected_out[i])) > + 0.0f) + return 1; + } + + return 0; +} + +template int test_marray_conversion(sycl::queue &queue) { + sycl::marray input(static_cast(1.25f), static_cast(2.5f)); + auto *data = sycl::malloc_shared(1, queue); + data[0] = fp8_e4m3_x2(input); + + queue.single_task([=]() { + fp8_e4m3_x2 value = data[0]; + sycl::marray f = static_cast>(value); + f[0] += static_cast(1.0f); + f[1] += static_cast(2.0f); + data[0] = fp8_e4m3_x2(f); + }); + queue.wait_and_throw(); + sycl::marray expected_input(static_cast(2.25f), static_cast(4.5f)); + fp8_e4m3_x2 expected(expected_input); + sycl::marray out = static_cast>(data[0]); + sycl::marray expected_out = static_cast>(expected); + + sycl::free(data, queue); + for (size_t i = 0; i < 2; ++i) { + if (std::fabs(static_cast(out[i]) - static_cast(expected_out[i])) > + 0.0f) + return 1; + } + return 0; +} + +template int test_carray_conversion(sycl::queue &queue) { + T input[2] = {static_cast(1.25f), static_cast(2.5f)}; + auto *data = sycl::malloc_shared(1, queue); + data[0] = fp8_e4m3_x2(input); + + queue.single_task([=]() { + fp8_e4m3_x2 value = data[0]; + sycl::marray unpacked = static_cast>(value); + T output[2] = {unpacked[0] + static_cast(1.0f), + unpacked[1] + static_cast(4.0f)}; + data[0] = fp8_e4m3_x2(output); + }); + queue.wait_and_throw(); + + T expected_input[2] = {static_cast(2.25f), static_cast(6.5f)}; + fp8_e4m3_x2 expected(expected_input); + sycl::marray out = static_cast>(data[0]); + sycl::marray expected_out = static_cast>(expected); + + sycl::free(data, queue); + for (size_t i = 0; i < 2; ++i) { + if (std::fabs(static_cast(out[i]) - + static_cast(expected_out[i])) > 0.0f) + return 1; + } + + return 0; +} + +int main() { + auto async_handler = [](sycl::exception_list exceptions) { + for (const std::exception_ptr &e : exceptions) { + try { + std::rethrow_exception(e); + } catch (const sycl::exception &ex) { + std::cerr << "Async SYCL exception: " << ex.what() << '\n'; + std::terminate(); + } + } + }; + + sycl::queue queue{async_handler}; + + // fp8_e4m3_x2 only supports packed conversions through marray, + // marray, and marray. + int ret = test_fp8_simple_type_conversion(queue); + ret |= test_fp8_simple_type_conversion(queue); + ret |= test_fp8_simple_type_conversion(queue); + + ret |= test_marray_conversion(queue); + ret |= test_marray_conversion(queue); +// ret |= test_marray_conversion(queue); + + ret |= test_carray_conversion(queue); + ret |= test_carray_conversion(queue); + // ret |= test_carray_conversion(queue); + + ret |= test_explicit_to_even_carray_constructor(queue); + ret |= test_explicit_to_even_carray_constructor(queue); + // ret |= test_explicit_to_even_carray_constructor(queue); + + ret |= test_explicit_to_even_marray_constructor(queue); + ret |= test_explicit_to_even_marray_constructor(queue); + // ret |= test_explicit_to_even_marray_constructor(queue); + + ret |= test_boundary_round_trip_nan(queue); + ret |= test_boundary_round_trip_negative_zero(queue); + ret |= test_boundary_round_trip_subnormals(queue); + ret |= test_boundary_round_trip_saturation_and_infinity_clamp(queue); + return ret; +} diff --git a/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp b/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp index cea92e9b9076d..b8127a712e670 100644 --- a/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp +++ b/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp @@ -85,6 +85,32 @@ TEST(FP8E4M3Test, VariadicNaNEncodingFloat) { EXPECT_EQ(a.vals[1], 0xFF); // -NaN -> 0b1_1111_111 } +TEST(FP8E4M3Test, ScalarInfinityClampsToMaxNormalPreservingSign) { + // Spec: non-stochastic conversion with finite saturation clamps Infinity to + // max normal while preserving sign. + fp8_e4m3 pos(std::numeric_limits::infinity()); + fp8_e4m3 neg(-std::numeric_limits::infinity()); + + EXPECT_EQ(pos.vals[0], 0x7E); // +448.0 -> 0b0_1111_110 + EXPECT_EQ(neg.vals[0], 0xFE); // -448.0 -> 0b1_1111_110 + + EXPECT_EQ(static_cast(pos), 448.0f); + EXPECT_EQ(static_cast(neg), -448.0f); +} + +TEST(FP8E4M3Test, X2InfinityClampsToMaxNormalPreservingSign) { + const float in[2] = {std::numeric_limits::infinity(), + -std::numeric_limits::infinity()}; + fp8_e4m3_x2 value(in); + + EXPECT_EQ(value.vals[0], 0x7E); // +448.0 -> 0b0_1111_110 + EXPECT_EQ(value.vals[1], 0xFE); // -448.0 -> 0b1_1111_110 + + sycl::marray out = static_cast>(value); + EXPECT_EQ(out[0], 448.0f); + EXPECT_EQ(out[1], -448.0f); +} + TEST(FP8E4M3Test, IntegerToEvenFiniteAndSize) { // Integer constructors: to_even + finite saturation (CPU). fp8_e4m3 a0(0); @@ -119,6 +145,43 @@ TEST(FP8E4M3Test, AssignmentOperatorToEvenFiniteAndSize) { EXPECT_EQ(a.vals[0], 0x08); } +TEST(FP8E4M3Test, AssignmentOperatorsAllScalarWidths) { + fp8_e4m3 value(11.0f); + + EXPECT_EQ(&(value = sycl::half(11.0f)), &value); + EXPECT_EQ(static_cast(value), 11.0f); + + EXPECT_EQ(&(value = sycl::ext::oneapi::bfloat16(-13.0f)), &value); + EXPECT_EQ(static_cast(value), -13.0f); + + EXPECT_EQ(&(value = 14.0f), &value); + EXPECT_EQ(static_cast(value), 14.0f); + + EXPECT_EQ(&(value = static_cast(9)), &value); + EXPECT_EQ(static_cast(value), 9.0f); + + EXPECT_EQ(&(value = -10), &value); + EXPECT_EQ(static_cast(value), -10.0f); + + EXPECT_EQ(&(value = 22L), &value); + EXPECT_EQ(static_cast(value), 22.0f); + + EXPECT_EQ(&(value = -26LL), &value); + EXPECT_EQ(static_cast(value), -26.0f); + + EXPECT_EQ(&(value = static_cast(30)), &value); + EXPECT_EQ(static_cast(value), 30.0f); + + EXPECT_EQ(&(value = 44U), &value); + EXPECT_EQ(static_cast(value), 44.0f); + + EXPECT_EQ(&(value = 52UL), &value); + EXPECT_EQ(static_cast(value), 52.0f); + + EXPECT_EQ(&(value = 88ULL), &value); + EXPECT_EQ(static_cast(value), 88.0f); +} + TEST(FP8E4M3Test, FloatingPointConversionOperators) { // Floating-point operators: convert stored fp8 to the respective type. fp8_e4m3 one(1.0f); @@ -199,25 +262,6 @@ TEST(FP8E4M3Test, CArrayFloatHostToEvenFinite) { EXPECT_EQ(a2.vals[1], 0x00); // 0 } -TEST(FP8E4M3Test, CArrayDoubleToEvenFinite) { - // Double c-array: to_even + finite saturation. - const double in[2] = {448.0, 449.0}; - const double in1[2] = {0.015625, 0.013671875}; - const double in2[2] = {0.001953125, std::numeric_limits::quiet_NaN()}; - fp8_e4m3_x2 a(in); - fp8_e4m3_x2 a1(in1); - fp8_e4m3_x2 a2(in2); - - EXPECT_EQ(sizeof(a.vals), 2u); - EXPECT_EQ(sizeof(a1.vals), 2u); - EXPECT_EQ(sizeof(a2.vals), 2u); - EXPECT_EQ(a.vals[0], 0x7E); // +448 - EXPECT_EQ(a.vals[1], 0x7E); // 449 -> clamp to +448 - EXPECT_EQ(a1.vals[0], 0x08); // min normal - EXPECT_EQ(a1.vals[1], 0x07); // max subnormal - EXPECT_EQ(a2.vals[0], 0x01); // min subnormal - EXPECT_EQ(a2.vals[1], 0x7F); // NaN -} TEST(FP8E4M3Test, CArrayHalfHostToEvenFinite) { // Host code supports only rounding::to_even and saturation::finite. @@ -317,11 +361,9 @@ TEST(FP8E4M3Test, FloatingPointConversionOperatorsMoreTypes) { EXPECT_EQ(sizeof(b.vals), 1u); EXPECT_EQ(sizeof(nanv.vals), 1u); - double da = static_cast(a); sycl::half ha = static_cast(a); sycl::ext::oneapi::bfloat16 ba = static_cast(a); - EXPECT_EQ(da, 1.0); EXPECT_EQ(static_cast(ha), 1.0f); EXPECT_EQ(static_cast(ba), 1.0f); @@ -331,6 +373,29 @@ TEST(FP8E4M3Test, FloatingPointConversionOperatorsMoreTypes) { EXPECT_TRUE(std::isnan(fn)); } +TEST(FP8E4M3Test, MarrayConversionOperatorsHalfNumericValues) { + fp8_e4m3_x2 a(6.5f, -0.3125f); + + EXPECT_EQ(sizeof(a.vals), 2u); + + sycl::marray out = static_cast>(a); + + EXPECT_EQ(static_cast(out[0]), 6.5f); + EXPECT_EQ(static_cast(out[1]), -0.3125f); +} + +TEST(FP8E4M3Test, MarrayConversionOperatorsBFloat16NumericValues) { + fp8_e4m3_x2 a(-12.0f, 0.1875f); + + EXPECT_EQ(sizeof(a.vals), 2u); + + sycl::marray out = + static_cast>(a); + + EXPECT_EQ(static_cast(out[0]), -12.0f); + EXPECT_EQ(static_cast(out[1]), 0.1875f); +} + TEST(FP8E4M3Test, IntegerConversionOperatorsMultipleWidthsTowardZero) { fp8_e4m3 p(1.5f); fp8_e4m3 n(-1.5f); @@ -346,6 +411,32 @@ TEST(FP8E4M3Test, IntegerConversionOperatorsMultipleWidthsTowardZero) { EXPECT_EQ(ll, -1); } +TEST(FP8E4M3Test, IntegerConversionOperatorsRemainingWidthsTowardZero) { + fp8_e4m3 pos_char(13.0f); + fp8_e4m3 neg_schar(-11.0f); + fp8_e4m3 pos_uchar(14.0f); + fp8_e4m3 pos_ushort(22.0f); + fp8_e4m3 pos_uint(30.0f); + fp8_e4m3 pos_ulong(44.0f); + fp8_e4m3 pos_ull(88.0f); + + char c = static_cast(pos_char); + signed char sc = static_cast(neg_schar); + unsigned char uc = static_cast(pos_uchar); + unsigned short us = static_cast(pos_ushort); + unsigned int ui = static_cast(pos_uint); + unsigned long ul = static_cast(pos_ulong); + unsigned long long ull = static_cast(pos_ull); + + EXPECT_EQ(c, static_cast(13)); + EXPECT_EQ(sc, static_cast(-11)); + EXPECT_EQ(uc, static_cast(14)); + EXPECT_EQ(us, static_cast(22)); + EXPECT_EQ(ui, 30u); + EXPECT_EQ(ul, 44ul); + EXPECT_EQ(ull, 88ull); +} + TEST(FP8E4M3Test, CArrayFloatRoundingToEven) { const float in[2] = {0.012f, 1000.0f}; fp8_e4m3_x2 a(in, rounding::to_even); @@ -399,13 +490,6 @@ TEST(FP8E4M3Test, MarrayFloatRoundingToEven) { EXPECT_EQ(a.vals[1], 0x38); } -TEST(FP8E4M3Test, MarrayDoubleToEven) { - const sycl::marray in = {0.012, 1.0625}; - fp8_e4m3_x2 a(in); - - EXPECT_EQ(a.vals[0], 0x06); - EXPECT_EQ(a.vals[1], 0x38); -} TEST(FP8E4M3Test, VariadicRejectsMixedTypes) { EXPECT_FALSE((std::is_constructible_v)); @@ -448,9 +532,6 @@ TEST(FP8E4M3Test, X2NotConstructibleFromSingleFloat) { EXPECT_FALSE((std::is_constructible_v)); } -TEST(FP8E4M3Test, X2NotConstructibleFromSingleDouble) { - EXPECT_FALSE((std::is_constructible_v)); -} TEST(FP8E4M3Test, X2NotConstructibleFromSingleBFloat16) { EXPECT_FALSE( @@ -482,9 +563,6 @@ TEST(FP8E4M3Test, X2NotAssignableFromSingleFloat) { EXPECT_FALSE((std::is_assignable_v)); } -TEST(FP8E4M3Test, X2NotAssignableFromSingleDouble) { - EXPECT_FALSE((std::is_assignable_v)); -} TEST(FP8E4M3Test, X2NotAssignableFromSingleChar) { EXPECT_FALSE((std::is_assignable_v)); diff --git a/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp b/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp index 6f67cae20189c..eefc07fa807d4 100644 --- a/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp +++ b/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp @@ -249,27 +249,6 @@ TEST(FP8E5M2Test, CArrayFloatHostToEvenFinite) { EXPECT_EQ(a1.vals[1], 0x7B); // finite saturation => +57344 } -TEST(FP8E5M2Test, CArrayDoubleToEvenFinite) { - // Double c-array: to_even + finite saturation. - const double in[2] = {57344.0, 60000.0}; - const double in1[2] = {0.00006103515625, 0.0000457763671875}; - const double in2[2] = {0.0000152587890625, - std::numeric_limits::quiet_NaN()}; - fp8_e5m2_x2 a(in); - fp8_e5m2_x2 a1(in1); - fp8_e5m2_x2 a2(in2); - - EXPECT_EQ(sizeof(a.vals), 2u); - EXPECT_EQ(sizeof(a1.vals), 2u); - EXPECT_EQ(sizeof(a2.vals), 2u); - EXPECT_EQ(a.vals[0], 0x7B); // +57344 - EXPECT_EQ(a.vals[1], 0x7B); // 60000 -> clamp to +57344 - EXPECT_EQ(a1.vals[0], 0x04); // min normal - EXPECT_EQ(a1.vals[1], 0x03); // max subnormal - EXPECT_EQ(a2.vals[0], 0x01); // min subnormal - EXPECT_EQ(a2.vals[1], 0x7F); // NaN -} - TEST(FP8E5M2Test, CArrayHalfHostToEvenFinite) { const sycl::half in[2] = {sycl::half(1.0f), sycl::half(2.0f)}; const sycl::half in1[2] = {sycl::half(1.125f), sycl::half(-0.0f)}; @@ -338,22 +317,6 @@ TEST(FP8E5M2Test, MarrayAndOperators) { EXPECT_EQ(out3[1], -1.5f); } -TEST(FP8E5M2Test, MarrayDouble) { - sycl::marray dvals = {1.0, 2.0}; - sycl::marray dvals1 = {57344.0, -0.0}; - - fp8_e5m2_x2 ah(dvals); - fp8_e5m2_x2 ah1(dvals1); - - EXPECT_EQ(sizeof(ah.vals), 2u); - EXPECT_EQ(sizeof(ah1.vals), 2u); - - EXPECT_EQ(ah.vals[0], 0x3C); - EXPECT_EQ(ah.vals[1], 0x40); - EXPECT_EQ(ah1.vals[0], 0x7B); - EXPECT_EQ(ah1.vals[1], 0x80); -} - TEST(FP8E5M2Test, FloatingPointConversionOperatorsMoreTypes) { fp8_e5m2 a(1.0f); fp8_e5m2 b(0.00006103515625f); @@ -363,11 +326,9 @@ TEST(FP8E5M2Test, FloatingPointConversionOperatorsMoreTypes) { EXPECT_EQ(sizeof(b.vals), 1u); EXPECT_EQ(sizeof(nanv.vals), 1u); - double da = static_cast(a); sycl::half ha = static_cast(a); sycl::ext::oneapi::bfloat16 ba = static_cast(a); - EXPECT_EQ(da, 1.0); EXPECT_EQ(static_cast(ha), 1.0f); EXPECT_EQ(static_cast(ba), 1.0f); @@ -434,9 +395,6 @@ TEST(FP8E5M2Test, AssignmentOperatorsAllTypes) { a = 3.0f; EXPECT_EQ(a.vals[0], 0x42); // 3.0 - a = 4.0; - EXPECT_EQ(a.vals[0], 0x44); // 4.0 - a = static_cast(-1); EXPECT_EQ(a.vals[0], 0xBC); @@ -482,8 +440,6 @@ TEST(FP8E5M2Test, BoolOperatorWithNaN) { TEST(FP8E5M2Test, VariadicMixedScalarTypes) { EXPECT_FALSE((std::is_constructible_v)); - EXPECT_FALSE( - (std::is_constructible_v)); } TEST(FP8E5M2Test, X2NotConstructibleFromSingleShort) { @@ -522,10 +478,6 @@ TEST(FP8E5M2Test, X2NotConstructibleFromSingleFloat) { EXPECT_FALSE((std::is_constructible_v)); } -TEST(FP8E5M2Test, X2NotConstructibleFromSingleDouble) { - EXPECT_FALSE((std::is_constructible_v)); -} - TEST(FP8E5M2Test, X2NotConstructibleFromSingleBFloat16) { EXPECT_FALSE( (std::is_constructible_v)); @@ -556,10 +508,6 @@ TEST(FP8E5M2Test, X2NotAssignableFromSingleFloat) { EXPECT_FALSE((std::is_assignable_v)); } -TEST(FP8E5M2Test, X2NotAssignableFromSingleDouble) { - EXPECT_FALSE((std::is_assignable_v)); -} - TEST(FP8E5M2Test, X2NotAssignableFromSingleChar) { EXPECT_FALSE((std::is_assignable_v)); } diff --git a/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp b/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp index fad814697b6e2..56537aa6ce6ab 100644 --- a/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp +++ b/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp @@ -56,14 +56,6 @@ TEST(FP8E8M0Test, VariadicBFloat16) { EXPECT_EQ(a.vals[1], 0x80); } -TEST(FP8E8M0Test, VariadicDouble) { - fp8_e8m0_x2 a(1.0, 3.0); - - EXPECT_EQ(sizeof(a.vals), 2u); - EXPECT_EQ(a.vals[0], 0x7F); - EXPECT_EQ(a.vals[1], 0x81); -} - TEST(FP8E8M0Test, VariadicBoundaryEncodings) { fp8_e8m0_x2 a(std::ldexp(1.0f, -127), std::numeric_limits::quiet_NaN()); @@ -133,15 +125,6 @@ TEST(FP8E8M0Test, CArrayBFloat16HostUpwardFinite) { EXPECT_EQ(a.vals[1], 0x80); } -TEST(FP8E8M0Test, CArrayDoubleDefaultUpwardFinite) { - const double in[2] = {1.0, 3.0}; - fp8_e8m0_x2 a(in); - - EXPECT_EQ(sizeof(a.vals), 2u); - EXPECT_EQ(a.vals[0], 0x7F); - EXPECT_EQ(a.vals[1], 0x81); -} - TEST(FP8E8M0Test, MarrayAndOperatorsFloat) { sycl::marray in = {1.0f, 2.0f}; sycl::marray in1 = {3.0f, 0.0f}; @@ -164,26 +147,21 @@ TEST(FP8E8M0Test, MarrayAndOperatorsFloat) { EXPECT_EQ(out1[1], std::ldexp(1.0f, -127)); } -TEST(FP8E8M0Test, MarrayHalfBFloat16Double) { +TEST(FP8E8M0Test, MarrayHalfBFloat16) { sycl::marray hvals = {sycl::half(1.0f), sycl::half(3.0f)}; sycl::marray bvals = { sycl::ext::oneapi::bfloat16(1.0f), sycl::ext::oneapi::bfloat16(2.0f)}; - sycl::marray dvals = {1.0, 3.0}; fp8_e8m0_x2 ah(hvals, rounding::upward); fp8_e8m0_x2 ab(bvals, rounding::upward); - fp8_e8m0_x2 ad(dvals); EXPECT_EQ(sizeof(ah.vals), 2u); EXPECT_EQ(sizeof(ab.vals), 2u); - EXPECT_EQ(sizeof(ad.vals), 2u); EXPECT_EQ(ah.vals[0], 0x7F); EXPECT_EQ(ah.vals[1], 0x81); EXPECT_EQ(ab.vals[0], 0x7F); EXPECT_EQ(ab.vals[1], 0x80); - EXPECT_EQ(ad.vals[0], 0x7F); - EXPECT_EQ(ad.vals[1], 0x81); } TEST(FP8E8M0Test, IntegerConstructorsAllTypes) { @@ -228,9 +206,6 @@ TEST(FP8E8M0Test, AssignmentOperatorsAllTypes) { a = 3.0f; EXPECT_EQ(a.vals[0], 0x81); - a = 4.0; - EXPECT_EQ(a.vals[0], 0x81); - a = static_cast(1); EXPECT_EQ(a.vals[0], 0x7F); @@ -267,13 +242,11 @@ TEST(FP8E8M0Test, FloatingPointConversionOperators) { EXPECT_EQ(min.vals[0], 0x00); float fo = static_cast(one); - double doo = static_cast(one); sycl::half ho = static_cast(one); sycl::ext::oneapi::bfloat16 bo = static_cast(one); EXPECT_EQ(fo, 1.0f); - EXPECT_EQ(doo, 1.0); EXPECT_EQ(static_cast(ho), 1.0f); EXPECT_EQ(static_cast(bo), 1.0f); @@ -312,8 +285,6 @@ TEST(FP8E8M0Test, MarrayConversionOperators) { TEST(FP8E8M0Test, VariadicMixedTypes) { EXPECT_FALSE((std::is_constructible_v)); - EXPECT_FALSE((std::is_constructible_v)); } TEST(FP8E8M0Test, X2NotConstructibleFromSingleShort) { @@ -352,10 +323,6 @@ TEST(FP8E8M0Test, X2NotConstructibleFromSingleFloat) { EXPECT_FALSE((std::is_constructible_v)); } -TEST(FP8E8M0Test, X2NotConstructibleFromSingleDouble) { - EXPECT_FALSE((std::is_constructible_v)); -} - TEST(FP8E8M0Test, X2NotConstructibleFromSingleBFloat16) { EXPECT_FALSE( (std::is_constructible_v)); @@ -386,10 +353,6 @@ TEST(FP8E8M0Test, X2NotAssignableFromSingleFloat) { EXPECT_FALSE((std::is_assignable_v)); } -TEST(FP8E8M0Test, X2NotAssignableFromSingleDouble) { - EXPECT_FALSE((std::is_assignable_v)); -} - TEST(FP8E8M0Test, X2NotAssignableFromSingleChar) { EXPECT_FALSE((std::is_assignable_v)); } From 113be776914b82238f6cada660bb31804ddbe5ea Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Tue, 12 May 2026 18:15:10 +0200 Subject: [PATCH 48/89] [SYCL][TESTS] FP8: fix builtin mock tests Co-authored-by: Copilot --- sycl/unittests/Extensions/fp8/builtin_call_tests.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sycl/unittests/Extensions/fp8/builtin_call_tests.cpp b/sycl/unittests/Extensions/fp8/builtin_call_tests.cpp index be8c031ffad14..ccecb5d921cd7 100644 --- a/sycl/unittests/Extensions/fp8/builtin_call_tests.cpp +++ b/sycl/unittests/Extensions/fp8/builtin_call_tests.cpp @@ -57,11 +57,11 @@ TEST_F(Fp8BuiltinCallTest, E4M3CastToBf16CallsConvertE4M3ToBF16) { EXPECT_EQ(fp8_builtin_mock::getCounters().ConvertE4M3ToBF16EXT, 1); } -TEST_F(Fp8BuiltinCallTest, E4M3CastToBoolCallsConvertE4M3ToFP16) { +TEST_F(Fp8BuiltinCallTest, E4M3CastToBoolDoesNotCallConvertE4M3ToFP16) { fp8_e4m3 Value(static_cast(1.0f)); fp8_builtin_mock::resetCounters(); - (void)static_cast(Value); - EXPECT_EQ(fp8_builtin_mock::getCounters().ConvertE4M3ToFP16EXT, 1); + EXPECT_TRUE(static_cast(Value)); + EXPECT_EQ(fp8_builtin_mock::getCounters().ConvertE4M3ToFP16EXT, 0); } TEST_F(Fp8BuiltinCallTest, E4M3MarrayCastToHalfCallsConvertE4M3ToFP16) { From 9bfe2b74c4b12f25502caceed53e0d814b3fed26 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Wed, 13 May 2026 17:37:50 +0200 Subject: [PATCH 49/89] [SYCL][TESTE2E] make coverage of fp8 data types about 90 --- .../Experimental/fp8/e4m3_cri_conversion.cpp | 32 ++ .../fp8/e4m3_x2_cri_conversion.cpp | 66 ++- .../Experimental/fp8/e5m2_cri_conversion.cpp | 173 +++++++ .../fp8/e5m2_x2_cri_conversion.cpp | 439 ++++++++++++++++ .../Experimental/fp8/e8m0_cri_conversion.cpp | 444 ++++++++++++++++ .../fp8/e8m0_x2_cri_conversion.cpp | 479 ++++++++++++++++++ sycl/unittests/Extensions/fp8/fp8_e4m3.cpp | 28 + 7 files changed, 1658 insertions(+), 3 deletions(-) create mode 100644 sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp create mode 100644 sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp create mode 100644 sycl/test-e2e/Experimental/fp8/e8m0_cri_conversion.cpp create mode 100644 sycl/test-e2e/Experimental/fp8/e8m0_x2_cri_conversion.cpp diff --git a/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp index ac7594d42ab8d..06381e25b18ba 100644 --- a/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp @@ -8,6 +8,7 @@ // TODO need to set requirement of intel_feature_gpu_cri #include +#include #include #include #include @@ -53,6 +54,29 @@ int test_boolean_conversion(sycl::queue &queue, float test_value, return ret; } +template int test_single_element_carray_constructor(sycl::queue &queue) { + T input[1] = {static_cast(1.25f)}; + auto *data = sycl::malloc_shared(1, queue); + data[0] = fp8_e4m3(input); + + queue.single_task([=]() { + fp8_e4m3 value = data[0]; + T output[1] = {static_cast(value) + static_cast(1.0f)}; + data[0] = fp8_e4m3(output); + }); + queue.wait_and_throw(); + + fp8_e4m3 expected(static_cast(2.25f)); + T out = static_cast(data[0]); + T expected_out = static_cast(expected); + + sycl::free(data, queue); + if (std::fabs(static_cast(out) - static_cast(expected_out)) > + 0.0f) + return 1; + return 0; +} + template int test_marray_conversion(sycl::queue &queue) { sycl::marray input(static_cast(1.25f)); auto *data = sycl::malloc_shared(1, queue); @@ -134,6 +158,14 @@ int main() { ret |= test_boolean_conversion(queue, -0.0f, false); ret |= test_boolean_conversion(queue, 1.0f, true); ret |= test_boolean_conversion(queue, -1.0f, true); + ret |= test_boolean_conversion(queue, std::numeric_limits::quiet_NaN(), + true); + ret |= test_boolean_conversion(queue, 0.001953125f, true); + + ret |= test_single_element_carray_constructor(queue); + ret |= test_single_element_carray_constructor(queue); + // ret |= + // test_single_element_carray_constructor(queue); ret |= test_marray_conversion(queue); ret |= test_marray_conversion(queue); diff --git a/sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp index 36f991296986a..dce5c9ae77ac2 100644 --- a/sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp @@ -162,6 +162,62 @@ int test_boundary_round_trip_subnormals(sycl::queue &queue) { return ret; } +int test_boundary_round_trip_exact_normals(sycl::queue &queue) { + const float input[2] = {448.0f, 0.015625f}; + auto *data = sycl::malloc_shared(1, queue); + auto *out = sycl::malloc_shared(2, queue); + data[0] = fp8_e4m3_x2(input, rounding::to_even); + + queue.single_task([=]() { + fp8_e4m3_x2 value = data[0]; + sycl::marray unpacked = static_cast>(value); + data[0] = fp8_e4m3_x2(unpacked, rounding::to_even); + sycl::marray round_tripped = + static_cast>(data[0]); + out[0] = round_tripped[0]; + out[1] = round_tripped[1]; + }); + queue.wait_and_throw(); + + int ret = 0; + if (out[0] != 448.0f) + ret = 1; + if (out[1] != 0.015625f) + ret = 1; + + sycl::free(data, queue); + sycl::free(out, queue); + return ret; +} + +int test_boundary_round_trip_exact_subnormal_limits(sycl::queue &queue) { + const float input[2] = {0.013671875f, 0.001953125f}; + auto *data = sycl::malloc_shared(1, queue); + auto *out = sycl::malloc_shared(2, queue); + data[0] = fp8_e4m3_x2(input, rounding::to_even); + + queue.single_task([=]() { + fp8_e4m3_x2 value = data[0]; + sycl::marray unpacked = static_cast>(value); + data[0] = fp8_e4m3_x2(unpacked, rounding::to_even); + sycl::marray round_tripped = + static_cast>(data[0]); + out[0] = round_tripped[0]; + out[1] = round_tripped[1]; + }); + queue.wait_and_throw(); + + int ret = 0; + if (out[0] != 0.013671875f) + ret = 1; + if (out[1] != 0.001953125f) + ret = 1; + + sycl::free(data, queue); + sycl::free(out, queue); + return ret; +} + int test_boundary_round_trip_saturation_and_infinity_clamp(sycl::queue &queue) { const float input[2] = {600.0f, -std::numeric_limits::infinity()}; auto *data = sycl::malloc_shared(1, queue); @@ -294,7 +350,7 @@ int main() { // marray, and marray. int ret = test_fp8_simple_type_conversion(queue); ret |= test_fp8_simple_type_conversion(queue); - ret |= test_fp8_simple_type_conversion(queue); + // ret |= test_fp8_simple_type_conversion(queue); ret |= test_marray_conversion(queue); ret |= test_marray_conversion(queue); @@ -306,15 +362,19 @@ int main() { ret |= test_explicit_to_even_carray_constructor(queue); ret |= test_explicit_to_even_carray_constructor(queue); - // ret |= test_explicit_to_even_carray_constructor(queue); + // ret |= + // test_explicit_to_even_carray_constructor(queue); ret |= test_explicit_to_even_marray_constructor(queue); ret |= test_explicit_to_even_marray_constructor(queue); - // ret |= test_explicit_to_even_marray_constructor(queue); + // ret |= + // test_explicit_to_even_marray_constructor(queue); ret |= test_boundary_round_trip_nan(queue); ret |= test_boundary_round_trip_negative_zero(queue); ret |= test_boundary_round_trip_subnormals(queue); + ret |= test_boundary_round_trip_exact_normals(queue); + ret |= test_boundary_round_trip_exact_subnormal_limits(queue); ret |= test_boundary_round_trip_saturation_and_infinity_clamp(queue); return ret; } diff --git a/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp new file mode 100644 index 0000000000000..8b27778787187 --- /dev/null +++ b/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp @@ -0,0 +1,173 @@ + +// RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out +// RUN: %{run} SYCL_UR_TRACE=1 %t.out + +// Warning! This test requires CRI device or its simulator run to communicate +// via TCP socket with port 60999, or any other from config + +// TODO need to set requirement of intel_feature_gpu_cri + +#include +#include +#include +#include +#include + +using namespace sycl::ext::oneapi::experimental; + +template int test_fp8_simple_type_conversion(sycl::queue &queue) { + auto *data = sycl::malloc_shared(1, queue); + data[0] = fp8_e5m2(static_cast(1.5)); + + queue.single_task([=]() { + fp8_e5m2 value = data[0]; + T f = static_cast(value); + f += static_cast(1.0f); + data[0] = fp8_e5m2(f); + }); + queue.wait_and_throw(); + + fp8_e5m2 expected(2.5f); + T out = static_cast(data[0]); + T expected_out = static_cast(expected); + + sycl::free(data, queue); + if (std::fabs(out - expected_out) > 0.0f) + return 1; + + return 0; +} + +int test_boolean_conversion(sycl::queue &queue, float test_value, + bool expected) { + auto *data = sycl::malloc_shared(1, queue); + auto *res = sycl::malloc_shared(1, queue); + data[0] = fp8_e5m2(test_value); + queue.single_task([=]() { + fp8_e5m2 value = data[0]; + res[0] = static_cast(value); + }); + queue.wait_and_throw(); + int ret = res[0] == expected ? 0 : 1; + sycl::free(data, queue); + sycl::free(res, queue); + return ret; +} + +template +int test_single_element_carray_constructor(sycl::queue &queue) { + T input[1] = {static_cast(1.5f)}; + auto *data = sycl::malloc_shared(1, queue); + data[0] = fp8_e5m2(input); + + queue.single_task([=]() { + fp8_e5m2 value = data[0]; + T output[1] = {static_cast(value) + static_cast(1.0f)}; + data[0] = fp8_e5m2(output); + }); + queue.wait_and_throw(); + + fp8_e5m2 expected(static_cast(2.5f)); + T out = static_cast(data[0]); + T expected_out = static_cast(expected); + + sycl::free(data, queue); + if (std::fabs(static_cast(out) - static_cast(expected_out)) > + 0.0f) + return 1; + return 0; +} + +template int test_marray_conversion(sycl::queue &queue) { + sycl::marray input(static_cast(1.5f)); + auto *data = sycl::malloc_shared(1, queue); + data[0] = fp8_e5m2(input); + + queue.single_task([=]() { + fp8_e5m2 value = data[0]; + sycl::marray f = static_cast>(value); + f[0] += static_cast(1.0f); + data[0] = fp8_e5m2(f); + }); + queue.wait_and_throw(); + + return 0; +} + +template int test_carray_conversion(sycl::queue &queue) { + auto *data = sycl::malloc_shared(1, queue); + data[0] = fp8_e5m2(static_cast(1.5f)); + + queue.single_task([=]() { + fp8_e5m2 value = data[0]; + T f = {static_cast(value)}; + f += static_cast(1.0f); + data[0] = fp8_e5m2(f); + }); + queue.wait_and_throw(); + + fp8_e5m2 expected(static_cast(2.5f)); + T out = {static_cast(data[0])}; + T expected_out = {static_cast(expected)}; + + sycl::free(data, queue); + if (std::fabs(out - expected_out) > 0.0f) + return 1; + return 0; +} + +int main() { + auto async_handler = [](sycl::exception_list exceptions) { + for (const std::exception_ptr &e : exceptions) { + try { + std::rethrow_exception(e); + } catch (const sycl::exception &ex) { + std::cerr << "Async SYCL exception: " << ex.what() << '\n'; + std::terminate(); + } + } + }; + + sycl::queue queue{async_handler}; + + int ret = test_fp8_simple_type_conversion(queue); + ret |= test_fp8_simple_type_conversion(queue); + ret |= test_fp8_simple_type_conversion(queue); + ret |= test_fp8_simple_type_conversion(queue); + ret |= test_fp8_simple_type_conversion(queue); + ret |= test_fp8_simple_type_conversion(queue); + ret |= test_fp8_simple_type_conversion(queue); + ret |= test_fp8_simple_type_conversion(queue); + ret |= test_fp8_simple_type_conversion(queue); + ret |= test_fp8_simple_type_conversion(queue); + ret |= test_fp8_simple_type_conversion(queue); + ret |= test_fp8_simple_type_conversion(queue); + ret |= test_fp8_simple_type_conversion(queue); + ret |= test_fp8_simple_type_conversion(queue); + + ret |= test_boolean_conversion(queue, 0.0f, false); + ret |= test_boolean_conversion(queue, -0.0f, false); + ret |= test_boolean_conversion(queue, 1.0f, true); + ret |= test_boolean_conversion(queue, -1.0f, true); + ret |= test_boolean_conversion(queue, std::numeric_limits::quiet_NaN(), + true); + ret |= test_boolean_conversion(queue, std::numeric_limits::infinity(), + true); + ret |= test_boolean_conversion(queue, 1.52587890625e-05f, true); + + ret |= test_single_element_carray_constructor(queue); + ret |= test_single_element_carray_constructor(queue); + // ret |= + // test_single_element_carray_constructor(queue); + + ret |= test_marray_conversion(queue); + ret |= test_marray_conversion(queue); + // TODO: uncomment when bfloat16 conversion is fixed + // ret |= test_marray_conversion(queue); + + ret |= test_carray_conversion(queue); + ret |= test_carray_conversion(queue); + // TODO: uncomment when bfloat16 conversion is fixed + // ret |= test_carray_conversion(queue); + return ret; +} diff --git a/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp new file mode 100644 index 0000000000000..992a42fab8045 --- /dev/null +++ b/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp @@ -0,0 +1,439 @@ + +// RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out +// RUN: %{run} SYCL_UR_TRACE=1 %t.out + +// Warning! This test requires CRI device or its simulator run to communicate +// via TCP socket with port 60999, or any other from config + +// TODO need to set requirement of intel_feature_gpu_cri + +#include +#include +#include +#include +#include + +using namespace sycl::ext::oneapi::experimental; + +namespace { + +bool equal_or_both_nan(float actual, float expected) { + if (std::isnan(expected)) + return std::isnan(actual); + return actual == expected; +} + +bool equal_with_zero_sign(float actual, float expected) { + if (!equal_or_both_nan(actual, expected)) + return false; + if (expected == 0.0f) + return std::signbit(actual) == std::signbit(expected); + return true; +} + +template +int test_explicit_to_even_carray_constructor(sycl::queue &queue) { + T input[2] = {static_cast(3.0517578125e-05f), static_cast(-6.0f)}; + auto *data = sycl::malloc_shared(1, queue); + auto *out = sycl::malloc_shared(2, queue); + data[0] = fp8_e5m2_x2(input, rounding::to_even); + + queue.single_task([=]() { + sycl::marray unpacked = static_cast>(data[0]); + out[0] = unpacked[0]; + out[1] = unpacked[1]; + }); + queue.wait_and_throw(); + + int ret = 0; + if (static_cast(out[0]) != 3.0517578125e-05f) + ret = 1; + if (static_cast(out[1]) != -6.0f) + ret = 1; + + sycl::free(data, queue); + sycl::free(out, queue); + return ret; +} + +template +int test_explicit_to_even_marray_constructor(sycl::queue &queue) { + sycl::marray input(static_cast(3.0f), + static_cast(-1.52587890625e-05f)); + auto *data = sycl::malloc_shared(1, queue); + auto *out = sycl::malloc_shared(2, queue); + data[0] = fp8_e5m2_x2(input, rounding::to_even); + + queue.single_task([=]() { + sycl::marray unpacked = static_cast>(data[0]); + out[0] = unpacked[0]; + out[1] = unpacked[1]; + }); + queue.wait_and_throw(); + + int ret = 0; + if (static_cast(out[0]) != 3.0f) + ret = 1; + if (static_cast(out[1]) != -1.52587890625e-05f) + ret = 1; + + sycl::free(data, queue); + sycl::free(out, queue); + return ret; +} + +int test_boundary_round_trip_nan(sycl::queue &queue) { + const float input[2] = {std::numeric_limits::quiet_NaN(), + -std::numeric_limits::quiet_NaN()}; + auto *data = sycl::malloc_shared(1, queue); + auto *out = sycl::malloc_shared(2, queue); + data[0] = fp8_e5m2_x2(input, rounding::to_even); + + queue.single_task([=]() { + fp8_e5m2_x2 value = data[0]; + sycl::marray unpacked = + static_cast>(value); + data[0] = fp8_e5m2_x2(unpacked, rounding::to_even); + sycl::marray round_tripped = + static_cast>(data[0]); + out[0] = round_tripped[0]; + out[1] = round_tripped[1]; + }); + queue.wait_and_throw(); + + int ret = !(std::isnan(out[0]) && std::isnan(out[1])); + sycl::free(data, queue); + sycl::free(out, queue); + return ret; +} + +int test_boundary_round_trip_negative_zero(sycl::queue &queue) { + const float input[2] = {-0.0f, 7.0f}; + auto *data = sycl::malloc_shared(1, queue); + auto *out = sycl::malloc_shared(2, queue); + data[0] = fp8_e5m2_x2(input, rounding::to_even); + + queue.single_task([=]() { + fp8_e5m2_x2 value = data[0]; + sycl::marray unpacked = + static_cast>(value); + data[0] = fp8_e5m2_x2(unpacked, rounding::to_even); + sycl::marray round_tripped = + static_cast>(data[0]); + out[0] = round_tripped[0]; + out[1] = round_tripped[1]; + }); + queue.wait_and_throw(); + + int ret = 0; + if (!equal_with_zero_sign(out[0], -0.0f)) + ret = 1; + if (out[1] != 7.0f) + ret = 1; + + sycl::free(data, queue); + sycl::free(out, queue); + return ret; +} + +int test_boundary_round_trip_subnormals(sycl::queue &queue) { + const float input[2] = {3.0517578125e-05f, -4.57763671875e-05f}; + auto *data = sycl::malloc_shared(1, queue); + auto *out = sycl::malloc_shared(2, queue); + data[0] = fp8_e5m2_x2(input, rounding::to_even); + + queue.single_task([=]() { + fp8_e5m2_x2 value = data[0]; + sycl::marray unpacked = + static_cast>(value); + data[0] = fp8_e5m2_x2(unpacked, rounding::to_even); + sycl::marray round_tripped = + static_cast>(data[0]); + out[0] = round_tripped[0]; + out[1] = round_tripped[1]; + }); + queue.wait_and_throw(); + + int ret = 0; + if (out[0] != 3.0517578125e-05f) + ret = 1; + if (out[1] != -4.57763671875e-05f) + ret = 1; + + sycl::free(data, queue); + sycl::free(out, queue); + return ret; +} + +int test_boundary_round_trip_exact_normals(sycl::queue &queue) { + const float input[2] = {57344.0f, 6.103515625e-05f}; + auto *data = sycl::malloc_shared(1, queue); + auto *out = sycl::malloc_shared(2, queue); + data[0] = fp8_e5m2_x2(input, rounding::to_even); + + queue.single_task([=]() { + fp8_e5m2_x2 value = data[0]; + sycl::marray unpacked = + static_cast>(value); + data[0] = fp8_e5m2_x2(unpacked, rounding::to_even); + sycl::marray round_tripped = + static_cast>(data[0]); + out[0] = round_tripped[0]; + out[1] = round_tripped[1]; + }); + queue.wait_and_throw(); + + int ret = 0; + if (out[0] != 57344.0f) + ret = 1; + if (out[1] != 6.103515625e-05f) + ret = 1; + + sycl::free(data, queue); + sycl::free(out, queue); + return ret; +} + +int test_boundary_round_trip_exact_subnormal_limits(sycl::queue &queue) { + const float input[2] = {4.57763671875e-05f, 1.52587890625e-05f}; + auto *data = sycl::malloc_shared(1, queue); + auto *out = sycl::malloc_shared(2, queue); + data[0] = fp8_e5m2_x2(input, rounding::to_even); + + queue.single_task([=]() { + fp8_e5m2_x2 value = data[0]; + sycl::marray unpacked = + static_cast>(value); + data[0] = fp8_e5m2_x2(unpacked, rounding::to_even); + sycl::marray round_tripped = + static_cast>(data[0]); + out[0] = round_tripped[0]; + out[1] = round_tripped[1]; + }); + queue.wait_and_throw(); + + int ret = 0; + if (out[0] != 4.57763671875e-05f) + ret = 1; + if (out[1] != 1.52587890625e-05f) + ret = 1; + + sycl::free(data, queue); + sycl::free(out, queue); + return ret; +} + +int test_boundary_round_trip_saturation_and_infinity_clamp(sycl::queue &queue) { + const float input[2] = {60000.0f, -std::numeric_limits::infinity()}; + auto *data = sycl::malloc_shared(1, queue); + auto *out = sycl::malloc_shared(2, queue); + data[0] = fp8_e5m2_x2(input, rounding::to_even); + + queue.single_task([=]() { + fp8_e5m2_x2 value = data[0]; + sycl::marray unpacked = + static_cast>(value); + data[0] = fp8_e5m2_x2(unpacked, rounding::to_even); + sycl::marray round_tripped = + static_cast>(data[0]); + out[0] = round_tripped[0]; + out[1] = round_tripped[1]; + }); + queue.wait_and_throw(); + + int ret = 0; + if (out[0] != 57344.0f) + ret = 1; + if (out[1] != -57344.0f) + ret = 1; + + sycl::free(data, queue); + sycl::free(out, queue); + return ret; +} + +int test_boundary_infinity_no_saturation(sycl::queue &queue) { + const float input[2] = {std::numeric_limits::infinity(), + -std::numeric_limits::infinity()}; + auto *data = sycl::malloc_shared(1, queue); + auto *out = sycl::malloc_shared(2, queue); + data[0] = fp8_e5m2_x2(input, rounding::to_even, saturation::none); + + queue.single_task([=]() { + fp8_e5m2_x2 value = data[0]; + sycl::marray unpacked = + static_cast>(value); + out[0] = unpacked[0]; + out[1] = unpacked[1]; + }); + queue.wait_and_throw(); + + int ret = 0; + if (out[0] != std::numeric_limits::infinity()) + ret = 1; + if (out[1] != -std::numeric_limits::infinity()) + ret = 1; + + sycl::free(data, queue); + sycl::free(out, queue); + return ret; +} + +int test_boundary_overflow_no_saturation(sycl::queue &queue) { + const float input[2] = {60000.0f, -60000.0f}; + auto *data = sycl::malloc_shared(1, queue); + auto *out = sycl::malloc_shared(2, queue); + data[0] = fp8_e5m2_x2(input, rounding::to_even, saturation::none); + + queue.single_task([=]() { + fp8_e5m2_x2 value = data[0]; + sycl::marray unpacked = + static_cast>(value); + out[0] = unpacked[0]; + out[1] = unpacked[1]; + }); + queue.wait_and_throw(); + + int ret = 0; + if (out[0] != std::numeric_limits::infinity()) + ret = 1; + if (out[1] != -std::numeric_limits::infinity()) + ret = 1; + + sycl::free(data, queue); + sycl::free(out, queue); + return ret; +} + +} // namespace + +template int test_fp8_simple_type_conversion(sycl::queue &queue) { + auto *data = sycl::malloc_shared(1, queue); + data[0] = fp8_e5m2_x2(static_cast(1.5f), static_cast(2.5f)); + + queue.single_task([=]() { + fp8_e5m2_x2 value = data[0]; + sycl::marray f = static_cast>(value); + f[0] += static_cast(1.0f); + f[1] += static_cast(1.0f); + data[0] = fp8_e5m2_x2(f); + }); + queue.wait_and_throw(); + + sycl::marray expected_input(static_cast(2.5f), static_cast(3.5f)); + fp8_e5m2_x2 expected(expected_input); + sycl::marray out = static_cast>(data[0]); + sycl::marray expected_out = static_cast>(expected); + + sycl::free(data, queue); + for (size_t i = 0; i < 2; ++i) { + if (std::fabs(static_cast(out[i]) - + static_cast(expected_out[i])) > 0.0f) + return 1; + } + + return 0; +} + +template int test_marray_conversion(sycl::queue &queue) { + sycl::marray input(static_cast(1.0f), static_cast(2.0f)); + auto *data = sycl::malloc_shared(1, queue); + data[0] = fp8_e5m2_x2(input); + + queue.single_task([=]() { + fp8_e5m2_x2 value = data[0]; + sycl::marray f = static_cast>(value); + f[0] += static_cast(1.0f); + f[1] += static_cast(2.0f); + data[0] = fp8_e5m2_x2(f); + }); + queue.wait_and_throw(); + sycl::marray expected_input(static_cast(2.0f), static_cast(4.0f)); + fp8_e5m2_x2 expected(expected_input); + sycl::marray out = static_cast>(data[0]); + sycl::marray expected_out = static_cast>(expected); + + sycl::free(data, queue); + for (size_t i = 0; i < 2; ++i) { + if (std::fabs(static_cast(out[i]) - + static_cast(expected_out[i])) > 0.0f) + return 1; + } + return 0; +} + +template int test_carray_conversion(sycl::queue &queue) { + T input[2] = {static_cast(1.0f), static_cast(3.0f)}; + auto *data = sycl::malloc_shared(1, queue); + data[0] = fp8_e5m2_x2(input); + + queue.single_task([=]() { + fp8_e5m2_x2 value = data[0]; + sycl::marray unpacked = static_cast>(value); + T output[2] = {unpacked[0] + static_cast(1.0f), + unpacked[1] + static_cast(4.0f)}; + data[0] = fp8_e5m2_x2(output); + }); + queue.wait_and_throw(); + + T expected_input[2] = {static_cast(2.0f), static_cast(7.0f)}; + fp8_e5m2_x2 expected(expected_input); + sycl::marray out = static_cast>(data[0]); + sycl::marray expected_out = static_cast>(expected); + + sycl::free(data, queue); + for (size_t i = 0; i < 2; ++i) { + if (std::fabs(static_cast(out[i]) - + static_cast(expected_out[i])) > 0.0f) + return 1; + } + + return 0; +} + +int main() { + auto async_handler = [](sycl::exception_list exceptions) { + for (const std::exception_ptr &e : exceptions) { + try { + std::rethrow_exception(e); + } catch (const sycl::exception &ex) { + std::cerr << "Async SYCL exception: " << ex.what() << '\n'; + std::terminate(); + } + } + }; + + sycl::queue queue{async_handler}; + + int ret = test_fp8_simple_type_conversion(queue); + ret |= test_fp8_simple_type_conversion(queue); + // ret |= test_fp8_simple_type_conversion(queue); + + ret |= test_marray_conversion(queue); + ret |= test_marray_conversion(queue); + // ret |= test_marray_conversion(queue); + + ret |= test_carray_conversion(queue); + ret |= test_carray_conversion(queue); + // ret |= test_carray_conversion(queue); + + ret |= test_explicit_to_even_carray_constructor(queue); + ret |= test_explicit_to_even_carray_constructor(queue); + // ret |= + // test_explicit_to_even_carray_constructor(queue); + + ret |= test_explicit_to_even_marray_constructor(queue); + ret |= test_explicit_to_even_marray_constructor(queue); + // ret |= + // test_explicit_to_even_marray_constructor(queue); + + ret |= test_boundary_round_trip_nan(queue); + ret |= test_boundary_round_trip_negative_zero(queue); + ret |= test_boundary_round_trip_subnormals(queue); + ret |= test_boundary_round_trip_exact_normals(queue); + ret |= test_boundary_round_trip_exact_subnormal_limits(queue); + ret |= test_boundary_round_trip_saturation_and_infinity_clamp(queue); + ret |= test_boundary_infinity_no_saturation(queue); + ret |= test_boundary_overflow_no_saturation(queue); + return ret; +} diff --git a/sycl/test-e2e/Experimental/fp8/e8m0_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e8m0_cri_conversion.cpp new file mode 100644 index 0000000000000..515e0369eee03 --- /dev/null +++ b/sycl/test-e2e/Experimental/fp8/e8m0_cri_conversion.cpp @@ -0,0 +1,444 @@ + +// RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out +// RUN: %{run} SYCL_UR_TRACE=1 %t.out + +// Warning! This test requires CRI device or its simulator run to communicate +// via TCP socket with port 60999, or any other from config + +// TODO need to set requirement of intel_feature_gpu_cri + +#include +#include +#include +#include +#include + +using namespace sycl::ext::oneapi::experimental; + +template int test_fp8_simple_type_conversion(sycl::queue &queue) { + auto *data = sycl::malloc_shared(1, queue); + data[0] = fp8_e8m0(static_cast(4.0f)); + + queue.single_task([=]() { + fp8_e8m0 value = data[0]; + T f = static_cast(value); + f *= static_cast(2.0f); + data[0] = fp8_e8m0(f); + }); + queue.wait_and_throw(); + + fp8_e8m0 expected(8.0f); + T out = static_cast(data[0]); + T expected_out = static_cast(expected); + + sycl::free(data, queue); + if (std::fabs(static_cast(out) - static_cast(expected_out)) > + 0.0f) + return 1; + + return 0; +} + +int test_boolean_conversion(sycl::queue &queue) { + auto *data = sycl::malloc_shared(1, queue); + auto *res = sycl::malloc_shared(1, queue); + data[0] = fp8_e8m0(1.0f); + queue.single_task([=]() { + fp8_e8m0 value = data[0]; + res[0] = static_cast(value); + }); + queue.wait_and_throw(); + int ret = res[0] == true ? 0 : 1; + sycl::free(data, queue); + sycl::free(res, queue); + return ret; +} + +int test_boolean_conversion_large(sycl::queue &queue) { + auto *data = sycl::malloc_shared(1, queue); + auto *res = sycl::malloc_shared(1, queue); + data[0] = fp8_e8m0(128.0f); + queue.single_task([=]() { + fp8_e8m0 value = data[0]; + res[0] = static_cast(value); + }); + queue.wait_and_throw(); + int ret = res[0] == true ? 0 : 1; + sycl::free(data, queue); + sycl::free(res, queue); + return ret; +} + +int test_boolean_conversion_nan(sycl::queue &queue) { + auto *data = sycl::malloc_shared(1, queue); + auto *res = sycl::malloc_shared(1, queue); + data[0] = fp8_e8m0(std::numeric_limits::quiet_NaN()); + queue.single_task([=]() { + fp8_e8m0 value = data[0]; + res[0] = static_cast(value); + }); + queue.wait_and_throw(); + int ret = res[0] == true ? 0 : 1; + sycl::free(data, queue); + sycl::free(res, queue); + return ret; +} + +template +int test_single_element_carray_constructor(sycl::queue &queue) { + T input[1] = {static_cast(4.0f)}; + auto *data = sycl::malloc_shared(1, queue); + data[0] = fp8_e8m0(input); + + queue.single_task([=]() { + fp8_e8m0 value = data[0]; + T output[1] = {static_cast(value) * static_cast(2.0f)}; + data[0] = fp8_e8m0(output); + }); + queue.wait_and_throw(); + + fp8_e8m0 expected(static_cast(8.0f)); + T out = static_cast(data[0]); + T expected_out = static_cast(expected); + + sycl::free(data, queue); + if (std::fabs(static_cast(out) - static_cast(expected_out)) > + 0.0f) + return 1; + return 0; +} + +template int test_marray_conversion(sycl::queue &queue) { + sycl::marray input(static_cast(4.0f)); + auto *data = sycl::malloc_shared(1, queue); + data[0] = fp8_e8m0(input); + + queue.single_task([=]() { + fp8_e8m0 value = data[0]; + sycl::marray f = static_cast>(value); + f[0] *= static_cast(2.0f); + data[0] = fp8_e8m0(f); + }); + queue.wait_and_throw(); + + sycl::marray expected_input(static_cast(8.0f)); + fp8_e8m0 expected(expected_input); + sycl::marray out = static_cast>(data[0]); + sycl::marray expected_out = static_cast>(expected); + + sycl::free(data, queue); + if (std::fabs(static_cast(out[0]) - + static_cast(expected_out[0])) > 0.0f) + return 1; + return 0; +} + +template int test_carray_conversion(sycl::queue &queue) { + auto *data = sycl::malloc_shared(1, queue); + data[0] = fp8_e8m0(static_cast(4.0f)); + + queue.single_task([=]() { + fp8_e8m0 value = data[0]; + T f = {static_cast(value)}; + f *= static_cast(2.0f); + data[0] = fp8_e8m0(f); + }); + queue.wait_and_throw(); + + fp8_e8m0 expected(static_cast(8.0f)); + T out = {static_cast(data[0])}; + T expected_out = {static_cast(expected)}; + + sycl::free(data, queue); + if (std::fabs(static_cast(out) - static_cast(expected_out)) > + 0.0f) + return 1; + return 0; +} + +int test_rounding_upward(sycl::queue &queue) { + float input[1] = {3.0f}; + auto *data = sycl::malloc_shared(1, queue); + data[0] = fp8_e8m0(input, rounding::upward); + + queue.single_task([=]() { + fp8_e8m0 value = data[0]; + float out = static_cast(value); + float expected[1] = {out}; + data[0] = fp8_e8m0(expected, rounding::upward); + }); + queue.wait_and_throw(); + + float out = static_cast(data[0]); + sycl::free(data, queue); + if (out != 4.0f) + return 1; + return 0; +} + +int test_rounding_toward_zero(sycl::queue &queue) { + float input[1] = {3.0f}; + auto *data = sycl::malloc_shared(1, queue); + data[0] = fp8_e8m0(input, rounding::toward_zero); + + queue.single_task([=]() { + fp8_e8m0 value = data[0]; + float out = static_cast(value); + float expected[1] = {out}; + data[0] = fp8_e8m0(expected, rounding::toward_zero); + }); + queue.wait_and_throw(); + + float out = static_cast(data[0]); + sycl::free(data, queue); + if (out != 2.0f) + return 1; + return 0; +} + +int test_rounding_upward_marray(sycl::queue &queue) { + sycl::marray input(5.0f); + auto *data = sycl::malloc_shared(1, queue); + data[0] = fp8_e8m0(input, rounding::upward); + + queue.single_task([=]() { + fp8_e8m0 value = data[0]; + sycl::marray f = static_cast>(value); + data[0] = fp8_e8m0(f, rounding::upward); + }); + queue.wait_and_throw(); + + float out = static_cast(data[0]); + sycl::free(data, queue); + if (out != 8.0f) + return 1; + return 0; +} + +int test_rounding_toward_zero_marray(sycl::queue &queue) { + sycl::marray input(5.0f); + auto *data = sycl::malloc_shared(1, queue); + data[0] = fp8_e8m0(input, rounding::toward_zero); + + queue.single_task([=]() { + fp8_e8m0 value = data[0]; + sycl::marray f = static_cast>(value); + data[0] = fp8_e8m0(f, rounding::toward_zero); + }); + queue.wait_and_throw(); + + float out = static_cast(data[0]); + sycl::free(data, queue); + if (out != 4.0f) + return 1; + return 0; +} + +int test_power_of_two_round_trip(sycl::queue &queue) { + auto *data = sycl::malloc_shared(1, queue); + auto *out = sycl::malloc_shared(1, queue); + data[0] = fp8_e8m0(16.0f); + + queue.single_task([=]() { + fp8_e8m0 value = data[0]; + out[0] = static_cast(value); + }); + queue.wait_and_throw(); + + int ret = (out[0] != 16.0f) ? 1 : 0; + sycl::free(data, queue); + sycl::free(out, queue); + return ret; +} + +int test_max_normal_round_trip(sycl::queue &queue) { + float max_val = std::ldexp(1.0f, 127); + auto *data = sycl::malloc_shared(1, queue); + auto *out = sycl::malloc_shared(1, queue); + float input[1] = {max_val}; + data[0] = fp8_e8m0(input, rounding::toward_zero); + + queue.single_task([=]() { + fp8_e8m0 value = data[0]; + out[0] = static_cast(value); + }); + queue.wait_and_throw(); + + int ret = (out[0] != max_val) ? 1 : 0; + sycl::free(data, queue); + sycl::free(out, queue); + return ret; +} + +int test_min_normal_round_trip(sycl::queue &queue) { + float min_val = std::ldexp(1.0f, -127); + auto *data = sycl::malloc_shared(1, queue); + auto *out = sycl::malloc_shared(1, queue); + float input[1] = {min_val}; + data[0] = fp8_e8m0(input, rounding::toward_zero); + + queue.single_task([=]() { + fp8_e8m0 value = data[0]; + out[0] = static_cast(value); + }); + queue.wait_and_throw(); + + int ret = (out[0] != min_val) ? 1 : 0; + sycl::free(data, queue); + sycl::free(out, queue); + return ret; +} + +int test_nan_round_trip(sycl::queue &queue) { + auto *data = sycl::malloc_shared(1, queue); + auto *out = sycl::malloc_shared(1, queue); + float input[1] = {std::numeric_limits::quiet_NaN()}; + data[0] = fp8_e8m0(input, rounding::upward); + + queue.single_task([=]() { + fp8_e8m0 value = data[0]; + out[0] = static_cast(value); + }); + queue.wait_and_throw(); + + int ret = std::isnan(out[0]) ? 0 : 1; + sycl::free(data, queue); + sycl::free(out, queue); + return ret; +} + +int test_saturation_large_value(sycl::queue &queue) { + float large = std::numeric_limits::infinity(); + auto *data = sycl::malloc_shared(1, queue); + auto *out = sycl::malloc_shared(1, queue); + float input[1] = {large}; + data[0] = fp8_e8m0(input, rounding::upward); + + queue.single_task([=]() { + fp8_e8m0 value = data[0]; + out[0] = static_cast(value); + }); + queue.wait_and_throw(); + + float max_e8m0 = std::ldexp(1.0f, 127); + int ret = (out[0] != max_e8m0) ? 1 : 0; + sycl::free(data, queue); + sycl::free(out, queue); + return ret; +} + +int test_saturation_overflow(sycl::queue &queue) { + float large = std::ldexp(1.0f, 128); + auto *data = sycl::malloc_shared(1, queue); + auto *out = sycl::malloc_shared(1, queue); + float input[1] = {large}; + data[0] = fp8_e8m0(input, rounding::upward); + + queue.single_task([=]() { + fp8_e8m0 value = data[0]; + out[0] = static_cast(value); + }); + queue.wait_and_throw(); + + float max_e8m0 = std::ldexp(1.0f, 127); + int ret = (out[0] != max_e8m0) ? 1 : 0; + sycl::free(data, queue); + sycl::free(out, queue); + return ret; +} + +int test_raw_vals_access(sycl::queue &queue) { + auto *data = sycl::malloc_shared(1, queue); + auto *out = sycl::malloc_shared(1, queue); + data[0] = fp8_e8m0(1.0f); + + queue.single_task([=]() { + out[0] = data[0].vals[0]; + }); + queue.wait_and_throw(); + + int ret = (out[0] != 127) ? 1 : 0; + sycl::free(data, queue); + sycl::free(out, queue); + return ret; +} + +int test_negative_input_drops_sign(sycl::queue &queue) { + float input[1] = {-8.0f}; + auto *data = sycl::malloc_shared(1, queue); + auto *out = sycl::malloc_shared(1, queue); + data[0] = fp8_e8m0(input, rounding::upward); + + queue.single_task([=]() { + fp8_e8m0 value = data[0]; + out[0] = static_cast(value); + }); + queue.wait_and_throw(); + + int ret = (out[0] != 8.0f) ? 1 : 0; + sycl::free(data, queue); + sycl::free(out, queue); + return ret; +} + +int main() { + auto async_handler = [](sycl::exception_list exceptions) { + for (const std::exception_ptr &e : exceptions) { + try { + std::rethrow_exception(e); + } catch (const sycl::exception &ex) { + std::cerr << "Async SYCL exception: " << ex.what() << '\n'; + std::terminate(); + } + } + }; + + sycl::queue queue{async_handler}; + + int ret = test_fp8_simple_type_conversion(queue); + ret |= test_fp8_simple_type_conversion(queue); + ret |= test_fp8_simple_type_conversion(queue); + ret |= test_fp8_simple_type_conversion(queue); + ret |= test_fp8_simple_type_conversion(queue); + ret |= test_fp8_simple_type_conversion(queue); + ret |= test_fp8_simple_type_conversion(queue); + ret |= test_fp8_simple_type_conversion(queue); + ret |= test_fp8_simple_type_conversion(queue); + ret |= test_fp8_simple_type_conversion(queue); + ret |= test_fp8_simple_type_conversion(queue); + ret |= test_fp8_simple_type_conversion(queue); + ret |= test_fp8_simple_type_conversion(queue); + ret |= test_fp8_simple_type_conversion(queue); + + ret |= test_boolean_conversion(queue); + ret |= test_boolean_conversion_large(queue); + ret |= test_boolean_conversion_nan(queue); + + ret |= test_single_element_carray_constructor(queue); + ret |= test_single_element_carray_constructor(queue); + ret |= test_single_element_carray_constructor( + queue); + + ret |= test_marray_conversion(queue); + ret |= test_marray_conversion(queue); + ret |= test_marray_conversion(queue); + + ret |= test_carray_conversion(queue); + ret |= test_carray_conversion(queue); + ret |= test_carray_conversion(queue); + + ret |= test_rounding_upward(queue); + ret |= test_rounding_toward_zero(queue); + ret |= test_rounding_upward_marray(queue); + ret |= test_rounding_toward_zero_marray(queue); + + ret |= test_power_of_two_round_trip(queue); + ret |= test_max_normal_round_trip(queue); + ret |= test_min_normal_round_trip(queue); + ret |= test_nan_round_trip(queue); + ret |= test_saturation_large_value(queue); + ret |= test_saturation_overflow(queue); + ret |= test_raw_vals_access(queue); + ret |= test_negative_input_drops_sign(queue); + return ret; +} diff --git a/sycl/test-e2e/Experimental/fp8/e8m0_x2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e8m0_x2_cri_conversion.cpp new file mode 100644 index 0000000000000..3fc689a867e42 --- /dev/null +++ b/sycl/test-e2e/Experimental/fp8/e8m0_x2_cri_conversion.cpp @@ -0,0 +1,479 @@ + +// RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out +// RUN: %{run} SYCL_UR_TRACE=1 %t.out + +// Warning! This test requires CRI device or its simulator run to communicate +// via TCP socket with port 60999, or any other from config + +// TODO need to set requirement of intel_feature_gpu_cri + +#include +#include +#include +#include +#include + +using namespace sycl::ext::oneapi::experimental; + +namespace { + +template +int test_explicit_upward_carray_constructor(sycl::queue &queue) { + T input[2] = {static_cast(4.0f), static_cast(16.0f)}; + auto *data = sycl::malloc_shared(1, queue); + auto *out = sycl::malloc_shared(2, queue); + data[0] = fp8_e8m0_x2(input, rounding::upward); + + queue.single_task([=]() { + sycl::marray unpacked = static_cast>(data[0]); + out[0] = unpacked[0]; + out[1] = unpacked[1]; + }); + queue.wait_and_throw(); + + int ret = 0; + if (static_cast(out[0]) != 4.0f) + ret = 1; + if (static_cast(out[1]) != 16.0f) + ret = 1; + + sycl::free(data, queue); + sycl::free(out, queue); + return ret; +} + +template +int test_explicit_toward_zero_carray_constructor(sycl::queue &queue) { + T input[2] = {static_cast(5.0f), static_cast(12.0f)}; + auto *data = sycl::malloc_shared(1, queue); + auto *out = sycl::malloc_shared(2, queue); + data[0] = fp8_e8m0_x2(input, rounding::toward_zero); + + queue.single_task([=]() { + sycl::marray unpacked = static_cast>(data[0]); + out[0] = unpacked[0]; + out[1] = unpacked[1]; + }); + queue.wait_and_throw(); + + int ret = 0; + if (static_cast(out[0]) != 4.0f) + ret = 1; + if (static_cast(out[1]) != 8.0f) + ret = 1; + + sycl::free(data, queue); + sycl::free(out, queue); + return ret; +} + +template +int test_explicit_upward_marray_constructor(sycl::queue &queue) { + sycl::marray input(static_cast(2.0f), static_cast(64.0f)); + auto *data = sycl::malloc_shared(1, queue); + auto *out = sycl::malloc_shared(2, queue); + data[0] = fp8_e8m0_x2(input, rounding::upward); + + queue.single_task([=]() { + sycl::marray unpacked = static_cast>(data[0]); + out[0] = unpacked[0]; + out[1] = unpacked[1]; + }); + queue.wait_and_throw(); + + int ret = 0; + if (static_cast(out[0]) != 2.0f) + ret = 1; + if (static_cast(out[1]) != 64.0f) + ret = 1; + + sycl::free(data, queue); + sycl::free(out, queue); + return ret; +} + +template +int test_explicit_toward_zero_marray_constructor(sycl::queue &queue) { + sycl::marray input(static_cast(3.0f), static_cast(10.0f)); + auto *data = sycl::malloc_shared(1, queue); + auto *out = sycl::malloc_shared(2, queue); + data[0] = fp8_e8m0_x2(input, rounding::toward_zero); + + queue.single_task([=]() { + sycl::marray unpacked = static_cast>(data[0]); + out[0] = unpacked[0]; + out[1] = unpacked[1]; + }); + queue.wait_and_throw(); + + int ret = 0; + if (static_cast(out[0]) != 2.0f) + ret = 1; + if (static_cast(out[1]) != 8.0f) + ret = 1; + + sycl::free(data, queue); + sycl::free(out, queue); + return ret; +} + +int test_boundary_round_trip_nan(sycl::queue &queue) { + const float input[2] = {std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN()}; + auto *data = sycl::malloc_shared(1, queue); + auto *out = sycl::malloc_shared(2, queue); + data[0] = fp8_e8m0_x2(input, rounding::upward); + + queue.single_task([=]() { + fp8_e8m0_x2 value = data[0]; + sycl::marray unpacked = + static_cast>(value); + data[0] = fp8_e8m0_x2(unpacked, rounding::upward); + sycl::marray round_tripped = + static_cast>(data[0]); + out[0] = round_tripped[0]; + out[1] = round_tripped[1]; + }); + queue.wait_and_throw(); + + int ret = !(std::isnan(out[0]) && std::isnan(out[1])); + sycl::free(data, queue); + sycl::free(out, queue); + return ret; +} + +int test_boundary_round_trip_exact_powers_of_two(sycl::queue &queue) { + const float input[2] = {32.0f, 0.25f}; + auto *data = sycl::malloc_shared(1, queue); + auto *out = sycl::malloc_shared(2, queue); + data[0] = fp8_e8m0_x2(input, rounding::upward); + + queue.single_task([=]() { + fp8_e8m0_x2 value = data[0]; + sycl::marray unpacked = + static_cast>(value); + data[0] = fp8_e8m0_x2(unpacked, rounding::upward); + sycl::marray round_tripped = + static_cast>(data[0]); + out[0] = round_tripped[0]; + out[1] = round_tripped[1]; + }); + queue.wait_and_throw(); + + int ret = 0; + if (out[0] != 32.0f) + ret = 1; + if (out[1] != 0.25f) + ret = 1; + + sycl::free(data, queue); + sycl::free(out, queue); + return ret; +} + +int test_boundary_round_trip_max_min_normal(sycl::queue &queue) { + float max_val = std::ldexp(1.0f, 127); + float min_val = std::ldexp(1.0f, -127); + const float input[2] = {max_val, min_val}; + auto *data = sycl::malloc_shared(1, queue); + auto *out = sycl::malloc_shared(2, queue); + data[0] = fp8_e8m0_x2(input, rounding::toward_zero); + + queue.single_task([=]() { + fp8_e8m0_x2 value = data[0]; + sycl::marray unpacked = + static_cast>(value); + data[0] = fp8_e8m0_x2(unpacked, rounding::toward_zero); + sycl::marray round_tripped = + static_cast>(data[0]); + out[0] = round_tripped[0]; + out[1] = round_tripped[1]; + }); + queue.wait_and_throw(); + + int ret = 0; + if (out[0] != max_val) + ret = 1; + if (out[1] != min_val) + ret = 1; + + sycl::free(data, queue); + sycl::free(out, queue); + return ret; +} + +int test_boundary_saturation_infinity_clamp(sycl::queue &queue) { + const float input[2] = {std::numeric_limits::infinity(), + std::ldexp(1.0f, 128)}; + auto *data = sycl::malloc_shared(1, queue); + auto *out = sycl::malloc_shared(2, queue); + data[0] = fp8_e8m0_x2(input, rounding::upward); + + queue.single_task([=]() { + fp8_e8m0_x2 value = data[0]; + sycl::marray unpacked = + static_cast>(value); + out[0] = unpacked[0]; + out[1] = unpacked[1]; + }); + queue.wait_and_throw(); + + float max_e8m0 = std::ldexp(1.0f, 127); + int ret = 0; + if (out[0] != max_e8m0) + ret = 1; + if (out[1] != max_e8m0) + ret = 1; + + sycl::free(data, queue); + sycl::free(out, queue); + return ret; +} + +int test_boundary_negative_input_drops_sign(sycl::queue &queue) { + const float input[2] = {-4.0f, -32.0f}; + auto *data = sycl::malloc_shared(1, queue); + auto *out = sycl::malloc_shared(2, queue); + data[0] = fp8_e8m0_x2(input, rounding::upward); + + queue.single_task([=]() { + fp8_e8m0_x2 value = data[0]; + sycl::marray unpacked = + static_cast>(value); + out[0] = unpacked[0]; + out[1] = unpacked[1]; + }); + queue.wait_and_throw(); + + int ret = 0; + if (out[0] != 4.0f) + ret = 1; + if (out[1] != 32.0f) + ret = 1; + + sycl::free(data, queue); + sycl::free(out, queue); + return ret; +} + +int test_rounding_upward_non_power_of_two(sycl::queue &queue) { + const float input[2] = {3.0f, 6.0f}; + auto *data = sycl::malloc_shared(1, queue); + auto *out = sycl::malloc_shared(2, queue); + data[0] = fp8_e8m0_x2(input, rounding::upward); + + queue.single_task([=]() { + fp8_e8m0_x2 value = data[0]; + sycl::marray unpacked = + static_cast>(value); + out[0] = unpacked[0]; + out[1] = unpacked[1]; + }); + queue.wait_and_throw(); + + int ret = 0; + if (out[0] != 4.0f) + ret = 1; + if (out[1] != 8.0f) + ret = 1; + + sycl::free(data, queue); + sycl::free(out, queue); + return ret; +} + +int test_rounding_toward_zero_non_power_of_two(sycl::queue &queue) { + const float input[2] = {3.0f, 6.0f}; + auto *data = sycl::malloc_shared(1, queue); + auto *out = sycl::malloc_shared(2, queue); + data[0] = fp8_e8m0_x2(input, rounding::toward_zero); + + queue.single_task([=]() { + fp8_e8m0_x2 value = data[0]; + sycl::marray unpacked = + static_cast>(value); + out[0] = unpacked[0]; + out[1] = unpacked[1]; + }); + queue.wait_and_throw(); + + int ret = 0; + if (out[0] != 2.0f) + ret = 1; + if (out[1] != 4.0f) + ret = 1; + + sycl::free(data, queue); + sycl::free(out, queue); + return ret; +} + +int test_raw_vals_access(sycl::queue &queue) { + auto *data = sycl::malloc_shared(1, queue); + auto *out = sycl::malloc_shared(2, queue); + float input[2] = {1.0f, 2.0f}; + data[0] = fp8_e8m0_x2(input, rounding::upward); + + queue.single_task([=]() { + out[0] = data[0].vals[0]; + out[1] = data[0].vals[1]; + }); + queue.wait_and_throw(); + + int ret = 0; + if (out[0] != 127) + ret = 1; + if (out[1] != 128) + ret = 1; + + sycl::free(data, queue); + sycl::free(out, queue); + return ret; +} + +} // namespace + +template int test_fp8_simple_type_conversion(sycl::queue &queue) { + auto *data = sycl::malloc_shared(1, queue); + data[0] = fp8_e8m0_x2(static_cast(4.0f), static_cast(16.0f)); + + queue.single_task([=]() { + fp8_e8m0_x2 value = data[0]; + sycl::marray f = static_cast>(value); + f[0] *= static_cast(2.0f); + f[1] *= static_cast(2.0f); + data[0] = fp8_e8m0_x2(f); + }); + queue.wait_and_throw(); + + sycl::marray expected_input(static_cast(8.0f), + static_cast(32.0f)); + fp8_e8m0_x2 expected(expected_input); + sycl::marray out = static_cast>(data[0]); + sycl::marray expected_out = static_cast>(expected); + + sycl::free(data, queue); + for (size_t i = 0; i < 2; ++i) { + if (std::fabs(static_cast(out[i]) - + static_cast(expected_out[i])) > 0.0f) + return 1; + } + + return 0; +} + +template int test_marray_conversion(sycl::queue &queue) { + sycl::marray input(static_cast(4.0f), static_cast(16.0f)); + auto *data = sycl::malloc_shared(1, queue); + data[0] = fp8_e8m0_x2(input); + + queue.single_task([=]() { + fp8_e8m0_x2 value = data[0]; + sycl::marray f = static_cast>(value); + f[0] *= static_cast(2.0f); + f[1] *= static_cast(4.0f); + data[0] = fp8_e8m0_x2(f); + }); + queue.wait_and_throw(); + sycl::marray expected_input(static_cast(8.0f), + static_cast(64.0f)); + fp8_e8m0_x2 expected(expected_input); + sycl::marray out = static_cast>(data[0]); + sycl::marray expected_out = static_cast>(expected); + + sycl::free(data, queue); + for (size_t i = 0; i < 2; ++i) { + if (std::fabs(static_cast(out[i]) - + static_cast(expected_out[i])) > 0.0f) + return 1; + } + return 0; +} + +template int test_carray_conversion(sycl::queue &queue) { + T input[2] = {static_cast(4.0f), static_cast(16.0f)}; + auto *data = sycl::malloc_shared(1, queue); + data[0] = fp8_e8m0_x2(input); + + queue.single_task([=]() { + fp8_e8m0_x2 value = data[0]; + sycl::marray unpacked = static_cast>(value); + T output[2] = {unpacked[0] * static_cast(2.0f), + unpacked[1] * static_cast(4.0f)}; + data[0] = fp8_e8m0_x2(output); + }); + queue.wait_and_throw(); + + T expected_input[2] = {static_cast(8.0f), static_cast(64.0f)}; + fp8_e8m0_x2 expected(expected_input); + sycl::marray out = static_cast>(data[0]); + sycl::marray expected_out = static_cast>(expected); + + sycl::free(data, queue); + for (size_t i = 0; i < 2; ++i) { + if (std::fabs(static_cast(out[i]) - + static_cast(expected_out[i])) > 0.0f) + return 1; + } + + return 0; +} + +int main() { + auto async_handler = [](sycl::exception_list exceptions) { + for (const std::exception_ptr &e : exceptions) { + try { + std::rethrow_exception(e); + } catch (const sycl::exception &ex) { + std::cerr << "Async SYCL exception: " << ex.what() << '\n'; + std::terminate(); + } + } + }; + + sycl::queue queue{async_handler}; + + int ret = test_fp8_simple_type_conversion(queue); + ret |= test_fp8_simple_type_conversion(queue); + ret |= test_fp8_simple_type_conversion(queue); + + ret |= test_marray_conversion(queue); + ret |= test_marray_conversion(queue); + ret |= test_marray_conversion(queue); + + ret |= test_carray_conversion(queue); + ret |= test_carray_conversion(queue); + ret |= test_carray_conversion(queue); + + ret |= test_explicit_upward_carray_constructor(queue); + ret |= test_explicit_upward_carray_constructor(queue); + ret |= test_explicit_upward_carray_constructor( + queue); + + ret |= test_explicit_toward_zero_carray_constructor(queue); + ret |= test_explicit_toward_zero_carray_constructor(queue); + ret |= + test_explicit_toward_zero_carray_constructor( + queue); + + ret |= test_explicit_upward_marray_constructor(queue); + ret |= test_explicit_upward_marray_constructor(queue); + ret |= test_explicit_upward_marray_constructor( + queue); + + ret |= test_explicit_toward_zero_marray_constructor(queue); + ret |= test_explicit_toward_zero_marray_constructor(queue); + ret |= + test_explicit_toward_zero_marray_constructor( + queue); + + ret |= test_boundary_round_trip_nan(queue); + ret |= test_boundary_round_trip_exact_powers_of_two(queue); + ret |= test_boundary_round_trip_max_min_normal(queue); + ret |= test_boundary_saturation_infinity_clamp(queue); + ret |= test_boundary_negative_input_drops_sign(queue); + ret |= test_rounding_upward_non_power_of_two(queue); + ret |= test_rounding_toward_zero_non_power_of_two(queue); + ret |= test_raw_vals_access(queue); + return ret; +} diff --git a/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp b/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp index b8127a712e670..1b4851ca5a579 100644 --- a/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp +++ b/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp @@ -13,6 +13,26 @@ code thus unit tests check only API using namespace sycl::ext::oneapi::experimental; +TEST(FP8E4M3Test, TrivialSpecialMembers) { + EXPECT_TRUE((std::is_trivially_default_constructible_v)); + EXPECT_TRUE((std::is_trivially_copy_constructible_v)); + EXPECT_TRUE((std::is_trivially_destructible_v)); + EXPECT_TRUE((std::is_trivially_copy_assignable_v)); + + EXPECT_TRUE((std::is_trivially_default_constructible_v)); + EXPECT_TRUE((std::is_trivially_copy_constructible_v)); + EXPECT_TRUE((std::is_trivially_destructible_v)); + EXPECT_TRUE((std::is_trivially_copy_assignable_v)); + + fp8_e4m3 source(1.0f); + fp8_e4m3 copy(source); + fp8_e4m3 assigned; + assigned = source; + + EXPECT_EQ(copy.vals[0], source.vals[0]); + EXPECT_EQ(assigned.vals[0], source.vals[0]); +} + TEST(FP8E4M3Test, VariadicHalf) { fp8_e4m3_x2 a(sycl::half(1.0f), sycl::half(2.0f)); @@ -242,6 +262,14 @@ TEST(FP8E4M3Test, BoolOperatorZeroRules) { EXPECT_TRUE(static_cast(sub)); } +TEST(FP8E4M3Test, BoolOperatorTreatsNaNAsTrue) { + fp8_e4m3 nanv(std::numeric_limits::quiet_NaN()); + + EXPECT_EQ(sizeof(nanv.vals), 1u); + EXPECT_EQ(nanv.vals[0], 0x7F); + EXPECT_TRUE(static_cast(nanv)); +} + TEST(FP8E4M3Test, CArrayFloatHostToEvenFinite) { // Host code supports only rounding::to_even and saturation::finite. const float in[2] = {1.0f, 1.1f}; From 4b44878e54410f341b7515b35cd76662d46b7cd1 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Mon, 18 May 2026 09:44:47 +0200 Subject: [PATCH 50/89] [SYCL][E2E] add more FP8 e2e tests --- .../Experimental/fp8/e4m3_cri_conversion.cpp | 36 +++++++------- .../fp8/e4m3_x2_cri_conversion.cpp | 43 ++++++++++------- .../Experimental/fp8/e5m2_cri_conversion.cpp | 8 ++-- .../fp8/e5m2_x2_cri_conversion.cpp | 48 ++++++++++++++----- .../fp8/e8m0_x2_cri_conversion.cpp | 10 ++-- 5 files changed, 87 insertions(+), 58 deletions(-) diff --git a/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp index 06381e25b18ba..881cded8be3cf 100644 --- a/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp @@ -54,7 +54,8 @@ int test_boolean_conversion(sycl::queue &queue, float test_value, return ret; } -template int test_single_element_carray_constructor(sycl::queue &queue) { +template +int test_single_element_carray_constructor(sycl::queue &queue) { T input[1] = {static_cast(1.25f)}; auto *data = sycl::malloc_shared(1, queue); data[0] = fp8_e4m3(input); @@ -89,16 +90,16 @@ template int test_marray_conversion(sycl::queue &queue) { data[0] = fp8_e4m3(f); }); queue.wait_and_throw(); - /* - sycl::marray expected_input(static_cast(2.25f)); - fp8_e4m3 expected(expected_input); - sycl::marray out = static_cast>(data[0]); - sycl::marray expected_out = static_cast>(expected); - - sycl::free(data, queue); - if (std::fabs(out[0] - expected_out[0]) > 0.0f) - return 1; - */ + + sycl::marray expected_input(static_cast(2.25f)); + fp8_e4m3 expected(expected_input); + sycl::marray out = static_cast>(data[0]); + sycl::marray expected_out = static_cast>(expected); + + sycl::free(data, queue); + if (std::fabs(out[0] - expected_out[0]) > 0.0f) + return 1; + return 0; } @@ -153,8 +154,9 @@ int main() { ret |= test_fp8_simple_type_conversion(queue); ret |= test_fp8_simple_type_conversion(queue); // check special requirement for boolean conversion - only +0.0 and -0.0 - // should be converted to false, all other values should be converted to true - ret |= test_boolean_conversion(queue, 0.0f, false); + should be converted to false, + all other values should be converted to true ret |= + test_boolean_conversion(queue, 0.0f, false); ret |= test_boolean_conversion(queue, -0.0f, false); ret |= test_boolean_conversion(queue, 1.0f, true); ret |= test_boolean_conversion(queue, -1.0f, true); @@ -164,17 +166,17 @@ int main() { ret |= test_single_element_carray_constructor(queue); ret |= test_single_element_carray_constructor(queue); - // ret |= - // test_single_element_carray_constructor(queue); + ret |= test_single_element_carray_constructor( + queue); ret |= test_marray_conversion(queue); ret |= test_marray_conversion(queue); // TODO: uncomment when bfloat16 conversion is fixed - //ret |= test_marray_conversion(queue); + ret |= test_marray_conversion(queue); ret |= test_carray_conversion(queue); ret |= test_carray_conversion(queue); // TODO: uncomment when bfloat16 conversion is fixed - //ret |= test_carray_conversion(queue); + ret |= test_carray_conversion(queue); return ret; } diff --git a/sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp index dce5c9ae77ac2..9700f21989a26 100644 --- a/sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp @@ -91,7 +91,8 @@ int test_boundary_round_trip_nan(sycl::queue &queue) { queue.single_task([=]() { fp8_e4m3_x2 value = data[0]; - sycl::marray unpacked = static_cast>(value); + sycl::marray unpacked = + static_cast>(value); data[0] = fp8_e4m3_x2(unpacked, rounding::to_even); sycl::marray round_tripped = static_cast>(data[0]); @@ -114,7 +115,8 @@ int test_boundary_round_trip_negative_zero(sycl::queue &queue) { queue.single_task([=]() { fp8_e4m3_x2 value = data[0]; - sycl::marray unpacked = static_cast>(value); + sycl::marray unpacked = + static_cast>(value); data[0] = fp8_e4m3_x2(unpacked, rounding::to_even); sycl::marray round_tripped = static_cast>(data[0]); @@ -142,7 +144,8 @@ int test_boundary_round_trip_subnormals(sycl::queue &queue) { queue.single_task([=]() { fp8_e4m3_x2 value = data[0]; - sycl::marray unpacked = static_cast>(value); + sycl::marray unpacked = + static_cast>(value); data[0] = fp8_e4m3_x2(unpacked, rounding::to_even); sycl::marray round_tripped = static_cast>(data[0]); @@ -170,7 +173,8 @@ int test_boundary_round_trip_exact_normals(sycl::queue &queue) { queue.single_task([=]() { fp8_e4m3_x2 value = data[0]; - sycl::marray unpacked = static_cast>(value); + sycl::marray unpacked = + static_cast>(value); data[0] = fp8_e4m3_x2(unpacked, rounding::to_even); sycl::marray round_tripped = static_cast>(data[0]); @@ -198,7 +202,8 @@ int test_boundary_round_trip_exact_subnormal_limits(sycl::queue &queue) { queue.single_task([=]() { fp8_e4m3_x2 value = data[0]; - sycl::marray unpacked = static_cast>(value); + sycl::marray unpacked = + static_cast>(value); data[0] = fp8_e4m3_x2(unpacked, rounding::to_even); sycl::marray round_tripped = static_cast>(data[0]); @@ -226,7 +231,8 @@ int test_boundary_round_trip_saturation_and_infinity_clamp(sycl::queue &queue) { queue.single_task([=]() { fp8_e4m3_x2 value = data[0]; - sycl::marray unpacked = static_cast>(value); + sycl::marray unpacked = + static_cast>(value); data[0] = fp8_e4m3_x2(unpacked, rounding::to_even); sycl::marray round_tripped = static_cast>(data[0]); @@ -268,8 +274,8 @@ template int test_fp8_simple_type_conversion(sycl::queue &queue) { sycl::free(data, queue); for (size_t i = 0; i < 2; ++i) { - if (std::fabs(static_cast(out[i]) - static_cast(expected_out[i])) > - 0.0f) + if (std::fabs(static_cast(out[i]) - + static_cast(expected_out[i])) > 0.0f) return 1; } @@ -289,15 +295,16 @@ template int test_marray_conversion(sycl::queue &queue) { data[0] = fp8_e4m3_x2(f); }); queue.wait_and_throw(); - sycl::marray expected_input(static_cast(2.25f), static_cast(4.5f)); + sycl::marray expected_input(static_cast(2.25f), + static_cast(4.5f)); fp8_e4m3_x2 expected(expected_input); sycl::marray out = static_cast>(data[0]); sycl::marray expected_out = static_cast>(expected); sycl::free(data, queue); for (size_t i = 0; i < 2; ++i) { - if (std::fabs(static_cast(out[i]) - static_cast(expected_out[i])) > - 0.0f) + if (std::fabs(static_cast(out[i]) - + static_cast(expected_out[i])) > 0.0f) return 1; } return 0; @@ -350,25 +357,25 @@ int main() { // marray, and marray. int ret = test_fp8_simple_type_conversion(queue); ret |= test_fp8_simple_type_conversion(queue); - // ret |= test_fp8_simple_type_conversion(queue); + ret |= test_fp8_simple_type_conversion(queue); ret |= test_marray_conversion(queue); ret |= test_marray_conversion(queue); -// ret |= test_marray_conversion(queue); + ret |= test_marray_conversion(queue); ret |= test_carray_conversion(queue); ret |= test_carray_conversion(queue); - // ret |= test_carray_conversion(queue); + ret |= test_carray_conversion(queue); ret |= test_explicit_to_even_carray_constructor(queue); ret |= test_explicit_to_even_carray_constructor(queue); - // ret |= - // test_explicit_to_even_carray_constructor(queue); + ret |= test_explicit_to_even_carray_constructor( + queue); ret |= test_explicit_to_even_marray_constructor(queue); ret |= test_explicit_to_even_marray_constructor(queue); - // ret |= - // test_explicit_to_even_marray_constructor(queue); + ret |= test_explicit_to_even_marray_constructor( + queue); ret |= test_boundary_round_trip_nan(queue); ret |= test_boundary_round_trip_negative_zero(queue); diff --git a/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp index 8b27778787187..ad22ca7d73ff0 100644 --- a/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp @@ -157,17 +157,17 @@ int main() { ret |= test_single_element_carray_constructor(queue); ret |= test_single_element_carray_constructor(queue); - // ret |= - // test_single_element_carray_constructor(queue); + ret |= test_single_element_carray_constructor( + queue); ret |= test_marray_conversion(queue); ret |= test_marray_conversion(queue); // TODO: uncomment when bfloat16 conversion is fixed - // ret |= test_marray_conversion(queue); + ret |= test_marray_conversion(queue); ret |= test_carray_conversion(queue); ret |= test_carray_conversion(queue); // TODO: uncomment when bfloat16 conversion is fixed - // ret |= test_carray_conversion(queue); + ret |= test_carray_conversion(queue); return ret; } diff --git a/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp index 992a42fab8045..b2e518ef93918 100644 --- a/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp @@ -9,12 +9,23 @@ #include #include +#include #include #include #include using namespace sycl::ext::oneapi::experimental; +#ifdef __SYCL_DEVICE_ONLY__ +#define CONSTANT __attribute__((opencl_constant)) +#else +#define CONSTANT +#endif + +static const CONSTANT char kKernelStartFmt[] = "kernel: enter vals=(%u,%u)\n"; +static const CONSTANT char kKernelUnpackedFmt[] = "kernel: unpacked=(%f,%f)\n"; +static const CONSTANT char kKernelStoreFmt[] = "kernel: store vals=(%u,%u)\n"; + namespace { bool equal_or_both_nan(float actual, float expected) { @@ -311,20 +322,32 @@ template int test_fp8_simple_type_conversion(sycl::queue &queue) { auto *data = sycl::malloc_shared(1, queue); data[0] = fp8_e5m2_x2(static_cast(1.5f), static_cast(2.5f)); + std::cout << "KErnel start\n"; queue.single_task([=]() { - fp8_e5m2_x2 value = data[0]; - sycl::marray f = static_cast>(value); - f[0] += static_cast(1.0f); - f[1] += static_cast(1.0f); - data[0] = fp8_e5m2_x2(f); + // sycl::ext::oneapi::experimental::printf("1\n"); + // fp8_e5m2_x2 value = data[0]; + // sycl::ext::oneapi::experimental::printf(kKernelStartFmt, + // (unsigned int)value.vals[0], + // (unsigned int)value.vals[1]); + // sycl::marray f = static_cast>(value); + // sycl::ext::oneapi::experimental::printf(kKernelUnpackedFmt, + // (float)f[0], (float)f[1]); + // f[0] += static_cast(1.0f); + // f[1] += static_cast(1.0f); + // data[0] = fp8_e5m2_x2(f); + // sycl::ext::oneapi::experimental::printf(kKernelStoreFmt, + // (unsigned int)data[0].vals[0], + // (unsigned int)data[0].vals[1]); }); queue.wait_and_throw(); + std::cout << "KErnel finish\n"; sycl::marray expected_input(static_cast(2.5f), static_cast(3.5f)); fp8_e5m2_x2 expected(expected_input); sycl::marray out = static_cast>(data[0]); sycl::marray expected_out = static_cast>(expected); + std::cout << "free data\n"; sycl::free(data, queue); for (size_t i = 0; i < 2; ++i) { if (std::fabs(static_cast(out[i]) - @@ -332,6 +355,7 @@ template int test_fp8_simple_type_conversion(sycl::queue &queue) { return 1; } + std::cout << "success\n"; return 0; } @@ -407,25 +431,25 @@ int main() { int ret = test_fp8_simple_type_conversion(queue); ret |= test_fp8_simple_type_conversion(queue); - // ret |= test_fp8_simple_type_conversion(queue); + ret |= test_fp8_simple_type_conversion(queue); ret |= test_marray_conversion(queue); ret |= test_marray_conversion(queue); - // ret |= test_marray_conversion(queue); + ret |= test_marray_conversion(queue); ret |= test_carray_conversion(queue); ret |= test_carray_conversion(queue); - // ret |= test_carray_conversion(queue); + ret |= test_carray_conversion(queue); ret |= test_explicit_to_even_carray_constructor(queue); ret |= test_explicit_to_even_carray_constructor(queue); - // ret |= - // test_explicit_to_even_carray_constructor(queue); + ret |= test_explicit_to_even_carray_constructor( + queue); ret |= test_explicit_to_even_marray_constructor(queue); ret |= test_explicit_to_even_marray_constructor(queue); - // ret |= - // test_explicit_to_even_marray_constructor(queue); + ret |= test_explicit_to_even_marray_constructor( + queue); ret |= test_boundary_round_trip_nan(queue); ret |= test_boundary_round_trip_negative_zero(queue); diff --git a/sycl/test-e2e/Experimental/fp8/e8m0_x2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e8m0_x2_cri_conversion.cpp index 3fc689a867e42..ff6844b1fd594 100644 --- a/sycl/test-e2e/Experimental/fp8/e8m0_x2_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e8m0_x2_cri_conversion.cpp @@ -334,9 +334,11 @@ int test_raw_vals_access(sycl::queue &queue) { } // namespace template int test_fp8_simple_type_conversion(sycl::queue &queue) { + auto *data = sycl::malloc_shared(1, queue); data[0] = fp8_e8m0_x2(static_cast(4.0f), static_cast(16.0f)); + std::cout << "kernel\n"; queue.single_task([=]() { fp8_e8m0_x2 value = data[0]; sycl::marray f = static_cast>(value); @@ -346,6 +348,7 @@ template int test_fp8_simple_type_conversion(sycl::queue &queue) { }); queue.wait_and_throw(); + std::cout << "kernel finished\n"; sycl::marray expected_input(static_cast(8.0f), static_cast(32.0f)); fp8_e8m0_x2 expected(expected_input); @@ -436,37 +439,30 @@ int main() { int ret = test_fp8_simple_type_conversion(queue); ret |= test_fp8_simple_type_conversion(queue); ret |= test_fp8_simple_type_conversion(queue); - ret |= test_marray_conversion(queue); ret |= test_marray_conversion(queue); ret |= test_marray_conversion(queue); - ret |= test_carray_conversion(queue); ret |= test_carray_conversion(queue); ret |= test_carray_conversion(queue); - ret |= test_explicit_upward_carray_constructor(queue); ret |= test_explicit_upward_carray_constructor(queue); ret |= test_explicit_upward_carray_constructor( queue); - ret |= test_explicit_toward_zero_carray_constructor(queue); ret |= test_explicit_toward_zero_carray_constructor(queue); ret |= test_explicit_toward_zero_carray_constructor( queue); - ret |= test_explicit_upward_marray_constructor(queue); ret |= test_explicit_upward_marray_constructor(queue); ret |= test_explicit_upward_marray_constructor( queue); - ret |= test_explicit_toward_zero_marray_constructor(queue); ret |= test_explicit_toward_zero_marray_constructor(queue); ret |= test_explicit_toward_zero_marray_constructor( queue); - ret |= test_boundary_round_trip_nan(queue); ret |= test_boundary_round_trip_exact_powers_of_two(queue); ret |= test_boundary_round_trip_max_min_normal(queue); From 0bbf809ba48698109b43b3ae6372c4c7c04f0ae2 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Mon, 18 May 2026 10:19:45 +0200 Subject: [PATCH 51/89] [SYCL] do not set port for communication with sumilator --- .../Experimental/fp8/e4m3_cri_conversion.cpp | 4 +--- .../fp8/e5m2_x2_cri_conversion.cpp | 23 ++++--------------- .../Experimental/fp8/lit.local.cfg.py | 1 - 3 files changed, 6 insertions(+), 22 deletions(-) diff --git a/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp index 881cded8be3cf..687ba88793d5b 100644 --- a/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp @@ -154,9 +154,7 @@ int main() { ret |= test_fp8_simple_type_conversion(queue); ret |= test_fp8_simple_type_conversion(queue); // check special requirement for boolean conversion - only +0.0 and -0.0 - should be converted to false, - all other values should be converted to true ret |= - test_boolean_conversion(queue, 0.0f, false); + ret |= test_boolean_conversion(queue, 0.0f, false); ret |= test_boolean_conversion(queue, -0.0f, false); ret |= test_boolean_conversion(queue, 1.0f, true); ret |= test_boolean_conversion(queue, -1.0f, true); diff --git a/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp index b2e518ef93918..33d1329fc70bb 100644 --- a/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp @@ -322,32 +322,20 @@ template int test_fp8_simple_type_conversion(sycl::queue &queue) { auto *data = sycl::malloc_shared(1, queue); data[0] = fp8_e5m2_x2(static_cast(1.5f), static_cast(2.5f)); - std::cout << "KErnel start\n"; queue.single_task([=]() { - // sycl::ext::oneapi::experimental::printf("1\n"); - // fp8_e5m2_x2 value = data[0]; - // sycl::ext::oneapi::experimental::printf(kKernelStartFmt, - // (unsigned int)value.vals[0], - // (unsigned int)value.vals[1]); - // sycl::marray f = static_cast>(value); - // sycl::ext::oneapi::experimental::printf(kKernelUnpackedFmt, - // (float)f[0], (float)f[1]); - // f[0] += static_cast(1.0f); - // f[1] += static_cast(1.0f); - // data[0] = fp8_e5m2_x2(f); - // sycl::ext::oneapi::experimental::printf(kKernelStoreFmt, - // (unsigned int)data[0].vals[0], - // (unsigned int)data[0].vals[1]); + fp8_e5m2_x2 value = data[0]; + sycl::marray f = static_cast>(value); + f[0] += static_cast(1.0f); + f[1] += static_cast(1.0f); + data[0] = fp8_e5m2_x2(f); }); queue.wait_and_throw(); - std::cout << "KErnel finish\n"; sycl::marray expected_input(static_cast(2.5f), static_cast(3.5f)); fp8_e5m2_x2 expected(expected_input); sycl::marray out = static_cast>(data[0]); sycl::marray expected_out = static_cast>(expected); - std::cout << "free data\n"; sycl::free(data, queue); for (size_t i = 0; i < 2; ++i) { if (std::fabs(static_cast(out[i]) - @@ -355,7 +343,6 @@ template int test_fp8_simple_type_conversion(sycl::queue &queue) { return 1; } - std::cout << "success\n"; return 0; } diff --git a/sycl/test-e2e/Experimental/fp8/lit.local.cfg.py b/sycl/test-e2e/Experimental/fp8/lit.local.cfg.py index 605551f377933..edd1afe5db831 100644 --- a/sycl/test-e2e/Experimental/fp8/lit.local.cfg.py +++ b/sycl/test-e2e/Experimental/fp8/lit.local.cfg.py @@ -2,7 +2,6 @@ config.environment["ProductFamilyOverride"] = "cri" config.environment["HardwareInfoOverride"] = "1x8x8" config.environment["SetCommandStreamReceiver"] = "2" -config.environment["TbxPort"] = "60999" config.environment["RebuildPrecompiledKernels"] = "1" config.environment["EnableDirectSubmission"] = "0" config.environment["EnableBlitterOperationsSupport"] = "1" From 30cf7c6bf8ce4340ae67d95656938ab5f23b4fbe Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Mon, 18 May 2026 10:25:31 +0200 Subject: [PATCH 52/89] [SYCL][E2E] add requirement of cri gpu Co-authored-by: Copilot --- sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp | 6 +----- sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp | 5 +---- sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp | 6 +----- sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp | 7 +------ sycl/test-e2e/Experimental/fp8/e8m0_cri_conversion.cpp | 6 +----- sycl/test-e2e/Experimental/fp8/e8m0_x2_cri_conversion.cpp | 5 +---- 6 files changed, 6 insertions(+), 29 deletions(-) diff --git a/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp index 687ba88793d5b..0a2e3bb7d135a 100644 --- a/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp @@ -1,12 +1,8 @@ +// REQUIRES: intel_feature_gpu_cri // RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out // RUN: %{run} SYCL_UR_TRACE=1 %t.out -// Warning! This test requires CRI device or its simulator run to communicate -// via TCP socket with port 60999, or any other from config - -// TODO need to set requirement of intel_feature_gpu_cri - #include #include #include diff --git a/sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp index 9700f21989a26..07a948772972e 100644 --- a/sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp @@ -1,11 +1,8 @@ +// REQUIRES: intel_feature_gpu_cri // RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out // RUN: %{run} SYCL_UR_TRACE=1 %t.out -// Warning! This test requires CRI device or its simulator run to communicate -// via TCP socket with port 60999, or any other from config - -// TODO need to set requirement of intel_feature_gpu_cri #include #include diff --git a/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp index ad22ca7d73ff0..cbc779f86bcb7 100644 --- a/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp @@ -1,12 +1,8 @@ +// REQUIRES: intel_feature_gpu_cri // RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out // RUN: %{run} SYCL_UR_TRACE=1 %t.out -// Warning! This test requires CRI device or its simulator run to communicate -// via TCP socket with port 60999, or any other from config - -// TODO need to set requirement of intel_feature_gpu_cri - #include #include #include diff --git a/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp index 33d1329fc70bb..433c454357f2a 100644 --- a/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp @@ -1,12 +1,7 @@ - +// REQUIRES: intel_feature_gpu_cri // RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out // RUN: %{run} SYCL_UR_TRACE=1 %t.out -// Warning! This test requires CRI device or its simulator run to communicate -// via TCP socket with port 60999, or any other from config - -// TODO need to set requirement of intel_feature_gpu_cri - #include #include #include diff --git a/sycl/test-e2e/Experimental/fp8/e8m0_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e8m0_cri_conversion.cpp index 515e0369eee03..cdcae2a853635 100644 --- a/sycl/test-e2e/Experimental/fp8/e8m0_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e8m0_cri_conversion.cpp @@ -1,12 +1,8 @@ +// REQUIRES: intel_feature_gpu_cri // RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out // RUN: %{run} SYCL_UR_TRACE=1 %t.out -// Warning! This test requires CRI device or its simulator run to communicate -// via TCP socket with port 60999, or any other from config - -// TODO need to set requirement of intel_feature_gpu_cri - #include #include #include diff --git a/sycl/test-e2e/Experimental/fp8/e8m0_x2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e8m0_x2_cri_conversion.cpp index ff6844b1fd594..2d36c9eca52a1 100644 --- a/sycl/test-e2e/Experimental/fp8/e8m0_x2_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e8m0_x2_cri_conversion.cpp @@ -1,11 +1,8 @@ +// REQUIRES: intel_feature_gpu_cri // RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out // RUN: %{run} SYCL_UR_TRACE=1 %t.out -// Warning! This test requires CRI device or its simulator run to communicate -// via TCP socket with port 60999, or any other from config - -// TODO need to set requirement of intel_feature_gpu_cri #include #include From 655e8ad0aae4673a2fea65606522909e2160d807 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Mon, 18 May 2026 11:00:31 +0200 Subject: [PATCH 53/89] [SYCL] fix formatting --- .../sycl/ext/oneapi/experimental/float_8bit/types.hpp | 2 -- sycl/unittests/Extensions/fp8/fp8_e4m3.cpp | 6 +----- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index 7cd35a62a6ec6..0affdd152ae41 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -1458,7 +1458,6 @@ template class fp8_e5m2_x { return *this; } - template > fp8_e5m2_x &operator=(short val) { vals[0] = ConvertToFP8(val, saturation::finite); @@ -1525,7 +1524,6 @@ template class fp8_e5m2_x { return ConvertFromFP8(vals[0]); } - // Convert to integer types. // Available only when N==1. diff --git a/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp b/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp index 1b4851ca5a579..58659c8f1a029 100644 --- a/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp +++ b/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp @@ -290,7 +290,6 @@ TEST(FP8E4M3Test, CArrayFloatHostToEvenFinite) { EXPECT_EQ(a2.vals[1], 0x00); // 0 } - TEST(FP8E4M3Test, CArrayHalfHostToEvenFinite) { // Host code supports only rounding::to_even and saturation::finite. const sycl::half in[2] = {sycl::half(448.0f), sycl::half(449.0f)}; @@ -518,7 +517,6 @@ TEST(FP8E4M3Test, MarrayFloatRoundingToEven) { EXPECT_EQ(a.vals[1], 0x38); } - TEST(FP8E4M3Test, VariadicRejectsMixedTypes) { EXPECT_FALSE((std::is_constructible_v)); EXPECT_FALSE((std::is_constructible_v)); @@ -560,7 +558,6 @@ TEST(FP8E4M3Test, X2NotConstructibleFromSingleFloat) { EXPECT_FALSE((std::is_constructible_v)); } - TEST(FP8E4M3Test, X2NotConstructibleFromSingleBFloat16) { EXPECT_FALSE( (std::is_constructible_v)); @@ -591,7 +588,6 @@ TEST(FP8E4M3Test, X2NotAssignableFromSingleFloat) { EXPECT_FALSE((std::is_assignable_v)); } - TEST(FP8E4M3Test, X2NotAssignableFromSingleChar) { EXPECT_FALSE((std::is_assignable_v)); } @@ -771,4 +767,4 @@ TEST(FP8E4M3Test, VariadicFloatReferences) { EXPECT_EQ(sizeof(a.vals), 2u); EXPECT_EQ(a.vals[0], 0x38); EXPECT_EQ(a.vals[1], 0x40); -} \ No newline at end of file +} From f338a2f13304e0eab2ec9cbe8322459720375064 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Mon, 18 May 2026 11:43:33 +0200 Subject: [PATCH 54/89] [SYCL] add FP8 feature macro Co-authored-by: Copilot --- sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index 0affdd152ae41..83030a7534e0a 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -19,6 +19,8 @@ #include #include +#define SYCL_EXT_ONEAPI_FP8 1 + #ifdef __SYCL_DEVICE_ONLY__ // FP8 builtins From e9eb83f4315c04b28b54727cfabd364865392ed6 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Mon, 18 May 2026 11:47:47 +0200 Subject: [PATCH 55/89] [SYCL] fix formating --- sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp | 1 - sycl/test-e2e/Experimental/fp8/e8m0_cri_conversion.cpp | 4 +--- sycl/test-e2e/Experimental/fp8/e8m0_x2_cri_conversion.cpp | 1 - 3 files changed, 1 insertion(+), 5 deletions(-) diff --git a/sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp index 07a948772972e..968314b088161 100644 --- a/sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp @@ -3,7 +3,6 @@ // RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out // RUN: %{run} SYCL_UR_TRACE=1 %t.out - #include #include #include diff --git a/sycl/test-e2e/Experimental/fp8/e8m0_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e8m0_cri_conversion.cpp index cdcae2a853635..c554cc209536f 100644 --- a/sycl/test-e2e/Experimental/fp8/e8m0_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e8m0_cri_conversion.cpp @@ -348,9 +348,7 @@ int test_raw_vals_access(sycl::queue &queue) { auto *out = sycl::malloc_shared(1, queue); data[0] = fp8_e8m0(1.0f); - queue.single_task([=]() { - out[0] = data[0].vals[0]; - }); + queue.single_task([=]() { out[0] = data[0].vals[0]; }); queue.wait_and_throw(); int ret = (out[0] != 127) ? 1 : 0; diff --git a/sycl/test-e2e/Experimental/fp8/e8m0_x2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e8m0_x2_cri_conversion.cpp index 2d36c9eca52a1..41af44d2b91a8 100644 --- a/sycl/test-e2e/Experimental/fp8/e8m0_x2_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e8m0_x2_cri_conversion.cpp @@ -3,7 +3,6 @@ // RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out // RUN: %{run} SYCL_UR_TRACE=1 %t.out - #include #include #include From 0ed3ca268472f8de2f5554ed58fe8b6b0f458dd7 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Mon, 18 May 2026 16:09:08 +0200 Subject: [PATCH 56/89] [SYCL][TEST] update test of post link drop known builtins Co-authored-by: Copilot --- .../sycl-external-funcs/drop-known-builtins.ll | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/llvm/test/tools/sycl-post-link/sycl-external-funcs/drop-known-builtins.ll b/llvm/test/tools/sycl-post-link/sycl-external-funcs/drop-known-builtins.ll index 88ce25ccc82b8..9153a653b1880 100644 --- a/llvm/test/tools/sycl-post-link/sycl-external-funcs/drop-known-builtins.ll +++ b/llvm/test/tools/sycl-post-link/sycl-external-funcs/drop-known-builtins.ll @@ -24,10 +24,19 @@ define dso_local spir_func void @_Z33__sXcl_getScalarSpecConstantValue() #0 { ret void } +define dso_local spir_func void @_Z29__builtin_spirv_foo() #0 { + ret void +} +define dso_local spir_func void @_Z29__builtin_spXrv_foo() #0 { + ret void +} + attributes #0 = { "sycl-module-id"="a.cpp" } ; CHECK-NOT: define dso_local spir_func void @_Z28__spirv_GlobalInvocationId_xv() ; CHECK-NOT: define dso_local spir_func void @_Z33__sycl_getScalarSpecConstantValue() +; CHECK-NOT: define dso_local spir_func void @_Z29__builtin_spirv_foo() ; CHECK-DAG: define dso_local spir_func void @_Z28__spXrv_GlobalInvocationId_xv() ; CHECK-DAG: define dso_local spir_func void @_Z33__sXcl_getScalarSpecConstantValue() +; CHECK-DAG: define dso_local spir_func void @_Z29__builtin_spXrv_foo() From f8f895200aab1ea254b57d49d542e8918daac318 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Mon, 18 May 2026 16:53:46 +0200 Subject: [PATCH 57/89] [SYCL][E2E] add spirv translator pattern to build command Co-authored-by: Copilot --- sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp | 3 +-- sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp | 2 +- sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp | 2 +- sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp | 2 +- sycl/test-e2e/Experimental/fp8/e8m0_cri_conversion.cpp | 2 +- sycl/test-e2e/Experimental/fp8/e8m0_x2_cri_conversion.cpp | 2 +- 6 files changed, 6 insertions(+), 7 deletions(-) diff --git a/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp index 0a2e3bb7d135a..75cf1105df5d3 100644 --- a/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp @@ -1,6 +1,5 @@ - // REQUIRES: intel_feature_gpu_cri -// RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out +// RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator=spir64 --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out // RUN: %{run} SYCL_UR_TRACE=1 %t.out #include diff --git a/sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp index 968314b088161..52bd239c81363 100644 --- a/sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp @@ -1,6 +1,6 @@ // REQUIRES: intel_feature_gpu_cri -// RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out +// RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator=spir64 --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out // RUN: %{run} SYCL_UR_TRACE=1 %t.out #include diff --git a/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp index cbc779f86bcb7..c6e0438a97b5b 100644 --- a/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp @@ -1,6 +1,6 @@ // REQUIRES: intel_feature_gpu_cri -// RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out +// RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator=spir64 --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out // RUN: %{run} SYCL_UR_TRACE=1 %t.out #include diff --git a/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp index 433c454357f2a..012542cda1104 100644 --- a/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp @@ -1,5 +1,5 @@ // REQUIRES: intel_feature_gpu_cri -// RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out +// RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator=spir64 --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out // RUN: %{run} SYCL_UR_TRACE=1 %t.out #include diff --git a/sycl/test-e2e/Experimental/fp8/e8m0_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e8m0_cri_conversion.cpp index c554cc209536f..a6fd8ecc453e4 100644 --- a/sycl/test-e2e/Experimental/fp8/e8m0_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e8m0_cri_conversion.cpp @@ -1,6 +1,6 @@ // REQUIRES: intel_feature_gpu_cri -// RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out +// RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator=spir64 --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out // RUN: %{run} SYCL_UR_TRACE=1 %t.out #include diff --git a/sycl/test-e2e/Experimental/fp8/e8m0_x2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e8m0_x2_cri_conversion.cpp index 41af44d2b91a8..3a1eb49720861 100644 --- a/sycl/test-e2e/Experimental/fp8/e8m0_x2_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e8m0_x2_cri_conversion.cpp @@ -1,6 +1,6 @@ // REQUIRES: intel_feature_gpu_cri -// RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out +// RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator=spir64 --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out // RUN: %{run} SYCL_UR_TRACE=1 %t.out #include From c82aa03730e81ec7c66011a311a3d8d271434db2 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Tue, 19 May 2026 14:30:59 +0200 Subject: [PATCH 58/89] [SYCL][E2E] make tests XFAIL until driver is installed on CI machines Co-authored-by: Copilot --- .../Experimental/fp8/e4m3_cri_conversion.cpp | 5 +++++ .../Experimental/fp8/e4m3_x2_cri_conversion.cpp | 4 ++++ .../Experimental/fp8/e5m2_cri_conversion.cpp | 4 ++++ .../Experimental/fp8/e5m2_x2_cri_conversion.cpp | 14 ++++---------- .../Experimental/fp8/e8m0_cri_conversion.cpp | 4 ++++ .../Experimental/fp8/e8m0_x2_cri_conversion.cpp | 6 ++++-- 6 files changed, 25 insertions(+), 12 deletions(-) diff --git a/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp index 75cf1105df5d3..5311741d8ad1e 100644 --- a/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp @@ -2,6 +2,11 @@ // RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator=spir64 --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out // RUN: %{run} SYCL_UR_TRACE=1 %t.out + +// make it XFAIL until driver will be installed on CI machines and the test will be enabled in the test suite +// XFAIL: * +// XFAIL-TRACKER: CMPLRLLVM-69851 + #include #include #include diff --git a/sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp index 52bd239c81363..453f3bb9cf420 100644 --- a/sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp @@ -3,6 +3,10 @@ // RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator=spir64 --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out // RUN: %{run} SYCL_UR_TRACE=1 %t.out +// make it XFAIL until driver will be installed on CI machines and the test will be enabled in the test suite +// XFAIL: * +// XFAIL-TRACKER: CMPLRLLVM-69851 + #include #include #include diff --git a/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp index c6e0438a97b5b..a516897611d01 100644 --- a/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp @@ -3,6 +3,10 @@ // RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator=spir64 --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out // RUN: %{run} SYCL_UR_TRACE=1 %t.out +// make it XFAIL until driver will be installed on CI machines and the test will be enabled in the test suite +// XFAIL: * +// XFAIL-TRACKER: CMPLRLLVM-69851 + #include #include #include diff --git a/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp index 012542cda1104..8dab1b92862c6 100644 --- a/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp @@ -2,6 +2,10 @@ // RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator=spir64 --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out // RUN: %{run} SYCL_UR_TRACE=1 %t.out +// make it XFAIL until driver will be installed on CI machines and the test will be enabled in the test suite +// XFAIL: * +// XFAIL-TRACKER: CMPLRLLVM-69851 + #include #include #include @@ -11,16 +15,6 @@ using namespace sycl::ext::oneapi::experimental; -#ifdef __SYCL_DEVICE_ONLY__ -#define CONSTANT __attribute__((opencl_constant)) -#else -#define CONSTANT -#endif - -static const CONSTANT char kKernelStartFmt[] = "kernel: enter vals=(%u,%u)\n"; -static const CONSTANT char kKernelUnpackedFmt[] = "kernel: unpacked=(%f,%f)\n"; -static const CONSTANT char kKernelStoreFmt[] = "kernel: store vals=(%u,%u)\n"; - namespace { bool equal_or_both_nan(float actual, float expected) { diff --git a/sycl/test-e2e/Experimental/fp8/e8m0_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e8m0_cri_conversion.cpp index a6fd8ecc453e4..5ff940f99fdd6 100644 --- a/sycl/test-e2e/Experimental/fp8/e8m0_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e8m0_cri_conversion.cpp @@ -3,6 +3,10 @@ // RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator=spir64 --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out // RUN: %{run} SYCL_UR_TRACE=1 %t.out +// make it XFAIL until driver will be installed on CI machines and the test will be enabled in the test suite +// XFAIL: * +// XFAIL-TRACKER: CMPLRLLVM-69851 + #include #include #include diff --git a/sycl/test-e2e/Experimental/fp8/e8m0_x2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e8m0_x2_cri_conversion.cpp index 3a1eb49720861..b4d83e84e2e79 100644 --- a/sycl/test-e2e/Experimental/fp8/e8m0_x2_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e8m0_x2_cri_conversion.cpp @@ -3,6 +3,10 @@ // RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator=spir64 --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out // RUN: %{run} SYCL_UR_TRACE=1 %t.out +// make it XFAIL until driver will be installed on CI machines and the test will be enabled in the test suite +// XFAIL: * +// XFAIL-TRACKER: CMPLRLLVM-69851 + #include #include #include @@ -334,7 +338,6 @@ template int test_fp8_simple_type_conversion(sycl::queue &queue) { auto *data = sycl::malloc_shared(1, queue); data[0] = fp8_e8m0_x2(static_cast(4.0f), static_cast(16.0f)); - std::cout << "kernel\n"; queue.single_task([=]() { fp8_e8m0_x2 value = data[0]; sycl::marray f = static_cast>(value); @@ -344,7 +347,6 @@ template int test_fp8_simple_type_conversion(sycl::queue &queue) { }); queue.wait_and_throw(); - std::cout << "kernel finished\n"; sycl::marray expected_input(static_cast(8.0f), static_cast(32.0f)); fp8_e8m0_x2 expected(expected_input); From 880c7258cdc40b711b19bc5bff0d89a090f0f2a0 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Tue, 19 May 2026 14:59:37 +0200 Subject: [PATCH 59/89] [SYCL][TESTE2E] fix formatting --- sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp | 4 ++-- sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp | 3 ++- sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp | 3 ++- sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp | 3 ++- sycl/test-e2e/Experimental/fp8/e8m0_cri_conversion.cpp | 3 ++- sycl/test-e2e/Experimental/fp8/e8m0_x2_cri_conversion.cpp | 3 ++- 6 files changed, 12 insertions(+), 7 deletions(-) diff --git a/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp index 5311741d8ad1e..497969db244f0 100644 --- a/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp @@ -2,8 +2,8 @@ // RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator=spir64 --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out // RUN: %{run} SYCL_UR_TRACE=1 %t.out - -// make it XFAIL until driver will be installed on CI machines and the test will be enabled in the test suite +// make it XFAIL until driver will be installed on CI machines and the test will +// be enabled in the test suite // XFAIL: * // XFAIL-TRACKER: CMPLRLLVM-69851 diff --git a/sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp index 453f3bb9cf420..b28101eadd1d7 100644 --- a/sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp @@ -3,7 +3,8 @@ // RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator=spir64 --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out // RUN: %{run} SYCL_UR_TRACE=1 %t.out -// make it XFAIL until driver will be installed on CI machines and the test will be enabled in the test suite +// make it XFAIL until driver will be installed on CI machines and the test will +// be enabled in the test suite // XFAIL: * // XFAIL-TRACKER: CMPLRLLVM-69851 diff --git a/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp index a516897611d01..21a4340ac3977 100644 --- a/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp @@ -3,7 +3,8 @@ // RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator=spir64 --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out // RUN: %{run} SYCL_UR_TRACE=1 %t.out -// make it XFAIL until driver will be installed on CI machines and the test will be enabled in the test suite +// make it XFAIL until driver will be installed on CI machines and the test will +// be enabled in the test suite // XFAIL: * // XFAIL-TRACKER: CMPLRLLVM-69851 diff --git a/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp index 8dab1b92862c6..6991e0bbe2936 100644 --- a/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp @@ -2,7 +2,8 @@ // RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator=spir64 --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out // RUN: %{run} SYCL_UR_TRACE=1 %t.out -// make it XFAIL until driver will be installed on CI machines and the test will be enabled in the test suite +// make it XFAIL until driver will be installed on CI machines and the test will +// be enabled in the test suite // XFAIL: * // XFAIL-TRACKER: CMPLRLLVM-69851 diff --git a/sycl/test-e2e/Experimental/fp8/e8m0_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e8m0_cri_conversion.cpp index 5ff940f99fdd6..98a58009d4586 100644 --- a/sycl/test-e2e/Experimental/fp8/e8m0_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e8m0_cri_conversion.cpp @@ -3,7 +3,8 @@ // RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator=spir64 --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out // RUN: %{run} SYCL_UR_TRACE=1 %t.out -// make it XFAIL until driver will be installed on CI machines and the test will be enabled in the test suite +// make it XFAIL until driver will be installed on CI machines and the test will +// be enabled in the test suite // XFAIL: * // XFAIL-TRACKER: CMPLRLLVM-69851 diff --git a/sycl/test-e2e/Experimental/fp8/e8m0_x2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e8m0_x2_cri_conversion.cpp index b4d83e84e2e79..571fe42e2858a 100644 --- a/sycl/test-e2e/Experimental/fp8/e8m0_x2_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e8m0_x2_cri_conversion.cpp @@ -3,7 +3,8 @@ // RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator=spir64 --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out // RUN: %{run} SYCL_UR_TRACE=1 %t.out -// make it XFAIL until driver will be installed on CI machines and the test will be enabled in the test suite +// make it XFAIL until driver will be installed on CI machines and the test will +// be enabled in the test suite // XFAIL: * // XFAIL-TRACKER: CMPLRLLVM-69851 From b0f212c6b9be7330f6f96a0eec9de85287989b24 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Tue, 26 May 2026 14:48:38 +0200 Subject: [PATCH 60/89] [SYCL] fix build issues on win --- .../sycl/ext/oneapi/experimental/float_8bit/types.hpp | 4 ++-- sycl/unittests/Extensions/fp8/builtin_mocks.hpp | 5 +++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index 83030a7534e0a..036faca72c599 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -409,7 +409,7 @@ static inline uint8_t ConvertFloatToFP8_CPU(T f, rounding R, }; UInt bits; - __builtin_memcpy(&bits, &f, sizeof(bits)); + std::memcpy(&bits, &f, sizeof(bits)); const uint8_t sign = (bits & SignMask) ? 0x80u : 0x00u; bits &= ~SignMask; @@ -541,7 +541,7 @@ static inline uint8_t ConvertFloatToE8M0_CPU(T f, rounding R, constexpr int TargetEmax = 127; UInt h; - __builtin_memcpy(&h, &f, sizeof(h)); + std::memcpy(&h, &f, sizeof(h)); h &= ~SignMask; UInt exp = (h & ExpMask) >> Traits::FracBits; diff --git a/sycl/unittests/Extensions/fp8/builtin_mocks.hpp b/sycl/unittests/Extensions/fp8/builtin_mocks.hpp index 0d5f9cee1f7c7..f5822744ca7f5 100644 --- a/sycl/unittests/Extensions/fp8/builtin_mocks.hpp +++ b/sycl/unittests/Extensions/fp8/builtin_mocks.hpp @@ -7,6 +7,11 @@ #include #include +#if defined(_MSC_VER) +#define _Float16 sycl::half +#define __bf16 sycl::ext::oneapi::bfloat16 +#endif + // Force code path that uses helpers.hpp wrappers. #ifndef __SYCL_DEVICE_ONLY__ #define __SYCL_DEVICE_ONLY__ 1 From d24737e0307caf38d29f1a389cfe6ecb283437a3 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Wed, 27 May 2026 12:18:45 +0200 Subject: [PATCH 61/89] [SYCL] use safer method to cast to _Float16 --- .../include/sycl/ext/oneapi/experimental/float_8bit/types.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index 036faca72c599..27b4da2e6408d 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -821,7 +821,7 @@ template class fp8_e4m3_x { #ifdef __SYCL_DEVICE_ONLY__ _Float16 v{0}; if constexpr (std::is_same_v, sycl::half>) - v = static_cast<_Float16>(static_cast(h)); + v = sycl::bit_cast<_Float16>(h); else v = static_cast<_Float16>(h); return __builtin_spirv_ClampConvertFP16ToE4M3INTEL(v); @@ -1177,7 +1177,7 @@ template class fp8_e5m2_x { #ifdef __SYCL_DEVICE_ONLY__ _Float16 v{0}; if constexpr (std::is_same_v, sycl::half>) - v = static_cast<_Float16>(static_cast(h)); + v = sycl::bit_cast<_Float16>(h); else v = static_cast<_Float16>(h); return s == saturation::finite From b64f10791240b3cd7d476d08efd6fbf27e23863c Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Thu, 28 May 2026 19:06:23 +0200 Subject: [PATCH 62/89] [SYCL] throw runtime errors when stochastic constructor is used on hos side [SYCL] add deduction guides --- .../oneapi/experimental/float_8bit/types.hpp | 16 +++++ sycl/unittests/Extensions/fp8/fp8_e4m3.cpp | 8 +++ sycl/unittests/Extensions/fp8/fp8_e5m2.cpp | 67 ++++++++++++++++++- sycl/unittests/Extensions/fp8/fp8_e8m0.cpp | 8 +++ 4 files changed, 98 insertions(+), 1 deletion(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index 27b4da2e6408d..0eaf16cf5138e 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -1334,6 +1334,9 @@ template class fp8_e5m2_x { } current_seed = *seed.pseed; } +#else + throw std::runtime_error( + "stochastic rounding constructors are not supported on host"); #endif } @@ -1352,6 +1355,9 @@ template class fp8_e5m2_x { } current_seed = *seed.pseed; } +#else + throw std::runtime_error( + "stochastic rounding constructors are not supported on host"); #endif } @@ -1375,6 +1381,9 @@ template class fp8_e5m2_x { } current_seed = *seed.pseed; } +#else + throw std::runtime_error( + "stochastic rounding constructors are not supported on host"); #endif } @@ -1393,6 +1402,9 @@ template class fp8_e5m2_x { } current_seed = *seed.pseed; } +#else + throw std::runtime_error( + "stochastic rounding constructors are not supported on host"); #endif } @@ -1899,6 +1911,10 @@ template class fp8_e8m0_x { uint8_t vals[N]; }; +template fp8_e4m3_x(Ts...) -> fp8_e4m3_x; +template fp8_e5m2_x(Ts...) -> fp8_e5m2_x; +template fp8_e8m0_x(Ts...) -> fp8_e8m0_x; + using fp8_e4m3 = fp8_e4m3_x<1>; using fp8_e4m3_x2 = fp8_e4m3_x<2>; using fp8_e5m2 = fp8_e5m2_x<1>; diff --git a/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp b/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp index 58659c8f1a029..2c8ae12125178 100644 --- a/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp +++ b/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp @@ -13,6 +13,14 @@ code thus unit tests check only API using namespace sycl::ext::oneapi::experimental; +TEST(FP8E4M3Test, DeductionGuide) { + fp8_e4m3_x one(1.0f); + fp8_e4m3_x pair(1.0f, 2.0f); + + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); +} + TEST(FP8E4M3Test, TrivialSpecialMembers) { EXPECT_TRUE((std::is_trivially_default_constructible_v)); EXPECT_TRUE((std::is_trivially_copy_constructible_v)); diff --git a/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp b/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp index eefc07fa807d4..fd394afc093c4 100644 --- a/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp +++ b/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp @@ -13,6 +13,14 @@ code thus unit tests check only API using namespace sycl::ext::oneapi::experimental; +TEST(FP8E5M2Test, DeductionGuide) { + fp8_e5m2_x one(1.0f); + fp8_e5m2_x pair(1.0f, 2.0f); + + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); +} + TEST(FP8E5M2Test, VariadicHalf) { fp8_e5m2_x2 a(sycl::half(1.0f), sycl::half(2.0f)); @@ -317,6 +325,62 @@ TEST(FP8E5M2Test, MarrayAndOperators) { EXPECT_EQ(out3[1], -1.5f); } +TEST(FP8E5M2Test, StochasticMarrayHalfConstructorThrowsOnHost) { + sycl::marray in = {sycl::half(1.0f), sycl::half(2.0f)}; + uint32_t seed_value = 1234; + stochastic_seed seed(&seed_value); + + EXPECT_THROW( + { + fp8_e5m2_x2 value(in, seed); + (void)value; + }, + std::runtime_error); +} + +TEST(FP8E5M2Test, StochasticCArrayHalfConstructorThrowsOnHost) { + const sycl::half in[2] = {sycl::half(1.0f), sycl::half(2.0f)}; + uint32_t seed_value = 2345; + stochastic_seed seed(&seed_value); + + EXPECT_THROW( + { + fp8_e5m2_x2 value(in, seed); + (void)value; + }, + std::runtime_error); +} + +TEST(FP8E5M2Test, StochasticMarrayBFloat16ConstructorThrowsOnHost) { + sycl::marray in = { + sycl::ext::oneapi::bfloat16(1.0f), + sycl::ext::oneapi::bfloat16(2.0f)}; + uint32_t seed_value = 5678; + stochastic_seed seed(&seed_value); + + EXPECT_THROW( + { + fp8_e5m2_x2 value(in, seed); + (void)value; + }, + std::runtime_error); +} + +TEST(FP8E5M2Test, StochasticCArrayBFloat16ConstructorThrowsOnHost) { + const sycl::ext::oneapi::bfloat16 in[2] = { + sycl::ext::oneapi::bfloat16(1.0f), + sycl::ext::oneapi::bfloat16(2.0f)}; + uint32_t seed_value = 6789; + stochastic_seed seed(&seed_value); + + EXPECT_THROW( + { + fp8_e5m2_x2 value(in, seed); + (void)value; + }, + std::runtime_error); +} + TEST(FP8E5M2Test, FloatingPointConversionOperatorsMoreTypes) { fp8_e5m2 a(1.0f); fp8_e5m2 b(0.00006103515625f); @@ -687,4 +751,5 @@ TEST(FP8E5M2Test, VariadicFloatReferences) { EXPECT_EQ(sizeof(a.vals), 2u); EXPECT_EQ(a.vals[0], 0x3C); EXPECT_EQ(a.vals[1], 0x40); -} \ No newline at end of file +} + diff --git a/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp b/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp index 56537aa6ce6ab..69bc25431d99f 100644 --- a/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp +++ b/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp @@ -27,6 +27,14 @@ bool checkCode(float Input, rounding Mode, uint8_t Expected) { } // namespace +TEST(FP8E8M0Test, DeductionGuide) { + fp8_e8m0_x one(1.0f); + fp8_e8m0_x pair(1.0f, 2.0f); + + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); +} + TEST(FP8E8M0Test, VariadicFloat) { fp8_e8m0_x2 a(1.0f, 2.0f); fp8_e8m0_x2 a1(1.1f, 0.0f); From e399a24606c834435ff198fc5e1b582fdac37546 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Thu, 28 May 2026 19:11:38 +0200 Subject: [PATCH 63/89] [SYCL] move FP8 feature macro to header --- sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp | 2 -- sycl/source/feature_test.hpp.in | 1 + 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index 0eaf16cf5138e..2490611670da3 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -19,8 +19,6 @@ #include #include -#define SYCL_EXT_ONEAPI_FP8 1 - #ifdef __SYCL_DEVICE_ONLY__ // FP8 builtins diff --git a/sycl/source/feature_test.hpp.in b/sycl/source/feature_test.hpp.in index d87264fb6c593..c8bb06d8034ec 100644 --- a/sycl/source/feature_test.hpp.in +++ b/sycl/source/feature_test.hpp.in @@ -124,6 +124,7 @@ inline namespace _V1 { #define SYCL_KHR_STATIC_ADDRSPACE_CAST 1 #define SYCL_KHR_WORK_ITEM_QUERIES 1 #define SYCL_KHR_GROUP_INTERFACE 1 +#define SYCL_EXT_ONEAPI_FP8 1 // Unfinished KHR extensions. These extensions are only available if the // __DPCPP_ENABLE_UNFINISHED_KHR_EXTENSIONS macro is defined. From 956410027b5f4a25c24d524b8dd9ab7b226214f7 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Thu, 28 May 2026 19:16:19 +0200 Subject: [PATCH 64/89] [SYCL][TESTE2E] remove TODOs from tests --- sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp | 2 -- sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp | 2 -- 2 files changed, 4 deletions(-) diff --git a/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp index 497969db244f0..789a88b1cf1fe 100644 --- a/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp @@ -169,12 +169,10 @@ int main() { ret |= test_marray_conversion(queue); ret |= test_marray_conversion(queue); - // TODO: uncomment when bfloat16 conversion is fixed ret |= test_marray_conversion(queue); ret |= test_carray_conversion(queue); ret |= test_carray_conversion(queue); - // TODO: uncomment when bfloat16 conversion is fixed ret |= test_carray_conversion(queue); return ret; } diff --git a/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp index 21a4340ac3977..13f809c0cd062 100644 --- a/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp @@ -163,12 +163,10 @@ int main() { ret |= test_marray_conversion(queue); ret |= test_marray_conversion(queue); - // TODO: uncomment when bfloat16 conversion is fixed ret |= test_marray_conversion(queue); ret |= test_carray_conversion(queue); ret |= test_carray_conversion(queue); - // TODO: uncomment when bfloat16 conversion is fixed ret |= test_carray_conversion(queue); return ret; } From a4a6ec564fda0653199f51cc63ae123e50058276 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Thu, 28 May 2026 19:36:46 +0200 Subject: [PATCH 65/89] [SYCL][TESTE2E] remove extra carray tests --- .../Experimental/fp8/e4m3_cri_conversion.cpp | 26 ------------------- .../Experimental/fp8/e5m2_cri_conversion.cpp | 26 ------------------- 2 files changed, 52 deletions(-) diff --git a/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp index 789a88b1cf1fe..81f453a586113 100644 --- a/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp @@ -103,28 +103,6 @@ template int test_marray_conversion(sycl::queue &queue) { return 0; } -template int test_carray_conversion(sycl::queue &queue) { - auto *data = sycl::malloc_shared(1, queue); - data[0] = fp8_e4m3(static_cast(1.25f)); - - queue.single_task([=]() { - fp8_e4m3 value = data[0]; - T f = {static_cast(value)}; - f += static_cast(1.0f); - data[0] = fp8_e4m3(f); - }); - queue.wait_and_throw(); - - fp8_e4m3 expected(static_cast(2.25f)); - T out = {static_cast(data[0])}; - T expected_out = {static_cast(expected)}; - - sycl::free(data, queue); - if (std::fabs(out - expected_out) > 0.0f) - return 1; - return 0; -} - int main() { auto async_handler = [](sycl::exception_list exceptions) { for (const std::exception_ptr &e : exceptions) { @@ -170,9 +148,5 @@ int main() { ret |= test_marray_conversion(queue); ret |= test_marray_conversion(queue); ret |= test_marray_conversion(queue); - - ret |= test_carray_conversion(queue); - ret |= test_carray_conversion(queue); - ret |= test_carray_conversion(queue); return ret; } diff --git a/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp index 13f809c0cd062..6a2d2b87e68a8 100644 --- a/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp @@ -95,28 +95,6 @@ template int test_marray_conversion(sycl::queue &queue) { return 0; } -template int test_carray_conversion(sycl::queue &queue) { - auto *data = sycl::malloc_shared(1, queue); - data[0] = fp8_e5m2(static_cast(1.5f)); - - queue.single_task([=]() { - fp8_e5m2 value = data[0]; - T f = {static_cast(value)}; - f += static_cast(1.0f); - data[0] = fp8_e5m2(f); - }); - queue.wait_and_throw(); - - fp8_e5m2 expected(static_cast(2.5f)); - T out = {static_cast(data[0])}; - T expected_out = {static_cast(expected)}; - - sycl::free(data, queue); - if (std::fabs(out - expected_out) > 0.0f) - return 1; - return 0; -} - int main() { auto async_handler = [](sycl::exception_list exceptions) { for (const std::exception_ptr &e : exceptions) { @@ -164,9 +142,5 @@ int main() { ret |= test_marray_conversion(queue); ret |= test_marray_conversion(queue); ret |= test_marray_conversion(queue); - - ret |= test_carray_conversion(queue); - ret |= test_carray_conversion(queue); - ret |= test_carray_conversion(queue); return ret; } From e960f8b1cbf5742f054b198c039a767d72d67d6d Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Mon, 1 Jun 2026 11:06:26 +0200 Subject: [PATCH 66/89] [SYCL][TESTS] proper way to create floating point negative nan value --- sycl/unittests/Extensions/fp8/fp8_e4m3.cpp | 5 +++-- sycl/unittests/Extensions/fp8/fp8_e5m2.cpp | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp b/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp index 2c8ae12125178..cfa334186728c 100644 --- a/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp +++ b/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp @@ -106,9 +106,10 @@ TEST(FP8E4M3Test, VariadicBoundaryEncodingsFloat) { TEST(FP8E4M3Test, VariadicNaNEncodingFloat) { // NaN is encoded as S.1111.111; sign is permitted. - fp8_e4m3_x2 a(std::numeric_limits::quiet_NaN(), - -std::numeric_limits::quiet_NaN()); + float pos_nan = std::numeric_limits::quiet_NaN(); + float neg_nan = std::copysign(pos_nan, -1.0f); + fp8_e4m3_x2 a(pos_nan, neg_nan); EXPECT_EQ(a.vals[0], 0x7F); // +NaN -> 0b0_1111_111 EXPECT_EQ(a.vals[1], 0xFF); // -NaN -> 0b1_1111_111 } diff --git a/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp b/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp index fd394afc093c4..3062eefb396b4 100644 --- a/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp +++ b/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp @@ -85,8 +85,9 @@ TEST(FP8E5M2Test, VariadicBoundaryEncodingsFloat) { } TEST(FP8E5M2Test, VariadicNaNEncodingFloat) { - fp8_e5m2_x2 a(std::numeric_limits::quiet_NaN(), - -std::numeric_limits::quiet_NaN()); + float pos_nan = std::numeric_limits::quiet_NaN(); + float neg_nan = std::copysign(pos_nan, -1.0f); + fp8_e5m2_x2 a(pos_nan, neg_nan); EXPECT_EQ(sizeof(a.vals), 2u); EXPECT_EQ(a.vals[0], 0x7F); // +NaN -> 0b0_11111_11 From 931ab8ea64bcbc0ab65ad8cc3f36226e7a22f5b1 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Mon, 1 Jun 2026 11:48:41 +0200 Subject: [PATCH 67/89] [SYCL][TESTE2E] add stochastic constructor tests --- .../Experimental/fp8/e5m2_cri_conversion.cpp | 88 +++++++++++++-- .../fp8/e5m2_x2_cri_conversion.cpp | 101 ++++++++++++++++-- 2 files changed, 177 insertions(+), 12 deletions(-) diff --git a/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp index 6a2d2b87e68a8..d17e187a8356d 100644 --- a/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp @@ -1,14 +1,9 @@ -// REQUIRES: intel_feature_gpu_cri // RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator=spir64 --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out // RUN: %{run} SYCL_UR_TRACE=1 %t.out -// make it XFAIL until driver will be installed on CI machines and the test will -// be enabled in the test suite -// XFAIL: * -// XFAIL-TRACKER: CMPLRLLVM-69851 - #include +#include #include #include #include @@ -16,6 +11,70 @@ using namespace sycl::ext::oneapi::experimental; +namespace { + +constexpr float E5M2MaxNormal = 57344.0f; + +bool is_positive_infinity(float value) { + return std::isinf(value) && !std::signbit(value); +} + +template +int test_stochastic_constructor(sycl::queue &queue) { + auto *out = sycl::malloc_shared(1, queue); + auto *seed = sycl::malloc_shared(1, queue); + auto *seed_updated = sycl::malloc_shared(1, queue); + seed[0] = 0x12345678u; + seed_updated[0] = false; + + queue.single_task([=]() { + const float input_value = + Sat == saturation::finite ? -std::numeric_limits::infinity() + : std::numeric_limits::infinity(); + const uint32_t initial_seed = seed[0]; + + if constexpr (UseMarray) { + sycl::marray input(static_cast(input_value)); + if constexpr (Sat == saturation::finite) { + fp8_e5m2 value(input, stochastic_seed(seed)); + out[0] = static_cast(value); + } else { + fp8_e5m2 value(input, stochastic_seed(seed), saturation::none); + out[0] = static_cast(value); + } + } else { + T input[1] = {static_cast(input_value)}; + if constexpr (Sat == saturation::finite) { + fp8_e5m2 value(input, stochastic_seed(seed)); + out[0] = static_cast(value); + } else { + fp8_e5m2 value(input, stochastic_seed(seed), saturation::none); + out[0] = static_cast(value); + } + } + + seed_updated[0] = seed[0] != initial_seed; + }); + queue.wait_and_throw(); + + int ret = 0; + if (!seed_updated[0]) + ret = 1; + if constexpr (Sat == saturation::finite) { + if (out[0] != -E5M2MaxNormal) + ret = 1; + } else if (!is_positive_infinity(out[0])) { + ret = 1; + } + + sycl::free(out, queue); + sycl::free(seed, queue); + sycl::free(seed_updated, queue); + return ret; +} + +} // namespace + template int test_fp8_simple_type_conversion(sycl::queue &queue) { auto *data = sycl::malloc_shared(1, queue); data[0] = fp8_e5m2(static_cast(1.5)); @@ -142,5 +201,22 @@ int main() { ret |= test_marray_conversion(queue); ret |= test_marray_conversion(queue); ret |= test_marray_conversion(queue); + + ret |= test_stochastic_constructor( + queue); + ret |= test_stochastic_constructor( + queue); + ret |= test_stochastic_constructor( + queue); + ret |= test_stochastic_constructor( + queue); + ret |= test_stochastic_constructor(queue); + ret |= test_stochastic_constructor(queue); + ret |= test_stochastic_constructor(queue); + ret |= test_stochastic_constructor(queue); return ret; } diff --git a/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp index 6991e0bbe2936..77c7051e65f23 100644 --- a/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp @@ -1,13 +1,8 @@ -// REQUIRES: intel_feature_gpu_cri // RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator=spir64 --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out // RUN: %{run} SYCL_UR_TRACE=1 %t.out -// make it XFAIL until driver will be installed on CI machines and the test will -// be enabled in the test suite -// XFAIL: * -// XFAIL-TRACKER: CMPLRLLVM-69851 - #include +#include #include #include #include @@ -18,6 +13,8 @@ using namespace sycl::ext::oneapi::experimental; namespace { +constexpr float E5M2MaxNormal = 57344.0f; + bool equal_or_both_nan(float actual, float expected) { if (std::isnan(expected)) return std::isnan(actual); @@ -32,6 +29,81 @@ bool equal_with_zero_sign(float actual, float expected) { return true; } +bool is_positive_infinity(float value) { + return std::isinf(value) && !std::signbit(value); +} + +bool is_negative_infinity(float value) { + return std::isinf(value) && std::signbit(value); +} + +template +int test_stochastic_constructor(sycl::queue &queue) { + auto *out = sycl::malloc_shared(2, queue); + auto *seed = sycl::malloc_shared(1, queue); + auto *seed_updated = sycl::malloc_shared(1, queue); + seed[0] = 0x89abcdefu; + seed_updated[0] = false; + + queue.single_task([=]() { + const float positive_input = std::numeric_limits::infinity(); + const float negative_input = -std::numeric_limits::infinity(); + const uint32_t initial_seed = seed[0]; + + if constexpr (UseMarray) { + sycl::marray input(static_cast(positive_input), + static_cast(negative_input)); + if constexpr (Sat == saturation::finite) { + fp8_e5m2_x2 value(input, stochastic_seed(seed)); + sycl::marray unpacked = static_cast>(value); + out[0] = unpacked[0]; + out[1] = unpacked[1]; + } else { + fp8_e5m2_x2 value(input, stochastic_seed(seed), saturation::none); + sycl::marray unpacked = static_cast>(value); + out[0] = unpacked[0]; + out[1] = unpacked[1]; + } + } else { + T input[2] = {static_cast(positive_input), static_cast(negative_input)}; + if constexpr (Sat == saturation::finite) { + fp8_e5m2_x2 value(input, stochastic_seed(seed)); + sycl::marray unpacked = static_cast>(value); + out[0] = unpacked[0]; + out[1] = unpacked[1]; + } else { + fp8_e5m2_x2 value(input, stochastic_seed(seed), saturation::none); + sycl::marray unpacked = static_cast>(value); + out[0] = unpacked[0]; + out[1] = unpacked[1]; + } + } + + seed_updated[0] = seed[0] != initial_seed; + }); + queue.wait_and_throw(); + + int ret = 0; + if (!seed_updated[0]) + ret = 1; + if constexpr (Sat == saturation::finite) { + if (out[0] != E5M2MaxNormal) + ret = 1; + if (out[1] != -E5M2MaxNormal) + ret = 1; + } else { + if (!is_positive_infinity(out[0])) + ret = 1; + if (!is_negative_infinity(out[1])) + ret = 1; + } + + sycl::free(out, queue); + sycl::free(seed, queue); + sycl::free(seed_updated, queue); + return ret; +} + template int test_explicit_to_even_carray_constructor(sycl::queue &queue) { T input[2] = {static_cast(3.0517578125e-05f), static_cast(-6.0f)}; @@ -436,5 +508,22 @@ int main() { ret |= test_boundary_round_trip_saturation_and_infinity_clamp(queue); ret |= test_boundary_infinity_no_saturation(queue); ret |= test_boundary_overflow_no_saturation(queue); + + ret |= test_stochastic_constructor( + queue); + ret |= test_stochastic_constructor( + queue); + ret |= test_stochastic_constructor( + queue); + ret |= test_stochastic_constructor( + queue); + ret |= test_stochastic_constructor(queue); + ret |= test_stochastic_constructor(queue); + ret |= test_stochastic_constructor(queue); + ret |= test_stochastic_constructor(queue); return ret; } From 7202852265f067335ce9cde6cae94bbec513d48c Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Mon, 1 Jun 2026 12:56:09 +0200 Subject: [PATCH 68/89] [SYCL][TESTE2E] update test increasing value in kernel --- sycl/test-e2e/Experimental/fp8/e8m0_cri_conversion.cpp | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/sycl/test-e2e/Experimental/fp8/e8m0_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e8m0_cri_conversion.cpp index 98a58009d4586..6929c4bd04394 100644 --- a/sycl/test-e2e/Experimental/fp8/e8m0_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e8m0_cri_conversion.cpp @@ -3,11 +3,6 @@ // RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator=spir64 --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out // RUN: %{run} SYCL_UR_TRACE=1 %t.out -// make it XFAIL until driver will be installed on CI machines and the test will -// be enabled in the test suite -// XFAIL: * -// XFAIL-TRACKER: CMPLRLLVM-69851 - #include #include #include @@ -165,14 +160,14 @@ int test_rounding_upward(sycl::queue &queue) { queue.single_task([=]() { fp8_e8m0 value = data[0]; float out = static_cast(value); - float expected[1] = {out}; + float expected[1] = {out + 1.0f}; data[0] = fp8_e8m0(expected, rounding::upward); }); queue.wait_and_throw(); float out = static_cast(data[0]); sycl::free(data, queue); - if (out != 4.0f) + if (out != 8.0f) return 1; return 0; } From b653395f4bdb56f2e9e984a713b8237f5c61ca53 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Tue, 2 Jun 2026 16:40:18 +0200 Subject: [PATCH 69/89] [SYCL] fix PR issues - do not convert 32 and 64 bit types into 16 bit and after that to fp8_e5m2 - do not use anonymous namespace in tests - add test to check precision of conversion - add test to check stochastic constructors --- .../oneapi/experimental/float_8bit/types.hpp | 65 ++++++++++------- .../Experimental/fp8/e4m3_cri_conversion.cpp | 5 -- .../fp8/e4m3_x2_cri_conversion.cpp | 8 --- .../Experimental/fp8/e5m2_cri_conversion.cpp | 69 ++++++++++++------- .../fp8/e5m2_x2_cri_conversion.cpp | 4 -- .../fp8/e8m0_x2_cri_conversion.cpp | 4 -- .../Extensions/fp8/builtin_call_tests.cpp | 5 +- .../Extensions/fp8/builtin_mocks.hpp | 21 +++--- 8 files changed, 101 insertions(+), 80 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index 2490611670da3..0f5a886b3b0f1 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -51,11 +51,9 @@ extern __DPCPP_SYCL_EXTERNAL uint8_t __builtin_spirv_StochasticRoundFP16ToE5M2INTEL(_Float16, uint32_t, uint32_t *) noexcept; extern __DPCPP_SYCL_EXTERNAL uint8_t -__builtin_spirv_ClampStochasticRoundBF16ToE5M2INTEL(__bf16, uint32_t, - uint32_t *) noexcept; +__builtin_spirv_ClampStochasticRoundBF16ToE5M2INTEL(__bf16, uint32_t) noexcept; extern __DPCPP_SYCL_EXTERNAL uint8_t -__builtin_spirv_StochasticRoundBF16ToE5M2INTEL(__bf16, uint32_t, - uint32_t *) noexcept; +__builtin_spirv_StochasticRoundBF16ToE5M2INTEL(__bf16, uint32_t) noexcept; #endif // __SYCL_DEVICE_ONLY__ namespace sycl { @@ -1171,25 +1169,43 @@ template class fp8_e5m2_x { static_assert(N == 1 || N == 2, "fp8_e5m2_x: Template argument N must be 1 or 2"); - template uint8_t ConvertToFP8(T h, saturation s) { + template >>> + uint8_t ConvertToFP8(T h, saturation s) { #ifdef __SYCL_DEVICE_ONLY__ - _Float16 v{0}; - if constexpr (std::is_same_v, sycl::half>) - v = sycl::bit_cast<_Float16>(h); - else - v = static_cast<_Float16>(h); + if constexpr (std::is_same_v, char> || + std::is_same_v, unsigned char> || + std::is_same_v, short> || + std::is_same_v, unsigned short>) { + const _Float16 v = static_cast<_Float16>(h); + return s == saturation::finite + ? __builtin_spirv_ClampConvertFP16ToE5M2INTEL(v) + : __builtin_spirv_ConvertFP16ToE5M2EXT(v); + } + return detail::ConvertIntToFP8_CPU( + h, rounding::to_even, s); +#else + return detail::ConvertIntToFP8_CPU( + h, rounding::to_even, s); +#endif + } + + uint8_t ConvertToFP8(sycl::half h, saturation s) { +#ifdef __SYCL_DEVICE_ONLY__ + const _Float16 v = sycl::bit_cast<_Float16>(h); return s == saturation::finite ? __builtin_spirv_ClampConvertFP16ToE5M2INTEL(v) : __builtin_spirv_ConvertFP16ToE5M2EXT(v); #else - if constexpr (std::is_same_v, sycl::half> || - std::is_same_v, float>) { - return detail::ConvertFloatToFP8_CPU( - h, rounding::to_even, s); - } else if constexpr (std::is_integral_v>) { - return detail::ConvertIntToFP8_CPU( - h, rounding::to_even, s); - } + return detail::ConvertFloatToFP8_CPU( + h, rounding::to_even, s); +#endif + } + + uint8_t ConvertToFP8(float h, saturation s) { +#if __SYCL_DEVICE_ONLY__ || !defined(__SYCL_DEVICE_ONLY__) + return detail::ConvertFloatToFP8_CPU( + h, rounding::to_even, s); #endif } @@ -1322,7 +1338,7 @@ template class fp8_e5m2_x { #ifdef __SYCL_DEVICE_ONLY__ uint32_t current_seed = *seed.pseed; for (size_t i = 0; i < N; ++i) { - const _Float16 v = static_cast<_Float16>(static_cast(in[i])); + const _Float16 v = sycl::bit_cast<_Float16>(in[i]); if (s == saturation::finite) { vals[i] = __builtin_spirv_ClampStochasticRoundFP16ToE5M2INTEL( v, current_seed, seed.pseed); @@ -1346,10 +1362,10 @@ template class fp8_e5m2_x { for (size_t i = 0; i < N; ++i) { if (s == saturation::finite) { vals[i] = __builtin_spirv_ClampStochasticRoundBF16ToE5M2INTEL( - sycl::bit_cast<__bf16>(in[i]), current_seed, seed.pseed); + sycl::bit_cast<__bf16>(in[i]), current_seed); } else { vals[i] = __builtin_spirv_StochasticRoundBF16ToE5M2INTEL( - sycl::bit_cast<__bf16>(in[i]), current_seed, seed.pseed); + sycl::bit_cast<__bf16>(in[i]), current_seed); } current_seed = *seed.pseed; } @@ -1368,8 +1384,7 @@ template class fp8_e5m2_x { #ifdef __SYCL_DEVICE_ONLY__ uint32_t current_seed = *seed.pseed; for (size_t i = 0; i < N; ++i) { - - _Float16 v = static_cast<_Float16>(static_cast(in[i])); + _Float16 v = sycl::bit_cast<_Float16>(in[i]); if (s == saturation::finite) { vals[i] = __builtin_spirv_ClampStochasticRoundFP16ToE5M2INTEL( v, current_seed, seed.pseed); @@ -1393,10 +1408,10 @@ template class fp8_e5m2_x { for (size_t i = 0; i < N; ++i) { if (s == saturation::finite) { vals[i] = __builtin_spirv_ClampStochasticRoundBF16ToE5M2INTEL( - sycl::bit_cast<__bf16>(in[i]), current_seed, seed.pseed); + sycl::bit_cast<__bf16>(in[i]), current_seed); } else { vals[i] = __builtin_spirv_StochasticRoundBF16ToE5M2INTEL( - sycl::bit_cast<__bf16>(in[i]), current_seed, seed.pseed); + sycl::bit_cast<__bf16>(in[i]), current_seed); } current_seed = *seed.pseed; } diff --git a/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp index 81f453a586113..5f0b39611dfee 100644 --- a/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp @@ -1,11 +1,6 @@ -// REQUIRES: intel_feature_gpu_cri // RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator=spir64 --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out // RUN: %{run} SYCL_UR_TRACE=1 %t.out -// make it XFAIL until driver will be installed on CI machines and the test will -// be enabled in the test suite -// XFAIL: * -// XFAIL-TRACKER: CMPLRLLVM-69851 #include #include diff --git a/sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp index b28101eadd1d7..7ed5f6eac9602 100644 --- a/sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp @@ -3,10 +3,6 @@ // RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator=spir64 --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out // RUN: %{run} SYCL_UR_TRACE=1 %t.out -// make it XFAIL until driver will be installed on CI machines and the test will -// be enabled in the test suite -// XFAIL: * -// XFAIL-TRACKER: CMPLRLLVM-69851 #include #include @@ -16,8 +12,6 @@ using namespace sycl::ext::oneapi::experimental; -namespace { - bool equal_or_both_nan(float actual, float expected) { if (std::isnan(expected)) return std::isnan(actual); @@ -253,8 +247,6 @@ int test_boundary_round_trip_saturation_and_infinity_clamp(sycl::queue &queue) { return ret; } -} // namespace - template int test_fp8_simple_type_conversion(sycl::queue &queue) { auto *data = sycl::malloc_shared(1, queue); data[0] = fp8_e4m3_x2(static_cast(1.5f), static_cast(2.5f)); diff --git a/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp index d17e187a8356d..34d275936797e 100644 --- a/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp @@ -1,4 +1,4 @@ - +// REQUIRES: intel_feature_gpu_cri // RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator=spir64 --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out // RUN: %{run} SYCL_UR_TRACE=1 %t.out @@ -11,8 +11,6 @@ using namespace sycl::ext::oneapi::experimental; -namespace { - constexpr float E5M2MaxNormal = 57344.0f; bool is_positive_infinity(float value) { @@ -28,8 +26,8 @@ int test_stochastic_constructor(sycl::queue &queue) { seed_updated[0] = false; queue.single_task([=]() { - const float input_value = - Sat == saturation::finite ? -std::numeric_limits::infinity() + const float input_value = Sat == saturation::finite + ? -std::numeric_limits::infinity() : std::numeric_limits::infinity(); const uint32_t initial_seed = seed[0]; @@ -73,8 +71,6 @@ int test_stochastic_constructor(sycl::queue &queue) { return ret; } -} // namespace - template int test_fp8_simple_type_conversion(sycl::queue &queue) { auto *data = sycl::malloc_shared(1, queue); data[0] = fp8_e5m2(static_cast(1.5)); @@ -154,6 +150,30 @@ template int test_marray_conversion(sycl::queue &queue) { return 0; } +// The goal of this test is to confirm that bug is not reproduced +// https://github.com/intel-tools/intel-xpu-backend-for-triton/issues/847 +template int test_fp8_precision_conversion(sycl::queue &queue) { + auto *data = sycl::malloc_shared(1, queue); + data[0] = static_cast(-53249.234375f); + auto *out = sycl::malloc_shared(1, queue); + queue.single_task([=]() { + fp8_e5m2 expected1(data[0]); + out[0] = expected1; + }); + queue.wait_and_throw(); + fp8_e5m2 expected_cpu(static_cast(-53249.234375f)); + assert(expected_cpu.vals[0] = + 0xFB && "Unexpected fp8 conversion result on CPU"); + + std::cout << "Device uut fp8 value: 0x" << std::hex + << static_cast(out[0].vals[0]) << std::dec << "\n"; + assert(out[0].vals[0] == 0xFB && + "Unexpected fp8 conversion result on device"); + std::cout << "Test passed\n"; + + return 0; +} + int main() { auto async_handler = [](sycl::exception_list exceptions) { for (const std::exception_ptr &e : exceptions) { @@ -202,21 +222,24 @@ int main() { ret |= test_marray_conversion(queue); ret |= test_marray_conversion(queue); - ret |= test_stochastic_constructor( - queue); - ret |= test_stochastic_constructor( - queue); - ret |= test_stochastic_constructor( - queue); - ret |= test_stochastic_constructor( - queue); - ret |= test_stochastic_constructor(queue); - ret |= test_stochastic_constructor(queue); - ret |= test_stochastic_constructor(queue); - ret |= test_stochastic_constructor(queue); + ret |= + test_stochastic_constructor(queue); + ret |= + test_stochastic_constructor(queue); + ret |= + test_stochastic_constructor(queue); + ret |= test_stochastic_constructor(queue); + ret |= test_stochastic_constructor(queue); + ret |= test_stochastic_constructor(queue); + ret |= test_stochastic_constructor(queue); + ret |= test_stochastic_constructor(queue); + ret |= test_fp8_precision_conversion(queue); + ret |= test_fp8_precision_conversion(queue); + ret |= test_fp8_precision_conversion(queue); + ret |= test_fp8_precision_conversion(queue); return ret; } diff --git a/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp index 77c7051e65f23..969a0d5a325f9 100644 --- a/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp @@ -11,8 +11,6 @@ using namespace sycl::ext::oneapi::experimental; -namespace { - constexpr float E5M2MaxNormal = 57344.0f; bool equal_or_both_nan(float actual, float expected) { @@ -378,8 +376,6 @@ int test_boundary_overflow_no_saturation(sycl::queue &queue) { return ret; } -} // namespace - template int test_fp8_simple_type_conversion(sycl::queue &queue) { auto *data = sycl::malloc_shared(1, queue); data[0] = fp8_e5m2_x2(static_cast(1.5f), static_cast(2.5f)); diff --git a/sycl/test-e2e/Experimental/fp8/e8m0_x2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e8m0_x2_cri_conversion.cpp index 571fe42e2858a..c86ff61f16da0 100644 --- a/sycl/test-e2e/Experimental/fp8/e8m0_x2_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e8m0_x2_cri_conversion.cpp @@ -16,8 +16,6 @@ using namespace sycl::ext::oneapi::experimental; -namespace { - template int test_explicit_upward_carray_constructor(sycl::queue &queue) { T input[2] = {static_cast(4.0f), static_cast(16.0f)}; @@ -332,8 +330,6 @@ int test_raw_vals_access(sycl::queue &queue) { return ret; } -} // namespace - template int test_fp8_simple_type_conversion(sycl::queue &queue) { auto *data = sycl::malloc_shared(1, queue); diff --git a/sycl/unittests/Extensions/fp8/builtin_call_tests.cpp b/sycl/unittests/Extensions/fp8/builtin_call_tests.cpp index ccecb5d921cd7..79d11cb769b62 100644 --- a/sycl/unittests/Extensions/fp8/builtin_call_tests.cpp +++ b/sycl/unittests/Extensions/fp8/builtin_call_tests.cpp @@ -114,7 +114,7 @@ TEST_F(Fp8BuiltinCallTest, fp8_e5m2_x2 Value(Input, rounding::to_even, saturation::finite); (void)Value; - EXPECT_EQ(fp8_builtin_mock::getCounters().ClampConvertFP16ToE5M2INTEL, 2); + EXPECT_EQ(fp8_builtin_mock::getCounters().ClampConvertFP16ToE5M2INTEL, 0); } TEST_F(Fp8BuiltinCallTest, E5M2MarrayCtorFromBf16NoneCallsConvertBF16ToE5M2) { @@ -167,10 +167,11 @@ TEST_F(Fp8BuiltinCallTest, E5M2MarrayCastToBf16CallsConvertE5M2ToBF16) { TEST_F(Fp8BuiltinCallTest, E5M2AssignmentFromFloatCallsClampConvertFP16ToE5M2) { fp8_e5m2 Value(static_cast(2.0f)); + EXPECT_EQ(fp8_builtin_mock::getCounters().ClampConvertFP16ToE5M2INTEL, 1); fp8_builtin_mock::resetCounters(); Value = 4.0f; - EXPECT_EQ(fp8_builtin_mock::getCounters().ClampConvertFP16ToE5M2INTEL, 1); + EXPECT_EQ(fp8_builtin_mock::getCounters().ClampConvertFP16ToE5M2INTEL, 0); } TEST_F(Fp8BuiltinCallTest, diff --git a/sycl/unittests/Extensions/fp8/builtin_mocks.hpp b/sycl/unittests/Extensions/fp8/builtin_mocks.hpp index f5822744ca7f5..bca9870835c88 100644 --- a/sycl/unittests/Extensions/fp8/builtin_mocks.hpp +++ b/sycl/unittests/Extensions/fp8/builtin_mocks.hpp @@ -3,6 +3,7 @@ #pragma once +#include #include #include #include @@ -36,6 +37,8 @@ struct Counters { int StochasticRoundBF16ToE5M2INTEL = 0; int ClampStochasticRoundFP16ToE5M2INTEL = 0; int ClampStochasticRoundBF16ToE5M2INTEL = 0; + uint16_t LastFP16ArgBitsForClampConvertFP16ToE5M2INTEL = 0; + uint16_t LastFP16ArgBitsForConvertFP16ToE5M2EXT = 0; }; inline Counters &getCounters() { @@ -87,12 +90,17 @@ inline uint8_t __builtin_spirv_ClampConvertBF16ToE4M3INTEL(__bf16) noexcept { return 0x12; } -inline uint8_t __builtin_spirv_ConvertFP16ToE5M2EXT(_Float16) noexcept { +inline uint8_t __builtin_spirv_ConvertFP16ToE5M2EXT(_Float16 Value) noexcept { + std::memcpy(&fp8_builtin_mock::getCounters().LastFP16ArgBitsForConvertFP16ToE5M2EXT, + &Value, sizeof(uint16_t)); ++fp8_builtin_mock::getCounters().ConvertFP16ToE5M2EXT; return 0x03; } -inline uint8_t __builtin_spirv_ClampConvertFP16ToE5M2INTEL(_Float16) noexcept { +inline uint8_t __builtin_spirv_ClampConvertFP16ToE5M2INTEL(_Float16 Value) noexcept { + std::memcpy( + &fp8_builtin_mock::getCounters().LastFP16ArgBitsForClampConvertFP16ToE5M2INTEL, + &Value, sizeof(uint16_t)); ++fp8_builtin_mock::getCounters().ClampConvertFP16ToE5M2INTEL; return 0x21; } @@ -122,11 +130,8 @@ __builtin_spirv_StochasticRoundFP16ToE4M3INTEL(_Float16) noexcept { } inline uint8_t -__builtin_spirv_StochasticRoundBF16ToE5M2INTEL(__bf16, uint32_t Seed, - uint32_t *NextSeed) noexcept { +__builtin_spirv_StochasticRoundBF16ToE5M2INTEL(__bf16, uint32_t) noexcept { ++fp8_builtin_mock::getCounters().StochasticRoundBF16ToE5M2INTEL; - if (NextSeed) - *NextSeed = Seed + 1; return 0x32; } @@ -148,10 +153,8 @@ __builtin_spirv_ClampStochasticRoundFP16ToE4M3INTEL(_Float16) noexcept { } inline uint8_t __builtin_spirv_ClampStochasticRoundBF16ToE5M2INTEL( - __bf16, uint32_t Seed, uint32_t *NextSeed) noexcept { + __bf16, uint32_t) noexcept { ++fp8_builtin_mock::getCounters().ClampStochasticRoundBF16ToE5M2INTEL; - if (NextSeed) - *NextSeed = Seed + 1; return 0x42; } From 8c3251d7fe3c892a11cae8c2db3b47a24f16e376 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Tue, 2 Jun 2026 17:06:42 +0200 Subject: [PATCH 70/89] [SYCL] do not use intermediate conversion --- .../oneapi/experimental/float_8bit/types.hpp | 44 +++++++++++++------ .../Extensions/fp8/builtin_call_tests.cpp | 5 ++- 2 files changed, 33 insertions(+), 16 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index 0f5a886b3b0f1..c2011b6f9a418 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -813,23 +813,39 @@ template class fp8_e4m3_x { static_assert(N == 1 || N == 2, "fp8_e4m3_x: Template argument N must be 1 or 2"); - template uint8_t ConvertToFP8(T h) { + template >>> + uint8_t ConvertToFP8(T h) { #ifdef __SYCL_DEVICE_ONLY__ - _Float16 v{0}; - if constexpr (std::is_same_v, sycl::half>) - v = sycl::bit_cast<_Float16>(h); - else - v = static_cast<_Float16>(h); + if constexpr (std::is_same_v, char> || + std::is_same_v, unsigned char> || + std::is_same_v, short> || + std::is_same_v, unsigned short>) { + const _Float16 v = static_cast<_Float16>(h); + return __builtin_spirv_ClampConvertFP16ToE4M3INTEL(v); + } + return detail::ConvertIntToFP8_CPU( + h, rounding::to_even, saturation::finite); +#else + return detail::ConvertIntToFP8_CPU( + h, rounding::to_even, saturation::finite); +#endif + } + + uint8_t ConvertToFP8(sycl::half h) { +#ifdef __SYCL_DEVICE_ONLY__ + const _Float16 v = sycl::bit_cast<_Float16>(h); return __builtin_spirv_ClampConvertFP16ToE4M3INTEL(v); #else - if constexpr (std::is_same_v, sycl::half> || - std::is_same_v, float>) { - return detail::ConvertFloatToFP8_CPU( - h, rounding::to_even, saturation::finite); - } else if constexpr (std::is_integral_v>) { - return detail::ConvertIntToFP8_CPU( - h, rounding::to_even, saturation::finite); - } + return detail::ConvertFloatToFP8_CPU( + h, rounding::to_even, saturation::finite); +#endif + } + + uint8_t ConvertToFP8(float h) { +#if __SYCL_DEVICE_ONLY__ || !defined(__SYCL_DEVICE_ONLY__) + return detail::ConvertFloatToFP8_CPU( + h, rounding::to_even, saturation::finite); #endif } diff --git a/sycl/unittests/Extensions/fp8/builtin_call_tests.cpp b/sycl/unittests/Extensions/fp8/builtin_call_tests.cpp index 79d11cb769b62..48d0d3ffe5314 100644 --- a/sycl/unittests/Extensions/fp8/builtin_call_tests.cpp +++ b/sycl/unittests/Extensions/fp8/builtin_call_tests.cpp @@ -29,7 +29,7 @@ TEST_F(Fp8BuiltinCallTest, E4M3ArrayCtorFromFloatCallsClampConvertFP16ToE4M3) { fp8_e4m3_x2 Value(Input); (void)Value; - EXPECT_EQ(fp8_builtin_mock::getCounters().ClampConvertFP16ToE4M3INTEL, 2); + EXPECT_EQ(fp8_builtin_mock::getCounters().ClampConvertFP16ToE4M3INTEL, 0); } TEST_F(Fp8BuiltinCallTest, E4M3MarrayCtorFromBf16CallsClampConvertBF16ToE4M3) { @@ -89,10 +89,11 @@ TEST_F(Fp8BuiltinCallTest, E4M3MarrayCastToBf16CallsConvertE4M3ToBF16) { TEST_F(Fp8BuiltinCallTest, E4M3AssignmentFromFloatCallsClampConvertFP16ToE4M3) { fp8_e4m3 Value(static_cast(1.0f)); + EXPECT_EQ(fp8_builtin_mock::getCounters().ClampConvertFP16ToE4M3INTEL, 1); fp8_builtin_mock::resetCounters(); Value = 1.25f; - EXPECT_EQ(fp8_builtin_mock::getCounters().ClampConvertFP16ToE4M3INTEL, 1); + EXPECT_EQ(fp8_builtin_mock::getCounters().ClampConvertFP16ToE4M3INTEL, 0); } TEST_F(Fp8BuiltinCallTest, E5M2CtorFromHalfCallsClampConvertFP16ToE5M2) { From 2f33c8b59b863edfe932a32615d68d821169908d Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Tue, 2 Jun 2026 17:21:37 +0200 Subject: [PATCH 71/89] [SYCL][TESTE2E] check conversion from FP8 to data type too --- .../Experimental/fp8/e5m2_cri_conversion.cpp | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp index 34d275936797e..98a2f26f34d76 100644 --- a/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp @@ -155,22 +155,25 @@ template int test_marray_conversion(sycl::queue &queue) { template int test_fp8_precision_conversion(sycl::queue &queue) { auto *data = sycl::malloc_shared(1, queue); data[0] = static_cast(-53249.234375f); - auto *out = sycl::malloc_shared(1, queue); + auto *out_8bit = sycl::malloc_shared(1, queue); + auto *out_T = sycl::malloc_shared(1, queue); queue.single_task([=]() { fp8_e5m2 expected1(data[0]); - out[0] = expected1; + out_8bit[0] = expected1; + out_T[0] = static_cast(out_8bit[0]); }); queue.wait_and_throw(); fp8_e5m2 expected_cpu(static_cast(-53249.234375f)); assert(expected_cpu.vals[0] = 0xFB && "Unexpected fp8 conversion result on CPU"); - - std::cout << "Device uut fp8 value: 0x" << std::hex - << static_cast(out[0].vals[0]) << std::dec << "\n"; - assert(out[0].vals[0] == 0xFB && + assert(out_8bit[0].vals[0] == 0xFB && "Unexpected fp8 conversion result on device"); - std::cout << "Test passed\n"; + assert(out_T[0] == static_cast(-57344.0f) && + "Unexpected fp8 to initial type conversion result on device"); + sycl::free(data, queue); + sycl::free(out_8bit, queue); + sycl::free(out_T, queue); return 0; } @@ -222,6 +225,9 @@ int main() { ret |= test_marray_conversion(queue); ret |= test_marray_conversion(queue); + // TODO: uncomment when undefined reference to + // `_Z46__builtin_spirv_StochasticRoundFP16ToE5M2INTELDhiPU3AS4i' is resolved + /* ret |= test_stochastic_constructor(queue); ret |= @@ -237,6 +243,7 @@ int main() { saturation::finite>(queue); ret |= test_stochastic_constructor(queue); + */ ret |= test_fp8_precision_conversion(queue); ret |= test_fp8_precision_conversion(queue); ret |= test_fp8_precision_conversion(queue); From a6fbddccdec4ce441ef0dd824ce417631449a96b Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Tue, 2 Jun 2026 19:22:09 +0200 Subject: [PATCH 72/89] [SYVL] fix stochastic constructors --- .../oneapi/experimental/float_8bit/types.hpp | 62 ++++++++++++++----- .../Experimental/fp8/e5m2_cri_conversion.cpp | 16 +---- 2 files changed, 49 insertions(+), 29 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index c2011b6f9a418..c56ad68d7e747 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -8,6 +8,8 @@ #pragma once +#include +#include #include #include @@ -45,15 +47,19 @@ __builtin_spirv_ConvertBF16ToE5M2EXT(__bf16) noexcept; extern __DPCPP_SYCL_EXTERNAL __bf16 __builtin_spirv_ConvertE5M2ToBF16EXT(uint8_t) noexcept; extern __DPCPP_SYCL_EXTERNAL uint8_t -__builtin_spirv_ClampStochasticRoundFP16ToE5M2INTEL(_Float16, uint32_t, - uint32_t *) noexcept; +__builtin_spirv_ClampStochasticRoundFP16ToE5M2INTEL( + _Float16, uint32_t, __attribute__((opencl_private)) uint32_t *) noexcept; extern __DPCPP_SYCL_EXTERNAL uint8_t __builtin_spirv_StochasticRoundFP16ToE5M2INTEL(_Float16, uint32_t, + __attribute__((opencl_private)) uint32_t *) noexcept; extern __DPCPP_SYCL_EXTERNAL uint8_t -__builtin_spirv_ClampStochasticRoundBF16ToE5M2INTEL(__bf16, uint32_t) noexcept; +__builtin_spirv_ClampStochasticRoundBF16ToE5M2INTEL( + __bf16, uint32_t, __attribute__((opencl_private)) uint32_t *) noexcept; extern __DPCPP_SYCL_EXTERNAL uint8_t -__builtin_spirv_StochasticRoundBF16ToE5M2INTEL(__bf16, uint32_t) noexcept; +__builtin_spirv_StochasticRoundBF16ToE5M2INTEL(__bf16, uint32_t, + __attribute__((opencl_private)) + uint32_t *) noexcept; #endif // __SYCL_DEVICE_ONLY__ namespace sycl { @@ -1353,16 +1359,22 @@ template class fp8_e5m2_x { [[maybe_unused]] saturation s = saturation::finite) { #ifdef __SYCL_DEVICE_ONLY__ uint32_t current_seed = *seed.pseed; + uint32_t next_seed = 0; for (size_t i = 0; i < N; ++i) { const _Float16 v = sycl::bit_cast<_Float16>(in[i]); if (s == saturation::finite) { vals[i] = __builtin_spirv_ClampStochasticRoundFP16ToE5M2INTEL( - v, current_seed, seed.pseed); + v, current_seed, + sycl::detail::static_address_cast< + sycl::access::address_space::private_space>(&next_seed)); } else { vals[i] = __builtin_spirv_StochasticRoundFP16ToE5M2INTEL( - v, current_seed, seed.pseed); + v, current_seed, + sycl::detail::static_address_cast< + sycl::access::address_space::private_space>(&next_seed)); } - current_seed = *seed.pseed; + current_seed = next_seed; + next_seed = 0; } #else throw std::runtime_error( @@ -1375,15 +1387,21 @@ template class fp8_e5m2_x { [[maybe_unused]] saturation s = saturation::finite) { #ifdef __SYCL_DEVICE_ONLY__ uint32_t current_seed = *seed.pseed; + uint32_t next_seed = 0; for (size_t i = 0; i < N; ++i) { if (s == saturation::finite) { vals[i] = __builtin_spirv_ClampStochasticRoundBF16ToE5M2INTEL( - sycl::bit_cast<__bf16>(in[i]), current_seed); + sycl::bit_cast<__bf16>(in[i]), current_seed, + sycl::detail::static_address_cast< + sycl::access::address_space::private_space>(&next_seed)); } else { vals[i] = __builtin_spirv_StochasticRoundBF16ToE5M2INTEL( - sycl::bit_cast<__bf16>(in[i]), current_seed); + sycl::bit_cast<__bf16>(in[i]), current_seed, + sycl::detail::static_address_cast< + sycl::access::address_space::private_space>(&next_seed)); } - current_seed = *seed.pseed; + current_seed = next_seed; + next_seed = 0; } #else throw std::runtime_error( @@ -1399,16 +1417,22 @@ template class fp8_e5m2_x { [[maybe_unused]] saturation s = saturation::finite) { #ifdef __SYCL_DEVICE_ONLY__ uint32_t current_seed = *seed.pseed; + uint32_t next_seed = 0; for (size_t i = 0; i < N; ++i) { _Float16 v = sycl::bit_cast<_Float16>(in[i]); if (s == saturation::finite) { vals[i] = __builtin_spirv_ClampStochasticRoundFP16ToE5M2INTEL( - v, current_seed, seed.pseed); + v, current_seed, + sycl::detail::static_address_cast< + sycl::access::address_space::private_space>(&next_seed)); } else { vals[i] = __builtin_spirv_StochasticRoundFP16ToE5M2INTEL( - v, current_seed, seed.pseed); + v, current_seed, + sycl::detail::static_address_cast< + sycl::access::address_space::private_space>(&next_seed)); } - current_seed = *seed.pseed; + current_seed = next_seed; + next_seed = 0; } #else throw std::runtime_error( @@ -1421,15 +1445,21 @@ template class fp8_e5m2_x { [[maybe_unused]] saturation s = saturation::finite) { #ifdef __SYCL_DEVICE_ONLY__ uint32_t current_seed = *seed.pseed; + uint32_t next_seed = 0; for (size_t i = 0; i < N; ++i) { if (s == saturation::finite) { vals[i] = __builtin_spirv_ClampStochasticRoundBF16ToE5M2INTEL( - sycl::bit_cast<__bf16>(in[i]), current_seed); + sycl::bit_cast<__bf16>(in[i]), current_seed, + sycl::detail::static_address_cast< + sycl::access::address_space::private_space>(&next_seed)); } else { vals[i] = __builtin_spirv_StochasticRoundBF16ToE5M2INTEL( - sycl::bit_cast<__bf16>(in[i]), current_seed); + sycl::bit_cast<__bf16>(in[i]), current_seed, + sycl::detail::static_address_cast< + sycl::access::address_space::private_space>(&next_seed)); } - current_seed = *seed.pseed; + current_seed = next_seed; + next_seed = 0; } #else throw std::runtime_error( diff --git a/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp index 98a2f26f34d76..8e13b58951d68 100644 --- a/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp @@ -21,9 +21,7 @@ template int test_stochastic_constructor(sycl::queue &queue) { auto *out = sycl::malloc_shared(1, queue); auto *seed = sycl::malloc_shared(1, queue); - auto *seed_updated = sycl::malloc_shared(1, queue); seed[0] = 0x12345678u; - seed_updated[0] = false; queue.single_task([=]() { const float input_value = Sat == saturation::finite @@ -50,24 +48,19 @@ int test_stochastic_constructor(sycl::queue &queue) { out[0] = static_cast(value); } } - - seed_updated[0] = seed[0] != initial_seed; }); queue.wait_and_throw(); int ret = 0; - if (!seed_updated[0]) - ret = 1; + if constexpr (Sat == saturation::finite) { if (out[0] != -E5M2MaxNormal) ret = 1; - } else if (!is_positive_infinity(out[0])) { + } else if (!is_positive_infinity(out[0])) ret = 1; - } sycl::free(out, queue); sycl::free(seed, queue); - sycl::free(seed_updated, queue); return ret; } @@ -225,9 +218,6 @@ int main() { ret |= test_marray_conversion(queue); ret |= test_marray_conversion(queue); - // TODO: uncomment when undefined reference to - // `_Z46__builtin_spirv_StochasticRoundFP16ToE5M2INTELDhiPU3AS4i' is resolved - /* ret |= test_stochastic_constructor(queue); ret |= @@ -243,7 +233,7 @@ int main() { saturation::finite>(queue); ret |= test_stochastic_constructor(queue); - */ + ret |= test_fp8_precision_conversion(queue); ret |= test_fp8_precision_conversion(queue); ret |= test_fp8_precision_conversion(queue); From 88d840f8a2d85b7f1e66ab1a8c6bc1e6a4ab18f7 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Tue, 2 Jun 2026 19:31:52 +0200 Subject: [PATCH 73/89] [SYCL] use api to cast address space --- .../oneapi/experimental/float_8bit/types.hpp | 35 ++++++++++--------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index c56ad68d7e747..28d8bd453b1c5 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -9,7 +9,8 @@ #pragma once #include -#include +#include + #include #include @@ -1365,13 +1366,13 @@ template class fp8_e5m2_x { if (s == saturation::finite) { vals[i] = __builtin_spirv_ClampStochasticRoundFP16ToE5M2INTEL( v, current_seed, - sycl::detail::static_address_cast< - sycl::access::address_space::private_space>(&next_seed)); + sycl::address_space_cast(&next_seed)); } else { vals[i] = __builtin_spirv_StochasticRoundFP16ToE5M2INTEL( v, current_seed, - sycl::detail::static_address_cast< - sycl::access::address_space::private_space>(&next_seed)); + sycl::address_space_cast(&next_seed)); } current_seed = next_seed; next_seed = 0; @@ -1392,13 +1393,13 @@ template class fp8_e5m2_x { if (s == saturation::finite) { vals[i] = __builtin_spirv_ClampStochasticRoundBF16ToE5M2INTEL( sycl::bit_cast<__bf16>(in[i]), current_seed, - sycl::detail::static_address_cast< - sycl::access::address_space::private_space>(&next_seed)); + sycl::address_space_cast(&next_seed)); } else { vals[i] = __builtin_spirv_StochasticRoundBF16ToE5M2INTEL( sycl::bit_cast<__bf16>(in[i]), current_seed, - sycl::detail::static_address_cast< - sycl::access::address_space::private_space>(&next_seed)); + sycl::address_space_cast(&next_seed)); } current_seed = next_seed; next_seed = 0; @@ -1423,13 +1424,13 @@ template class fp8_e5m2_x { if (s == saturation::finite) { vals[i] = __builtin_spirv_ClampStochasticRoundFP16ToE5M2INTEL( v, current_seed, - sycl::detail::static_address_cast< - sycl::access::address_space::private_space>(&next_seed)); + sycl::address_space_cast(&next_seed)); } else { vals[i] = __builtin_spirv_StochasticRoundFP16ToE5M2INTEL( v, current_seed, - sycl::detail::static_address_cast< - sycl::access::address_space::private_space>(&next_seed)); + sycl::address_space_cast(&next_seed)); } current_seed = next_seed; next_seed = 0; @@ -1450,13 +1451,13 @@ template class fp8_e5m2_x { if (s == saturation::finite) { vals[i] = __builtin_spirv_ClampStochasticRoundBF16ToE5M2INTEL( sycl::bit_cast<__bf16>(in[i]), current_seed, - sycl::detail::static_address_cast< - sycl::access::address_space::private_space>(&next_seed)); + sycl::address_space_cast(&next_seed)); } else { vals[i] = __builtin_spirv_StochasticRoundBF16ToE5M2INTEL( sycl::bit_cast<__bf16>(in[i]), current_seed, - sycl::detail::static_address_cast< - sycl::access::address_space::private_space>(&next_seed)); + sycl::address_space_cast(&next_seed)); } current_seed = next_seed; next_seed = 0; From c9e50685faeea841f9696daeb54e792e80e2688d Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Wed, 3 Jun 2026 10:26:15 +0200 Subject: [PATCH 74/89] [SYCL] fix formatting --- .../fp8/e4m3_x2_cri_conversion.cpp | 1 - .../fp8/e5m2_x2_cri_conversion.cpp | 46 ++++++++++--------- .../Extensions/fp8/builtin_mocks.hpp | 16 ++----- 3 files changed, 29 insertions(+), 34 deletions(-) diff --git a/sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp index 7ed5f6eac9602..d56b2dc86e224 100644 --- a/sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp @@ -3,7 +3,6 @@ // RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator=spir64 --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out // RUN: %{run} SYCL_UR_TRACE=1 %t.out - #include #include #include diff --git a/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp index 969a0d5a325f9..a2354ecd486ea 100644 --- a/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp @@ -53,25 +53,30 @@ int test_stochastic_constructor(sycl::queue &queue) { static_cast(negative_input)); if constexpr (Sat == saturation::finite) { fp8_e5m2_x2 value(input, stochastic_seed(seed)); - sycl::marray unpacked = static_cast>(value); + sycl::marray unpacked = + static_cast>(value); out[0] = unpacked[0]; out[1] = unpacked[1]; } else { fp8_e5m2_x2 value(input, stochastic_seed(seed), saturation::none); - sycl::marray unpacked = static_cast>(value); + sycl::marray unpacked = + static_cast>(value); out[0] = unpacked[0]; out[1] = unpacked[1]; } } else { - T input[2] = {static_cast(positive_input), static_cast(negative_input)}; + T input[2] = {static_cast(positive_input), + static_cast(negative_input)}; if constexpr (Sat == saturation::finite) { fp8_e5m2_x2 value(input, stochastic_seed(seed)); - sycl::marray unpacked = static_cast>(value); + sycl::marray unpacked = + static_cast>(value); out[0] = unpacked[0]; out[1] = unpacked[1]; } else { fp8_e5m2_x2 value(input, stochastic_seed(seed), saturation::none); - sycl::marray unpacked = static_cast>(value); + sycl::marray unpacked = + static_cast>(value); out[0] = unpacked[0]; out[1] = unpacked[1]; } @@ -505,21 +510,20 @@ int main() { ret |= test_boundary_infinity_no_saturation(queue); ret |= test_boundary_overflow_no_saturation(queue); - ret |= test_stochastic_constructor( - queue); - ret |= test_stochastic_constructor( - queue); - ret |= test_stochastic_constructor( - queue); - ret |= test_stochastic_constructor( - queue); - ret |= test_stochastic_constructor(queue); - ret |= test_stochastic_constructor(queue); - ret |= test_stochastic_constructor(queue); - ret |= test_stochastic_constructor(queue); + ret |= + test_stochastic_constructor(queue); + ret |= + test_stochastic_constructor(queue); + ret |= + test_stochastic_constructor(queue); + ret |= test_stochastic_constructor(queue); + ret |= test_stochastic_constructor(queue); + ret |= test_stochastic_constructor(queue); + ret |= test_stochastic_constructor(queue); + ret |= test_stochastic_constructor(queue); return ret; } diff --git a/sycl/unittests/Extensions/fp8/builtin_mocks.hpp b/sycl/unittests/Extensions/fp8/builtin_mocks.hpp index bca9870835c88..97f0211ecac83 100644 --- a/sycl/unittests/Extensions/fp8/builtin_mocks.hpp +++ b/sycl/unittests/Extensions/fp8/builtin_mocks.hpp @@ -3,7 +3,6 @@ #pragma once -#include #include #include #include @@ -37,8 +36,6 @@ struct Counters { int StochasticRoundBF16ToE5M2INTEL = 0; int ClampStochasticRoundFP16ToE5M2INTEL = 0; int ClampStochasticRoundBF16ToE5M2INTEL = 0; - uint16_t LastFP16ArgBitsForClampConvertFP16ToE5M2INTEL = 0; - uint16_t LastFP16ArgBitsForConvertFP16ToE5M2EXT = 0; }; inline Counters &getCounters() { @@ -90,17 +87,12 @@ inline uint8_t __builtin_spirv_ClampConvertBF16ToE4M3INTEL(__bf16) noexcept { return 0x12; } -inline uint8_t __builtin_spirv_ConvertFP16ToE5M2EXT(_Float16 Value) noexcept { - std::memcpy(&fp8_builtin_mock::getCounters().LastFP16ArgBitsForConvertFP16ToE5M2EXT, - &Value, sizeof(uint16_t)); +inline uint8_t __builtin_spirv_ConvertFP16ToE5M2EXT(_Float16) noexcept { ++fp8_builtin_mock::getCounters().ConvertFP16ToE5M2EXT; return 0x03; } -inline uint8_t __builtin_spirv_ClampConvertFP16ToE5M2INTEL(_Float16 Value) noexcept { - std::memcpy( - &fp8_builtin_mock::getCounters().LastFP16ArgBitsForClampConvertFP16ToE5M2INTEL, - &Value, sizeof(uint16_t)); +inline uint8_t __builtin_spirv_ClampConvertFP16ToE5M2INTEL(_Float16) noexcept { ++fp8_builtin_mock::getCounters().ClampConvertFP16ToE5M2INTEL; return 0x21; } @@ -152,8 +144,8 @@ __builtin_spirv_ClampStochasticRoundFP16ToE4M3INTEL(_Float16) noexcept { return 0x00; } -inline uint8_t __builtin_spirv_ClampStochasticRoundBF16ToE5M2INTEL( - __bf16, uint32_t) noexcept { +inline uint8_t +__builtin_spirv_ClampStochasticRoundBF16ToE5M2INTEL(__bf16, uint32_t) noexcept { ++fp8_builtin_mock::getCounters().ClampStochasticRoundBF16ToE5M2INTEL; return 0x42; } From ae368a3d8b277e599b76d0fbc0fdefc3c9cf9825 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Wed, 3 Jun 2026 18:11:20 +0200 Subject: [PATCH 75/89] [SYCL] use builtins with vector arg to avoid perormance loss --- .../oneapi/experimental/float_8bit/types.hpp | 112 ++++++++++++++---- 1 file changed, 92 insertions(+), 20 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index 28d8bd453b1c5..2503c8cdbf1b9 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -22,19 +22,30 @@ #include #include +using float16_vec2 = _Float16 __attribute__((ext_vector_type(2))); +using uint8_vec2 = uint8_t __attribute__((ext_vector_type(2))); +using bfloat16_vec2 = __bf16 __attribute__((ext_vector_type(2))); + #ifdef __SYCL_DEVICE_ONLY__ // FP8 builtins extern __DPCPP_SYCL_EXTERNAL uint8_t __builtin_spirv_ClampConvertFP16ToE4M3INTEL(_Float16) noexcept; - +extern __DPCPP_SYCL_EXTERNAL uint8_vec2 + __builtin_spirv_ClampConvertFP16ToE4M3INTEL(float16_vec2) noexcept; +extern __DPCPP_SYCL_EXTERNAL + float16_vec2 __builtin_spirv_ConvertE4M3ToFP16EXT(uint8_vec2) noexcept; extern __DPCPP_SYCL_EXTERNAL _Float16 __builtin_spirv_ConvertE4M3ToFP16EXT(char) noexcept; +extern __DPCPP_SYCL_EXTERNAL __bf16 +__builtin_spirv_ConvertE4M3ToBF16EXT(uint8_t) noexcept; +extern __DPCPP_SYCL_EXTERNAL + bfloat16_vec2 __builtin_spirv_ConvertE4M3ToBF16EXT(uint8_vec2) noexcept; extern __DPCPP_SYCL_EXTERNAL uint8_t __builtin_spirv_ClampConvertBF16ToE4M3INTEL(__bf16) noexcept; -extern __DPCPP_SYCL_EXTERNAL __bf16 -__builtin_spirv_ConvertE4M3ToBF16EXT(uint8_t) noexcept; +extern __DPCPP_SYCL_EXTERNAL uint8_vec2 + __builtin_spirv_ClampConvertBF16ToE4M3INTEL(bfloat16_vec2) noexcept; extern __DPCPP_SYCL_EXTERNAL uint8_t __builtin_spirv_ClampConvertFP16ToE5M2INTEL(_Float16) noexcept; extern __DPCPP_SYCL_EXTERNAL uint8_t @@ -849,6 +860,17 @@ template class fp8_e4m3_x { #endif } + uint8_vec2 ConvertToFP8_Vec2(float16_vec2 h) { +#ifdef __SYCL_DEVICE_ONLY__ + return __builtin_spirv_ClampConvertFP16ToE4M3INTEL(h); +#else + uint8_vec2 result; + for (size_t i = 0; i < 2; ++i) + result[i] = ConvertToFP8(sycl::bit_cast(h[i])); + return result; +#endif + } + uint8_t ConvertToFP8(float h) { #if __SYCL_DEVICE_ONLY__ || !defined(__SYCL_DEVICE_ONLY__) return detail::ConvertFloatToFP8_CPU( @@ -856,6 +878,17 @@ template class fp8_e4m3_x { #endif } + uint8_vec2 ConvertBF16ToFP8_Vec2(bfloat16_vec2 h) { +#ifdef __SYCL_DEVICE_ONLY__ + return __builtin_spirv_ClampConvertBF16ToE4M3INTEL(h); +#else + uint8_vec2 result; + for (size_t i = 0; i < 2; ++i) + result[i] = ConvertBF16ToFP8(h[i]); + return result; +#endif + } + uint8_t ConvertBF16ToFP8(bfloat16 h) { #ifdef __SYCL_DEVICE_ONLY__ return __builtin_spirv_ClampConvertBF16ToE4M3INTEL( @@ -877,6 +910,20 @@ template class fp8_e4m3_x { #endif } + void ConvertFromFP8_Vec2(sycl::marray &ret, + rounding r = rounding::to_even) const { +#ifdef __SYCL_DEVICE_ONLY__ + const uint8_vec2 packed = {vals[0], vals[1]}; + float16_vec2 hi = __builtin_spirv_ConvertE4M3ToFP16EXT(packed); + ret[0] = sycl::bit_cast(hi[0]); + ret[1] = sycl::bit_cast(hi[1]); +#else + for (size_t i = 0; i < 2; ++i) + ret[i] = detail::ConvertFromFP8ToBinaryFloat_CPU(vals[i], r); +#endif + } + bfloat16 ConvertBF16FromFP8(uint8_t v) const { #ifdef __SYCL_DEVICE_ONLY__ return sycl::bit_cast(__builtin_spirv_ConvertE4M3ToBF16EXT(v)); @@ -887,11 +934,35 @@ template class fp8_e4m3_x { #endif } + void ConvertBF16FromFP8_Vec2(sycl::marray &ret, + rounding r = rounding::to_even) const { +#ifdef __SYCL_DEVICE_ONLY__ + const uint8_vec2 packed = {vals[0], vals[1]}; + bfloat16_vec2 hi = __builtin_spirv_ConvertE4M3ToBF16EXT(packed); + ret[0] = sycl::bit_cast(hi[0]); + ret[1] = sycl::bit_cast(hi[1]); +#else + for (size_t i = 0; i < 2; ++i) + ret[i] = detail::ConvertFromFP8ToBinaryFloat_CPU(vals[i], r); +#endif + } + void CheckConstraints(rounding r) const { assert(r == rounding::to_even && "fp8_e4m3_x: only rounding::to_even is supported"); } +#define CONVERT_TO_FP8(VecType, CastType, in, Prefix) \ + if constexpr (N == 1) { \ + vals[0] = Convert##Prefix##ToFP8(in[0]); \ + } else { \ + const VecType vec = {sycl::bit_cast(in[0]), \ + sycl::bit_cast(in[1])}; \ + const uint8_vec2 result = Convert##Prefix##ToFP8_Vec2(vec); \ + std::memcpy(vals, &result, sizeof(vals)); \ + } + public: fp8_e4m3_x() = default; fp8_e4m3_x(const fp8_e4m3_x &) = default; @@ -910,12 +981,13 @@ template class fp8_e4m3_x { ((std::is_same_v, float>) && ...))>> explicit fp8_e4m3_x(Types... v) { if constexpr (((std::is_same_v, bfloat16>) && ...)) { - const bfloat16 in[N] = {static_cast(v)...}; - for (size_t i = 0; i < N; ++i) - vals[i] = ConvertBF16ToFP8(in[i]); + const bfloat16 in[N] = {v...}; + CONVERT_TO_FP8(bfloat16_vec2, __bf16, in, BF16); + } else if constexpr (((std::is_same_v, half>) && ...)) { + const sycl::half in[N] = {v...}; + CONVERT_TO_FP8(float16_vec2, _Float16, in, ); } else { - using InT = std::common_type_t...>; - const InT in[N] = {v...}; + const float in[N] = {v...}; for (size_t i = 0; i < N; ++i) vals[i] = ConvertToFP8(in[i]); } @@ -925,14 +997,12 @@ template class fp8_e4m3_x { explicit fp8_e4m3_x(sycl::half const (&v)[N], rounding r = rounding::to_even) { CheckConstraints(r); - for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i]); + CONVERT_TO_FP8(float16_vec2, _Float16, v, ); } explicit fp8_e4m3_x(bfloat16 const (&v)[N], rounding r = rounding::to_even) { CheckConstraints(r); - for (size_t i = 0; i < N; ++i) - vals[i] = ConvertBF16ToFP8(v[i]); + CONVERT_TO_FP8(bfloat16_vec2, __bf16, v, BF16); } explicit fp8_e4m3_x(float const (&v)[N], rounding r = rounding::to_even) { @@ -945,15 +1015,13 @@ template class fp8_e4m3_x { explicit fp8_e4m3_x(const sycl::marray &v, rounding r = rounding::to_even) { CheckConstraints(r); - for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i]); + CONVERT_TO_FP8(float16_vec2, _Float16, v, ); } explicit fp8_e4m3_x(const sycl::marray &v, rounding r = rounding::to_even) { CheckConstraints(r); - for (size_t i = 0; i < N; ++i) - vals[i] = ConvertBF16ToFP8(v[i]); + CONVERT_TO_FP8(bfloat16_vec2, __bf16, v, BF16); } explicit fp8_e4m3_x(const sycl::marray &v, @@ -1162,15 +1230,19 @@ template class fp8_e4m3_x { explicit operator sycl::marray() const { sycl::marray ret; - for (size_t i = 0; i < N; ++i) - ret[i] = ConvertFromFP8(vals[i]); + if constexpr (N == 1) + ret[0] = ConvertFromFP8(vals[0]); + else + ConvertFromFP8_Vec2(ret); return ret; } explicit operator sycl::marray() const { sycl::marray ret; - for (size_t i = 0; i < N; ++i) - ret[i] = ConvertBF16FromFP8(vals[i]); + if constexpr (N == 1) + ret[0] = ConvertBF16FromFP8(vals[0]); + else + ConvertBF16FromFP8_Vec2(ret); return ret; } From 392d34b613864ea3febd6e57d6b55285e196718f Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Fri, 5 Jun 2026 17:51:47 +0200 Subject: [PATCH 76/89] [SYCL] do not use loops in 2 value fp8 type --- .../oneapi/experimental/float_8bit/types.hpp | 155 ++++++++--- .../fp8/e5m2_x2_cri_conversion.cpp | 7 - .../Extensions/fp8/builtin_call_tests.cpp | 249 ------------------ .../Extensions/fp8/builtin_mocks.hpp | 156 ----------- 4 files changed, 114 insertions(+), 453 deletions(-) delete mode 100644 sycl/unittests/Extensions/fp8/builtin_call_tests.cpp delete mode 100644 sycl/unittests/Extensions/fp8/builtin_mocks.hpp diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index 2503c8cdbf1b9..1474d4f2682e5 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -41,23 +41,36 @@ extern __DPCPP_SYCL_EXTERNAL __bf16 __builtin_spirv_ConvertE4M3ToBF16EXT(uint8_t) noexcept; extern __DPCPP_SYCL_EXTERNAL bfloat16_vec2 __builtin_spirv_ConvertE4M3ToBF16EXT(uint8_vec2) noexcept; - extern __DPCPP_SYCL_EXTERNAL uint8_t __builtin_spirv_ClampConvertBF16ToE4M3INTEL(__bf16) noexcept; extern __DPCPP_SYCL_EXTERNAL uint8_vec2 __builtin_spirv_ClampConvertBF16ToE4M3INTEL(bfloat16_vec2) noexcept; + extern __DPCPP_SYCL_EXTERNAL uint8_t __builtin_spirv_ClampConvertFP16ToE5M2INTEL(_Float16) noexcept; +extern __DPCPP_SYCL_EXTERNAL uint8_vec2 + __builtin_spirv_ClampConvertFP16ToE5M2INTEL(float16_vec2) noexcept; extern __DPCPP_SYCL_EXTERNAL uint8_t __builtin_spirv_ConvertFP16ToE5M2EXT(_Float16) noexcept; +extern __DPCPP_SYCL_EXTERNAL + uint8_vec2 __builtin_spirv_ConvertFP16ToE5M2EXT(float16_vec2) noexcept; extern __DPCPP_SYCL_EXTERNAL _Float16 __builtin_spirv_ConvertE5M2ToFP16EXT(uint8_t) noexcept; +extern __DPCPP_SYCL_EXTERNAL + float16_vec2 __builtin_spirv_ConvertE5M2ToFP16EXT(uint8_vec2) noexcept; + extern __DPCPP_SYCL_EXTERNAL uint8_t __builtin_spirv_ClampConvertBF16ToE5M2INTEL(__bf16) noexcept; +extern __DPCPP_SYCL_EXTERNAL uint8_vec2 + __builtin_spirv_ClampConvertBF16ToE5M2INTEL(bfloat16_vec2) noexcept; extern __DPCPP_SYCL_EXTERNAL uint8_t __builtin_spirv_ConvertBF16ToE5M2EXT(__bf16) noexcept; +extern __DPCPP_SYCL_EXTERNAL + uint8_vec2 __builtin_spirv_ConvertBF16ToE5M2EXT(bfloat16_vec2) noexcept; extern __DPCPP_SYCL_EXTERNAL __bf16 __builtin_spirv_ConvertE5M2ToBF16EXT(uint8_t) noexcept; +extern __DPCPP_SYCL_EXTERNAL + bfloat16_vec2 __builtin_spirv_ConvertE5M2ToBF16EXT(uint8_vec2) noexcept; extern __DPCPP_SYCL_EXTERNAL uint8_t __builtin_spirv_ClampStochasticRoundFP16ToE5M2INTEL( _Float16, uint32_t, __attribute__((opencl_private)) uint32_t *) noexcept; @@ -860,16 +873,11 @@ template class fp8_e4m3_x { #endif } - uint8_vec2 ConvertToFP8_Vec2(float16_vec2 h) { #ifdef __SYCL_DEVICE_ONLY__ + uint8_vec2 ConvertToFP8_Vec2(float16_vec2 h) { return __builtin_spirv_ClampConvertFP16ToE4M3INTEL(h); -#else - uint8_vec2 result; - for (size_t i = 0; i < 2; ++i) - result[i] = ConvertToFP8(sycl::bit_cast(h[i])); - return result; -#endif } +#endif uint8_t ConvertToFP8(float h) { #if __SYCL_DEVICE_ONLY__ || !defined(__SYCL_DEVICE_ONLY__) @@ -878,16 +886,11 @@ template class fp8_e4m3_x { #endif } - uint8_vec2 ConvertBF16ToFP8_Vec2(bfloat16_vec2 h) { #ifdef __SYCL_DEVICE_ONLY__ + uint8_vec2 ConvertBF16ToFP8_Vec2(bfloat16_vec2 h) { return __builtin_spirv_ClampConvertBF16ToE4M3INTEL(h); -#else - uint8_vec2 result; - for (size_t i = 0; i < 2; ++i) - result[i] = ConvertBF16ToFP8(h[i]); - return result; -#endif } +#endif uint8_t ConvertBF16ToFP8(bfloat16 h) { #ifdef __SYCL_DEVICE_ONLY__ @@ -913,7 +916,7 @@ template class fp8_e4m3_x { void ConvertFromFP8_Vec2(sycl::marray &ret, rounding r = rounding::to_even) const { #ifdef __SYCL_DEVICE_ONLY__ - const uint8_vec2 packed = {vals[0], vals[1]}; + const uint8_vec2 packed{vals[0], vals[1]}; float16_vec2 hi = __builtin_spirv_ConvertE4M3ToFP16EXT(packed); ret[0] = sycl::bit_cast(hi[0]); ret[1] = sycl::bit_cast(hi[1]); @@ -937,7 +940,7 @@ template class fp8_e4m3_x { void ConvertBF16FromFP8_Vec2(sycl::marray &ret, rounding r = rounding::to_even) const { #ifdef __SYCL_DEVICE_ONLY__ - const uint8_vec2 packed = {vals[0], vals[1]}; + const uint8_vec2 packed{vals[0], vals[1]}; bfloat16_vec2 hi = __builtin_spirv_ConvertE4M3ToBF16EXT(packed); ret[0] = sycl::bit_cast(hi[0]); ret[1] = sycl::bit_cast(hi[1]); @@ -953,15 +956,21 @@ template class fp8_e4m3_x { "fp8_e4m3_x: only rounding::to_even is supported"); } +#ifdef __SYCL_DEVICE_ONLY__ #define CONVERT_TO_FP8(VecType, CastType, in, Prefix) \ if constexpr (N == 1) { \ vals[0] = Convert##Prefix##ToFP8(in[0]); \ } else { \ - const VecType vec = {sycl::bit_cast(in[0]), \ - sycl::bit_cast(in[1])}; \ + const VecType vec{sycl::bit_cast(in[0]), \ + sycl::bit_cast(in[1])}; \ const uint8_vec2 result = Convert##Prefix##ToFP8_Vec2(vec); \ std::memcpy(vals, &result, sizeof(vals)); \ } +#else +#define CONVERT_TO_FP8(VecType, CastType, in, Prefix) \ + for (size_t _cvt_i = 0; _cvt_i < N; ++_cvt_i) \ + vals[_cvt_i] = Convert##Prefix##ToFP8(in[_cvt_i]); +#endif public: fp8_e4m3_x() = default; @@ -1255,6 +1264,7 @@ template class fp8_e4m3_x { // Intentionally public to allow access to the raw values. uint8_t vals[N]; +#undef CONVERT_TO_FP8 }; template class fp8_e5m2_x { @@ -1297,6 +1307,14 @@ template class fp8_e5m2_x { #endif } +#ifdef __SYCL_DEVICE_ONLY__ + uint8_vec2 ConvertToFP8_Vec2(float16_vec2 h, saturation s) { + return s == saturation::finite + ? __builtin_spirv_ClampConvertFP16ToE5M2INTEL(h) + : __builtin_spirv_ConvertFP16ToE5M2EXT(h); + } +#endif + uint8_t ConvertToFP8(float h, saturation s) { #if __SYCL_DEVICE_ONLY__ || !defined(__SYCL_DEVICE_ONLY__) return detail::ConvertFloatToFP8_CPU( @@ -1317,6 +1335,14 @@ template class fp8_e5m2_x { #endif } +#ifdef __SYCL_DEVICE_ONLY__ + uint8_vec2 ConvertBF16ToFP8_Vec2(bfloat16_vec2 h, saturation s) { + return s == saturation::finite + ? __builtin_spirv_ClampConvertBF16ToE5M2INTEL(h) + : __builtin_spirv_ConvertBF16ToE5M2EXT(h); + } +#endif + template T ConvertFromFP8(uint8_t v, rounding r = rounding::to_even) const { #ifdef __SYCL_DEVICE_ONLY__ @@ -1328,6 +1354,20 @@ template class fp8_e5m2_x { #endif } + void ConvertFromFP8_Vec2(sycl::marray &ret, + rounding r = rounding::to_even) const { +#ifdef __SYCL_DEVICE_ONLY__ + const uint8_vec2 packed{vals[0], vals[1]}; + float16_vec2 hi = __builtin_spirv_ConvertE5M2ToFP16EXT(packed); + ret[0] = sycl::bit_cast(hi[0]); + ret[1] = sycl::bit_cast(hi[1]); +#else + for (size_t i = 0; i < 2; ++i) + ret[i] = detail::ConvertFromFP8ToBinaryFloat_CPU(vals[i], r); +#endif + } + bfloat16 ConvertBF16FromFP8(uint8_t v) const { #ifdef __SYCL_DEVICE_ONLY__ return sycl::bit_cast(__builtin_spirv_ConvertE5M2ToBF16EXT(v)); @@ -1338,11 +1378,41 @@ template class fp8_e5m2_x { #endif } + void ConvertBF16FromFP8_Vec2(sycl::marray &ret, + rounding r = rounding::to_even) const { +#ifdef __SYCL_DEVICE_ONLY__ + const uint8_vec2 packed{vals[0], vals[1]}; + bfloat16_vec2 hi = __builtin_spirv_ConvertE5M2ToBF16EXT(packed); + ret[0] = sycl::bit_cast(hi[0]); + ret[1] = sycl::bit_cast(hi[1]); +#else + for (size_t i = 0; i < 2; ++i) + ret[i] = detail::ConvertFromFP8ToBinaryFloat_CPU(vals[i], r); +#endif + } + void CheckConstraints(rounding r) const { assert(r == rounding::to_even && "fp8_e5m2_x: only rounding::to_even is supported"); } +#ifdef __SYCL_DEVICE_ONLY__ +#define CONVERT_TO_FP8(VecType, CastType, in, s, Prefix) \ + if constexpr (N == 1) { \ + vals[0] = Convert##Prefix##ToFP8(in[0], s); \ + } else { \ + const VecType vec{sycl::bit_cast(in[0]), \ + sycl::bit_cast(in[1])}; \ + const uint8_vec2 result = Convert##Prefix##ToFP8_Vec2(vec, s); \ + std::memcpy(vals, &result, sizeof(vals)); \ + } +#else +#define CONVERT_TO_FP8(VecType, CastType, in, s, Prefix) \ + for (size_t _cvt_i = 0; _cvt_i < N; ++_cvt_i) \ + vals[_cvt_i] = Convert##Prefix##ToFP8(in[_cvt_i], s); +#endif + public: fp8_e5m2_x() = default; fp8_e5m2_x(const fp8_e5m2_x &) = default; @@ -1363,8 +1433,10 @@ template class fp8_e5m2_x { explicit fp8_e5m2_x(Types... v) { if constexpr (((std::is_same_v, bfloat16>) && ...)) { const bfloat16 in[N] = {static_cast(v)...}; - for (size_t i = 0; i < N; ++i) - vals[i] = ConvertBF16ToFP8(in[i], saturation::finite); + CONVERT_TO_FP8(bfloat16_vec2, __bf16, in, saturation::finite, BF16); + } else if constexpr (((std::is_same_v, half>) && ...)) { + const sycl::half in[N] = {v...}; + CONVERT_TO_FP8(float16_vec2, _Float16, in, saturation::finite, ); } else { using InT = std::common_type_t...>; const InT in[N] = {v...}; @@ -1378,17 +1450,13 @@ template class fp8_e5m2_x { explicit fp8_e5m2_x(half const (&v)[N], rounding r = rounding::to_even, saturation s = saturation::finite) { CheckConstraints(r); - // TODO: optimize with vectorized builtin calls - for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i], s); + CONVERT_TO_FP8(float16_vec2, _Float16, v, s, ); } explicit fp8_e5m2_x(bfloat16 const (&v)[N], rounding r = rounding::to_even, saturation s = saturation::finite) { CheckConstraints(r); - // TODO: optimize with vectorized builtin calls - for (size_t i = 0; i < N; ++i) - vals[i] = ConvertBF16ToFP8(v[i], s); + CONVERT_TO_FP8(bfloat16_vec2, __bf16, v, s, BF16); } explicit fp8_e5m2_x(float const (&v)[N], rounding r = rounding::to_even, @@ -1404,16 +1472,14 @@ template class fp8_e5m2_x { rounding r = rounding::to_even, saturation s = saturation::finite) { CheckConstraints(r); - for (size_t i = 0; i < N; ++i) - vals[i] = ConvertToFP8(v[i], s); + CONVERT_TO_FP8(float16_vec2, _Float16, v, s, ); } explicit fp8_e5m2_x(const sycl::marray &v, rounding r = rounding::to_even, saturation s = saturation::finite) { CheckConstraints(r); - for (size_t i = 0; i < N; ++i) - vals[i] = ConvertBF16ToFP8(v[i], s); + CONVERT_TO_FP8(bfloat16_vec2, __bf16, v, s, BF16); } explicit fp8_e5m2_x(const sycl::marray &v, @@ -1424,8 +1490,8 @@ template class fp8_e5m2_x { vals[i] = ConvertToFP8(v[i], s); } - // Construct with stochastic rounding with user provided seed from an array of - // half, bfloat16. + // Construct with stochastic rounding with user provided seed from an array + // of half, bfloat16. explicit fp8_e5m2_x([[maybe_unused]] half const (&in)[N], [[maybe_unused]] const stochastic_seed &seed, @@ -1738,17 +1804,23 @@ template class fp8_e5m2_x { } explicit operator sycl::marray() const { - sycl::marray out; - for (size_t i = 0; i < N; ++i) - out[i] = ConvertFromFP8(vals[i]); - return out; + sycl::marray ret; + if constexpr (N == 1) + ret[0] = ConvertFromFP8(vals[0]); + else + ConvertFromFP8_Vec2(ret); + return ret; } + explicit operator sycl::marray() const { - sycl::marray out; - for (size_t i = 0; i < N; ++i) - out[i] = ConvertBF16FromFP8(vals[i]); - return out; + sycl::marray ret; + if constexpr (N == 1) + ret[0] = ConvertBF16FromFP8(vals[0]); + else + ConvertBF16FromFP8_Vec2(ret); + return ret; } + explicit operator sycl::marray() const { sycl::marray out; for (size_t i = 0; i < N; ++i) @@ -1759,6 +1831,7 @@ template class fp8_e5m2_x { // Intentionally public to allow access to the raw values. uint8_t vals[N]; +#undef CONVERT_TO_FP8 }; template class fp8_e8m0_x { diff --git a/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp index a2354ecd486ea..21736deb368fd 100644 --- a/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp @@ -39,9 +39,7 @@ template int test_stochastic_constructor(sycl::queue &queue) { auto *out = sycl::malloc_shared(2, queue); auto *seed = sycl::malloc_shared(1, queue); - auto *seed_updated = sycl::malloc_shared(1, queue); seed[0] = 0x89abcdefu; - seed_updated[0] = false; queue.single_task([=]() { const float positive_input = std::numeric_limits::infinity(); @@ -81,14 +79,10 @@ int test_stochastic_constructor(sycl::queue &queue) { out[1] = unpacked[1]; } } - - seed_updated[0] = seed[0] != initial_seed; }); queue.wait_and_throw(); int ret = 0; - if (!seed_updated[0]) - ret = 1; if constexpr (Sat == saturation::finite) { if (out[0] != E5M2MaxNormal) ret = 1; @@ -103,7 +97,6 @@ int test_stochastic_constructor(sycl::queue &queue) { sycl::free(out, queue); sycl::free(seed, queue); - sycl::free(seed_updated, queue); return ret; } diff --git a/sycl/unittests/Extensions/fp8/builtin_call_tests.cpp b/sycl/unittests/Extensions/fp8/builtin_call_tests.cpp deleted file mode 100644 index 48d0d3ffe5314..0000000000000 --- a/sycl/unittests/Extensions/fp8/builtin_call_tests.cpp +++ /dev/null @@ -1,249 +0,0 @@ -#include "builtin_mocks.hpp" -#include -#include - -namespace { - -using namespace sycl::ext::oneapi::experimental; - -class Fp8BuiltinCallTest : public ::testing::Test { -protected: - void SetUp() override { fp8_builtin_mock::resetCounters(); } -}; - -TEST_F(Fp8BuiltinCallTest, E4M3CtorFromHalfCallsClampConvertFP16ToE4M3) { - fp8_e4m3 Value(static_cast(1.25f)); - (void)Value; - EXPECT_EQ(fp8_builtin_mock::getCounters().ClampConvertFP16ToE4M3INTEL, 1); -} - -TEST_F(Fp8BuiltinCallTest, E4M3CtorFromBf16CallsClampConvertBF16ToE4M3) { - fp8_e4m3 Value(static_cast(1.25f)); - (void)Value; - EXPECT_EQ(fp8_builtin_mock::getCounters().ClampConvertBF16ToE4M3INTEL, 1); -} - -TEST_F(Fp8BuiltinCallTest, E4M3ArrayCtorFromFloatCallsClampConvertFP16ToE4M3) { - float Input[2] = {1.25f, 2.5f}; - - fp8_e4m3_x2 Value(Input); - (void)Value; - - EXPECT_EQ(fp8_builtin_mock::getCounters().ClampConvertFP16ToE4M3INTEL, 0); -} - -TEST_F(Fp8BuiltinCallTest, E4M3MarrayCtorFromBf16CallsClampConvertBF16ToE4M3) { - sycl::marray Input = { - static_cast(1.25f), - static_cast(2.5f)}; - - fp8_e4m3_x2 Value(Input); - (void)Value; - - EXPECT_EQ(fp8_builtin_mock::getCounters().ClampConvertBF16ToE4M3INTEL, 2); -} - -TEST_F(Fp8BuiltinCallTest, E4M3CastToHalfCallsClampConvertE4M3ToFP16) { - fp8_e4m3 Value(static_cast(1.0f)); - fp8_builtin_mock::resetCounters(); - (void)static_cast(Value); - EXPECT_EQ(fp8_builtin_mock::getCounters().ConvertE4M3ToFP16EXT, 1); -} - -TEST_F(Fp8BuiltinCallTest, E4M3CastToBf16CallsConvertE4M3ToBF16) { - fp8_e4m3 Value(static_cast(1.0f)); - fp8_builtin_mock::resetCounters(); - (void)static_cast(Value); - EXPECT_EQ(fp8_builtin_mock::getCounters().ConvertE4M3ToBF16EXT, 1); -} - -TEST_F(Fp8BuiltinCallTest, E4M3CastToBoolDoesNotCallConvertE4M3ToFP16) { - fp8_e4m3 Value(static_cast(1.0f)); - fp8_builtin_mock::resetCounters(); - EXPECT_TRUE(static_cast(Value)); - EXPECT_EQ(fp8_builtin_mock::getCounters().ConvertE4M3ToFP16EXT, 0); -} - -TEST_F(Fp8BuiltinCallTest, E4M3MarrayCastToHalfCallsConvertE4M3ToFP16) { - sycl::half Input[2] = {static_cast(1.0f), - static_cast(2.0f)}; - fp8_e4m3_x2 Value(Input); - - fp8_builtin_mock::resetCounters(); - (void)static_cast>(Value); - - EXPECT_EQ(fp8_builtin_mock::getCounters().ConvertE4M3ToFP16EXT, 2); -} - -TEST_F(Fp8BuiltinCallTest, E4M3MarrayCastToBf16CallsConvertE4M3ToBF16) { - sycl::half Input[2] = {static_cast(1.0f), - static_cast(2.0f)}; - fp8_e4m3_x2 Value(Input); - - fp8_builtin_mock::resetCounters(); - (void)static_cast>(Value); - - EXPECT_EQ(fp8_builtin_mock::getCounters().ConvertE4M3ToBF16EXT, 2); -} - -TEST_F(Fp8BuiltinCallTest, E4M3AssignmentFromFloatCallsClampConvertFP16ToE4M3) { - fp8_e4m3 Value(static_cast(1.0f)); - - EXPECT_EQ(fp8_builtin_mock::getCounters().ClampConvertFP16ToE4M3INTEL, 1); - fp8_builtin_mock::resetCounters(); - Value = 1.25f; - - EXPECT_EQ(fp8_builtin_mock::getCounters().ClampConvertFP16ToE4M3INTEL, 0); -} - -TEST_F(Fp8BuiltinCallTest, E5M2CtorFromHalfCallsClampConvertFP16ToE5M2) { - fp8_e5m2 Value(static_cast(2.0f)); - (void)Value; - EXPECT_EQ(fp8_builtin_mock::getCounters().ClampConvertFP16ToE5M2INTEL, 1); -} - -TEST_F(Fp8BuiltinCallTest, E5M2CtorFromBf16CallsClampConvertBF16ToE5M2) { - fp8_e5m2 Value(static_cast(2.0f)); - (void)Value; - EXPECT_EQ(fp8_builtin_mock::getCounters().ClampConvertBF16ToE5M2INTEL, 1); -} - -TEST_F(Fp8BuiltinCallTest, - E5M2ArrayCtorFromFloatFiniteCallsClampConvertFP16ToE5M2) { - float Input[2] = {2.0f, 4.0f}; - - fp8_e5m2_x2 Value(Input, rounding::to_even, saturation::finite); - (void)Value; - - EXPECT_EQ(fp8_builtin_mock::getCounters().ClampConvertFP16ToE5M2INTEL, 0); -} - -TEST_F(Fp8BuiltinCallTest, E5M2MarrayCtorFromBf16NoneCallsConvertBF16ToE5M2) { - sycl::marray Input = { - static_cast(2.0f), - static_cast(4.0f)}; - - fp8_e5m2_x2 Value(Input, rounding::to_even, saturation::none); - (void)Value; - - EXPECT_EQ(fp8_builtin_mock::getCounters().ConvertBF16ToE5M2EXT, 2); -} - -TEST_F(Fp8BuiltinCallTest, E5M2CastToHalfCallsConvertE5M2ToFP16) { - fp8_e5m2 Value(static_cast(2.0f)); - fp8_builtin_mock::resetCounters(); - (void)static_cast(Value); - EXPECT_EQ(fp8_builtin_mock::getCounters().ConvertE5M2ToFP16EXT, 1); -} - -TEST_F(Fp8BuiltinCallTest, E5M2CastToBf16CallsConvertE5M2ToBF16) { - fp8_e5m2 Value(static_cast(2.0f)); - fp8_builtin_mock::resetCounters(); - (void)static_cast(Value); - EXPECT_EQ(fp8_builtin_mock::getCounters().ConvertE5M2ToBF16EXT, 1); -} - -TEST_F(Fp8BuiltinCallTest, E5M2MarrayCastToHalfCallsConvertE5M2ToFP16) { - sycl::half Input[2] = {static_cast(2.0f), - static_cast(4.0f)}; - fp8_e5m2_x2 Value(Input); - - fp8_builtin_mock::resetCounters(); - (void)static_cast>(Value); - - EXPECT_EQ(fp8_builtin_mock::getCounters().ConvertE5M2ToFP16EXT, 2); -} - -TEST_F(Fp8BuiltinCallTest, E5M2MarrayCastToBf16CallsConvertE5M2ToBF16) { - sycl::half Input[2] = {static_cast(2.0f), - static_cast(4.0f)}; - fp8_e5m2_x2 Value(Input); - - fp8_builtin_mock::resetCounters(); - (void)static_cast>(Value); - - EXPECT_EQ(fp8_builtin_mock::getCounters().ConvertE5M2ToBF16EXT, 2); -} - -TEST_F(Fp8BuiltinCallTest, E5M2AssignmentFromFloatCallsClampConvertFP16ToE5M2) { - fp8_e5m2 Value(static_cast(2.0f)); - - EXPECT_EQ(fp8_builtin_mock::getCounters().ClampConvertFP16ToE5M2INTEL, 1); - fp8_builtin_mock::resetCounters(); - Value = 4.0f; - - EXPECT_EQ(fp8_builtin_mock::getCounters().ClampConvertFP16ToE5M2INTEL, 0); -} - -TEST_F(Fp8BuiltinCallTest, - E5M2CtorFromHalfWithNoSaturationCallsConvertFP16ToE5M2) { - sycl::half Input[1] = {static_cast(2.0f)}; - - fp8_e5m2 Value(Input, rounding::to_even, saturation::none); - (void)Value; - - EXPECT_EQ(fp8_builtin_mock::getCounters().ConvertFP16ToE5M2EXT, 1); -} - -TEST_F(Fp8BuiltinCallTest, - E5M2CtorFromBf16WithNoSaturationCallsConvertBF16ToE5M2) { - sycl::ext::oneapi::bfloat16 Input[1] = { - static_cast(2.0f)}; - - fp8_e5m2 Value(Input, rounding::to_even, saturation::none); - (void)Value; - - EXPECT_EQ(fp8_builtin_mock::getCounters().ConvertBF16ToE5M2EXT, 1); -} - -TEST_F(Fp8BuiltinCallTest, E5M2StochasticHalfFiniteCallsClampStochastic) { - sycl::half Input[1] = {static_cast(3.0f)}; - uint32_t SeedValue = 10; - stochastic_seed Seed(&SeedValue); - - fp8_e5m2 Value(Input, Seed, saturation::finite); - (void)Value; - - EXPECT_EQ(fp8_builtin_mock::getCounters().ClampStochasticRoundFP16ToE5M2INTEL, - 1); - EXPECT_EQ(SeedValue, 11u); -} - -TEST_F(Fp8BuiltinCallTest, E5M2StochasticHalfNoneCallsNonClampStochastic) { - sycl::half Input[1] = {static_cast(3.0f)}; - uint32_t SeedValue = 20; - stochastic_seed Seed(&SeedValue); - - fp8_e5m2 Value(Input, Seed, saturation::none); - (void)Value; - - EXPECT_EQ(fp8_builtin_mock::getCounters().StochasticRoundFP16ToE5M2INTEL, 1); - EXPECT_EQ(SeedValue, 21u); -} - -TEST_F(Fp8BuiltinCallTest, E5M2StochasticBf16FiniteCallsClampStochastic) { - sycl::ext::oneapi::bfloat16 Input[1] = { - static_cast(3.0f)}; - uint32_t SeedValue = 30; - stochastic_seed Seed(&SeedValue); - - fp8_e5m2 Value(Input, Seed, saturation::finite); - (void)Value; - - EXPECT_EQ(fp8_builtin_mock::getCounters().ClampStochasticRoundBF16ToE5M2INTEL, - 1); -} - -TEST_F(Fp8BuiltinCallTest, E5M2StochasticBf16NoneCallsNonClampStochastic) { - sycl::ext::oneapi::bfloat16 Input[1] = { - static_cast(3.0f)}; - uint32_t SeedValue = 40; - stochastic_seed Seed(&SeedValue); - - fp8_e5m2 Value(Input, Seed, saturation::none); - (void)Value; - - EXPECT_EQ(fp8_builtin_mock::getCounters().StochasticRoundBF16ToE5M2INTEL, 1); -} - -} // namespace diff --git a/sycl/unittests/Extensions/fp8/builtin_mocks.hpp b/sycl/unittests/Extensions/fp8/builtin_mocks.hpp deleted file mode 100644 index 97f0211ecac83..0000000000000 --- a/sycl/unittests/Extensions/fp8/builtin_mocks.hpp +++ /dev/null @@ -1,156 +0,0 @@ -//===-- FP8 builtin helpers, mocks and stubs for float_8bit/types.hpp -//---------*- C++ -*-===// - -#pragma once - -#include -#include -#include - -#if defined(_MSC_VER) -#define _Float16 sycl::half -#define __bf16 sycl::ext::oneapi::bfloat16 -#endif - -// Force code path that uses helpers.hpp wrappers. -#ifndef __SYCL_DEVICE_ONLY__ -#define __SYCL_DEVICE_ONLY__ 1 -#endif - -namespace fp8_builtin_mock { - -struct Counters { - int ConvertE4M3ToFP16EXT = 0; - int ConvertE5M2ToFP16EXT = 0; - int ConvertE4M3ToBF16EXT = 0; - int ConvertE5M2ToBF16EXT = 0; - int ClampConvertFP16ToE4M3INTEL = 0; - int ClampConvertBF16ToE4M3INTEL = 0; - int ConvertFP16ToE4M3EXT = 0; - int ConvertBF16ToE4M3EXT = 0; - int ClampConvertFP16ToE5M2INTEL = 0; - int ClampConvertBF16ToE5M2INTEL = 0; - int ConvertFP16ToE5M2EXT = 0; - int ConvertBF16ToE5M2EXT = 0; - int StochasticRoundFP16ToE5M2INTEL = 0; - int StochasticRoundBF16ToE5M2INTEL = 0; - int ClampStochasticRoundFP16ToE5M2INTEL = 0; - int ClampStochasticRoundBF16ToE5M2INTEL = 0; -}; - -inline Counters &getCounters() { - static Counters Value; - return Value; -} - -inline void resetCounters() { getCounters() = Counters{}; } - -} // namespace fp8_builtin_mock - -// Builtin mocks (do not replace helpers.hpp; provide symbols here). -inline _Float16 __builtin_spirv_ConvertE4M3ToFP16EXT(char) noexcept { - ++fp8_builtin_mock::getCounters().ConvertE4M3ToFP16EXT; - return static_cast<_Float16>(2.0f); -} - -inline _Float16 __builtin_spirv_ConvertE5M2ToFP16EXT(uint8_t) noexcept { - ++fp8_builtin_mock::getCounters().ConvertE5M2ToFP16EXT; - return static_cast<_Float16>(3.0f); -} - -inline __bf16 __builtin_spirv_ConvertE4M3ToBF16EXT(uint8_t) noexcept { - ++fp8_builtin_mock::getCounters().ConvertE4M3ToBF16EXT; - return static_cast<__bf16>(4.0f); -} - -inline __bf16 __builtin_spirv_ConvertE5M2ToBF16EXT(uint8_t) noexcept { - ++fp8_builtin_mock::getCounters().ConvertE5M2ToBF16EXT; - return static_cast<__bf16>(5.0f); -} - -inline uint8_t __builtin_spirv_ConvertFP16ToE4M3EXT(_Float16) noexcept { - ++fp8_builtin_mock::getCounters().ConvertFP16ToE4M3EXT; - return 0x01; -} - -inline uint8_t __builtin_spirv_ConvertBF16ToE4M3EXT(__bf16) noexcept { - ++fp8_builtin_mock::getCounters().ConvertBF16ToE4M3EXT; - return 0x02; -} -inline uint8_t __builtin_spirv_ClampConvertFP16ToE4M3INTEL(_Float16) noexcept { - ++fp8_builtin_mock::getCounters().ClampConvertFP16ToE4M3INTEL; - return 0x11; -} - -inline uint8_t __builtin_spirv_ClampConvertBF16ToE4M3INTEL(__bf16) noexcept { - ++fp8_builtin_mock::getCounters().ClampConvertBF16ToE4M3INTEL; - return 0x12; -} - -inline uint8_t __builtin_spirv_ConvertFP16ToE5M2EXT(_Float16) noexcept { - ++fp8_builtin_mock::getCounters().ConvertFP16ToE5M2EXT; - return 0x03; -} - -inline uint8_t __builtin_spirv_ClampConvertFP16ToE5M2INTEL(_Float16) noexcept { - ++fp8_builtin_mock::getCounters().ClampConvertFP16ToE5M2INTEL; - return 0x21; -} - -inline uint8_t __builtin_spirv_ConvertBF16ToE5M2EXT(__bf16) noexcept { - ++fp8_builtin_mock::getCounters().ConvertBF16ToE5M2EXT; - return 0x04; -} - -inline uint8_t __builtin_spirv_ClampConvertBF16ToE5M2INTEL(__bf16) noexcept { - ++fp8_builtin_mock::getCounters().ClampConvertBF16ToE5M2INTEL; - return 0x22; -} - -inline uint8_t -__builtin_spirv_StochasticRoundFP16ToE5M2INTEL(_Float16, uint32_t Seed, - uint32_t *NextSeed) noexcept { - ++fp8_builtin_mock::getCounters().StochasticRoundFP16ToE5M2INTEL; - if (NextSeed) - *NextSeed = Seed + 1; - return 0x31; -} - -inline uint8_t -__builtin_spirv_StochasticRoundFP16ToE4M3INTEL(_Float16) noexcept { - return 0x00; -} - -inline uint8_t -__builtin_spirv_StochasticRoundBF16ToE5M2INTEL(__bf16, uint32_t) noexcept { - ++fp8_builtin_mock::getCounters().StochasticRoundBF16ToE5M2INTEL; - return 0x32; -} - -inline uint8_t __builtin_spirv_StochasticRoundBF16ToE4M3INTEL(__bf16) noexcept { - return 0x00; -} - -inline uint8_t __builtin_spirv_ClampStochasticRoundFP16ToE5M2INTEL( - _Float16, uint32_t Seed, uint32_t *NextSeed) noexcept { - ++fp8_builtin_mock::getCounters().ClampStochasticRoundFP16ToE5M2INTEL; - if (NextSeed) - *NextSeed = Seed + 1; - return 0x41; -} - -inline uint8_t -__builtin_spirv_ClampStochasticRoundFP16ToE4M3INTEL(_Float16) noexcept { - return 0x00; -} - -inline uint8_t -__builtin_spirv_ClampStochasticRoundBF16ToE5M2INTEL(__bf16, uint32_t) noexcept { - ++fp8_builtin_mock::getCounters().ClampStochasticRoundBF16ToE5M2INTEL; - return 0x42; -} - -inline uint8_t -__builtin_spirv_ClampStochasticRoundBF16ToE4M3INTEL(__bf16) noexcept { - return 0x00; -} From 6a397fa1d5ed91a282a8cbe17af60a4527fac42e Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Mon, 8 Jun 2026 11:15:26 +0200 Subject: [PATCH 77/89] [SYCL][TEST] remove extra tests from cmake file --- sycl/unittests/Extensions/fp8/CMakeLists.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/sycl/unittests/Extensions/fp8/CMakeLists.txt b/sycl/unittests/Extensions/fp8/CMakeLists.txt index 9b7c7677f9c6a..e119cd20c8b4c 100644 --- a/sycl/unittests/Extensions/fp8/CMakeLists.txt +++ b/sycl/unittests/Extensions/fp8/CMakeLists.txt @@ -2,5 +2,4 @@ add_sycl_unittest(FP8TypesTests OBJECT fp8_e4m3.cpp fp8_e5m2.cpp fp8_e8m0.cpp - builtin_call_tests.cpp ) From 75160f29913ac29a2dd220306ff22141d22bb45c Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Mon, 8 Jun 2026 11:32:00 +0200 Subject: [PATCH 78/89] [SYCL][TEST] fix formatting --- sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp | 2 +- sycl/unittests/Extensions/fp8/fp8_e5m2.cpp | 9 +++------ 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp index 5f0b39611dfee..0c842e3063535 100644 --- a/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp @@ -1,7 +1,7 @@ +// REQUIRES: intel_feature_gpu_cri // RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator=spir64 --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out // RUN: %{run} SYCL_UR_TRACE=1 %t.out - #include #include #include diff --git a/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp b/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp index 3062eefb396b4..fd594de229a1c 100644 --- a/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp +++ b/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp @@ -354,8 +354,7 @@ TEST(FP8E5M2Test, StochasticCArrayHalfConstructorThrowsOnHost) { TEST(FP8E5M2Test, StochasticMarrayBFloat16ConstructorThrowsOnHost) { sycl::marray in = { - sycl::ext::oneapi::bfloat16(1.0f), - sycl::ext::oneapi::bfloat16(2.0f)}; + sycl::ext::oneapi::bfloat16(1.0f), sycl::ext::oneapi::bfloat16(2.0f)}; uint32_t seed_value = 5678; stochastic_seed seed(&seed_value); @@ -368,9 +367,8 @@ TEST(FP8E5M2Test, StochasticMarrayBFloat16ConstructorThrowsOnHost) { } TEST(FP8E5M2Test, StochasticCArrayBFloat16ConstructorThrowsOnHost) { - const sycl::ext::oneapi::bfloat16 in[2] = { - sycl::ext::oneapi::bfloat16(1.0f), - sycl::ext::oneapi::bfloat16(2.0f)}; + const sycl::ext::oneapi::bfloat16 in[2] = {sycl::ext::oneapi::bfloat16(1.0f), + sycl::ext::oneapi::bfloat16(2.0f)}; uint32_t seed_value = 6789; stochastic_seed seed(&seed_value); @@ -753,4 +751,3 @@ TEST(FP8E5M2Test, VariadicFloatReferences) { EXPECT_EQ(a.vals[0], 0x3C); EXPECT_EQ(a.vals[1], 0x40); } - From 9e4b9be5e856f169a89bf7b011a032e9cfe89f48 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Mon, 8 Jun 2026 14:40:29 +0200 Subject: [PATCH 79/89] [SYCL][TESTE2E] run fp8 tests only on cri device --- sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp index 21736deb368fd..d0da6844e618f 100644 --- a/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp @@ -1,3 +1,4 @@ +// REQUIRES: intel_feature_gpu_cri // RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator=spir64 --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out // RUN: %{run} SYCL_UR_TRACE=1 %t.out From 03160c9f74a290ffe0683ae547ff3ac659573dac Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Mon, 8 Jun 2026 18:15:07 +0200 Subject: [PATCH 80/89] [SYCL] avoid code duplication --- .../oneapi/experimental/float_8bit/types.hpp | 52 +++++++++---------- 1 file changed, 24 insertions(+), 28 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index 1474d4f2682e5..de35e12b08f25 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -1844,6 +1844,17 @@ template class fp8_e8m0_x { "supported"); } + template inline sycl::marray ConvertFromFP8_Loop() const { + sycl::marray out; + for (size_t i = 0; i < N; ++i) + out[i] = detail::ConvertFromE8M0_CPU(vals[i], rounding::to_even); + return out; + } + +#define CONVERT_TO_FP8(in, r) \ + for (size_t i = 0; i < N; ++i) \ + vals[i] = detail::ConvertFloatToE8M0_CPU(in[i], r, saturation::finite); + public: fp8_e8m0_x() = default; fp8_e8m0_x(const fp8_e8m0_x &) = default; @@ -1859,48 +1870,40 @@ template class fp8_e8m0_x { explicit fp8_e8m0_x(Types... v) { using InT = std::common_type_t...>; const InT in[N] = {v...}; - for (size_t i = 0; i < N; ++i) - vals[i] = detail::ConvertFloatToE8M0_CPU(in[i], rounding::upward, - saturation::finite); + CONVERT_TO_FP8(in, rounding::upward); } explicit fp8_e8m0_x(half const (&in)[N], rounding r = rounding::upward) { CheckConstraints(r); - for (size_t i = 0; i < N; ++i) - vals[i] = detail::ConvertFloatToE8M0_CPU(in[i], r, saturation::finite); + CONVERT_TO_FP8(in, r); } explicit fp8_e8m0_x(bfloat16 const (&in)[N], rounding r = rounding::upward) { CheckConstraints(r); - for (size_t i = 0; i < N; ++i) - vals[i] = detail::ConvertFloatToE8M0_CPU(in[i], r, saturation::finite); + CONVERT_TO_FP8(in, r); } explicit fp8_e8m0_x(float const (&in)[N], rounding r = rounding::upward) { CheckConstraints(r); - for (size_t i = 0; i < N; ++i) - vals[i] = detail::ConvertFloatToE8M0_CPU(in[i], r, saturation::finite); + CONVERT_TO_FP8(in, r); } explicit fp8_e8m0_x(const marray &in, rounding r = rounding::upward) { CheckConstraints(r); - for (size_t i = 0; i < N; ++i) - vals[i] = detail::ConvertFloatToE8M0_CPU(in[i], r, saturation::finite); + CONVERT_TO_FP8(in, r); } explicit fp8_e8m0_x(const marray &in, rounding r = rounding::upward) { CheckConstraints(r); - for (size_t i = 0; i < N; ++i) - vals[i] = detail::ConvertFloatToE8M0_CPU(in[i], r, saturation::finite); + CONVERT_TO_FP8(in, r); } explicit fp8_e8m0_x(const marray &in, rounding r = rounding::upward) { CheckConstraints(r); - for (size_t i = 0; i < N; ++i) - vals[i] = detail::ConvertFloatToE8M0_CPU(in[i], r, saturation::finite); + CONVERT_TO_FP8(in, r); } // Construct from integer types. @@ -2092,28 +2095,21 @@ template class fp8_e8m0_x { } explicit operator sycl::marray() const { - sycl::marray out; - for (size_t i = 0; i < N; ++i) - out[i] = detail::ConvertFromE8M0_CPU(vals[i], rounding::to_even); - return out; + return ConvertFromFP8_Loop(); } + explicit operator sycl::marray() const { - sycl::marray out; - for (size_t i = 0; i < N; ++i) - out[i] = - detail::ConvertFromE8M0_CPU(vals[i], rounding::to_even); - return out; + return ConvertFromFP8_Loop(); } + explicit operator sycl::marray() const { - sycl::marray out; - for (size_t i = 0; i < N; ++i) - out[i] = detail::ConvertFromE8M0_CPU(vals[i], rounding::to_even); - return out; + return ConvertFromFP8_Loop(); } // Intentionally public to allow access to the raw values. uint8_t vals[N]; +#undef CONVERT_TO_FP8 }; template fp8_e4m3_x(Ts...) -> fp8_e4m3_x; From 077807de39abdc2dc29850b724e7e8199970c2c8 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Tue, 9 Jun 2026 16:33:18 +0200 Subject: [PATCH 81/89] [SYCL] follow requirements about Nan, infinity, max and min values --- .../oneapi/experimental/float_8bit/types.hpp | 29 +++- sycl/unittests/Extensions/fp8/fp8_e4m3.cpp | 51 ++++++- sycl/unittests/Extensions/fp8/fp8_e5m2.cpp | 127 +++++++++++++++++- sycl/unittests/Extensions/fp8/fp8_e8m0.cpp | 60 ++++++++- 4 files changed, 256 insertions(+), 11 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index de35e12b08f25..e76e971edb410 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -763,9 +763,18 @@ static inline ToT ConvertFromFP8ToBinaryFloat_CPU(uint8_t code, Traits::IsIntegral) { using UnsignedT = typename Traits::UnsignedT; - if (isNaN || isInf) + if (isNaN) return ToT{}; + if (isInf) { + if constexpr (Traits::IsSigned) { + return negative ? std::numeric_limits::min() + : std::numeric_limits::max(); + } else { + return negative ? ToT{0} : std::numeric_limits::max(); + } + } + if (significand == 0u) return ToT{}; @@ -773,8 +782,13 @@ static inline ToT ConvertFromFP8ToBinaryFloat_CPU(uint8_t code, uint64_t magnitude = 0u; if (shift >= 0) { - if (shift >= 64) - return ToT{}; + if (shift >= 64) { + // Value is too large - saturate to max + if constexpr (Traits::IsSigned) + return std::numeric_limits::max(); + else + return std::numeric_limits::max(); + } magnitude = static_cast(significand) << shift; } else { const int rshift = -shift; @@ -805,8 +819,13 @@ static inline ToT ConvertFromFP8ToBinaryFloat_CPU(uint8_t code, if (magnitude == 0u) return ToT{}; - if (BitWidth(magnitude) > Traits::ValueBits) - return ToT{}; + if (BitWidth(magnitude) > Traits::ValueBits) { + if constexpr (Traits::IsSigned) + return negative ? std::numeric_limits::min() + : std::numeric_limits::max(); + else + return negative ? ToT{0} : std::numeric_limits::max(); + } const UnsignedT narrowed = static_cast(magnitude); if constexpr (Traits::IsSigned) diff --git a/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp b/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp index cfa334186728c..84292dc57841f 100644 --- a/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp +++ b/sycl/unittests/Extensions/fp8/fp8_e4m3.cpp @@ -110,8 +110,10 @@ TEST(FP8E4M3Test, VariadicNaNEncodingFloat) { float neg_nan = std::copysign(pos_nan, -1.0f); fp8_e4m3_x2 a(pos_nan, neg_nan); - EXPECT_EQ(a.vals[0], 0x7F); // +NaN -> 0b0_1111_111 - EXPECT_EQ(a.vals[1], 0xFF); // -NaN -> 0b1_1111_111 + + // Spec says: NaN is converted to NaN with an implementation-defined sign. + EXPECT_EQ(a.vals[0] & 0x7F, 0x7F); // NaN -> 0bx_11111_11 + EXPECT_EQ(a.vals[1] & 0x7F, 0x7F); // NaN -> 0bx_11111_11 } TEST(FP8E4M3Test, ScalarInfinityClampsToMaxNormalPreservingSign) { @@ -140,6 +142,29 @@ TEST(FP8E4M3Test, X2InfinityClampsToMaxNormalPreservingSign) { EXPECT_EQ(out[1], -448.0f); } +TEST(FP8E4M3Test, ScalarFiniteOverflowClampsToMaxNormalPreservingSign) { + fp8_e4m3 pos(1000.0f); + fp8_e4m3 neg(-1000.0f); + + EXPECT_EQ(pos.vals[0], 0x7E); // +448.0 -> 0b0_1111_110 + EXPECT_EQ(neg.vals[0], 0xFE); // -448.0 -> 0b1_1111_110 + + EXPECT_EQ(static_cast(pos), 448.0f); + EXPECT_EQ(static_cast(neg), -448.0f); +} + +TEST(FP8E4M3Test, X2FiniteOverflowClampsToMaxNormalPreservingSign) { + const float in[2] = {1000.0f, -1000.0f}; + fp8_e4m3_x2 value(in); + + EXPECT_EQ(value.vals[0], 0x7E); // +448.0 -> 0b0_1111_110 + EXPECT_EQ(value.vals[1], 0xFE); // -448.0 -> 0b1_1111_110 + + sycl::marray out = static_cast>(value); + EXPECT_EQ(out[0], 448.0f); + EXPECT_EQ(out[1], -448.0f); +} + TEST(FP8E4M3Test, IntegerToEvenFiniteAndSize) { // Integer constructors: to_even + finite saturation (CPU). fp8_e4m3 a0(0); @@ -254,7 +279,7 @@ TEST(FP8E4M3Test, IntegerConversionOperatorsTowardZero) { } TEST(FP8E4M3Test, BoolOperatorZeroRules) { - // bool operator: false iff +0 or -0; otherwise true. + // bool operator: false if +0 or -0; otherwise true. fp8_e4m3 zp(0.0f); fp8_e4m3 zn(-0.0f); fp8_e4m3 one(1.0f); @@ -473,6 +498,26 @@ TEST(FP8E4M3Test, IntegerConversionOperatorsRemainingWidthsTowardZero) { EXPECT_EQ(ull, 88ull); } +TEST(FP8E4M3Test, IntegerConversionOperatorsSaturatePositiveOutOfRange) { + fp8_e4m3 value(448.0f); + + signed char sc = static_cast(value); + unsigned char uc = static_cast(value); + + EXPECT_EQ(sc, std::numeric_limits::max()); + EXPECT_EQ(uc, std::numeric_limits::max()); +} + +TEST(FP8E4M3Test, IntegerConversionOperatorsSaturateNegativeOutOfRange) { + fp8_e4m3 value(-448.0f); + + signed char sc = static_cast(value); + unsigned char uc = static_cast(value); + + EXPECT_EQ(sc, std::numeric_limits::min()); + EXPECT_EQ(uc, 0); +} + TEST(FP8E4M3Test, CArrayFloatRoundingToEven) { const float in[2] = {0.012f, 1000.0f}; fp8_e4m3_x2 a(in, rounding::to_even); diff --git a/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp b/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp index fd594de229a1c..8ae4bc9a2f709 100644 --- a/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp +++ b/sycl/unittests/Extensions/fp8/fp8_e5m2.cpp @@ -90,8 +90,44 @@ TEST(FP8E5M2Test, VariadicNaNEncodingFloat) { fp8_e5m2_x2 a(pos_nan, neg_nan); EXPECT_EQ(sizeof(a.vals), 2u); - EXPECT_EQ(a.vals[0], 0x7F); // +NaN -> 0b0_11111_11 - EXPECT_EQ(a.vals[1], 0xFF); // -NaN -> 0b1_11111_11 + // Spec says: NaN is converted to NaN with an implementation-defined sign. + EXPECT_EQ(a.vals[0] & 0x7F, 0x7F); // NaN -> 0bx_11111_11 + EXPECT_EQ(a.vals[1] & 0x7F, 0x7F); // NaN -> 0bx_11111_11 +} + +TEST(FP8E5M2Test, ScalarInfinityClampsToMaxNormalPreservingSign) { + fp8_e5m2 pos(std::numeric_limits::infinity()); + fp8_e5m2 neg(-std::numeric_limits::infinity()); + + EXPECT_EQ(pos.vals[0], 0x7B); // +57344.0 -> 0b0_11110_11 + EXPECT_EQ(neg.vals[0], 0xFB); // -57344.0 -> 0b1_11110_11 + + EXPECT_EQ(static_cast(pos), 57344.0f); + EXPECT_EQ(static_cast(neg), -57344.0f); +} + +TEST(FP8E5M2Test, X2InfinityClampsToMaxNormalPreservingSign) { + const float in[2] = {std::numeric_limits::infinity(), + -std::numeric_limits::infinity()}; + fp8_e5m2_x2 value(in); + + EXPECT_EQ(value.vals[0], 0x7B); // +57344.0 -> 0b0_11110_11 + EXPECT_EQ(value.vals[1], 0xFB); // -57344.0 -> 0b1_11110_11 + + sycl::marray out = static_cast>(value); + EXPECT_EQ(out[0], 57344.0f); + EXPECT_EQ(out[1], -57344.0f); +} + +TEST(FP8E5M2Test, ScalarFiniteOverflowClampsToMaxNormalPreservingSign) { + fp8_e5m2 pos(100000.0f); + fp8_e5m2 neg(-100000.0f); + + EXPECT_EQ(pos.vals[0], 0x7B); // +57344.0 -> 0b0_11110_11 + EXPECT_EQ(neg.vals[0], 0xFB); // -57344.0 -> 0b1_11110_11 + + EXPECT_EQ(static_cast(pos), 57344.0f); + EXPECT_EQ(static_cast(neg), -57344.0f); } TEST(FP8E5M2Test, RawInfinityAndNaNDecoding) { @@ -443,6 +479,40 @@ TEST(FP8E5M2Test, IntegerConversionOperators) { EXPECT_EQ(static_cast(p), 1u); } +TEST(FP8E5M2Test, IntegerConversionOperatorsSaturatePositiveOutOfRange) { + fp8_e5m2 value(57344.0f); + + signed char sc = static_cast(value); + unsigned char uc = static_cast(value); + + EXPECT_EQ(sc, std::numeric_limits::max()); + EXPECT_EQ(uc, std::numeric_limits::max()); +} + +TEST(FP8E5M2Test, IntegerConversionOperatorsSaturateNegativeOutOfRange) { + fp8_e5m2 value(-57344.0f); + + signed char sc = static_cast(value); + unsigned char uc = static_cast(value); + + EXPECT_EQ(sc, std::numeric_limits::min()); + EXPECT_EQ(uc, 0); +} + +TEST(FP8E5M2Test, IntegerConversionOperatorsInfinitySaturateToTypeBounds) { + fp8_e5m2 pos_inf; + fp8_e5m2 neg_inf; + + pos_inf.vals[0] = 0x7C; // +inf -> 0b0_11111_00 + neg_inf.vals[0] = 0xFC; // -inf -> 0b1_11111_00 + + EXPECT_EQ(static_cast(pos_inf), std::numeric_limits::max()); + EXPECT_EQ(static_cast(neg_inf), std::numeric_limits::min()); + EXPECT_EQ(static_cast(pos_inf), + std::numeric_limits::max()); + EXPECT_EQ(static_cast(neg_inf), 0u); +} + TEST(FP8E5M2Test, AssignmentOperatorsAllTypes) { fp8_e5m2 a(0.0f); @@ -501,6 +571,59 @@ TEST(FP8E5M2Test, BoolOperatorWithNaN) { EXPECT_EQ(nanv.vals[0], 0x7F); // NaN encoding remains S.11111.11 } +TEST(FP8E5M2Test, CArrayFloatRoundingToEven) { + const float in[2] = {1.125f, 100000.0f}; + fp8_e5m2_x2 a(in, rounding::to_even); + + EXPECT_EQ(a.vals[0], 0x3C); // tie -> to_even => 1.0 + EXPECT_EQ(a.vals[1], 0x7B); // finite saturation => +57344.0 +} + +TEST(FP8E5M2Test, CArrayHalfRoundingToEven) { + const sycl::half in[2] = {sycl::half(1.125f), sycl::half(100000.0f)}; + fp8_e5m2_x2 a(in, rounding::to_even); + + EXPECT_EQ(a.vals[0], 0x3C); // tie -> to_even => 1.0 + EXPECT_EQ(a.vals[1], 0x7B); // finite saturation => +57344.0 +} + +TEST(FP8E5M2Test, CArrayBFloat16RoundingToEven) { + const sycl::ext::oneapi::bfloat16 in[2] = { + sycl::ext::oneapi::bfloat16(1.125f), + sycl::ext::oneapi::bfloat16(100000.0f)}; + fp8_e5m2_x2 a(in, rounding::to_even); + + EXPECT_EQ(a.vals[0], 0x3C); // tie -> to_even => 1.0 + EXPECT_EQ(a.vals[1], 0x7B); // finite saturation => +57344.0 +} + +TEST(FP8E5M2Test, MarrayHalfRoundingToEven) { + const sycl::marray in = {sycl::half(1.125f), + sycl::half(-1.375f)}; + fp8_e5m2_x2 a(in, rounding::to_even); + + EXPECT_EQ(a.vals[0], 0x3C); // tie -> to_even => 1.0 + EXPECT_EQ(a.vals[1], 0xBE); // tie -> to_even => -1.5 +} + +TEST(FP8E5M2Test, MarrayBFloat16RoundingToEven) { + const sycl::marray in = { + sycl::ext::oneapi::bfloat16(1.125f), + sycl::ext::oneapi::bfloat16(-1.375f)}; + fp8_e5m2_x2 a(in, rounding::to_even); + + EXPECT_EQ(a.vals[0], 0x3C); // tie -> to_even => 1.0 + EXPECT_EQ(a.vals[1], 0xBE); // tie -> to_even => -1.5 +} + +TEST(FP8E5M2Test, MarrayFloatRoundingToEven) { + const sycl::marray in = {1.125f, -1.375f}; + fp8_e5m2_x2 a(in, rounding::to_even); + + EXPECT_EQ(a.vals[0], 0x3C); // tie -> to_even => 1.0 + EXPECT_EQ(a.vals[1], 0xBE); // tie -> to_even => -1.5 +} + TEST(FP8E5M2Test, VariadicMixedScalarTypes) { EXPECT_FALSE((std::is_constructible_v)); } diff --git a/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp b/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp index 69bc25431d99f..64b38687d1986 100644 --- a/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp +++ b/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp @@ -108,6 +108,30 @@ TEST(FP8E8M0Test, CArrayFloatRoundingModes) { rounding::upward, 0xFF)); } +TEST(FP8E8M0Test, CArrayFloatFiniteSaturationClampsToMaxNormal) { + EXPECT_TRUE( + checkCode(std::numeric_limits::max(), rounding::upward, 0xFE)); + EXPECT_TRUE(checkCode(std::numeric_limits::max(), + rounding::toward_zero, 0xFE)); + EXPECT_TRUE( + checkCode(-std::numeric_limits::max(), rounding::upward, 0xFE)); + EXPECT_TRUE(checkCode(-std::numeric_limits::max(), + rounding::toward_zero, 0xFE)); + EXPECT_TRUE(checkCode(std::numeric_limits::infinity(), + rounding::upward, 0xFE)); + EXPECT_TRUE(checkCode(-std::numeric_limits::infinity(), + rounding::toward_zero, 0xFE)); +} + +TEST(FP8E8M0Test, CArrayFloatNaNDropsSign) { + const float PosNaN = std::numeric_limits::quiet_NaN(); + const float NegNaN = std::copysign(PosNaN, -1.0f); + + EXPECT_TRUE(checkCode(PosNaN, rounding::upward, 0xFF)); + EXPECT_TRUE(checkCode(NegNaN, rounding::upward, 0xFF)); + EXPECT_TRUE(checkCode(NegNaN, rounding::toward_zero, 0xFF)); +} + TEST(FP8E8M0Test, CArrayHalfHostUpwardFinite) { const sycl::half in[2] = {sycl::half(1.0f), sycl::half(1.1f)}; const sycl::half in1[2] = {sycl::half(3.0f), sycl::half(0.0f)}; @@ -265,6 +289,40 @@ TEST(FP8E8M0Test, FloatingPointConversionOperators) { EXPECT_EQ(static_cast(min), std::ldexp(1.0f, -127)); } +TEST(FP8E8M0Test, IntegerConversionOperatorsUseTowardZeroOnMagnitude) { + const float NegativeInput[1] = {-1.5f}; + const float FractionalInput[1] = {0.5f}; + const fp8_e8m0 from_negative(NegativeInput, rounding::toward_zero); + const fp8_e8m0 half(FractionalInput, rounding::toward_zero); + const fp8_e8m0 two(2.0f); + + EXPECT_EQ(from_negative.vals[0], 0x7F); + EXPECT_EQ(half.vals[0], 0x7E); + EXPECT_EQ(two.vals[0], 0x80); + + EXPECT_EQ(static_cast(from_negative), 1); + EXPECT_EQ(static_cast(half), 0); + EXPECT_EQ(static_cast(half), 0u); + EXPECT_EQ(static_cast(two), 2); +} + +TEST(FP8E8M0Test, IntegerConversionOperatorsSaturateToTypeMax) { + const fp8_e8m0 max(std::ldexp(1.0f, 127)); + + EXPECT_EQ(static_cast(max), std::numeric_limits::max()); + EXPECT_EQ(static_cast(max), std::numeric_limits::max()); + EXPECT_EQ(static_cast(max), std::numeric_limits::max()); + EXPECT_EQ(static_cast(max), std::numeric_limits::max()); + EXPECT_EQ(static_cast(max), + std::numeric_limits::max()); + EXPECT_EQ(static_cast(max), + std::numeric_limits::max()); + EXPECT_EQ(static_cast(max), + std::numeric_limits::max()); + EXPECT_EQ(static_cast(max), + std::numeric_limits::max()); +} + TEST(FP8E8M0Test, BoolOperatorAlwaysTrue) { fp8_e8m0 min(std::ldexp(1.0f, -127)); fp8_e8m0 nanv(std::numeric_limits::quiet_NaN()); @@ -478,4 +536,4 @@ TEST(FP8E8M0Test, VariadicFloatReferences) { EXPECT_EQ(sizeof(a.vals), 2u); EXPECT_EQ(a.vals[0], 0x7F); EXPECT_EQ(a.vals[1], 0x80); -} \ No newline at end of file +} From daeb55405ccacd7c1cc79c7dd2fe5523a4dbaa32 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Tue, 9 Jun 2026 17:33:40 +0200 Subject: [PATCH 82/89] [SYCL][TESTE2E] fix ci issue with cuda --- sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp | 3 +++ sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp | 3 +++ sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp | 3 +++ sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp | 3 +++ sycl/test-e2e/Experimental/fp8/e8m0_cri_conversion.cpp | 3 +++ sycl/test-e2e/Experimental/fp8/e8m0_x2_cri_conversion.cpp | 6 ++---- 6 files changed, 17 insertions(+), 4 deletions(-) diff --git a/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp index 0c842e3063535..687813800d8e8 100644 --- a/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp @@ -2,6 +2,9 @@ // RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator=spir64 --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out // RUN: %{run} SYCL_UR_TRACE=1 %t.out +// UNSUPPORTED: cuda, hip +// UNSUPPORTED-INTENDED: only supported by backends with CRI driver + #include #include #include diff --git a/sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp index d56b2dc86e224..8ddfc9317f6cc 100644 --- a/sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp @@ -3,6 +3,9 @@ // RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator=spir64 --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out // RUN: %{run} SYCL_UR_TRACE=1 %t.out +// UNSUPPORTED: cuda, hip +// UNSUPPORTED-INTENDED: only supported by backends with CRI driver + #include #include #include diff --git a/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp index 8e13b58951d68..8479f55d531ce 100644 --- a/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp @@ -2,6 +2,9 @@ // RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator=spir64 --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out // RUN: %{run} SYCL_UR_TRACE=1 %t.out +// UNSUPPORTED: cuda, hip +// UNSUPPORTED-INTENDED: only supported by backends with CRI driver + #include #include #include diff --git a/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp index d0da6844e618f..e9aa950f7c0a0 100644 --- a/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp @@ -2,6 +2,9 @@ // RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator=spir64 --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out // RUN: %{run} SYCL_UR_TRACE=1 %t.out +// UNSUPPORTED: cuda, hip +// UNSUPPORTED-INTENDED: only supported by backends with CRI driver + #include #include #include diff --git a/sycl/test-e2e/Experimental/fp8/e8m0_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e8m0_cri_conversion.cpp index 6929c4bd04394..3f94a4235a077 100644 --- a/sycl/test-e2e/Experimental/fp8/e8m0_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e8m0_cri_conversion.cpp @@ -3,6 +3,9 @@ // RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator=spir64 --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out // RUN: %{run} SYCL_UR_TRACE=1 %t.out +// UNSUPPORTED: cuda, hip +// UNSUPPORTED-INTENDED: only supported by backends with CRI driver + #include #include #include diff --git a/sycl/test-e2e/Experimental/fp8/e8m0_x2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e8m0_x2_cri_conversion.cpp index c86ff61f16da0..ea8a9686302a8 100644 --- a/sycl/test-e2e/Experimental/fp8/e8m0_x2_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e8m0_x2_cri_conversion.cpp @@ -3,10 +3,8 @@ // RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator=spir64 --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out // RUN: %{run} SYCL_UR_TRACE=1 %t.out -// make it XFAIL until driver will be installed on CI machines and the test will -// be enabled in the test suite -// XFAIL: * -// XFAIL-TRACKER: CMPLRLLVM-69851 +// UNSUPPORTED: cuda, hip +// UNSUPPORTED-INTENDED: only supported by backends with CRI driver #include #include From ef524f8eb907469fef67c5411ca839164d8c627a Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Wed, 10 Jun 2026 12:17:09 +0200 Subject: [PATCH 83/89] [SYCL] move aliases into dev code --- sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index e76e971edb410..eb41d8030065f 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -22,11 +22,10 @@ #include #include +#ifdef __SYCL_DEVICE_ONLY__ using float16_vec2 = _Float16 __attribute__((ext_vector_type(2))); using uint8_vec2 = uint8_t __attribute__((ext_vector_type(2))); using bfloat16_vec2 = __bf16 __attribute__((ext_vector_type(2))); - -#ifdef __SYCL_DEVICE_ONLY__ // FP8 builtins extern __DPCPP_SYCL_EXTERNAL uint8_t From e1b3cfd6da09d549189ca3ca2c73405eb166a064 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Thu, 11 Jun 2026 14:28:17 +0200 Subject: [PATCH 84/89] [SYCL][TEST] do not pass negative values to e8m0 --- .../Experimental/fp8/e8m0_cri_conversion.cpp | 19 ------------- .../fp8/e8m0_x2_cri_conversion.cpp | 27 ------------------- sycl/unittests/Extensions/fp8/fp8_e8m0.cpp | 24 +---------------- 3 files changed, 1 insertion(+), 69 deletions(-) diff --git a/sycl/test-e2e/Experimental/fp8/e8m0_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e8m0_cri_conversion.cpp index 3f94a4235a077..7372ee567b0b5 100644 --- a/sycl/test-e2e/Experimental/fp8/e8m0_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e8m0_cri_conversion.cpp @@ -360,24 +360,6 @@ int test_raw_vals_access(sycl::queue &queue) { return ret; } -int test_negative_input_drops_sign(sycl::queue &queue) { - float input[1] = {-8.0f}; - auto *data = sycl::malloc_shared(1, queue); - auto *out = sycl::malloc_shared(1, queue); - data[0] = fp8_e8m0(input, rounding::upward); - - queue.single_task([=]() { - fp8_e8m0 value = data[0]; - out[0] = static_cast(value); - }); - queue.wait_and_throw(); - - int ret = (out[0] != 8.0f) ? 1 : 0; - sycl::free(data, queue); - sycl::free(out, queue); - return ret; -} - int main() { auto async_handler = [](sycl::exception_list exceptions) { for (const std::exception_ptr &e : exceptions) { @@ -436,6 +418,5 @@ int main() { ret |= test_saturation_large_value(queue); ret |= test_saturation_overflow(queue); ret |= test_raw_vals_access(queue); - ret |= test_negative_input_drops_sign(queue); return ret; } diff --git a/sycl/test-e2e/Experimental/fp8/e8m0_x2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e8m0_x2_cri_conversion.cpp index ea8a9686302a8..4d36af50b27b9 100644 --- a/sycl/test-e2e/Experimental/fp8/e8m0_x2_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e8m0_x2_cri_conversion.cpp @@ -227,32 +227,6 @@ int test_boundary_saturation_infinity_clamp(sycl::queue &queue) { return ret; } -int test_boundary_negative_input_drops_sign(sycl::queue &queue) { - const float input[2] = {-4.0f, -32.0f}; - auto *data = sycl::malloc_shared(1, queue); - auto *out = sycl::malloc_shared(2, queue); - data[0] = fp8_e8m0_x2(input, rounding::upward); - - queue.single_task([=]() { - fp8_e8m0_x2 value = data[0]; - sycl::marray unpacked = - static_cast>(value); - out[0] = unpacked[0]; - out[1] = unpacked[1]; - }); - queue.wait_and_throw(); - - int ret = 0; - if (out[0] != 4.0f) - ret = 1; - if (out[1] != 32.0f) - ret = 1; - - sycl::free(data, queue); - sycl::free(out, queue); - return ret; -} - int test_rounding_upward_non_power_of_two(sycl::queue &queue) { const float input[2] = {3.0f, 6.0f}; auto *data = sycl::malloc_shared(1, queue); @@ -460,7 +434,6 @@ int main() { ret |= test_boundary_round_trip_exact_powers_of_two(queue); ret |= test_boundary_round_trip_max_min_normal(queue); ret |= test_boundary_saturation_infinity_clamp(queue); - ret |= test_boundary_negative_input_drops_sign(queue); ret |= test_rounding_upward_non_power_of_two(queue); ret |= test_rounding_toward_zero_non_power_of_two(queue); ret |= test_raw_vals_access(queue); diff --git a/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp b/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp index 64b38687d1986..36221cb7ceae0 100644 --- a/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp +++ b/sycl/unittests/Extensions/fp8/fp8_e8m0.cpp @@ -91,15 +91,6 @@ TEST(FP8E8M0Test, CArrayFloatRoundingModes) { EXPECT_TRUE(checkCode(3.0f, rounding::upward, 0x81)); EXPECT_TRUE(checkCode(3.0f, rounding::toward_zero, 0x80)); - // E8M0 drops sign per the extension specification, so negative inputs are - // rounded using their magnitude. - EXPECT_TRUE(checkCode(-3.0f, rounding::upward, 0x81)); - EXPECT_TRUE(checkCode(-3.0f, rounding::toward_zero, 0x80)); - EXPECT_TRUE(checkCode(-1.5f, rounding::upward, 0x80)); - EXPECT_TRUE(checkCode(-1.5f, rounding::toward_zero, 0x7F)); - EXPECT_TRUE(checkCode(-0.5f, rounding::upward, 0x7E)); - EXPECT_TRUE(checkCode(-0.5f, rounding::toward_zero, 0x7E)); - EXPECT_TRUE(checkCode(1.0f, rounding::upward, 0x7F)); EXPECT_TRUE(checkCode(0.5f, rounding::upward, 0x7E)); EXPECT_TRUE(checkCode(0.5f, rounding::toward_zero, 0x7E)); @@ -113,23 +104,14 @@ TEST(FP8E8M0Test, CArrayFloatFiniteSaturationClampsToMaxNormal) { checkCode(std::numeric_limits::max(), rounding::upward, 0xFE)); EXPECT_TRUE(checkCode(std::numeric_limits::max(), rounding::toward_zero, 0xFE)); - EXPECT_TRUE( - checkCode(-std::numeric_limits::max(), rounding::upward, 0xFE)); - EXPECT_TRUE(checkCode(-std::numeric_limits::max(), - rounding::toward_zero, 0xFE)); EXPECT_TRUE(checkCode(std::numeric_limits::infinity(), rounding::upward, 0xFE)); - EXPECT_TRUE(checkCode(-std::numeric_limits::infinity(), - rounding::toward_zero, 0xFE)); } TEST(FP8E8M0Test, CArrayFloatNaNDropsSign) { const float PosNaN = std::numeric_limits::quiet_NaN(); - const float NegNaN = std::copysign(PosNaN, -1.0f); EXPECT_TRUE(checkCode(PosNaN, rounding::upward, 0xFF)); - EXPECT_TRUE(checkCode(NegNaN, rounding::upward, 0xFF)); - EXPECT_TRUE(checkCode(NegNaN, rounding::toward_zero, 0xFF)); } TEST(FP8E8M0Test, CArrayHalfHostUpwardFinite) { @@ -289,18 +271,14 @@ TEST(FP8E8M0Test, FloatingPointConversionOperators) { EXPECT_EQ(static_cast(min), std::ldexp(1.0f, -127)); } -TEST(FP8E8M0Test, IntegerConversionOperatorsUseTowardZeroOnMagnitude) { - const float NegativeInput[1] = {-1.5f}; +TEST(FP8E8M0Test, IntegerConversionOperatorsUseTowardZero) { const float FractionalInput[1] = {0.5f}; - const fp8_e8m0 from_negative(NegativeInput, rounding::toward_zero); const fp8_e8m0 half(FractionalInput, rounding::toward_zero); const fp8_e8m0 two(2.0f); - EXPECT_EQ(from_negative.vals[0], 0x7F); EXPECT_EQ(half.vals[0], 0x7E); EXPECT_EQ(two.vals[0], 0x80); - EXPECT_EQ(static_cast(from_negative), 1); EXPECT_EQ(static_cast(half), 0); EXPECT_EQ(static_cast(half), 0u); EXPECT_EQ(static_cast(two), 2); From be4b5bb2769ea890e5d749c3754a5debfbdf522f Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Thu, 11 Jun 2026 14:51:59 +0200 Subject: [PATCH 85/89] [SYCL] move types into details nemaspace --- .../oneapi/experimental/float_8bit/types.hpp | 100 ++++++++++-------- 1 file changed, 54 insertions(+), 46 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index eb41d8030065f..f6d3e6ce5a229 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -23,53 +23,57 @@ #include #ifdef __SYCL_DEVICE_ONLY__ + +namespace detail { using float16_vec2 = _Float16 __attribute__((ext_vector_type(2))); using uint8_vec2 = uint8_t __attribute__((ext_vector_type(2))); using bfloat16_vec2 = __bf16 __attribute__((ext_vector_type(2))); +} // namespace detail // FP8 builtins extern __DPCPP_SYCL_EXTERNAL uint8_t __builtin_spirv_ClampConvertFP16ToE4M3INTEL(_Float16) noexcept; -extern __DPCPP_SYCL_EXTERNAL uint8_vec2 - __builtin_spirv_ClampConvertFP16ToE4M3INTEL(float16_vec2) noexcept; -extern __DPCPP_SYCL_EXTERNAL - float16_vec2 __builtin_spirv_ConvertE4M3ToFP16EXT(uint8_vec2) noexcept; +extern __DPCPP_SYCL_EXTERNAL detail::uint8_vec2 + __builtin_spirv_ClampConvertFP16ToE4M3INTEL(detail::float16_vec2) noexcept; +extern __DPCPP_SYCL_EXTERNAL detail::float16_vec2 + __builtin_spirv_ConvertE4M3ToFP16EXT(detail::uint8_vec2) noexcept; extern __DPCPP_SYCL_EXTERNAL _Float16 __builtin_spirv_ConvertE4M3ToFP16EXT(char) noexcept; extern __DPCPP_SYCL_EXTERNAL __bf16 __builtin_spirv_ConvertE4M3ToBF16EXT(uint8_t) noexcept; -extern __DPCPP_SYCL_EXTERNAL - bfloat16_vec2 __builtin_spirv_ConvertE4M3ToBF16EXT(uint8_vec2) noexcept; +extern __DPCPP_SYCL_EXTERNAL detail::bfloat16_vec2 + __builtin_spirv_ConvertE4M3ToBF16EXT(detail::uint8_vec2) noexcept; extern __DPCPP_SYCL_EXTERNAL uint8_t __builtin_spirv_ClampConvertBF16ToE4M3INTEL(__bf16) noexcept; -extern __DPCPP_SYCL_EXTERNAL uint8_vec2 - __builtin_spirv_ClampConvertBF16ToE4M3INTEL(bfloat16_vec2) noexcept; +extern __DPCPP_SYCL_EXTERNAL detail::uint8_vec2 + __builtin_spirv_ClampConvertBF16ToE4M3INTEL(detail::bfloat16_vec2) noexcept; extern __DPCPP_SYCL_EXTERNAL uint8_t __builtin_spirv_ClampConvertFP16ToE5M2INTEL(_Float16) noexcept; -extern __DPCPP_SYCL_EXTERNAL uint8_vec2 - __builtin_spirv_ClampConvertFP16ToE5M2INTEL(float16_vec2) noexcept; +extern __DPCPP_SYCL_EXTERNAL ::detail::uint8_vec2 + __builtin_spirv_ClampConvertFP16ToE5M2INTEL( + ::detail::float16_vec2) noexcept; extern __DPCPP_SYCL_EXTERNAL uint8_t __builtin_spirv_ConvertFP16ToE5M2EXT(_Float16) noexcept; -extern __DPCPP_SYCL_EXTERNAL - uint8_vec2 __builtin_spirv_ConvertFP16ToE5M2EXT(float16_vec2) noexcept; +extern __DPCPP_SYCL_EXTERNAL detail::uint8_vec2 + __builtin_spirv_ConvertFP16ToE5M2EXT(detail::float16_vec2) noexcept; extern __DPCPP_SYCL_EXTERNAL _Float16 __builtin_spirv_ConvertE5M2ToFP16EXT(uint8_t) noexcept; -extern __DPCPP_SYCL_EXTERNAL - float16_vec2 __builtin_spirv_ConvertE5M2ToFP16EXT(uint8_vec2) noexcept; +extern __DPCPP_SYCL_EXTERNAL detail::float16_vec2 + __builtin_spirv_ConvertE5M2ToFP16EXT(detail::uint8_vec2) noexcept; extern __DPCPP_SYCL_EXTERNAL uint8_t __builtin_spirv_ClampConvertBF16ToE5M2INTEL(__bf16) noexcept; -extern __DPCPP_SYCL_EXTERNAL uint8_vec2 - __builtin_spirv_ClampConvertBF16ToE5M2INTEL(bfloat16_vec2) noexcept; +extern __DPCPP_SYCL_EXTERNAL detail::uint8_vec2 + __builtin_spirv_ClampConvertBF16ToE5M2INTEL(detail::bfloat16_vec2) noexcept; extern __DPCPP_SYCL_EXTERNAL uint8_t __builtin_spirv_ConvertBF16ToE5M2EXT(__bf16) noexcept; -extern __DPCPP_SYCL_EXTERNAL - uint8_vec2 __builtin_spirv_ConvertBF16ToE5M2EXT(bfloat16_vec2) noexcept; +extern __DPCPP_SYCL_EXTERNAL detail::uint8_vec2 + __builtin_spirv_ConvertBF16ToE5M2EXT(detail::bfloat16_vec2) noexcept; extern __DPCPP_SYCL_EXTERNAL __bf16 __builtin_spirv_ConvertE5M2ToBF16EXT(uint8_t) noexcept; -extern __DPCPP_SYCL_EXTERNAL - bfloat16_vec2 __builtin_spirv_ConvertE5M2ToBF16EXT(uint8_vec2) noexcept; +extern __DPCPP_SYCL_EXTERNAL detail::bfloat16_vec2 + __builtin_spirv_ConvertE5M2ToBF16EXT(detail::uint8_vec2) noexcept; extern __DPCPP_SYCL_EXTERNAL uint8_t __builtin_spirv_ClampStochasticRoundFP16ToE5M2INTEL( _Float16, uint32_t, __attribute__((opencl_private)) uint32_t *) noexcept; @@ -892,7 +896,7 @@ template class fp8_e4m3_x { } #ifdef __SYCL_DEVICE_ONLY__ - uint8_vec2 ConvertToFP8_Vec2(float16_vec2 h) { + ::detail::uint8_vec2 ConvertToFP8_Vec2(::detail::float16_vec2 h) { return __builtin_spirv_ClampConvertFP16ToE4M3INTEL(h); } #endif @@ -905,7 +909,7 @@ template class fp8_e4m3_x { } #ifdef __SYCL_DEVICE_ONLY__ - uint8_vec2 ConvertBF16ToFP8_Vec2(bfloat16_vec2 h) { + ::detail::uint8_vec2 ConvertBF16ToFP8_Vec2(::detail::bfloat16_vec2 h) { return __builtin_spirv_ClampConvertBF16ToE4M3INTEL(h); } #endif @@ -934,8 +938,8 @@ template class fp8_e4m3_x { void ConvertFromFP8_Vec2(sycl::marray &ret, rounding r = rounding::to_even) const { #ifdef __SYCL_DEVICE_ONLY__ - const uint8_vec2 packed{vals[0], vals[1]}; - float16_vec2 hi = __builtin_spirv_ConvertE4M3ToFP16EXT(packed); + const ::detail::uint8_vec2 packed{vals[0], vals[1]}; + ::detail::float16_vec2 hi = __builtin_spirv_ConvertE4M3ToFP16EXT(packed); ret[0] = sycl::bit_cast(hi[0]); ret[1] = sycl::bit_cast(hi[1]); #else @@ -958,8 +962,8 @@ template class fp8_e4m3_x { void ConvertBF16FromFP8_Vec2(sycl::marray &ret, rounding r = rounding::to_even) const { #ifdef __SYCL_DEVICE_ONLY__ - const uint8_vec2 packed{vals[0], vals[1]}; - bfloat16_vec2 hi = __builtin_spirv_ConvertE4M3ToBF16EXT(packed); + const ::detail::uint8_vec2 packed{vals[0], vals[1]}; + ::detail::bfloat16_vec2 hi = __builtin_spirv_ConvertE4M3ToBF16EXT(packed); ret[0] = sycl::bit_cast(hi[0]); ret[1] = sycl::bit_cast(hi[1]); #else @@ -981,7 +985,7 @@ template class fp8_e4m3_x { } else { \ const VecType vec{sycl::bit_cast(in[0]), \ sycl::bit_cast(in[1])}; \ - const uint8_vec2 result = Convert##Prefix##ToFP8_Vec2(vec); \ + const ::detail::uint8_vec2 result = Convert##Prefix##ToFP8_Vec2(vec); \ std::memcpy(vals, &result, sizeof(vals)); \ } #else @@ -1009,10 +1013,10 @@ template class fp8_e4m3_x { explicit fp8_e4m3_x(Types... v) { if constexpr (((std::is_same_v, bfloat16>) && ...)) { const bfloat16 in[N] = {v...}; - CONVERT_TO_FP8(bfloat16_vec2, __bf16, in, BF16); + CONVERT_TO_FP8(::detail::bfloat16_vec2, __bf16, in, BF16); } else if constexpr (((std::is_same_v, half>) && ...)) { const sycl::half in[N] = {v...}; - CONVERT_TO_FP8(float16_vec2, _Float16, in, ); + CONVERT_TO_FP8(::detail::float16_vec2, _Float16, in, ); } else { const float in[N] = {v...}; for (size_t i = 0; i < N; ++i) @@ -1024,12 +1028,12 @@ template class fp8_e4m3_x { explicit fp8_e4m3_x(sycl::half const (&v)[N], rounding r = rounding::to_even) { CheckConstraints(r); - CONVERT_TO_FP8(float16_vec2, _Float16, v, ); + CONVERT_TO_FP8(::detail::float16_vec2, _Float16, v, ); } explicit fp8_e4m3_x(bfloat16 const (&v)[N], rounding r = rounding::to_even) { CheckConstraints(r); - CONVERT_TO_FP8(bfloat16_vec2, __bf16, v, BF16); + CONVERT_TO_FP8(::detail::bfloat16_vec2, __bf16, v, BF16); } explicit fp8_e4m3_x(float const (&v)[N], rounding r = rounding::to_even) { @@ -1042,13 +1046,13 @@ template class fp8_e4m3_x { explicit fp8_e4m3_x(const sycl::marray &v, rounding r = rounding::to_even) { CheckConstraints(r); - CONVERT_TO_FP8(float16_vec2, _Float16, v, ); + CONVERT_TO_FP8(::detail::float16_vec2, _Float16, v, ); } explicit fp8_e4m3_x(const sycl::marray &v, rounding r = rounding::to_even) { CheckConstraints(r); - CONVERT_TO_FP8(bfloat16_vec2, __bf16, v, BF16); + CONVERT_TO_FP8(::detail::bfloat16_vec2, __bf16, v, BF16); } explicit fp8_e4m3_x(const sycl::marray &v, @@ -1326,7 +1330,8 @@ template class fp8_e5m2_x { } #ifdef __SYCL_DEVICE_ONLY__ - uint8_vec2 ConvertToFP8_Vec2(float16_vec2 h, saturation s) { + ::detail::uint8_vec2 ConvertToFP8_Vec2(::detail::float16_vec2 h, + saturation s) { return s == saturation::finite ? __builtin_spirv_ClampConvertFP16ToE5M2INTEL(h) : __builtin_spirv_ConvertFP16ToE5M2EXT(h); @@ -1354,7 +1359,8 @@ template class fp8_e5m2_x { } #ifdef __SYCL_DEVICE_ONLY__ - uint8_vec2 ConvertBF16ToFP8_Vec2(bfloat16_vec2 h, saturation s) { + ::detail::uint8_vec2 ConvertBF16ToFP8_Vec2(::detail::bfloat16_vec2 h, + saturation s) { return s == saturation::finite ? __builtin_spirv_ClampConvertBF16ToE5M2INTEL(h) : __builtin_spirv_ConvertBF16ToE5M2EXT(h); @@ -1375,8 +1381,8 @@ template class fp8_e5m2_x { void ConvertFromFP8_Vec2(sycl::marray &ret, rounding r = rounding::to_even) const { #ifdef __SYCL_DEVICE_ONLY__ - const uint8_vec2 packed{vals[0], vals[1]}; - float16_vec2 hi = __builtin_spirv_ConvertE5M2ToFP16EXT(packed); + const ::detail::uint8_vec2 packed{vals[0], vals[1]}; + ::detail::float16_vec2 hi = __builtin_spirv_ConvertE5M2ToFP16EXT(packed); ret[0] = sycl::bit_cast(hi[0]); ret[1] = sycl::bit_cast(hi[1]); #else @@ -1399,8 +1405,8 @@ template class fp8_e5m2_x { void ConvertBF16FromFP8_Vec2(sycl::marray &ret, rounding r = rounding::to_even) const { #ifdef __SYCL_DEVICE_ONLY__ - const uint8_vec2 packed{vals[0], vals[1]}; - bfloat16_vec2 hi = __builtin_spirv_ConvertE5M2ToBF16EXT(packed); + const ::detail::uint8_vec2 packed{vals[0], vals[1]}; + ::detail::bfloat16_vec2 hi = __builtin_spirv_ConvertE5M2ToBF16EXT(packed); ret[0] = sycl::bit_cast(hi[0]); ret[1] = sycl::bit_cast(hi[1]); #else @@ -1422,7 +1428,7 @@ template class fp8_e5m2_x { } else { \ const VecType vec{sycl::bit_cast(in[0]), \ sycl::bit_cast(in[1])}; \ - const uint8_vec2 result = Convert##Prefix##ToFP8_Vec2(vec, s); \ + const ::detail::uint8_vec2 result = Convert##Prefix##ToFP8_Vec2(vec, s); \ std::memcpy(vals, &result, sizeof(vals)); \ } #else @@ -1451,10 +1457,12 @@ template class fp8_e5m2_x { explicit fp8_e5m2_x(Types... v) { if constexpr (((std::is_same_v, bfloat16>) && ...)) { const bfloat16 in[N] = {static_cast(v)...}; - CONVERT_TO_FP8(bfloat16_vec2, __bf16, in, saturation::finite, BF16); + CONVERT_TO_FP8(::detail::bfloat16_vec2, __bf16, in, saturation::finite, + BF16); } else if constexpr (((std::is_same_v, half>) && ...)) { const sycl::half in[N] = {v...}; - CONVERT_TO_FP8(float16_vec2, _Float16, in, saturation::finite, ); + CONVERT_TO_FP8(::detail::float16_vec2, _Float16, in, + saturation::finite, ); } else { using InT = std::common_type_t...>; const InT in[N] = {v...}; @@ -1468,13 +1476,13 @@ template class fp8_e5m2_x { explicit fp8_e5m2_x(half const (&v)[N], rounding r = rounding::to_even, saturation s = saturation::finite) { CheckConstraints(r); - CONVERT_TO_FP8(float16_vec2, _Float16, v, s, ); + CONVERT_TO_FP8(::detail::float16_vec2, _Float16, v, s, ); } explicit fp8_e5m2_x(bfloat16 const (&v)[N], rounding r = rounding::to_even, saturation s = saturation::finite) { CheckConstraints(r); - CONVERT_TO_FP8(bfloat16_vec2, __bf16, v, s, BF16); + CONVERT_TO_FP8(::detail::bfloat16_vec2, __bf16, v, s, BF16); } explicit fp8_e5m2_x(float const (&v)[N], rounding r = rounding::to_even, @@ -1490,14 +1498,14 @@ template class fp8_e5m2_x { rounding r = rounding::to_even, saturation s = saturation::finite) { CheckConstraints(r); - CONVERT_TO_FP8(float16_vec2, _Float16, v, s, ); + CONVERT_TO_FP8(::detail::float16_vec2, _Float16, v, s, ); } explicit fp8_e5m2_x(const sycl::marray &v, rounding r = rounding::to_even, saturation s = saturation::finite) { CheckConstraints(r); - CONVERT_TO_FP8(bfloat16_vec2, __bf16, v, s, BF16); + CONVERT_TO_FP8(::detail::bfloat16_vec2, __bf16, v, s, BF16); } explicit fp8_e5m2_x(const sycl::marray &v, From dc8d5986656e1667e7447aab9c0ad159c8bc1137 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Thu, 11 Jun 2026 15:47:55 +0200 Subject: [PATCH 86/89] [SYCL][TESTE2E] fix unsupported targets --- sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp | 2 +- sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp | 2 +- sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp | 2 +- sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp | 2 +- sycl/test-e2e/Experimental/fp8/e8m0_cri_conversion.cpp | 2 +- sycl/test-e2e/Experimental/fp8/e8m0_x2_cri_conversion.cpp | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp index 687813800d8e8..8bee9e1582666 100644 --- a/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e4m3_cri_conversion.cpp @@ -2,7 +2,7 @@ // RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator=spir64 --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out // RUN: %{run} SYCL_UR_TRACE=1 %t.out -// UNSUPPORTED: cuda, hip +// UNSUPPORTED: target-nvidia, target-amd // UNSUPPORTED-INTENDED: only supported by backends with CRI driver #include diff --git a/sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp index 8ddfc9317f6cc..7c0cc828f93c1 100644 --- a/sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e4m3_x2_cri_conversion.cpp @@ -3,7 +3,7 @@ // RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator=spir64 --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out // RUN: %{run} SYCL_UR_TRACE=1 %t.out -// UNSUPPORTED: cuda, hip +// UNSUPPORTED: target-nvidia, target-amd // UNSUPPORTED-INTENDED: only supported by backends with CRI driver #include diff --git a/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp index 8479f55d531ce..174b902e99a71 100644 --- a/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e5m2_cri_conversion.cpp @@ -2,7 +2,7 @@ // RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator=spir64 --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out // RUN: %{run} SYCL_UR_TRACE=1 %t.out -// UNSUPPORTED: cuda, hip +// UNSUPPORTED: target-nvidia, target-amd // UNSUPPORTED-INTENDED: only supported by backends with CRI driver #include diff --git a/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp index e9aa950f7c0a0..5ccfd2e86a13b 100644 --- a/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e5m2_x2_cri_conversion.cpp @@ -2,7 +2,7 @@ // RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator=spir64 --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out // RUN: %{run} SYCL_UR_TRACE=1 %t.out -// UNSUPPORTED: cuda, hip +// UNSUPPORTED: target-nvidia, target-amd // UNSUPPORTED-INTENDED: only supported by backends with CRI driver #include diff --git a/sycl/test-e2e/Experimental/fp8/e8m0_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e8m0_cri_conversion.cpp index 7372ee567b0b5..26bb3f068e3e5 100644 --- a/sycl/test-e2e/Experimental/fp8/e8m0_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e8m0_cri_conversion.cpp @@ -3,7 +3,7 @@ // RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator=spir64 --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out // RUN: %{run} SYCL_UR_TRACE=1 %t.out -// UNSUPPORTED: cuda, hip +// UNSUPPORTED: target-nvidia, target-amd // UNSUPPORTED-INTENDED: only supported by backends with CRI driver #include diff --git a/sycl/test-e2e/Experimental/fp8/e8m0_x2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp8/e8m0_x2_cri_conversion.cpp index 4d36af50b27b9..b8324e750638b 100644 --- a/sycl/test-e2e/Experimental/fp8/e8m0_x2_cri_conversion.cpp +++ b/sycl/test-e2e/Experimental/fp8/e8m0_x2_cri_conversion.cpp @@ -3,7 +3,7 @@ // RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator=spir64 --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_EXT_float8,+SPV_KHR_bfloat16 -o %t.out // RUN: %{run} SYCL_UR_TRACE=1 %t.out -// UNSUPPORTED: cuda, hip +// UNSUPPORTED: target-nvidia, target-amd // UNSUPPORTED-INTENDED: only supported by backends with CRI driver #include From 39290a6edce51b39a5a7dd5c616f864600846025 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Fri, 12 Jun 2026 10:55:50 +0200 Subject: [PATCH 87/89] [SYCL] use detail namespace inside of sycl namespace --- .../oneapi/experimental/float_8bit/types.hpp | 111 ++++++++++-------- 1 file changed, 62 insertions(+), 49 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index f6d3e6ce5a229..bfd75d12405bf 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -24,56 +24,62 @@ #ifdef __SYCL_DEVICE_ONLY__ +namespace sycl { namespace detail { using float16_vec2 = _Float16 __attribute__((ext_vector_type(2))); using uint8_vec2 = uint8_t __attribute__((ext_vector_type(2))); using bfloat16_vec2 = __bf16 __attribute__((ext_vector_type(2))); } // namespace detail +} // namespace sycl // FP8 builtins extern __DPCPP_SYCL_EXTERNAL uint8_t __builtin_spirv_ClampConvertFP16ToE4M3INTEL(_Float16) noexcept; -extern __DPCPP_SYCL_EXTERNAL detail::uint8_vec2 - __builtin_spirv_ClampConvertFP16ToE4M3INTEL(detail::float16_vec2) noexcept; -extern __DPCPP_SYCL_EXTERNAL detail::float16_vec2 - __builtin_spirv_ConvertE4M3ToFP16EXT(detail::uint8_vec2) noexcept; +extern __DPCPP_SYCL_EXTERNAL ::sycl::detail::uint8_vec2 + __builtin_spirv_ClampConvertFP16ToE4M3INTEL( + ::sycl::detail::float16_vec2) noexcept; +extern __DPCPP_SYCL_EXTERNAL ::sycl::detail::float16_vec2 + __builtin_spirv_ConvertE4M3ToFP16EXT(::sycl::detail::uint8_vec2) noexcept; extern __DPCPP_SYCL_EXTERNAL _Float16 __builtin_spirv_ConvertE4M3ToFP16EXT(char) noexcept; extern __DPCPP_SYCL_EXTERNAL __bf16 __builtin_spirv_ConvertE4M3ToBF16EXT(uint8_t) noexcept; -extern __DPCPP_SYCL_EXTERNAL detail::bfloat16_vec2 - __builtin_spirv_ConvertE4M3ToBF16EXT(detail::uint8_vec2) noexcept; +extern __DPCPP_SYCL_EXTERNAL ::sycl::detail::bfloat16_vec2 + __builtin_spirv_ConvertE4M3ToBF16EXT(::sycl::detail::uint8_vec2) noexcept; extern __DPCPP_SYCL_EXTERNAL uint8_t __builtin_spirv_ClampConvertBF16ToE4M3INTEL(__bf16) noexcept; -extern __DPCPP_SYCL_EXTERNAL detail::uint8_vec2 - __builtin_spirv_ClampConvertBF16ToE4M3INTEL(detail::bfloat16_vec2) noexcept; +extern __DPCPP_SYCL_EXTERNAL ::sycl::detail::uint8_vec2 + __builtin_spirv_ClampConvertBF16ToE4M3INTEL( + ::sycl::detail::bfloat16_vec2) noexcept; extern __DPCPP_SYCL_EXTERNAL uint8_t __builtin_spirv_ClampConvertFP16ToE5M2INTEL(_Float16) noexcept; -extern __DPCPP_SYCL_EXTERNAL ::detail::uint8_vec2 +extern __DPCPP_SYCL_EXTERNAL ::sycl::detail::uint8_vec2 __builtin_spirv_ClampConvertFP16ToE5M2INTEL( - ::detail::float16_vec2) noexcept; + ::sycl::detail::float16_vec2) noexcept; extern __DPCPP_SYCL_EXTERNAL uint8_t __builtin_spirv_ConvertFP16ToE5M2EXT(_Float16) noexcept; -extern __DPCPP_SYCL_EXTERNAL detail::uint8_vec2 - __builtin_spirv_ConvertFP16ToE5M2EXT(detail::float16_vec2) noexcept; +extern __DPCPP_SYCL_EXTERNAL ::sycl::detail::uint8_vec2 + __builtin_spirv_ConvertFP16ToE5M2EXT(::sycl::detail::float16_vec2) noexcept; extern __DPCPP_SYCL_EXTERNAL _Float16 __builtin_spirv_ConvertE5M2ToFP16EXT(uint8_t) noexcept; -extern __DPCPP_SYCL_EXTERNAL detail::float16_vec2 - __builtin_spirv_ConvertE5M2ToFP16EXT(detail::uint8_vec2) noexcept; +extern __DPCPP_SYCL_EXTERNAL ::sycl::detail::float16_vec2 + __builtin_spirv_ConvertE5M2ToFP16EXT(::sycl::detail::uint8_vec2) noexcept; extern __DPCPP_SYCL_EXTERNAL uint8_t __builtin_spirv_ClampConvertBF16ToE5M2INTEL(__bf16) noexcept; -extern __DPCPP_SYCL_EXTERNAL detail::uint8_vec2 - __builtin_spirv_ClampConvertBF16ToE5M2INTEL(detail::bfloat16_vec2) noexcept; +extern __DPCPP_SYCL_EXTERNAL ::sycl::detail::uint8_vec2 + __builtin_spirv_ClampConvertBF16ToE5M2INTEL( + ::sycl::detail::bfloat16_vec2) noexcept; extern __DPCPP_SYCL_EXTERNAL uint8_t __builtin_spirv_ConvertBF16ToE5M2EXT(__bf16) noexcept; -extern __DPCPP_SYCL_EXTERNAL detail::uint8_vec2 - __builtin_spirv_ConvertBF16ToE5M2EXT(detail::bfloat16_vec2) noexcept; +extern __DPCPP_SYCL_EXTERNAL ::sycl::detail::uint8_vec2 + __builtin_spirv_ConvertBF16ToE5M2EXT( + ::sycl::detail::bfloat16_vec2) noexcept; extern __DPCPP_SYCL_EXTERNAL __bf16 __builtin_spirv_ConvertE5M2ToBF16EXT(uint8_t) noexcept; -extern __DPCPP_SYCL_EXTERNAL detail::bfloat16_vec2 - __builtin_spirv_ConvertE5M2ToBF16EXT(detail::uint8_vec2) noexcept; +extern __DPCPP_SYCL_EXTERNAL ::sycl::detail::bfloat16_vec2 + __builtin_spirv_ConvertE5M2ToBF16EXT(::sycl::detail::uint8_vec2) noexcept; extern __DPCPP_SYCL_EXTERNAL uint8_t __builtin_spirv_ClampStochasticRoundFP16ToE5M2INTEL( _Float16, uint32_t, __attribute__((opencl_private)) uint32_t *) noexcept; @@ -896,7 +902,7 @@ template class fp8_e4m3_x { } #ifdef __SYCL_DEVICE_ONLY__ - ::detail::uint8_vec2 ConvertToFP8_Vec2(::detail::float16_vec2 h) { + ::sycl::detail::uint8_vec2 ConvertToFP8_Vec2(::sycl::detail::float16_vec2 h) { return __builtin_spirv_ClampConvertFP16ToE4M3INTEL(h); } #endif @@ -909,7 +915,8 @@ template class fp8_e4m3_x { } #ifdef __SYCL_DEVICE_ONLY__ - ::detail::uint8_vec2 ConvertBF16ToFP8_Vec2(::detail::bfloat16_vec2 h) { + ::sycl::detail::uint8_vec2 + ConvertBF16ToFP8_Vec2(::sycl::detail::bfloat16_vec2 h) { return __builtin_spirv_ClampConvertBF16ToE4M3INTEL(h); } #endif @@ -938,8 +945,9 @@ template class fp8_e4m3_x { void ConvertFromFP8_Vec2(sycl::marray &ret, rounding r = rounding::to_even) const { #ifdef __SYCL_DEVICE_ONLY__ - const ::detail::uint8_vec2 packed{vals[0], vals[1]}; - ::detail::float16_vec2 hi = __builtin_spirv_ConvertE4M3ToFP16EXT(packed); + const ::sycl::detail::uint8_vec2 packed{vals[0], vals[1]}; + ::sycl::detail::float16_vec2 hi = + __builtin_spirv_ConvertE4M3ToFP16EXT(packed); ret[0] = sycl::bit_cast(hi[0]); ret[1] = sycl::bit_cast(hi[1]); #else @@ -962,8 +970,9 @@ template class fp8_e4m3_x { void ConvertBF16FromFP8_Vec2(sycl::marray &ret, rounding r = rounding::to_even) const { #ifdef __SYCL_DEVICE_ONLY__ - const ::detail::uint8_vec2 packed{vals[0], vals[1]}; - ::detail::bfloat16_vec2 hi = __builtin_spirv_ConvertE4M3ToBF16EXT(packed); + const ::sycl::detail::uint8_vec2 packed{vals[0], vals[1]}; + ::sycl::detail::bfloat16_vec2 hi = + __builtin_spirv_ConvertE4M3ToBF16EXT(packed); ret[0] = sycl::bit_cast(hi[0]); ret[1] = sycl::bit_cast(hi[1]); #else @@ -985,7 +994,8 @@ template class fp8_e4m3_x { } else { \ const VecType vec{sycl::bit_cast(in[0]), \ sycl::bit_cast(in[1])}; \ - const ::detail::uint8_vec2 result = Convert##Prefix##ToFP8_Vec2(vec); \ + const ::sycl::detail::uint8_vec2 result = \ + Convert##Prefix##ToFP8_Vec2(vec); \ std::memcpy(vals, &result, sizeof(vals)); \ } #else @@ -1013,10 +1023,10 @@ template class fp8_e4m3_x { explicit fp8_e4m3_x(Types... v) { if constexpr (((std::is_same_v, bfloat16>) && ...)) { const bfloat16 in[N] = {v...}; - CONVERT_TO_FP8(::detail::bfloat16_vec2, __bf16, in, BF16); + CONVERT_TO_FP8(::sycl::detail::bfloat16_vec2, __bf16, in, BF16); } else if constexpr (((std::is_same_v, half>) && ...)) { const sycl::half in[N] = {v...}; - CONVERT_TO_FP8(::detail::float16_vec2, _Float16, in, ); + CONVERT_TO_FP8(::sycl::detail::float16_vec2, _Float16, in, ); } else { const float in[N] = {v...}; for (size_t i = 0; i < N; ++i) @@ -1028,12 +1038,12 @@ template class fp8_e4m3_x { explicit fp8_e4m3_x(sycl::half const (&v)[N], rounding r = rounding::to_even) { CheckConstraints(r); - CONVERT_TO_FP8(::detail::float16_vec2, _Float16, v, ); + CONVERT_TO_FP8(::sycl::detail::float16_vec2, _Float16, v, ); } explicit fp8_e4m3_x(bfloat16 const (&v)[N], rounding r = rounding::to_even) { CheckConstraints(r); - CONVERT_TO_FP8(::detail::bfloat16_vec2, __bf16, v, BF16); + CONVERT_TO_FP8(::sycl::detail::bfloat16_vec2, __bf16, v, BF16); } explicit fp8_e4m3_x(float const (&v)[N], rounding r = rounding::to_even) { @@ -1046,13 +1056,13 @@ template class fp8_e4m3_x { explicit fp8_e4m3_x(const sycl::marray &v, rounding r = rounding::to_even) { CheckConstraints(r); - CONVERT_TO_FP8(::detail::float16_vec2, _Float16, v, ); + CONVERT_TO_FP8(::sycl::detail::float16_vec2, _Float16, v, ); } explicit fp8_e4m3_x(const sycl::marray &v, rounding r = rounding::to_even) { CheckConstraints(r); - CONVERT_TO_FP8(::detail::bfloat16_vec2, __bf16, v, BF16); + CONVERT_TO_FP8(::sycl::detail::bfloat16_vec2, __bf16, v, BF16); } explicit fp8_e4m3_x(const sycl::marray &v, @@ -1330,8 +1340,8 @@ template class fp8_e5m2_x { } #ifdef __SYCL_DEVICE_ONLY__ - ::detail::uint8_vec2 ConvertToFP8_Vec2(::detail::float16_vec2 h, - saturation s) { + ::sycl::detail::uint8_vec2 ConvertToFP8_Vec2(::sycl::detail::float16_vec2 h, + saturation s) { return s == saturation::finite ? __builtin_spirv_ClampConvertFP16ToE5M2INTEL(h) : __builtin_spirv_ConvertFP16ToE5M2EXT(h); @@ -1359,8 +1369,8 @@ template class fp8_e5m2_x { } #ifdef __SYCL_DEVICE_ONLY__ - ::detail::uint8_vec2 ConvertBF16ToFP8_Vec2(::detail::bfloat16_vec2 h, - saturation s) { + ::sycl::detail::uint8_vec2 + ConvertBF16ToFP8_Vec2(::sycl::detail::bfloat16_vec2 h, saturation s) { return s == saturation::finite ? __builtin_spirv_ClampConvertBF16ToE5M2INTEL(h) : __builtin_spirv_ConvertBF16ToE5M2EXT(h); @@ -1381,8 +1391,9 @@ template class fp8_e5m2_x { void ConvertFromFP8_Vec2(sycl::marray &ret, rounding r = rounding::to_even) const { #ifdef __SYCL_DEVICE_ONLY__ - const ::detail::uint8_vec2 packed{vals[0], vals[1]}; - ::detail::float16_vec2 hi = __builtin_spirv_ConvertE5M2ToFP16EXT(packed); + const ::sycl::detail::uint8_vec2 packed{vals[0], vals[1]}; + ::sycl::detail::float16_vec2 hi = + __builtin_spirv_ConvertE5M2ToFP16EXT(packed); ret[0] = sycl::bit_cast(hi[0]); ret[1] = sycl::bit_cast(hi[1]); #else @@ -1405,8 +1416,9 @@ template class fp8_e5m2_x { void ConvertBF16FromFP8_Vec2(sycl::marray &ret, rounding r = rounding::to_even) const { #ifdef __SYCL_DEVICE_ONLY__ - const ::detail::uint8_vec2 packed{vals[0], vals[1]}; - ::detail::bfloat16_vec2 hi = __builtin_spirv_ConvertE5M2ToBF16EXT(packed); + const ::sycl::detail::uint8_vec2 packed{vals[0], vals[1]}; + ::sycl::detail::bfloat16_vec2 hi = + __builtin_spirv_ConvertE5M2ToBF16EXT(packed); ret[0] = sycl::bit_cast(hi[0]); ret[1] = sycl::bit_cast(hi[1]); #else @@ -1428,7 +1440,8 @@ template class fp8_e5m2_x { } else { \ const VecType vec{sycl::bit_cast(in[0]), \ sycl::bit_cast(in[1])}; \ - const ::detail::uint8_vec2 result = Convert##Prefix##ToFP8_Vec2(vec, s); \ + const ::sycl::detail::uint8_vec2 result = \ + Convert##Prefix##ToFP8_Vec2(vec, s); \ std::memcpy(vals, &result, sizeof(vals)); \ } #else @@ -1457,11 +1470,11 @@ template class fp8_e5m2_x { explicit fp8_e5m2_x(Types... v) { if constexpr (((std::is_same_v, bfloat16>) && ...)) { const bfloat16 in[N] = {static_cast(v)...}; - CONVERT_TO_FP8(::detail::bfloat16_vec2, __bf16, in, saturation::finite, - BF16); + CONVERT_TO_FP8(::sycl::detail::bfloat16_vec2, __bf16, in, + saturation::finite, BF16); } else if constexpr (((std::is_same_v, half>) && ...)) { const sycl::half in[N] = {v...}; - CONVERT_TO_FP8(::detail::float16_vec2, _Float16, in, + CONVERT_TO_FP8(::sycl::detail::float16_vec2, _Float16, in, saturation::finite, ); } else { using InT = std::common_type_t...>; @@ -1476,13 +1489,13 @@ template class fp8_e5m2_x { explicit fp8_e5m2_x(half const (&v)[N], rounding r = rounding::to_even, saturation s = saturation::finite) { CheckConstraints(r); - CONVERT_TO_FP8(::detail::float16_vec2, _Float16, v, s, ); + CONVERT_TO_FP8(::sycl::detail::float16_vec2, _Float16, v, s, ); } explicit fp8_e5m2_x(bfloat16 const (&v)[N], rounding r = rounding::to_even, saturation s = saturation::finite) { CheckConstraints(r); - CONVERT_TO_FP8(::detail::bfloat16_vec2, __bf16, v, s, BF16); + CONVERT_TO_FP8(::sycl::detail::bfloat16_vec2, __bf16, v, s, BF16); } explicit fp8_e5m2_x(float const (&v)[N], rounding r = rounding::to_even, @@ -1498,14 +1511,14 @@ template class fp8_e5m2_x { rounding r = rounding::to_even, saturation s = saturation::finite) { CheckConstraints(r); - CONVERT_TO_FP8(::detail::float16_vec2, _Float16, v, s, ); + CONVERT_TO_FP8(::sycl::detail::float16_vec2, _Float16, v, s, ); } explicit fp8_e5m2_x(const sycl::marray &v, rounding r = rounding::to_even, saturation s = saturation::finite) { CheckConstraints(r); - CONVERT_TO_FP8(::detail::bfloat16_vec2, __bf16, v, s, BF16); + CONVERT_TO_FP8(::sycl::detail::bfloat16_vec2, __bf16, v, s, BF16); } explicit fp8_e5m2_x(const sycl::marray &v, From aa9a1a990abea7072feb80f80be43eec318af49c Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Fri, 12 Jun 2026 11:32:32 +0200 Subject: [PATCH 88/89] [SYCL] use khr address space cast --- .../oneapi/experimental/float_8bit/types.hpp | 41 +++++++++++-------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index bfd75d12405bf..0981ca720a9b5 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -12,6 +12,7 @@ #include #include +#include #include #include @@ -1543,13 +1544,15 @@ template class fp8_e5m2_x { if (s == saturation::finite) { vals[i] = __builtin_spirv_ClampStochasticRoundFP16ToE5M2INTEL( v, current_seed, - sycl::address_space_cast(&next_seed)); + sycl::khr::static_addrspace_cast< + sycl::access::address_space::private_space>(&next_seed) + .get_decorated()); } else { vals[i] = __builtin_spirv_StochasticRoundFP16ToE5M2INTEL( v, current_seed, - sycl::address_space_cast(&next_seed)); + sycl::khr::static_addrspace_cast< + sycl::access::address_space::private_space>(&next_seed) + .get_decorated()); } current_seed = next_seed; next_seed = 0; @@ -1570,13 +1573,15 @@ template class fp8_e5m2_x { if (s == saturation::finite) { vals[i] = __builtin_spirv_ClampStochasticRoundBF16ToE5M2INTEL( sycl::bit_cast<__bf16>(in[i]), current_seed, - sycl::address_space_cast(&next_seed)); + sycl::khr::static_addrspace_cast< + sycl::access::address_space::private_space>(&next_seed) + .get_decorated()); } else { vals[i] = __builtin_spirv_StochasticRoundBF16ToE5M2INTEL( sycl::bit_cast<__bf16>(in[i]), current_seed, - sycl::address_space_cast(&next_seed)); + sycl::khr::static_addrspace_cast< + sycl::access::address_space::private_space>(&next_seed) + .get_decorated()); } current_seed = next_seed; next_seed = 0; @@ -1601,13 +1606,15 @@ template class fp8_e5m2_x { if (s == saturation::finite) { vals[i] = __builtin_spirv_ClampStochasticRoundFP16ToE5M2INTEL( v, current_seed, - sycl::address_space_cast(&next_seed)); + sycl::khr::static_addrspace_cast< + sycl::access::address_space::private_space>(&next_seed) + .get_decorated()); } else { vals[i] = __builtin_spirv_StochasticRoundFP16ToE5M2INTEL( v, current_seed, - sycl::address_space_cast(&next_seed)); + sycl::khr::static_addrspace_cast< + sycl::access::address_space::private_space>(&next_seed) + .get_decorated()); } current_seed = next_seed; next_seed = 0; @@ -1628,13 +1635,15 @@ template class fp8_e5m2_x { if (s == saturation::finite) { vals[i] = __builtin_spirv_ClampStochasticRoundBF16ToE5M2INTEL( sycl::bit_cast<__bf16>(in[i]), current_seed, - sycl::address_space_cast(&next_seed)); + sycl::khr::static_addrspace_cast< + sycl::access::address_space::private_space>(&next_seed) + .get_decorated()); } else { vals[i] = __builtin_spirv_StochasticRoundBF16ToE5M2INTEL( sycl::bit_cast<__bf16>(in[i]), current_seed, - sycl::address_space_cast(&next_seed)); + sycl::khr::static_addrspace_cast< + sycl::access::address_space::private_space>(&next_seed) + .get_decorated()); } current_seed = next_seed; next_seed = 0; From f529ff419e51c73cbebfc634bbc4d4a77544ee61 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Fri, 12 Jun 2026 12:09:21 +0200 Subject: [PATCH 89/89] [SYCL] do not convert twice 16-bit integers --- .../sycl/ext/oneapi/experimental/float_8bit/types.hpp | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp index 0981ca720a9b5..8b97c30f7f836 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_8bit/types.hpp @@ -878,9 +878,8 @@ template class fp8_e4m3_x { uint8_t ConvertToFP8(T h) { #ifdef __SYCL_DEVICE_ONLY__ if constexpr (std::is_same_v, char> || - std::is_same_v, unsigned char> || - std::is_same_v, short> || - std::is_same_v, unsigned short>) { + std::is_same_v, signed char> || + std::is_same_v, unsigned char>) { const _Float16 v = static_cast<_Float16>(h); return __builtin_spirv_ClampConvertFP16ToE4M3INTEL(v); } @@ -1312,9 +1311,8 @@ template class fp8_e5m2_x { uint8_t ConvertToFP8(T h, saturation s) { #ifdef __SYCL_DEVICE_ONLY__ if constexpr (std::is_same_v, char> || - std::is_same_v, unsigned char> || - std::is_same_v, short> || - std::is_same_v, unsigned short>) { + std::is_same_v, signed char> || + std::is_same_v, unsigned char>) { const _Float16 v = static_cast<_Float16>(h); return s == saturation::finite ? __builtin_spirv_ClampConvertFP16ToE5M2INTEL(v)