Skip to content

Commit ce03bae

Browse files
committed
update
1 parent 493e3b3 commit ce03bae

2 files changed

Lines changed: 25 additions & 26 deletions

File tree

python/mscclpp_benchmark/correctness.py

Lines changed: 13 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -263,17 +263,27 @@ def _comparison_tolerance(case: Any, nranks: int) -> tuple[float, float] | None:
263263

264264

265265
_FP8_TABLES: dict[str, list[tuple[int, float]]] = {}
266+
_FP8_DEVICE_TABLES: dict[str, tuple] = {}
266267
_FP8_SPACING_CACHE: dict[tuple[str, float], float] = {}
267268

268269

270+
def _fp8_device_table(fp8_format: str):
    """Return the cached ``(table_bytes, table_values)`` pair for ``fp8_format``.

    ``table_bytes`` is a ``cp.uint8`` array holding the first element of every
    entry in the host table, and ``table_values`` is a ``cp.float32`` array
    holding the second element, both in host-table order. The pair is created
    on the first request for a format and memoized in ``_FP8_DEVICE_TABLES``,
    so later calls return the same arrays without converting the table again.
    """
    cached = _FP8_DEVICE_TABLES.get(fp8_format)
    if cached is not None:
        return cached

    # ``dict.setdefault`` evaluates its default argument eagerly, which would
    # rebuild the host table even when one is already cached. Look it up first
    # and only build on a miss.
    table = _FP8_TABLES.get(fp8_format)
    if table is None:
        table = _build_fp8_table(fp8_format)
        _FP8_TABLES[fp8_format] = table

    table_bytes = cp.asarray([byte for byte, _ in table], dtype=cp.uint8)
    table_values = cp.asarray([value for _, value in table], dtype=cp.float32)
    cached = (table_bytes, table_values)
    _FP8_DEVICE_TABLES[fp8_format] = cached
    return cached
279+
280+
269281
def _encode_fp8_values(fp8_format: str, values):
270282
values = values.astype(cp.float32)
271283
if fp8_format == "e4m3b15":
272284
return _encode_e4m3b15_values(values)
273285

274-
table = _FP8_TABLES.setdefault(fp8_format, _build_fp8_table(fp8_format))
275-
table_bytes = cp.asarray([byte for byte, _ in table], dtype=cp.uint8)
276-
table_values = cp.asarray([value for _, value in table], dtype=cp.float32)
286+
table_bytes, table_values = _fp8_device_table(fp8_format)
277287
flat_values = values.ravel()
278288
flat_encoded = cp.empty(flat_values.shape, dtype=cp.uint8)
279289
chunk_size = 65536

python/mscclpp_benchmark/tuning_config.py

Lines changed: 12 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class TunedConfigBySize:
3535

3636

3737
class TunedConfigStore:
38-
def __init__(self, profiles: dict[HardwareProfile | None, dict[str, list[TunedConfigBySize]]]) -> None:
38+
def __init__(self, profiles: dict[HardwareProfile, dict[str, list[TunedConfigBySize]]]) -> None:
3939
self._profiles = profiles
4040

4141
@classmethod
@@ -49,24 +49,15 @@ def load_path(cls, path: str | Path) -> "TunedConfigStore":
4949

5050
@classmethod
5151
def from_payload(cls, payload: Any) -> "TunedConfigStore":
52-
profiles: dict[HardwareProfile | None, dict[str, list[TunedConfigBySize]]] = {}
53-
if isinstance(payload, list):
54-
profiles[None] = _configs_by_collective_from_payload({"allreduce": payload})
55-
return cls(profiles)
56-
5752
if not isinstance(payload, dict):
58-
raise ValueError("MSCCL++ tuned config must be a JSON object or list")
59-
60-
if "profiles" in payload:
61-
raw_profiles = payload["profiles"]
62-
if not isinstance(raw_profiles, list):
63-
raise ValueError("MSCCL++ tuned config field 'profiles' must be a list")
64-
for raw_profile in raw_profiles:
65-
profile = _profile_from_payload(raw_profile)
66-
profiles[profile] = _configs_by_collective_from_payload(raw_profile.get("collectives", {}))
67-
return cls(profiles)
68-
69-
profiles[None] = _configs_by_collective_from_payload(payload.get("collectives", payload))
53+
raise ValueError("MSCCL++ tuned config must be a JSON object")
54+
raw_profiles = payload.get("profiles")
55+
if not isinstance(raw_profiles, list):
56+
raise ValueError("MSCCL++ tuned config must contain a 'profiles' list")
57+
profiles: dict[HardwareProfile, dict[str, list[TunedConfigBySize]]] = {}
58+
for raw_profile in raw_profiles:
59+
profile = _profile_from_payload(raw_profile)
60+
profiles[profile] = _configs_by_collective_from_payload(raw_profile.get("collectives", {}))
7061
return cls(profiles)
7162

7263
def select(self, profile: HardwareProfile, collective: str, message_size: int) -> TunedConfig | None:
@@ -89,7 +80,7 @@ def upsert(self, profile: HardwareProfile, collective: str, message_size: int, c
8980
def write_path(self, path: str | Path) -> None:
9081
profiles_payload: list[dict[str, Any]] = []
9182
for profile, configs_by_collective in sorted(
92-
((profile, configs) for profile, configs in self._profiles.items() if profile is not None),
83+
self._profiles.items(),
9384
key=lambda item: (item[0].sku is None, item[0].sku or "", item[0].scale is None, item[0].scale or 0),
9485
):
9586
collectives: dict[str, list[dict[str, Any]]] = {}
@@ -127,7 +118,7 @@ def _profile_from_payload(raw_profile: Any) -> HardwareProfile:
127118

128119

129120
def _matching_profiles(
130-
profiles: dict[HardwareProfile | None, dict[str, list[TunedConfigBySize]]],
121+
profiles: dict[HardwareProfile, dict[str, list[TunedConfigBySize]]],
131122
runtime_profile: HardwareProfile,
132123
) -> list[tuple[int, dict[str, list[TunedConfigBySize]]]]:
133124
matches: list[tuple[int, dict[str, list[TunedConfigBySize]]]] = []
@@ -138,9 +129,7 @@ def _matching_profiles(
138129
return sorted(matches, key=lambda item: item[0], reverse=True)
139130

140131

141-
def _profile_match_specificity(profile: HardwareProfile | None, runtime_profile: HardwareProfile) -> int | None:
142-
if profile is None:
143-
return -1
132+
def _profile_match_specificity(profile: HardwareProfile, runtime_profile: HardwareProfile) -> int | None:
144133
specificity = 0
145134
if profile.sku is not None:
146135
if profile.sku != runtime_profile.sku:

0 commit comments

Comments
 (0)