44from __future__ import annotations
55
66import argparse
7- import math
8- import struct
97from dataclasses import dataclass
108from typing import Any
119
1513_mscclpp_module = None
1614
1715from mscclpp_benchmark .comm import Comm
16+ from mscclpp_benchmark .correctness import (
17+ CorrectnessStats ,
18+ check_correctness as _check_correctness ,
19+ fill_case_for_benchmark as _fill_case_for_benchmark ,
20+ )
1821from mscclpp_benchmark .gpu import capture_graph , init_runtime
1922from mscclpp_benchmark .tuner import OfflineTuner
2023from mscclpp_benchmark .tuning_config import HardwareProfile , TunedConfig , TunedConfigStore , normalize_sku
@@ -94,10 +97,6 @@ class BenchmarkCase:
9497 symmetric_memory : bool = False
9598
9699
97- def config_accum_dtype (case : BenchmarkCase ) -> Any :
98- return case .dtype_spec .accum_dtype or case .dtype_spec .mscclpp_dtype
99-
100-
101100def _device_name () -> str :
102101 props = cp .cuda .runtime .getDeviceProperties (cp .cuda .Device ().id )
103102 name = props .get ("name" , "UNKNOWN" )
@@ -176,10 +175,6 @@ def _with_accum_type(dtype_spec: DTypeSpec, accum_type: str | None) -> DTypeSpec
176175 )
177176
178177
179- def _dtype_is_float (dtype : Any ) -> bool :
180- return dtype in (cp .float16 , cp .float32 )
181-
182-
183178def _human_size (size : int ) -> str :
184179 value = float (size )
185180 for unit in ("B" , "KiB" , "MiB" , "GiB" , "TiB" ):
@@ -333,167 +328,6 @@ def _make_case(
333328 )
334329
335330
336- def _fill_case_for_benchmark (case : BenchmarkCase , rank : int ) -> None :
337- if case .collective == _ALLREDUCE :
338- case .input .fill (0 )
339- return
340- _fill_allgather_input (case , rank )
341- case .output .fill (0 )
342- if case .allgather_mode == "in-place" :
343- _fill_allgather_input (case , rank )
344-
345-
346- def _fill_case_for_correctness (case : BenchmarkCase , rank : int , iteration : int ) -> None :
347- value = iteration * MPI .COMM_WORLD .size + rank
348- if case .collective == _ALLREDUCE :
349- case .input .fill (_dtype_value (case , value ))
350- return
351- case .output .fill (0 )
352- case .input .fill (_dtype_value (case , value ))
353-
354-
355- def _fill_allgather_input (case : BenchmarkCase , rank : int ) -> None :
356- value = rank + 1
357- case .input .fill (_dtype_value (case , value ))
358-
359-
360- def _dtype_value (case : BenchmarkCase , value : int ) -> int :
361- if case .dtype_spec .fp8_format is not None :
362- return _encode_fp8_scalar (case .dtype_spec .fp8_format , _correctness_numeric_value (case , value ))
363- if case .dtype_spec .cupy_dtype == cp .uint8 :
364- return value % 256
365- return value
366-
367-
368- def _correctness_numeric_value (case : BenchmarkCase , value : int ) -> float :
369- if case .dtype_spec .fp8_format is None :
370- return float (value )
371- scale = max (64 , MPI .COMM_WORLD .size * MPI .COMM_WORLD .size )
372- return float (value + 1 ) / float (scale )
373-
374-
375- def _check_correctness (
376- comm : Comm ,
377- case : BenchmarkCase ,
378- config : TunedConfig ,
379- * ,
380- raise_on_unsupported : bool = True ,
381- niter : int = 1 ,
382- ) -> bool :
383- if case .collective == _ALLREDUCE and not case .dtype_spec .supports_reduction_correctness :
384- if not raise_on_unsupported :
385- return True
386- raise ValueError (
387- f"Correctness checking for { case .collective } with { case .dtype_spec .name } is not implemented; "
388- "use --skip-correctness or a numeric dtype"
389- )
390-
391- all_ok = True
392- for iteration in range (niter ):
393- _fill_case_for_correctness (case , comm .rank , iteration )
394- comm .comm_group .barrier ()
395- ret = comm .run (case , config )
396- cp .cuda .runtime .deviceSynchronize ()
397- comm .comm_group .barrier ()
398- if ret != 0 :
399- return False
400-
401- expected = _expected_output (case , comm .nranks , iteration )
402- local_ok = _compare_output (case .output , expected )
403- all_ok = all_ok and local_ok
404-
405- if not local_ok :
406- mismatch = _mismatch_mask (case .output , expected )
407- print (
408- "not close: "
409- f"iter={ iteration } , rank={ comm .rank } , output={ case .output [mismatch ][0 ]} , "
410- f"expected={ expected [mismatch ][0 ]} " ,
411- flush = True ,
412- )
413-
414- return bool (MPI .COMM_WORLD .allreduce (all_ok , op = MPI .LAND ))
415-
416-
417- def _expected_output (case : BenchmarkCase , nranks : int , iteration : int ):
418- if case .collective == _ALLREDUCE :
419- if case .dtype_spec .fp8_format is not None :
420- expected_numeric = sum (
421- _decode_fp8_positive (
422- case .dtype_spec .fp8_format ,
423- _dtype_value (case , iteration * MPI .COMM_WORLD .size + rank ),
424- )
425- for rank in range (nranks )
426- )
427- return cp .full_like (case .output , _encode_fp8_scalar (case .dtype_spec .fp8_format , expected_numeric ))
428- expected_value = sum (iteration * MPI .COMM_WORLD .size + rank for rank in range (nranks ))
429- return cp .full_like (case .output , _dtype_value (case , expected_value ))
430-
431- expected = cp .empty_like (case .output )
432- chunk = case .input .size
433- for rank in range (nranks ):
434- expected [rank * chunk : (rank + 1 ) * chunk ].fill (_dtype_value (case , iteration * MPI .COMM_WORLD .size + rank ))
435- return expected
436-
437-
438- def _compare_output (output , expected ) -> bool :
439- if _dtype_is_float (output .dtype .type ):
440- return bool (cp .allclose (output , expected , rtol = 1.0e-2 , atol = 2 ).item ())
441- return bool (cp .all (output == expected ).item ())
442-
443-
444- def _mismatch_mask (output , expected ):
445- if _dtype_is_float (output .dtype .type ):
446- return ~ cp .isclose (output , expected , rtol = 1.0e-2 , atol = 2 )
447- return output != expected
448-
449-
450- _FP8_POSITIVE_TABLES : dict [str , list [tuple [int , float ]]] = {}
451-
452-
453- def _encode_fp8_scalar (fp8_format : str , value : float ) -> int :
454- if value < 0 :
455- raise ValueError ("FP8 correctness values are expected to be non-negative" )
456- if fp8_format == "e4m3b15" :
457- return _encode_e4m3b15_scalar (value )
458- table = _FP8_POSITIVE_TABLES .setdefault (fp8_format , _build_fp8_positive_table (fp8_format ))
459- return min (table , key = lambda item : abs (item [1 ] - value ))[0 ]
460-
461-
462- def _encode_e4m3b15_scalar (value : float ) -> int :
463- fp16_bits = struct .unpack ("<H" , struct .pack ("<e" , float (value )))[0 ]
464- abs_fp16 = fp16_bits & 0x7FFF
465- if abs_fp16 > 0x3F00 :
466- abs_fp16 = 0x3F00
467- sign16 = fp16_bits & 0x8000
468- adjusted = abs_fp16 * 2 + 0x0080
469- return ((sign16 | adjusted ) >> 8 ) & 0xFF
470-
471-
472- def _build_fp8_positive_table (fp8_format : str ) -> list [tuple [int , float ]]:
473- table = []
474- for byte in range (128 ):
475- value = _decode_fp8_positive (fp8_format , byte )
476- if not math .isnan (value ):
477- table .append ((byte , value ))
478- return table
479-
480-
481- def _decode_fp8_positive (fp8_format : str , byte : int ) -> float :
482- exp = (byte >> 3 ) & 0xF
483- mant = byte & 0x7
484- if fp8_format == "e4m3fn" and exp == 0xF and mant == 0x7 :
485- return float ("nan" )
486- if exp == 0 and mant == 0 :
487- return 0.0
488- if fp8_format == "e4m3fn" :
489- return math .ldexp (mant / 8.0 , - 6 ) if exp == 0 else math .ldexp (1.0 + mant / 8.0 , exp - 7 )
490- if fp8_format == "e4m3fnuz" :
491- return math .ldexp (mant / 8.0 , - 7 ) if exp == 0 else math .ldexp (1.0 + mant / 8.0 , exp - 8 )
492- if fp8_format == "e4m3b15" :
493- return math .ldexp (mant / 8.0 , - 14 ) if exp == 0 else math .ldexp (1.0 + mant / 8.0 , exp - 15 )
494- raise ValueError (f"Unknown FP8 format: { fp8_format } " )
495-
496-
497331def _try_measure_case (
498332 comm : Comm ,
499333 case : BenchmarkCase ,
@@ -595,6 +429,18 @@ def _format_table(headers: list[str], rows: list[list[str]]) -> str:
595429 return "\n " .join ([header_line , sep_line , * row_lines ])
596430
597431
432+ def _format_stat (value : float | None ) -> str :
433+ if value is None :
434+ return "-"
435+ return f"{ value :.6g} "
436+
437+
438+ def _format_mismatches (stats : CorrectnessStats | None ) -> str :
439+ if stats is None or stats .total == 0 :
440+ return "-"
441+ return f"{ stats .mismatches } /{ stats .total } "
442+
443+
598444def _build_parser () -> argparse .ArgumentParser :
599445 parser = argparse .ArgumentParser (description = "Benchmark MSCCL++ collectives without PyTorch dependencies" )
600446 parser .add_argument ("--collective" , choices = (_ALLREDUCE , _ALLGATHER ), default = _ALLREDUCE )
@@ -705,8 +551,10 @@ def main(argv: list[str] | None = None) -> None:
705551 config_store .upsert (hardware_profile , args .collective , case .message_size , config )
706552
707553 correctness = "SKIP"
554+ correctness_stats : CorrectnessStats | None = None
708555 if not args .skip_correctness :
709- correctness = "PASS" if _check_correctness (comm , case , config , niter = args .correctness_iters ) else "FAIL"
556+ correctness_stats = _check_correctness (comm , case , config , niter = args .correctness_iters )
557+ correctness = "PASS" if correctness_stats else "FAIL"
710558 comm .reset (config )
711559 if correctness != "PASS" :
712560 raise RuntimeError (
@@ -738,6 +586,9 @@ def main(argv: list[str] | None = None) -> None:
738586 f"{ algbw :.2f} " ,
739587 f"{ busbw :.2f} " ,
740588 correctness ,
589+ _format_stat (None if correctness_stats is None else correctness_stats .max_abs_diff ),
590+ _format_stat (None if correctness_stats is None else correctness_stats .mean_abs_diff ),
591+ _format_mismatches (correctness_stats ),
741592 ]
742593 )
743594 if comm .rank == 0 :
@@ -762,6 +613,9 @@ def main(argv: list[str] | None = None) -> None:
762613 "algBW_GB/s" ,
763614 "busBW_GB/s" ,
764615 "check" ,
616+ "max_diff" ,
617+ "mean_diff" ,
618+ "mismatch" ,
765619 ],
766620 rows ,
767621 ),
0 commit comments