Skip to content

Commit f830639

Browse files
committed
WIP
1 parent ce03bae commit f830639

1 file changed

Lines changed: 39 additions & 21 deletions

File tree

python/mscclpp_benchmark/correctness.py

Lines changed: 39 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -263,35 +263,48 @@ def _comparison_tolerance(case: Any, nranks: int) -> tuple[float, float] | None:
263263

264264

265265
_FP8_TABLES: dict[str, list[tuple[int, float]]] = {}
266-
_FP8_DEVICE_TABLES: dict[str, tuple] = {}
266+
_FP8_LOOKUP_CACHE: dict[str, tuple[Any, Any]] = {}
267267
_FP8_SPACING_CACHE: dict[tuple[str, float], float] = {}
268268

269269

270-
def _fp8_device_table(fp8_format: str):
271-
cached = _FP8_DEVICE_TABLES.get(fp8_format)
272-
if cached is None:
273-
table = _FP8_TABLES.setdefault(fp8_format, _build_fp8_table(fp8_format))
274-
table_bytes = cp.asarray([byte for byte, _ in table], dtype=cp.uint8)
275-
table_values = cp.asarray([value for _, value in table], dtype=cp.float32)
276-
cached = (table_bytes, table_values)
277-
_FP8_DEVICE_TABLES[fp8_format] = cached
278-
return cached
279-
280-
281270
def _encode_fp8_values(fp8_format: str, values):
    """Encode an array of values to FP8 bytes, rounding to nearest (ties to even).

    Args:
        fp8_format: Name of the FP8 format. ``"e4m3b15"`` is delegated to
            ``_encode_e4m3b15_values``; every other format goes through the
            sorted lookup table returned by ``_fp8_lookup_arrays``.
        values: Array to encode; it is cast to ``cp.float32`` first.
            (Presumably a CuPy array -- confirm against callers.)

    Returns:
        For the table-driven path, an array of ``cp.uint8`` byte codes with the
        same shape as ``values``. Inputs below the smallest or above the
        largest table value resolve to the corresponding end of the table.
    """
    as_f32 = values.astype(cp.float32)
    if fp8_format == "e4m3b15":
        return _encode_e4m3b15_values(as_f32)

    sorted_values, sorted_bytes = _fp8_lookup_arrays(fp8_format)
    flat = as_f32.ravel()

    # Bracket every input between two adjacent table entries. Clipping keeps
    # both neighbor indices inside the table for out-of-range inputs.
    insert_at = cp.searchsorted(sorted_values, flat)
    hi = cp.clip(insert_at, 1, sorted_values.size - 1)
    lo = hi - 1
    hi_bytes = sorted_bytes[hi]
    lo_bytes = sorted_bytes[lo]

    # Distance from the input to each neighbor.
    gap_above = sorted_values[hi] - flat
    gap_below = flat - sorted_values[lo]

    # Take the upper neighbor when it is strictly closer, or on an exact tie
    # when its byte code has a zero low bit (the "even" candidate).
    closer_above = gap_above < gap_below
    exact_tie = gap_above == gap_below
    hi_is_even = (hi_bytes & cp.uint8(1)) == 0

    encoded = cp.where(closer_above | (exact_tie & hi_is_even), hi_bytes, lo_bytes)
    return encoded.reshape(as_f32.shape)
290+
291+
292+
def _fp8_lookup_arrays(fp8_format: str):
    """Return cached, value-sorted lookup arrays for ``fp8_format``.

    Args:
        fp8_format: Name of the FP8 format; forwarded to ``_build_fp8_table``
            when no host-side table has been built for it yet.

    Returns:
        A ``(table_values, table_bytes)`` tuple: ``table_values`` is a
        ``cp.float32`` array of distinct decoded values in ascending order and
        ``table_bytes`` is the ``cp.uint8`` array of the byte code chosen for
        each value. The tuple is stored in ``_FP8_LOOKUP_CACHE`` and reused.
    """
    import math

    cached = _FP8_LOOKUP_CACHE.get(fp8_format)
    if cached is not None:
        return cached

    # Build the host-side table only on a miss; ``dict.setdefault`` would
    # evaluate ``_build_fp8_table`` even when the entry already exists.
    table = _FP8_TABLES.get(fp8_format)
    if table is None:
        table = _build_fp8_table(fp8_format)
        _FP8_TABLES[fp8_format] = table

    # Different bytes can decode to the same value (e.g. +0 and -0); keep the
    # smallest byte per value.
    byte_for_value: dict[float, int] = {}
    for byte, value in table:
        # NaN never compares equal or ordered: as a dict key it would not
        # dedupe, and it would break the ascending order that searchsorted
        # relies on. It is never a valid rounding target, so drop it here in
        # case the table carries NaN entries.
        if math.isnan(value):
            continue
        known = byte_for_value.get(value)
        if known is None or byte < known:
            byte_for_value[value] = byte

    ordered = sorted(byte_for_value.items())
    table_values = cp.asarray([value for value, _ in ordered], dtype=cp.float32)
    table_bytes = cp.asarray([byte for _, byte in ordered], dtype=cp.uint8)

    lookup = (table_values, table_bytes)
    _FP8_LOOKUP_CACHE[fp8_format] = lookup
    return lookup
295308

296309

297310
def _max_fp8_spacing(fp8_format: str, max_abs_value: float) -> float:
@@ -337,6 +350,8 @@ def _build_fp8_table(fp8_format: str) -> list[tuple[int, float]]:
337350

338351

339352
def _decode_fp8_scalar(fp8_format: str, byte: int) -> float:
    """Decode a single FP8 byte code to a Python float.

    Args:
        fp8_format: Name of the FP8 format.
        byte: The encoded value; bit 7 is read as the sign and the low seven
            bits are passed to ``_decode_fp8_positive``.

    Returns:
        The signed decoded value. For ``"e4m3fnuz"`` the byte ``0x80`` decodes
        to NaN instead of a negated zero-magnitude pattern.
    """
    if byte == 0x80 and fp8_format == "e4m3fnuz":
        return float("nan")

    magnitude = _decode_fp8_positive(fp8_format, byte & 0x7F)
    if byte & 0x80:
        return -1.0 * magnitude
    return 1.0 * magnitude
342357

@@ -379,4 +394,7 @@ def _decode_fp8_array(fp8_format: str, values):
379394
else:
380395
raise ValueError(f"Unknown FP8 format: {fp8_format}")
381396

382-
return cp.where(sign == 1, -decoded, decoded)
397+
result = cp.where(sign == 1, -decoded, decoded)
398+
if fp8_format == "e4m3fnuz":
399+
result = cp.where(bits == 0x80, cp.float32(float("nan")), result)
400+
return result

0 commit comments

Comments
 (0)