|
3 | 3 |
|
4 | 4 | from __future__ import annotations |
5 | 5 |
|
6 | | -import argparse |
7 | | -from pathlib import Path |
8 | | -from typing import Any, Callable, Iterable, Sequence |
| 6 | +from typing import Any, Callable, Iterable |
9 | 7 |
|
10 | 8 | from mscclpp_benchmark.tuning_config import TunedConfig |
11 | 9 |
|
@@ -84,109 +82,3 @@ def tune(self, case: Any) -> TunedConfig | None: |
84 | 82 | if best_config is None: |
85 | 83 | return self.comm.resolve_config(case) |
86 | 84 | return best_config |
87 | | - |
88 | | - |
89 | | -def _normalize_name(name: str | None) -> str: |
90 | | - if not name: |
91 | | - return "native" |
92 | | - return name.strip().lower().replace("-", "_") |
93 | | - |
94 | | - |
95 | | -def _build_parser() -> argparse.ArgumentParser: |
96 | | - parser = argparse.ArgumentParser(description="Generate offline MSCCL++ tuned configs") |
97 | | - parser.add_argument("--collective", choices=("allreduce", "allgather"), default="allreduce") |
98 | | - parser.add_argument("--dim", type=int, required=True) |
99 | | - parser.add_argument("--dtype", required=True) |
100 | | - parser.add_argument("--accum-type") |
101 | | - parser.add_argument("--sku", default="runtime", help="Used only for the default output filename") |
102 | | - parser.add_argument("--scale", type=int, help="Expected MPI world size") |
103 | | - parser.add_argument("--batch-sizes") |
104 | | - parser.add_argument("--output") |
105 | | - parser.add_argument("--scratch-buffer-size", type=int, default=1 << 27) |
106 | | - parser.add_argument("--warmup", type=int, default=5, help="Warmup graph replays during tuning") |
107 | | - parser.add_argument("--graph-launches", type=int, default=10, help="Timed graph replays during tuning") |
108 | | - parser.add_argument( |
109 | | - "--ops-per-graph", type=int, default=100, help="Collective ops captured per graph during tuning" |
110 | | - ) |
111 | | - parser.add_argument("--candidate-nblocks") |
112 | | - parser.add_argument("--candidate-nthreads") |
113 | | - parser.add_argument("--symmetric-memory", action="store_true") |
114 | | - parser.add_argument("--skip-correctness", action="store_true") |
115 | | - return parser |
116 | | - |
117 | | - |
118 | | -def _default_output_path(args: argparse.Namespace) -> str: |
119 | | - accum = _normalize_name(args.accum_type) |
120 | | - return ( |
121 | | - "mscclpp_tuned_" |
122 | | - f"{_normalize_name(args.collective)}_" |
123 | | - f"{_normalize_name(args.sku)}_" |
124 | | - f"s{args.scale or 'runtime'}_" |
125 | | - f"d{args.dim}_" |
126 | | - f"dtype_{_normalize_name(args.dtype)}_" |
127 | | - f"accum_{accum}.json" |
128 | | - ) |
129 | | - |
130 | | - |
131 | | -def _bench_collective_args(args: argparse.Namespace) -> list[str]: |
132 | | - output = args.output or _default_output_path(args) |
133 | | - bench_args = [ |
134 | | - "--collective", |
135 | | - args.collective, |
136 | | - "--d-model", |
137 | | - str(args.dim), |
138 | | - "--dtype", |
139 | | - args.dtype, |
140 | | - "--autotune", |
141 | | - "--write-config", |
142 | | - output, |
143 | | - "--scratch-buffer-size", |
144 | | - str(args.scratch_buffer_size), |
145 | | - "--tune-warmup", |
146 | | - str(args.warmup), |
147 | | - "--tune-graph-launches", |
148 | | - str(args.graph_launches), |
149 | | - "--tune-iterations", |
150 | | - str(args.ops_per_graph), |
151 | | - "--warmup", |
152 | | - "0", |
153 | | - "--graph-launches", |
154 | | - "1", |
155 | | - "--iterations", |
156 | | - "1", |
157 | | - ] |
158 | | - if args.batch_sizes: |
159 | | - bench_args += ["--batch-sizes", args.batch_sizes] |
160 | | - if args.accum_type: |
161 | | - bench_args += ["--accum-type", args.accum_type] |
162 | | - if args.candidate_nblocks: |
163 | | - bench_args += ["--candidate-nblocks", args.candidate_nblocks] |
164 | | - if args.candidate_nthreads: |
165 | | - bench_args += ["--candidate-nthreads", args.candidate_nthreads] |
166 | | - if args.symmetric_memory: |
167 | | - bench_args.append("--symmetric-memory") |
168 | | - if args.skip_correctness: |
169 | | - bench_args.append("--skip-correctness") |
170 | | - return bench_args |
171 | | - |
172 | | - |
173 | | -def main(argv: Sequence[str] | None = None) -> None: |
174 | | - parser = _build_parser() |
175 | | - args = parser.parse_args(argv) |
176 | | - |
177 | | - if args.scale is not None: |
178 | | - from mpi4py import MPI |
179 | | - |
180 | | - world_size = MPI.COMM_WORLD.Get_size() |
181 | | - if world_size != args.scale: |
182 | | - raise ValueError(f"MSCCL++ tuning scale mismatch: expected MPI world size {args.scale}, got {world_size}") |
183 | | - |
184 | | - from mscclpp_benchmark.bench_collective import main as bench_collective_main |
185 | | - |
186 | | - bench_collective_main(_bench_collective_args(args)) |
187 | | - if args.output is None: |
188 | | - print(f"Wrote tuned config to {Path(_default_output_path(args)).resolve()}", flush=True) |
189 | | - |
190 | | - |
191 | | -if __name__ == "__main__": |
192 | | - main() |
0 commit comments