@@ -263,35 +263,48 @@ def _comparison_tolerance(case: Any, nranks: int) -> tuple[float, float] | None:
263263
264264
265265_FP8_TABLES : dict [str , list [tuple [int , float ]]] = {}
266- _FP8_DEVICE_TABLES : dict [str , tuple ] = {}
266+ _FP8_LOOKUP_CACHE : dict [str , tuple [ Any , Any ] ] = {}
267267_FP8_SPACING_CACHE : dict [tuple [str , float ], float ] = {}
268268
269269
270- def _fp8_device_table (fp8_format : str ):
271- cached = _FP8_DEVICE_TABLES .get (fp8_format )
272- if cached is None :
273- table = _FP8_TABLES .setdefault (fp8_format , _build_fp8_table (fp8_format ))
274- table_bytes = cp .asarray ([byte for byte , _ in table ], dtype = cp .uint8 )
275- table_values = cp .asarray ([value for _ , value in table ], dtype = cp .float32 )
276- cached = (table_bytes , table_values )
277- _FP8_DEVICE_TABLES [fp8_format ] = cached
278- return cached
279-
280-
281270def _encode_fp8_values (fp8_format : str , values ):
282271 values = values .astype (cp .float32 )
283272 if fp8_format == "e4m3b15" :
284273 return _encode_e4m3b15_values (values )
285274
286- table_bytes , table_values = _fp8_device_table (fp8_format )
275+ # Round each value to the nearest representable FP8 value (ties to even).
276+ table_values , table_bytes = _fp8_lookup_arrays (fp8_format )
287277 flat_values = values .ravel ()
288- flat_encoded = cp .empty (flat_values .shape , dtype = cp .uint8 )
289- chunk_size = 65536
290- for start in range (0 , flat_values .size , chunk_size ):
291- end = min (start + chunk_size , flat_values .size )
292- diff = cp .abs (flat_values [start :end , None ] - table_values [None , :])
293- flat_encoded [start :end ] = table_bytes [cp .argmin (diff , axis = 1 )]
294- return flat_encoded .reshape (values .shape )
278+
279+ # For each value find its two surrounding table entries: lower <= value <= upper.
280+ upper = cp .clip (cp .searchsorted (table_values , flat_values ), 1 , table_values .size - 1 )
281+ lower = upper - 1
282+
283+ # Pick the closer neighbor; on an exact tie pick the one with an even byte.
284+ dist_to_upper = table_values [upper ] - flat_values
285+ dist_to_lower = flat_values - table_values [lower ]
286+ upper_is_even = (table_bytes [upper ] & cp .uint8 (1 )) == 0
287+ pick_upper = (dist_to_upper < dist_to_lower ) | ((dist_to_upper == dist_to_lower ) & upper_is_even )
288+
289+ return cp .where (pick_upper , table_bytes [upper ], table_bytes [lower ]).reshape (values .shape )
290+
291+
292+ def _fp8_lookup_arrays (fp8_format : str ):
293+ # Cache a sorted (value -> byte) table per format for fast nearest-value lookup.
294+ if fp8_format in _FP8_LOOKUP_CACHE :
295+ return _FP8_LOOKUP_CACHE [fp8_format ]
296+
297+ # Different bytes can decode to the same value (e.g. +0 and -0); keep one byte per value.
298+ byte_for_value : dict [float , int ] = {}
299+ for byte , value in _FP8_TABLES .setdefault (fp8_format , _build_fp8_table (fp8_format )):
300+ if value not in byte_for_value or byte < byte_for_value [value ]:
301+ byte_for_value [value ] = byte
302+
303+ table = sorted (byte_for_value .items ())
304+ table_values = cp .asarray ([value for value , _ in table ], dtype = cp .float32 )
305+ table_bytes = cp .asarray ([byte for _ , byte in table ], dtype = cp .uint8 )
306+ _FP8_LOOKUP_CACHE [fp8_format ] = (table_values , table_bytes )
307+ return _FP8_LOOKUP_CACHE [fp8_format ]
295308
296309
297310def _max_fp8_spacing (fp8_format : str , max_abs_value : float ) -> float :
@@ -337,6 +350,8 @@ def _build_fp8_table(fp8_format: str) -> list[tuple[int, float]]:
337350
338351
339352def _decode_fp8_scalar (fp8_format : str , byte : int ) -> float :
353+ if fp8_format == "e4m3fnuz" and byte == 0x80 :
354+ return float ("nan" )
340355 sign = - 1.0 if byte & 0x80 else 1.0
341356 return sign * _decode_fp8_positive (fp8_format , byte & 0x7F )
342357
@@ -379,4 +394,7 @@ def _decode_fp8_array(fp8_format: str, values):
379394 else :
380395 raise ValueError (f"Unknown FP8 format: { fp8_format } " )
381396
382- return cp .where (sign == 1 , - decoded , decoded )
397+ result = cp .where (sign == 1 , - decoded , decoded )
398+ if fp8_format == "e4m3fnuz" :
399+ result = cp .where (bits == 0x80 , cp .float32 (float ("nan" )), result )
400+ return result
0 commit comments