11#pragma once
22
// Bit-level reinterpretation between float and its IEEE-754 binary32 encoding.
// Uses Metal's as_type<> bit cast instead of pointer punning through the
// `thread` address space; as_type is the well-defined idiom in MSL and is what
// the conversion routines below already use.
inline float fp32_from_bits(uint32_t bits) {
  return as_type<float>(bits);
}
// NOTE: the return type must be uint32_t — returning float here would push the
// bit pattern back through a float conversion and corrupt it.
inline uint32_t fp32_to_bits(float x) {
  return as_type<uint32_t>(x);
}
9-
103struct fp8_e4m3 {
114 template <typename T>
125 fp8_e4m3 (T f) {
136 // From PyTorch
147 // https://github.com/pytorch/pytorch/blob/e3643e1e0e923f0fc063dfab6f45c956d568919d/c10/util/Float8_e4m3fn.h#L148
158 uint32_t fp8_max = 543 << 21 ;
169 uint32_t denorm_mask = 141 << 23 ;
17- uint32_t f_bits = fp32_to_bits (static_cast <float >(f));
10+ uint32_t f_bits = as_type< uint32_t > (static_cast <float >(f));
1811 uint32_t sign = f_bits & 0x80000000 ;
1912 f_bits ^= sign;
2013 if (f_bits >= fp8_max) {
2114 // Default behavior saturates to min/max
2215 bits = 0x7E ;
2316 } else {
2417 if (f_bits < (121 << 23 )) {
25- f_bits =
26- fp32_to_bits ( fp32_from_bits ( f_bits) + fp32_from_bits (denorm_mask));
18+ f_bits = as_type< uint32_t >(
19+ as_type< float >( f_bits) + as_type< float > (denorm_mask));
2720 bits = static_cast <uint8_t >(f_bits - denorm_mask);
2821 } else {
2922 // resulting mantissa is odd
@@ -53,7 +46,7 @@ struct fp8_e4m3 {
5346 ((((nonsign << renorm_shift >> 4 ) + ((0x78 - renorm_shift) << 23 )) |
5447 inf_nan_mask) &
5548 ~zero_mask);
56- return fp32_from_bits (result);
49+ return as_type< float > (result);
5750 }
5851
5952 uint8_t bits;
@@ -77,11 +70,12 @@ struct fp8_e8m0 {
7770 bits = static_cast <uint8_t >(n + 127 );
7871 }
7972
73+ operator bfloat16_t () {
74+ uint16_t out = (bits == 0 ? 0x40 : (static_cast <uint16_t >(bits) << 7 ));
75+ return as_type<bfloat16_t >(out);
76+ }
8077 operator float () {
81- if (bits == 0xFF ) {
82- return metal::numeric_limits<float >::quiet_NaN ();
83- }
84- return metal::ldexp (1 .0f , static_cast <int >(bits) - 127 );
78+ return static_cast <float >(this ->operator bfloat16_t ());
8579 }
8680
8781 uint8_t bits;
0 commit comments