@@ -69,7 +69,6 @@ class DTypeSpec:
6969 cupy_dtype : Any
7070 mscclpp_dtype : Any
7171 accum_dtype : Any | None = None
72- supports_reduction_correctness : bool = True
7372 fp8_format : str | None = None
7473
7574
@@ -79,7 +78,6 @@ class CandidateSpec:
7978 min_message_size : int | None = None
8079 max_message_size : int | None = None
8180 max_nblocks : int | None = None
82- min_nthreads : int | None = None
8381 supported_skus : tuple [str , ...] | None = None
8482 requires_nvls : bool = False
8583 requires_symmetric_memory : bool = False
@@ -93,7 +91,6 @@ class BenchmarkCase:
9391 input : cp .ndarray
9492 output : cp .ndarray
9593 dtype_spec : DTypeSpec
96- allgather_mode : str
9794 symmetric_memory : bool = False
9895
9996
@@ -170,7 +167,6 @@ def _with_accum_type(dtype_spec: DTypeSpec, accum_type: str | None) -> DTypeSpec
170167 cupy_dtype = dtype_spec .cupy_dtype ,
171168 mscclpp_dtype = dtype_spec .mscclpp_dtype ,
172169 accum_dtype = accum_dtype ,
173- supports_reduction_correctness = dtype_spec .supports_reduction_correctness ,
174170 fp8_format = dtype_spec .fp8_format ,
175171 )
176172
@@ -193,36 +189,40 @@ def _parse_int_list(raw: str | None, default: tuple[int, ...]) -> tuple[int, ...
193189 return values
194190
195191
196- def _candidate_specs (
197- collective : str , message_size : int , * , symmetric_memory : bool = False
198- ) -> tuple [CandidateSpec , ...]:
192+ def _candidate_specs (collective : str , * , symmetric_memory : bool = False ) -> tuple [CandidateSpec , ...]:
199193 if collective == _ALLGATHER :
200194 return (CandidateSpec ("default_allgather_fullmesh2" , max_nblocks = 64 , supported_skus = ("MI300X" ,)),)
201195 if collective != _ALLREDUCE :
202196 raise ValueError (f"Unsupported collective: { collective } " )
203- if message_size <= 512 * 1024 :
204- candidates = (
205- CandidateSpec (
206- "default_allreduce_nvls_packet" ,
207- max_nblocks = 16 ,
208- supported_skus = ("H100" , "GB300" ),
209- requires_nvls = True ,
210- ),
211- CandidateSpec ("default_allreduce_packet" , max_nblocks = 56 ),
212- CandidateSpec ("default_allreduce_allpair_packet" , max_nblocks = 56 ),
213- )
214- elif message_size <= 4 * 1024 * 1024 :
215- candidates = (
216- CandidateSpec ("default_allreduce_packet" , max_nblocks = 56 ),
217- CandidateSpec ("default_allreduce_allpair_packet" , max_nblocks = 56 ),
218- CandidateSpec ("default_allreduce_rsag_zero_copy" ),
219- CandidateSpec ("default_allreduce_fullmesh" , max_nblocks = 64 , supported_skus = ("MI300X" ,)),
220- )
221- else :
222- candidates = (
223- CandidateSpec ("default_allreduce_rsag_zero_copy" ),
224- CandidateSpec ("default_allreduce_fullmesh" , max_nblocks = 64 , supported_skus = ("MI300X" ,)),
225- )
197+ candidates = (
198+ CandidateSpec (
199+ "default_allreduce_nvls_packet" ,
200+ max_message_size = 512 * 1024 ,
201+ max_nblocks = 16 ,
202+ supported_skus = ("H100" , "GB300" ),
203+ requires_nvls = True ,
204+ ),
205+ CandidateSpec (
206+ "default_allreduce_packet" ,
207+ max_message_size = 4 * 1024 * 1024 ,
208+ max_nblocks = 56 ,
209+ ),
210+ CandidateSpec (
211+ "default_allreduce_allpair_packet" ,
212+ max_message_size = 4 * 1024 * 1024 ,
213+ max_nblocks = 56 ,
214+ ),
215+ CandidateSpec (
216+ "default_allreduce_rsag_zero_copy" ,
217+ min_message_size = 512 * 1024 + 1 ,
218+ ),
219+ CandidateSpec (
220+ "default_allreduce_fullmesh" ,
221+ min_message_size = 512 * 1024 + 1 ,
222+ max_nblocks = 64 ,
223+ supported_skus = ("MI300X" ,),
224+ ),
225+ )
226226 if symmetric_memory :
227227 return (
228228 CandidateSpec (
@@ -244,20 +244,17 @@ def _candidate_algorithms(comm: Comm, case: BenchmarkCase) -> list[tuple[Any, Ca
244244 symmetric_memory = case .symmetric_memory
245245 profile = getattr (comm , "hardware_profile" , None )
246246 filtered_out = False
247- for candidate in _candidate_specs (case .collective , case . message_size , symmetric_memory = symmetric_memory ):
247+ for candidate in _candidate_specs (case .collective , symmetric_memory = symmetric_memory ):
248248 if not _candidate_supports_profile (candidate , profile ):
249249 filtered_out = True
250250 continue
251- if candidate .requires_nvls and not _mscclpp ().is_nvls_supported ():
252- filtered_out = True
253- continue
254- if candidate .requires_symmetric_memory and not symmetric_memory :
251+ if not _candidate_supports_message_size (candidate , case .message_size ):
255252 filtered_out = True
256253 continue
257- if candidate .min_message_size is not None and case . message_size < candidate . min_message_size :
254+ if candidate .requires_nvls and not _mscclpp (). is_nvls_supported () :
258255 filtered_out = True
259256 continue
260- if candidate .max_message_size is not None and case . message_size > candidate . max_message_size :
257+ if candidate .requires_symmetric_memory and not symmetric_memory :
261258 filtered_out = True
262259 continue
263260 algorithm = available .get (candidate .algorithm )
@@ -281,6 +278,14 @@ def _candidate_supports_profile(candidate: CandidateSpec, profile: HardwareProfi
281278 return sku in candidate .supported_skus
282279
283280
281+ def _candidate_supports_message_size (candidate : CandidateSpec , message_size : int ) -> bool :
282+ if candidate .min_message_size is not None and message_size < candidate .min_message_size :
283+ return False
284+ if candidate .max_message_size is not None and message_size > candidate .max_message_size :
285+ return False
286+ return True
287+
288+
284289def _make_case (
285290 * ,
286291 collective : str ,
@@ -299,7 +304,6 @@ def _make_case(
299304 input = memory ,
300305 output = memory ,
301306 dtype_spec = dtype_spec ,
302- allgather_mode = allgather_mode ,
303307 symmetric_memory = symmetric_memory ,
304308 )
305309
@@ -323,7 +327,6 @@ def _make_case(
323327 input = input_buffer ,
324328 output = output ,
325329 dtype_spec = dtype_spec ,
326- allgather_mode = allgather_mode ,
327330 symmetric_memory = symmetric_memory ,
328331 )
329332
@@ -522,7 +525,6 @@ def main(argv: list[str] | None = None) -> None:
522525 candidate_algorithms = _candidate_algorithms ,
523526 check_correctness = _check_correctness ,
524527 measure = _try_measure_case ,
525- symmetric_memory = args .symmetric_memory ,
526528 )
527529
528530 rows : list [list [str ]] = []
0 commit comments