|
19 | 19 | import time
|
20 | 20 | import signal
|
21 | 21 | import psutil
|
| 22 | +import distutils |
22 | 23 | from collections import defaultdict
|
23 | 24 | from typing import Dict
|
24 | 25 | from argparse import ArgumentParser, REMAINDER
|
@@ -89,6 +90,18 @@ def parse_args():
|
89 | 90 | type=str,
|
90 | 91 | help="redirect the stdout and stderr from each rank into different log files")
|
91 | 92 |
|
| 93 | + parser.add_argument("--bind_cores_to_rank", |
| 94 | + action="store_true", |
| 95 | + help="Bind each rank to different cores of the host. " |
| 96 | + "This improves host efficiency especially for CPU backend") |
| 97 | + |
| 98 | + parser.add_argument("--bind_core_list", |
| 99 | + type=str, |
| 100 | + default=None, |
| 101 | + help="List of cores to bind to with comma separated list of " |
| 102 | + "numbers and range. i.e. 1,3-5,7 => [1,3,4,5,7]. When not " |
| 103 | + "specified, all cores on system would be used rank binding") |
| 104 | + |
92 | 105 | # positional
|
93 | 106 | parser.add_argument("training_script",
|
94 | 107 | type=str,
|
@@ -117,6 +130,89 @@ def terminate_process_tree(pid):
|
117 | 130 | p.kill()
|
118 | 131 |
|
119 | 132 |
|
| 133 | +def parse_range(rng): |
| 134 | + try: |
| 135 | + value = int(rng) |
| 136 | + return range(value, value + 1) |
| 137 | + except ValueError: |
| 138 | + # value is not a single number |
| 139 | + parts = rng.split('-') |
| 140 | + if len(parts) != 2: |
| 141 | + raise ValueError("Bad range: '%s', range must be either a number or two number separated by dash" % |
| 142 | + (rng, )) |
| 143 | + start = int(parts[0]) |
| 144 | + end = int(parts[1]) |
| 145 | + if start > end: |
| 146 | + raise ValueError("Bad range: '%s', range end must larger than or equal to start" % (rng, )) |
| 147 | + return range(start, end + 1) |
| 148 | + |
| 149 | + |
| 150 | +# parse comma and dash separated range list into list |
| 151 | +# i.e. "0,2-4,6" --> [0, 2, 3, 4, 6] |
| 152 | +# rules: |
| 153 | +# 1. Range list numser be comma sepeaated, each item are either a single number, |
| 154 | +# or a range marked by two numbers (both number are included in the range) |
| 155 | +# 2. Sub ranges must be in ascend order and not overlap with each other |
| 156 | +# 3. No space in the range expression |
| 157 | +def parse_range_list(range_str): |
| 158 | + number_list = [] |
| 159 | + last = -1 |
| 160 | + range_list = range_str.split(',') |
| 161 | + for sub_range in range_list: |
| 162 | + sub_number_list = parse_range(sub_range) |
| 163 | + if sub_number_list[0] <= last: |
| 164 | + raise ValueError( |
| 165 | + "Bad range: '%s', sub ranges must not overlap with each other and should be in ascend order" % |
| 166 | + (range_str, )) |
| 167 | + last = sub_number_list[-1] |
| 168 | + number_list.extend(sub_number_list) |
| 169 | + return number_list |
| 170 | + |
| 171 | + |
| 172 | +# return a list of list for cores to numa mapping |
| 173 | +# [ |
| 174 | +# [ cores for numa 0 ] |
| 175 | +# [ cores belong to numa 1 ] |
| 176 | +# ... |
| 177 | +# ] |
| 178 | +def get_numa_cores(): |
| 179 | + ret = [] |
| 180 | + output = subprocess.check_output(['numactl', '--hardware']).decode("utf-8") |
| 181 | + lines = output.split('\n') |
| 182 | + for line in lines: |
| 183 | + if line.startswith('available:'): |
| 184 | + num_numas = int(line.split(' ')[1]) |
| 185 | + break |
| 186 | + for numa in range(num_numas): |
| 187 | + for line in lines: |
| 188 | + if line.startswith(f'node {numa} cpus:'): |
| 189 | + cores = line.split(' ')[3:] |
| 190 | + ret.append([int(core) for core in cores]) |
| 191 | + return ret |
| 192 | + |
| 193 | + |
| 194 | +def check_for_numactl_pkg(): |
| 195 | + libs = dict( |
| 196 | + dpkg=["-l", "numactl", "apt"], |
| 197 | + pacman=["-Q", "numactl", "pacman"], |
| 198 | + rpm=["-q", "numactl", "yum"], |
| 199 | + ) |
| 200 | + |
| 201 | + found = False |
| 202 | + for pkgmgr, data in libs.items(): |
| 203 | + flag, lib, tool = data |
| 204 | + path = distutils.spawn.find_executable(pkgmgr) |
| 205 | + if path is not None: |
| 206 | + cmd = f"{pkgmgr} {flag} {lib}" |
| 207 | + result = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True) |
| 208 | + if result.wait() == 0: |
| 209 | + found = True |
| 210 | + else: |
| 211 | + print(f"please install the {lib} package with {tool}") |
| 212 | + break |
| 213 | + return found |
| 214 | + |
| 215 | + |
120 | 216 | def main():
|
121 | 217 | args = parse_args()
|
122 | 218 | current_env = os.environ.copy()
|
@@ -211,8 +307,43 @@ def main():
|
211 | 307 |
|
212 | 308 | # spawn the processes
|
213 | 309 | cmd = []
|
| 310 | + if args.bind_cores_to_rank: |
| 311 | + check_for_numactl_pkg() |
| 312 | + if 'KMP_AFFINITY' in os.environ.keys(): |
| 313 | + raise ValueError("Environment variable KMP_AFFINITY conflicts with numactl " |
| 314 | + "because it interfere with how many CPU cores numactl can set. " |
| 315 | + "Unset KMP_AFFINITY before launching deepspeed.\n\n" |
| 316 | + "\t$ unset KMP_AFFINITY\n" |
| 317 | + "\t$ deepspeed <deepspeed command parameters>") |
| 318 | + if args.bind_core_list != None: |
| 319 | + core_list = parse_range_list(args.bind_core_list) |
| 320 | + total_cores = len(core_list) |
| 321 | + else: |
| 322 | + total_cores = psutil.cpu_count(logical=False) |
| 323 | + core_list = range(total_cores) |
| 324 | + cores_per_rank = total_cores // num_local_procs |
| 325 | + assert cores_per_rank >= 1, "At least one core needs to be assigned to each rank" |
| 326 | + core_list_for_rank = core_list[cores_per_rank * local_rank:cores_per_rank * (local_rank + 1)] |
| 327 | + current_env["OMP_NUM_THREADS"] = f"{cores_per_rank}" |
| 328 | + cmd.append("numactl") |
| 329 | + |
| 330 | + # check if all cores belong to same numa, if true, bind process to that numa domain with -m parameter |
| 331 | + numa_cores = get_numa_cores() |
| 332 | + num_numas = len(numa_cores) |
| 333 | + for i in range(num_numas): |
| 334 | + if set(core_list_for_rank) <= set(numa_cores[i]): |
| 335 | + cmd.append("-m") |
| 336 | + cmd.append(f"{i}") |
| 337 | + break |
| 338 | + |
| 339 | + cmd.append("-C") |
| 340 | + core_list_str = f"{core_list_for_rank[0]}" |
| 341 | + for core_id in core_list_for_rank[1:]: |
| 342 | + core_list_str = f"{core_list_str},{core_id}" |
| 343 | + cmd.append(f"{core_list_str}") |
214 | 344 | if not args.no_python:
|
215 |
| - cmd = [sys.executable, "-u"] |
| 345 | + cmd.append(sys.executable) |
| 346 | + cmd.append("-u") |
216 | 347 | if args.module:
|
217 | 348 | cmd.append("-m")
|
218 | 349 | else:
|
|
0 commit comments