Skip to content

Commit 44dab3b

Browse files
committed
update correctness check
1 parent f1a5a7d commit 44dab3b

3 files changed

Lines changed: 405 additions & 182 deletions

File tree

python/mscclpp_benchmark/bench_collective.py

Lines changed: 26 additions & 172 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44
from __future__ import annotations
55

66
import argparse
7-
import math
8-
import struct
97
from dataclasses import dataclass
108
from typing import Any
119

@@ -15,6 +13,11 @@
1513
_mscclpp_module = None
1614

1715
from mscclpp_benchmark.comm import Comm
16+
from mscclpp_benchmark.correctness import (
17+
CorrectnessStats,
18+
check_correctness as _check_correctness,
19+
fill_case_for_benchmark as _fill_case_for_benchmark,
20+
)
1821
from mscclpp_benchmark.gpu import capture_graph, init_runtime
1922
from mscclpp_benchmark.tuner import OfflineTuner
2023
from mscclpp_benchmark.tuning_config import HardwareProfile, TunedConfig, TunedConfigStore, normalize_sku
@@ -94,10 +97,6 @@ class BenchmarkCase:
9497
symmetric_memory: bool = False
9598

9699

97-
def config_accum_dtype(case: BenchmarkCase) -> Any:
    """Return the accumulation dtype to use for ``case``.

    ``case.dtype_spec.accum_dtype`` is used when it is truthy; otherwise the
    result falls back to ``case.dtype_spec.mscclpp_dtype``.
    """
    spec = case.dtype_spec
    accum = spec.accum_dtype
    return accum if accum else spec.mscclpp_dtype
99-
100-
101100
def _device_name() -> str:
102101
props = cp.cuda.runtime.getDeviceProperties(cp.cuda.Device().id)
103102
name = props.get("name", "UNKNOWN")
@@ -176,10 +175,6 @@ def _with_accum_type(dtype_spec: DTypeSpec, accum_type: str | None) -> DTypeSpec
176175
)
177176

178177

179-
def _dtype_is_float(dtype: Any) -> bool:
    """Report whether ``dtype`` is one of the CuPy floating types handled here.

    Only ``cp.float16`` and ``cp.float32`` are treated as floating point.
    """
    floating_types = (cp.float16, cp.float32)
    return dtype in floating_types
181-
182-
183178
def _human_size(size: int) -> str:
184179
value = float(size)
185180
for unit in ("B", "KiB", "MiB", "GiB", "TiB"):
@@ -333,167 +328,6 @@ def _make_case(
333328
)
334329

335330

336-
def _fill_case_for_benchmark(case: BenchmarkCase, rank: int) -> None:
    """Initialise the buffers of ``case`` before a timed benchmark run.

    Allreduce cases get an all-zero input. Every other collective gets its
    input filled by ``_fill_allgather_input`` and its output zeroed.
    """
    if case.collective != _ALLREDUCE:
        _fill_allgather_input(case, rank)
        case.output.fill(0)
        # NOTE(review): presumably the in-place input aliases the output that
        # was just zeroed, hence the second fill -- confirm against the caller.
        if case.allgather_mode == "in-place":
            _fill_allgather_input(case, rank)
    else:
        case.input.fill(0)
344-
345-
346-
def _fill_case_for_correctness(case: BenchmarkCase, rank: int, iteration: int) -> None:
    """Fill the buffers of ``case`` with the pattern used by the correctness check.

    The input is filled with the dtype-encoded form of
    ``iteration * world_size + rank``. For collectives other than allreduce
    the output is zeroed first.
    """
    seed = iteration * MPI.COMM_WORLD.size + rank
    if case.collective != _ALLREDUCE:
        case.output.fill(0)
    case.input.fill(_dtype_value(case, seed))
353-
354-
355-
def _fill_allgather_input(case: BenchmarkCase, rank: int) -> None:
    """Fill the input buffer of ``case`` with the dtype-encoded value ``rank + 1``."""
    case.input.fill(_dtype_value(case, rank + 1))
358-
359-
360-
def _dtype_value(case: BenchmarkCase, value: int) -> int:
    """Convert the logical test value ``value`` into what is stored in the buffer.

    * FP8 dtypes: the value is scaled by ``_correctness_numeric_value`` and
      encoded to a byte by ``_encode_fp8_scalar``.
    * ``cp.uint8`` buffers: the value is reduced modulo 256.
    * Anything else: the value is returned unchanged.
    """
    spec = case.dtype_spec
    fmt = spec.fp8_format
    if fmt is None:
        return value % 256 if spec.cupy_dtype == cp.uint8 else value
    return _encode_fp8_scalar(fmt, _correctness_numeric_value(case, value))
366-
367-
368-
def _correctness_numeric_value(case: BenchmarkCase, value: int) -> float:
    """Map the integer test value to the real number used by the check.

    Non-FP8 dtypes use ``float(value)``. FP8 dtypes use
    ``(value + 1) / max(64, world_size ** 2)``.
    """
    if case.dtype_spec.fp8_format is not None:
        world = MPI.COMM_WORLD.size
        divisor = max(64, world * world)
        return float(value + 1) / float(divisor)
    return float(value)
373-
374-
375-
def _check_correctness(
    comm: Comm,
    case: BenchmarkCase,
    config: TunedConfig,
    *,
    raise_on_unsupported: bool = True,
    niter: int = 1,
) -> bool:
    """Run ``case`` ``niter`` times and compare the output with the expected pattern.

    Args:
        comm: Communicator wrapper used to run the collective and to barrier.
        case: Benchmark case whose input/output buffers are filled and checked.
        config: Tuned configuration passed through to ``comm.run``.
        raise_on_unsupported: When the case is an allreduce whose dtype does
            not support reduction checking, raise ``ValueError`` if true and
            return ``True`` (nothing to check) if false.
        niter: Number of fill/run/compare iterations.

    Returns:
        ``True`` only if every iteration matched on every rank of
        ``MPI.COMM_WORLD`` and no rank saw a nonzero status from ``comm.run``.

    Raises:
        ValueError: For an unsupported allreduce dtype when
            ``raise_on_unsupported`` is true.
    """
    if case.collective == _ALLREDUCE and not case.dtype_spec.supports_reduction_correctness:
        if not raise_on_unsupported:
            return True
        raise ValueError(
            f"Correctness checking for {case.collective} with {case.dtype_spec.name} is not implemented; "
            "use --skip-correctness or a numeric dtype"
        )

    all_ok = True
    for iteration in range(niter):
        _fill_case_for_correctness(case, comm.rank, iteration)
        comm.comm_group.barrier()
        ret = comm.run(case, config)
        cp.cuda.runtime.deviceSynchronize()
        comm.comm_group.barrier()

        # Agree on the launch status collectively. If only the failing rank
        # returned here, the remaining ranks would block in the next barrier
        # or in the final allreduce below.
        if MPI.COMM_WORLD.allreduce(ret != 0, op=MPI.LOR):
            return False

        expected = _expected_output(case, comm.nranks, iteration)
        local_ok = _compare_output(case.output, expected)
        all_ok = all_ok and local_ok

        if not local_ok:
            # Report only the first mismatching element to keep the log short.
            mismatch = _mismatch_mask(case.output, expected)
            print(
                "not close: "
                f"iter={iteration}, rank={comm.rank}, output={case.output[mismatch][0]}, "
                f"expected={expected[mismatch][0]}",
                flush=True,
            )

    return bool(MPI.COMM_WORLD.allreduce(all_ok, op=MPI.LAND))
415-
416-
417-
def _expected_output(case: BenchmarkCase, nranks: int, iteration: int):
    """Build the array the output buffer should equal after the collective.

    Rank ``r`` contributes the test value ``iteration * world_size + r``.

    * Allreduce: every element is the encoded sum of all contributions. For
      FP8 the contributions are encoded, decoded, summed, then re-encoded.
    * Otherwise: the output is ``nranks`` consecutive chunks of
      ``case.input.size`` elements, chunk ``r`` holding rank ``r``'s value.
    """
    if case.collective != _ALLREDUCE:
        gathered = cp.empty_like(case.output)
        per_rank = case.input.size
        for rank in range(nranks):
            start = rank * per_rank
            gathered[start : start + per_rank].fill(_dtype_value(case, iteration * MPI.COMM_WORLD.size + rank))
        return gathered

    fmt = case.dtype_spec.fp8_format
    if fmt is None:
        total = sum(iteration * MPI.COMM_WORLD.size + rank for rank in range(nranks))
        return cp.full_like(case.output, _dtype_value(case, total))

    decoded_total = sum(
        _decode_fp8_positive(fmt, _dtype_value(case, iteration * MPI.COMM_WORLD.size + rank))
        for rank in range(nranks)
    )
    return cp.full_like(case.output, _encode_fp8_scalar(fmt, decoded_total))
436-
437-
438-
def _compare_output(output, expected) -> bool:
    """Return whether ``output`` matches ``expected``.

    Floating outputs (per ``_dtype_is_float``) are compared with
    ``rtol=1e-2, atol=2``; everything else must match exactly.
    """
    if not _dtype_is_float(output.dtype.type):
        verdict = cp.all(output == expected)
    else:
        verdict = cp.allclose(output, expected, rtol=1.0e-2, atol=2)
    return bool(verdict.item())
442-
443-
444-
def _mismatch_mask(output, expected):
    """Return a boolean mask of the elements where ``output`` differs from ``expected``.

    Uses the same criteria as ``_compare_output``: ``rtol=1e-2, atol=2`` for
    floating outputs and exact inequality otherwise.
    """
    if not _dtype_is_float(output.dtype.type):
        return output != expected
    close = cp.isclose(output, expected, rtol=1.0e-2, atol=2)
    return ~close
448-
449-
450-
# Cache of (byte, decoded value) tables, keyed by FP8 format name.
_FP8_POSITIVE_TABLES: dict[str, list[tuple[int, float]]] = {}


def _encode_fp8_scalar(fp8_format: str, value: float) -> int:
    """Encode a non-negative ``value`` as an FP8 byte in ``fp8_format``.

    ``"e4m3b15"`` is delegated to ``_encode_e4m3b15_scalar``. Every other
    format picks the byte whose decoded value is nearest to ``value`` from a
    per-format table; ties go to the lower byte.

    Raises:
        ValueError: If ``value`` is negative, or if the table cannot be built
            because the format is unknown.
    """
    if value < 0:
        raise ValueError("FP8 correctness values are expected to be non-negative")
    if fp8_format == "e4m3b15":
        return _encode_e4m3b15_scalar(value)
    # Build the table only on a cache miss. ``dict.setdefault`` evaluates its
    # default argument eagerly, which rebuilt the table on every call.
    table = _FP8_POSITIVE_TABLES.get(fp8_format)
    if table is None:
        table = _FP8_POSITIVE_TABLES[fp8_format] = _build_fp8_positive_table(fp8_format)
    return min(table, key=lambda item: abs(item[1] - value))[0]
460-
461-
462-
def _encode_e4m3b15_scalar(value: float) -> int:
    """Encode ``value`` as an e4m3b15 byte by way of its IEEE half-precision bits.

    The half-precision magnitude is clamped to ``0x3F00``, doubled, rounded by
    adding ``0x0080``, and the top byte (with the sign bit) is returned.
    """
    (half_bits,) = struct.unpack("<H", struct.pack("<e", float(value)))
    magnitude = min(half_bits & 0x7FFF, 0x3F00)
    sign = half_bits & 0x8000
    rounded = magnitude * 2 + 0x0080
    return ((sign | rounded) >> 8) & 0xFF
470-
471-
472-
def _build_fp8_positive_table(fp8_format: str) -> list[tuple[int, float]]:
    """List ``(byte, value)`` for bytes 0..127 of ``fp8_format`` that do not decode to NaN."""
    decoded = ((code, _decode_fp8_positive(fp8_format, code)) for code in range(128))
    return [(code, number) for code, number in decoded if not math.isnan(number)]
479-
480-
481-
def _decode_fp8_positive(fp8_format: str, byte: int) -> float:
    """Decode the magnitude of an FP8 ``byte`` (4 exponent bits, 3 mantissa bits).

    Only bits 0-6 are read; the sign bit is ignored. Supported formats and
    their exponent biases are ``"e4m3fn"`` (7), ``"e4m3fnuz"`` (8) and
    ``"e4m3b15"`` (15). For ``"e4m3fn"`` the all-ones pattern decodes to NaN.

    Raises:
        ValueError: If ``fp8_format`` is not a supported format. The format
            is validated before decoding, so an unknown format is rejected
            for a zero byte too.
    """
    bias = {"e4m3fn": 7, "e4m3fnuz": 8, "e4m3b15": 15}.get(fp8_format)
    if bias is None:
        raise ValueError(f"Unknown FP8 format: {fp8_format}")

    exp = (byte >> 3) & 0xF
    mant = byte & 0x7
    if fp8_format == "e4m3fn" and exp == 0xF and mant == 0x7:
        return float("nan")
    if exp == 0:
        # Subnormal (or zero): no implicit leading one, fixed exponent 1 - bias.
        return math.ldexp(mant / 8.0, 1 - bias)
    return math.ldexp(1.0 + mant / 8.0, exp - bias)
495-
496-
497331
def _try_measure_case(
498332
comm: Comm,
499333
case: BenchmarkCase,
@@ -595,6 +429,18 @@ def _format_table(headers: list[str], rows: list[list[str]]) -> str:
595429
return "\n".join([header_line, sep_line, *row_lines])
596430

597431

432+
def _format_stat(value: float | None) -> str:
    """Render a statistic for the results table: ``"-"`` for ``None``, else six significant digits."""
    return "-" if value is None else f"{value:.6g}"
436+
437+
438+
def _format_mismatches(stats: CorrectnessStats | None) -> str:
    """Render ``mismatches/total`` for the results table.

    Returns ``"-"`` when ``stats`` is ``None`` or its ``total`` is zero.
    """
    if stats is not None and stats.total != 0:
        return f"{stats.mismatches}/{stats.total}"
    return "-"
442+
443+
598444
def _build_parser() -> argparse.ArgumentParser:
599445
parser = argparse.ArgumentParser(description="Benchmark MSCCL++ collectives without PyTorch dependencies")
600446
parser.add_argument("--collective", choices=(_ALLREDUCE, _ALLGATHER), default=_ALLREDUCE)
@@ -705,8 +551,10 @@ def main(argv: list[str] | None = None) -> None:
705551
config_store.upsert(hardware_profile, args.collective, case.message_size, config)
706552

707553
correctness = "SKIP"
554+
correctness_stats: CorrectnessStats | None = None
708555
if not args.skip_correctness:
709-
correctness = "PASS" if _check_correctness(comm, case, config, niter=args.correctness_iters) else "FAIL"
556+
correctness_stats = _check_correctness(comm, case, config, niter=args.correctness_iters)
557+
correctness = "PASS" if correctness_stats else "FAIL"
710558
comm.reset(config)
711559
if correctness != "PASS":
712560
raise RuntimeError(
@@ -738,6 +586,9 @@ def main(argv: list[str] | None = None) -> None:
738586
f"{algbw:.2f}",
739587
f"{busbw:.2f}",
740588
correctness,
589+
_format_stat(None if correctness_stats is None else correctness_stats.max_abs_diff),
590+
_format_stat(None if correctness_stats is None else correctness_stats.mean_abs_diff),
591+
_format_mismatches(correctness_stats),
741592
]
742593
)
743594
if comm.rank == 0:
@@ -762,6 +613,9 @@ def main(argv: list[str] | None = None) -> None:
762613
"algBW_GB/s",
763614
"busBW_GB/s",
764615
"check",
616+
"max_diff",
617+
"mean_diff",
618+
"mismatch",
765619
],
766620
rows,
767621
),

0 commit comments

Comments
 (0)