|
| 1 | +import time |
| 2 | + |
| 3 | +import matplotlib.pyplot as plt |
| 4 | +import torch |
| 5 | +import torch.distributed as dist |
| 6 | +from global_vars import ( |
| 7 | + ITERATION, |
| 8 | + LOCAL_RANK, |
| 9 | + LOCAL_WORLD_SIZE, |
| 10 | + RANK, |
| 11 | + WARMUP, |
| 12 | + WORLD_SIZE, |
| 13 | + get_hostnames, |
| 14 | +) |
| 15 | +from utility import create_dir, extract_first_middle_last, extract_number, log |
| 16 | + |
| 17 | + |
def _run_collective(comm, tensor, group):
    """Issue a single ``comm`` collective on ``tensor`` inside ``group``."""
    if comm == "allreduce":
        dist.all_reduce(tensor, group=group)
    elif comm == "alltoall":
        # The input buffer doubles as the output buffer; the contents are
        # never read back, only the elapsed time is used.
        dist.all_to_all_single(tensor, tensor, group=group)
    else:
        raise ValueError(f"Unsupported collective: {comm!r}")


def _write_result_table(f, heading, keys, results, hostnames, num_procs):
    """Write one result table to ``f`` (markdown) and mirror it through ``log``.

    Only the first rank of every ``num_procs``-sized block gets a row. Sizes a
    rank has no measurement for are reported as 0.
    """
    max_len = max(len(s) for s in hostnames) + 2

    f.write(f"{heading}\n")
    log(heading)

    f.write(f"| Hostname | Node | Rank | {' | '.join(keys)}|\n")
    f.write(f"|----------|----------|----------{'|----------' * len(keys)}|\n")

    formatted_keys = [f"{key:<6}" for key in keys]
    log(f"{'Hostname':<{max_len}} {'Node':<5} {'Rank':<5} {' '.join(formatted_keys)}")
    for rank in range(0, len(results), num_procs):
        r = results[rank] or {}
        hostname = hostnames[rank]
        node_id = rank // LOCAL_WORLD_SIZE

        formatted_values = [f"{r.get(key, 0):<6.2f}" for key in keys]
        log(f"{hostname:<{max_len}} {node_id:<5} {rank:<5} {' '.join(formatted_values)}")
        f.write(f"| {hostname} | {node_id} | {rank} | {' | '.join(formatted_values)}|\n")
    f.write("\n")


def _save_bandwidth_bar_chart(x, heights, xlabel, title, roofline_bandwidth, out_file, xtick_labels=None):
    """Draw a bandwidth bar chart with a roofline and save it to ``out_file``."""
    plt.figure(figsize=(10, 4))
    bars = plt.bar(x, heights)
    plt.xlabel(xlabel)
    plt.ylabel("Bandwidth")
    plt.title(title)
    if xtick_labels is not None:
        plt.xticks(x, xtick_labels)
    plt.grid(True, axis="y")

    # Add roofline
    plt.axhline(
        y=roofline_bandwidth,
        color="red",
        linestyle="--",
        linewidth=2,
        label=f"IB Unidirectional BW Roofline: {roofline_bandwidth} GB/s",
    )
    plt.legend()

    # Annotate every bar with its value
    for bar in bars:
        height = bar.get_height()
        plt.text(
            bar.get_x() + bar.get_width() / 2,
            height,
            f"{height:.2f}",
            ha="center",
            va="bottom",
        )

    plt.tight_layout()
    plt.savefig(out_file)
    plt.close()


def run_inter_node_comm(args):
    """Benchmark allreduce and alltoall across groups of adjacent nodes.

    For each collective and each group width (2 nodes, 4 nodes and all nodes,
    skipping widths larger than the job), consecutive ranks are partitioned
    into process groups spanning that many nodes. Every group times the
    collective for payloads from 2 MiB to 1 GiB. Per-rank latency (us) and
    bandwidth (GB/s) are gathered on rank 0, which appends tables to the
    markdown report and, optionally, saves bar charts.

    Args:
        args: Namespace providing ``markdown_file`` (report path, opened in
            append mode), ``plot`` (truthy to save charts), ``dump_path``
            (root directory for charts) and ``ib_bw`` (roofline value drawn
            on the charts).
    """
    device = torch.device(f"cuda:{LOCAL_RANK}")
    sizes = [2**i * 1024 * 1024 for i in range(1, 11)]
    assert WORLD_SIZE % LOCAL_WORLD_SIZE == 0
    num_nodes = WORLD_SIZE // LOCAL_WORLD_SIZE

    if num_nodes <= 1:
        log(f"Skip inter node comm benchmark, {num_nodes=}")
        return

    # N-node allreduce & alltoall (adjacent pairs)
    # 2-node allreduce, pair nodes: [0, 1], [2, 3], ...
    # 4-node allreduce, pair nodes: [0, 1, 2, 3], [4, 5, 6, 7]...
    # Sorted so that every rank walks the group widths in the same, explicit
    # order; the group creation below must be entered by all ranks in lockstep.
    node_counts = sorted({2, 4, num_nodes})
    cases = {
        "allreduce": node_counts,
        "alltoall": node_counts,
    }

    if RANK == 0:
        with open(args.markdown_file, "a", encoding="utf-8") as f:
            f.write("# InterNode Comm\n")

    for comm, adjacent_node_list in cases.items():
        if RANK == 0:
            with open(args.markdown_file, "a", encoding="utf-8") as f:
                f.write(f"## InterNode - {comm}\n")
        for adjacent_nodes in adjacent_node_list:
            if adjacent_nodes > num_nodes:
                continue

            case_name = f"{comm}-{adjacent_nodes}nodes"
            latency_results = {}
            bandwidth_results = {}

            num_procs = adjacent_nodes * LOCAL_WORLD_SIZE
            num_adjacent_groups = num_nodes // adjacent_nodes
            adjacent_group = None
            for i_group in range(num_adjacent_groups):
                first_rank = i_group * num_procs
                group_ranks = list(range(first_rank, first_rank + num_procs))
                # Every rank creates every group, including ones it is not in.
                tmp_group = dist.new_group(ranks=group_ranks)
                if RANK in group_ranks:
                    assert adjacent_group is None
                    adjacent_group = tmp_group
            if RANK < num_adjacent_groups * num_procs:
                assert adjacent_group is not None

            # Ranks on leftover nodes (num_nodes not a multiple of
            # adjacent_nodes) belong to no group and measure nothing.
            if adjacent_group is not None:
                for size in sizes:
                    numel = size // 2  # bfloat16 elements, 2 bytes each
                    if comm == "alltoall":
                        # all_to_all_single without explicit split sizes needs
                        # an element count divisible by the group size.
                        numel -= numel % num_procs
                    tensor = torch.rand(numel, dtype=torch.bfloat16, device=device)
                    payload_bytes = tensor.numel() * tensor.element_size()

                    dist.barrier(group=adjacent_group, device_ids=[torch.cuda.current_device()])
                    for _ in range(WARMUP):
                        _run_collective(comm, tensor, adjacent_group)
                    torch.cuda.synchronize()
                    start = time.perf_counter()
                    for _ in range(ITERATION):
                        _run_collective(comm, tensor, adjacent_group)
                    torch.cuda.synchronize()
                    elapsed = (time.perf_counter() - start) / ITERATION

                    scale = 2 if comm == "allreduce" else 1
                    comm_size = scale * payload_bytes * (num_procs - 1) / num_procs
                    gb_per_sec = comm_size / elapsed / 1e9
                    size_label = f"{size//1024//1024}MB"
                    latency_results[size_label] = elapsed * 1e6
                    bandwidth_results[size_label] = gb_per_sec

            dist.barrier(device_ids=[torch.cuda.current_device()])
            if adjacent_group is not None:
                dist.destroy_process_group(adjacent_group)

            all_latency_results = [None for _ in range(WORLD_SIZE)]
            all_bandwidth_results = [None for _ in range(WORLD_SIZE)]
            dist.gather_object(latency_results, all_latency_results if RANK == 0 else None, dst=0)
            dist.gather_object(bandwidth_results, all_bandwidth_results if RANK == 0 else None, dst=0)

            if RANK != 0:
                continue

            # ---- Reporting (rank 0 only) ----
            keys = sorted({k for r in all_bandwidth_results for k in (r or {})}, key=extract_number)
            hostnames = get_hostnames()

            with open(args.markdown_file, "a", encoding="utf-8") as f:
                _write_result_table(
                    f,
                    f"=======InterNodeComm - {case_name} (us)=======",
                    keys,
                    all_latency_results,
                    hostnames,
                    num_procs,
                )
                _write_result_table(
                    f,
                    f"=======InterNodeComm - {case_name} (GB/s)=======",
                    keys,
                    all_bandwidth_results,
                    hostnames,
                    num_procs,
                )

            if not args.plot:
                continue

            log(f"=======Plot InterNode {case_name} Bandwidth=======")
            with open(args.markdown_file, "a", encoding="utf-8") as f:
                f.write(f"=======Plot InterNode {case_name} Bandwidth=======\n")
            plot_case = f"inter_node_comm/{comm}"
            dump_path = f"{args.dump_path}/{plot_case}"
            create_dir(dump_path)
            print_keys = extract_first_middle_last(keys)

            # One bar per block of num_procs ranks, taken from the block's first rank.
            first_rank_bandwidth_results = [r or {} for r in all_bandwidth_results[::num_procs]]
            num_print_ranks = len(first_rank_bandwidth_results)
            xtick_labels = [f"{i*num_procs}" for i in range(num_print_ranks)]
            for size_key in print_keys:
                # Same convention as the tables: 0 for blocks without measurements.
                values = [r.get(size_key, 0) for r in first_rank_bandwidth_results]
                png_file = f"inter_node_{case_name}_bandwidth_{size_key.replace('x', '_')}.png"
                _save_bandwidth_bar_chart(
                    range(num_print_ranks),
                    values,
                    xlabel=f"RankPair ({num_procs} ranks)",
                    title=f"Inter Node {case_name} Bandwidth for {size_key}",
                    roofline_bandwidth=args.ib_bw,
                    out_file=f"{dump_path}/{png_file}",
                    xtick_labels=xtick_labels,
                )
                # NOTE(review): only a newline is written after each figure; a
                # markdown image link to png_file may have been intended - confirm.
                with open(args.markdown_file, "a", encoding="utf-8") as f:
                    f.write("\n")

            # Bar chart visualization for rank 0
            rank_0_values = [all_bandwidth_results[0][size_key] for size_key in keys]
            png_file = f"inter_node_{case_name}_bandwidth_rank_0.png"
            _save_bandwidth_bar_chart(
                keys,
                rank_0_values,
                xlabel="Size",
                title=f"Inter Node {case_name} Bandwidth for Rank 0",
                roofline_bandwidth=args.ib_bw,
                out_file=f"{dump_path}/{png_file}",
            )
            with open(args.markdown_file, "a", encoding="utf-8") as f:
                f.write("\n")
                f.write("\n")
            log("")
0 commit comments