Skip to content

Commit dc37dd6

Browse files
committed
remove some code
1 parent 44dab3b commit dc37dd6

4 files changed

Lines changed: 44 additions & 88 deletions

File tree

python/mscclpp_benchmark/bench_collective.py

Lines changed: 42 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@ class DTypeSpec:
6969
cupy_dtype: Any
7070
mscclpp_dtype: Any
7171
accum_dtype: Any | None = None
72-
supports_reduction_correctness: bool = True
7372
fp8_format: str | None = None
7473

7574

@@ -79,7 +78,6 @@ class CandidateSpec:
7978
min_message_size: int | None = None
8079
max_message_size: int | None = None
8180
max_nblocks: int | None = None
82-
min_nthreads: int | None = None
8381
supported_skus: tuple[str, ...] | None = None
8482
requires_nvls: bool = False
8583
requires_symmetric_memory: bool = False
@@ -93,7 +91,6 @@ class BenchmarkCase:
9391
input: cp.ndarray
9492
output: cp.ndarray
9593
dtype_spec: DTypeSpec
96-
allgather_mode: str
9794
symmetric_memory: bool = False
9895

9996

@@ -170,7 +167,6 @@ def _with_accum_type(dtype_spec: DTypeSpec, accum_type: str | None) -> DTypeSpec
170167
cupy_dtype=dtype_spec.cupy_dtype,
171168
mscclpp_dtype=dtype_spec.mscclpp_dtype,
172169
accum_dtype=accum_dtype,
173-
supports_reduction_correctness=dtype_spec.supports_reduction_correctness,
174170
fp8_format=dtype_spec.fp8_format,
175171
)
176172

@@ -193,36 +189,40 @@ def _parse_int_list(raw: str | None, default: tuple[int, ...]) -> tuple[int, ...
193189
return values
194190

195191

196-
def _candidate_specs(
197-
collective: str, message_size: int, *, symmetric_memory: bool = False
198-
) -> tuple[CandidateSpec, ...]:
192+
def _candidate_specs(collective: str, *, symmetric_memory: bool = False) -> tuple[CandidateSpec, ...]:
199193
if collective == _ALLGATHER:
200194
return (CandidateSpec("default_allgather_fullmesh2", max_nblocks=64, supported_skus=("MI300X",)),)
201195
if collective != _ALLREDUCE:
202196
raise ValueError(f"Unsupported collective: {collective}")
203-
if message_size <= 512 * 1024:
204-
candidates = (
205-
CandidateSpec(
206-
"default_allreduce_nvls_packet",
207-
max_nblocks=16,
208-
supported_skus=("H100", "GB300"),
209-
requires_nvls=True,
210-
),
211-
CandidateSpec("default_allreduce_packet", max_nblocks=56),
212-
CandidateSpec("default_allreduce_allpair_packet", max_nblocks=56),
213-
)
214-
elif message_size <= 4 * 1024 * 1024:
215-
candidates = (
216-
CandidateSpec("default_allreduce_packet", max_nblocks=56),
217-
CandidateSpec("default_allreduce_allpair_packet", max_nblocks=56),
218-
CandidateSpec("default_allreduce_rsag_zero_copy"),
219-
CandidateSpec("default_allreduce_fullmesh", max_nblocks=64, supported_skus=("MI300X",)),
220-
)
221-
else:
222-
candidates = (
223-
CandidateSpec("default_allreduce_rsag_zero_copy"),
224-
CandidateSpec("default_allreduce_fullmesh", max_nblocks=64, supported_skus=("MI300X",)),
225-
)
197+
candidates = (
198+
CandidateSpec(
199+
"default_allreduce_nvls_packet",
200+
max_message_size=512 * 1024,
201+
max_nblocks=16,
202+
supported_skus=("H100", "GB300"),
203+
requires_nvls=True,
204+
),
205+
CandidateSpec(
206+
"default_allreduce_packet",
207+
max_message_size=4 * 1024 * 1024,
208+
max_nblocks=56,
209+
),
210+
CandidateSpec(
211+
"default_allreduce_allpair_packet",
212+
max_message_size=4 * 1024 * 1024,
213+
max_nblocks=56,
214+
),
215+
CandidateSpec(
216+
"default_allreduce_rsag_zero_copy",
217+
min_message_size=512 * 1024 + 1,
218+
),
219+
CandidateSpec(
220+
"default_allreduce_fullmesh",
221+
min_message_size=512 * 1024 + 1,
222+
max_nblocks=64,
223+
supported_skus=("MI300X",),
224+
),
225+
)
226226
if symmetric_memory:
227227
return (
228228
CandidateSpec(
@@ -244,20 +244,17 @@ def _candidate_algorithms(comm: Comm, case: BenchmarkCase) -> list[tuple[Any, Ca
244244
symmetric_memory = case.symmetric_memory
245245
profile = getattr(comm, "hardware_profile", None)
246246
filtered_out = False
247-
for candidate in _candidate_specs(case.collective, case.message_size, symmetric_memory=symmetric_memory):
247+
for candidate in _candidate_specs(case.collective, symmetric_memory=symmetric_memory):
248248
if not _candidate_supports_profile(candidate, profile):
249249
filtered_out = True
250250
continue
251-
if candidate.requires_nvls and not _mscclpp().is_nvls_supported():
252-
filtered_out = True
253-
continue
254-
if candidate.requires_symmetric_memory and not symmetric_memory:
251+
if not _candidate_supports_message_size(candidate, case.message_size):
255252
filtered_out = True
256253
continue
257-
if candidate.min_message_size is not None and case.message_size < candidate.min_message_size:
254+
if candidate.requires_nvls and not _mscclpp().is_nvls_supported():
258255
filtered_out = True
259256
continue
260-
if candidate.max_message_size is not None and case.message_size > candidate.max_message_size:
257+
if candidate.requires_symmetric_memory and not symmetric_memory:
261258
filtered_out = True
262259
continue
263260
algorithm = available.get(candidate.algorithm)
@@ -281,6 +278,14 @@ def _candidate_supports_profile(candidate: CandidateSpec, profile: HardwareProfi
281278
return sku in candidate.supported_skus
282279

283280

281+
def _candidate_supports_message_size(candidate: CandidateSpec, message_size: int) -> bool:
282+
if candidate.min_message_size is not None and message_size < candidate.min_message_size:
283+
return False
284+
if candidate.max_message_size is not None and message_size > candidate.max_message_size:
285+
return False
286+
return True
287+
288+
284289
def _make_case(
285290
*,
286291
collective: str,
@@ -299,7 +304,6 @@ def _make_case(
299304
input=memory,
300305
output=memory,
301306
dtype_spec=dtype_spec,
302-
allgather_mode=allgather_mode,
303307
symmetric_memory=symmetric_memory,
304308
)
305309

@@ -323,7 +327,6 @@ def _make_case(
323327
input=input_buffer,
324328
output=output,
325329
dtype_spec=dtype_spec,
326-
allgather_mode=allgather_mode,
327330
symmetric_memory=symmetric_memory,
328331
)
329332

@@ -522,7 +525,6 @@ def main(argv: list[str] | None = None) -> None:
522525
candidate_algorithms=_candidate_algorithms,
523526
check_correctness=_check_correctness,
524527
measure=_try_measure_case,
525-
symmetric_memory=args.symmetric_memory,
526528
)
527529

528530
rows: list[list[str]] = []

python/mscclpp_benchmark/comm.py

Lines changed: 1 addition & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
from __future__ import annotations
55

66
import logging
7-
from contextlib import contextmanager
8-
from typing import Any, Iterator
7+
from typing import Any
98

109
logger = logging.getLogger(__name__)
1110
_ALLREDUCE_COLLECTIVE = "allreduce"
@@ -60,41 +59,6 @@ def data_ptr(self) -> int:
6059
return int(self.buffer.data())
6160

6261

63-
@contextmanager
64-
def init_mscclpp_comm_group_for_ranks(ranks: list[int], *, name: str) -> Iterator[Any]:
65-
del name
66-
from mpi4py import MPI
67-
68-
world_comm = MPI.COMM_WORLD
69-
unique_ranks = [int(rank) for rank in ranks]
70-
if len(unique_ranks) != len(set(unique_ranks)):
71-
raise ValueError(f"Duplicate ranks are not allowed: {unique_ranks}")
72-
if world_comm.Get_rank() not in unique_ranks:
73-
raise ValueError(f"Rank {world_comm.Get_rank()} is not a member of subgroup {unique_ranks}")
74-
75-
sub_comm = None
76-
if unique_ranks == list(range(world_comm.Get_size())):
77-
mpi_comm = world_comm
78-
else:
79-
subgroup = world_comm.group.Incl(unique_ranks)
80-
sub_comm = world_comm.Create_group(subgroup)
81-
mpi_comm = sub_comm
82-
83-
_ensure_device()
84-
comm_group = _mscclpp().CommGroup(mpi_comm)
85-
setattr(comm_group, "_mpi_comm", mpi_comm)
86-
try:
87-
yield comm_group
88-
finally:
89-
destroy = getattr(comm_group, "destroy", None)
90-
if callable(destroy):
91-
destroy()
92-
if sub_comm is not None:
93-
free = getattr(sub_comm, "Free", None)
94-
if callable(free):
95-
free()
96-
97-
9862
class _AllReduceOp:
9963
def __init__(self, comm: "Comm", x: Any, *, symmetric_memory: bool = False) -> None:
10064
self._comm = comm

python/mscclpp_benchmark/correctness.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,6 @@ def check_correctness(
5555
*,
5656
niter: int = 1,
5757
) -> CorrectnessStats:
58-
if case.collective == "allreduce" and not case.dtype_spec.supports_reduction_correctness:
59-
raise ValueError(
60-
f"Correctness checking for {case.collective} with {case.dtype_spec.name} is not implemented; "
61-
"use --skip-correctness or a numeric dtype"
62-
)
63-
6458
all_ok = True
6559
local_max_abs_diff = 0.0
6660
local_sum_abs_diff = 0.0

python/mscclpp_benchmark/tuner.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ def __init__(
2323
candidate_algorithms: Callable[[Any, Any], list[tuple[Any, Any]]],
2424
check_correctness: Callable[..., bool],
2525
measure: Callable[..., float | None],
26-
symmetric_memory: bool,
2726
) -> None:
2827
self.comm = comm
2928
self.candidate_nblocks = tuple(candidate_nblocks)
@@ -34,12 +33,11 @@ def __init__(
3433
self._candidate_algorithms = candidate_algorithms
3534
self._check_correctness = check_correctness
3635
self._measure = measure
37-
self._symmetric_memory = symmetric_memory
3836

3937
def tune(self, case: Any) -> TunedConfig | None:
4038
best_config: TunedConfig | None = None
4139
best_time_us = float("inf")
42-
symmetric_memory = bool(getattr(case, "symmetric_memory", self._symmetric_memory))
40+
symmetric_memory = bool(getattr(case, "symmetric_memory", False))
4341
candidates = self._candidate_algorithms(self.comm, case)
4442
if not candidates:
4543
if self.comm.rank == 0:
@@ -54,8 +52,6 @@ def tune(self, case: Any) -> TunedConfig | None:
5452
if candidate_spec.max_nblocks is not None and nblocks > candidate_spec.max_nblocks:
5553
continue
5654
for nthreads in self.candidate_nthreads:
57-
if candidate_spec.min_nthreads is not None and nthreads < candidate_spec.min_nthreads:
58-
continue
5955
config = TunedConfig(
6056
algorithm=algorithm.name,
6157
nblocks=nblocks,

0 commit comments

Comments
 (0)