Skip to content

Commit 947e52f

Browse files
wenxie-amdlimou102lxgsbqylbk
authored
preflight (#25)
Co-authored-by: limou102 <[email protected]> Co-authored-by: LI MOU <[email protected]>
1 parent d0da8db commit 947e52f

File tree

10 files changed

+1244
-0
lines changed

10 files changed

+1244
-0
lines changed

requirements.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,6 @@ loguru
22
wandb
33
pre-commit
44
nltk
5+
matplotlib
6+
markdown2
7+
weasyprint

tools/preflight/global_vars.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import os
2+
3+
# Distributed topology as exported by the launcher; each value falls back to a
# single-process default when the variable is not set.
WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1))
RANK = int(os.getenv("RANK", 0))
LOCAL_RANK = int(os.getenv("LOCAL_RANK", 0))
LOCAL_WORLD_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", 1))

# Rendezvous endpoint; both values are kept as strings.
MASTER_ADDR = os.getenv("MASTER_ADDR", "127.0.0.1")
MASTER_PORT = os.getenv("MASTER_PORT", "29500")

# Benchmark loop counts: untimed warm-up passes, then timed passes.
WARMUP = 10
ITERATION = 50

# Hostname registry; written by set_hostnames() and read by get_hostnames().
_HOST_NAMES = None
14+
15+
16+
def set_hostnames(hostnames):
    """Store ``hostnames`` in the module-level registry.

    The value is kept as the sole element of a one-item list, which is the
    layout ``get_hostnames`` reads back from.
    """
    global _HOST_NAMES
    holder = [hostnames]
    _HOST_NAMES = holder
19+
20+
21+
def get_hostnames():
    """Return the value most recently stored by ``set_hostnames``.

    Raises:
        AssertionError: if the registry has not been initialized yet.
    """
    # Raised explicitly rather than via ``assert`` so the guard still runs
    # under ``python -O``; the exception type and message are unchanged.
    if _HOST_NAMES is None:
        raise AssertionError("_HOST_NAMES not initialized")
    return _HOST_NAMES[0]

tools/preflight/inter_node_comm.py

Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
import time
2+
3+
import matplotlib.pyplot as plt
4+
import torch
5+
import torch.distributed as dist
6+
from global_vars import (
7+
ITERATION,
8+
LOCAL_RANK,
9+
LOCAL_WORLD_SIZE,
10+
RANK,
11+
WARMUP,
12+
WORLD_SIZE,
13+
get_hostnames,
14+
)
15+
from utility import create_dir, extract_first_middle_last, extract_number, log
16+
17+
18+
def run_inter_node_comm(args):
    """Benchmark inter-node all-reduce and all-to-all and append a report.

    Nodes are partitioned into groups of adjacent nodes (2, 4 and all nodes).
    Every group runs the collective on bf16 buffers of 2 MiB to 1 GiB, and the
    per-rank latency (us) and bandwidth (GB/s) are gathered on rank 0.
    Rank 0 appends result tables to ``args.markdown_file`` and, when
    ``args.plot`` is set, saves bar charts under ``args.dump_path``.

    Args:
        args: namespace providing ``markdown_file``, ``plot``, ``dump_path``
            and ``ib_bw`` (drawn as the roofline on the charts).
    """
    device = torch.device(f"cuda:{LOCAL_RANK}")
    sizes = [2**i * 1024 * 1024 for i in range(1, 11)]
    assert WORLD_SIZE % LOCAL_WORLD_SIZE == 0
    num_nodes = WORLD_SIZE // LOCAL_WORLD_SIZE

    if num_nodes <= 1:
        log(f"Skip inter node comm benchmark, {num_nodes=}")
        return

    # N-node allreduce & alltoall (adjacent groups)
    #   2-node case, groups of nodes: [0, 1], [2, 3], ...
    #   4-node case, groups of nodes: [0, 1, 2, 3], [4, 5, 6, 7], ...
    # Sorted so the cases run and are reported in ascending group size.
    node_counts = sorted({2, 4, num_nodes})
    cases = {"allreduce": node_counts, "alltoall": node_counts}

    if RANK == 0:
        _append_markdown(args.markdown_file, "# InterNode Comm\n")

    for comm, adjacent_node_list in cases.items():
        if RANK == 0:
            _append_markdown(args.markdown_file, f"## InterNode - {comm}\n")

        for adjacent_nodes in adjacent_node_list:
            if adjacent_nodes > num_nodes:
                continue

            case_name = f"{comm}-{adjacent_nodes}nodes"
            latency_results = {}
            bandwidth_results = {}

            num_procs = adjacent_nodes * LOCAL_WORLD_SIZE
            num_adjacent_groups = num_nodes // adjacent_nodes
            adjacent_group = None
            for i_group in range(num_adjacent_groups):
                first_rank = i_group * num_procs
                group_ranks = list(range(first_rank, first_rank + num_procs))
                # Every rank creates every group, members or not, so that all
                # processes issue the same sequence of new_group calls.
                tmp_group = dist.new_group(ranks=group_ranks)
                if RANK in group_ranks:
                    assert adjacent_group is None
                    adjacent_group = tmp_group
            if RANK < num_adjacent_groups * num_procs:
                assert adjacent_group is not None

            # Ranks on leftover nodes (num_nodes not a multiple of
            # adjacent_nodes) belong to no group and report empty results.
            if adjacent_group is not None:
                for size in sizes:
                    # bf16 is 2 bytes per element, so the buffer is `size` bytes.
                    tensor = torch.rand(size // 2, dtype=torch.bfloat16, device=device)
                    dist.barrier(group=adjacent_group, device_ids=[torch.cuda.current_device()])
                    for _ in range(WARMUP):
                        _run_collective(comm, tensor, adjacent_group)
                    torch.cuda.synchronize(device)
                    start = time.perf_counter()
                    for _ in range(ITERATION):
                        _run_collective(comm, tensor, adjacent_group)
                    torch.cuda.synchronize(device)
                    elapsed = (time.perf_counter() - start) / ITERATION

                    scale = 2 if comm == "allreduce" else 1
                    comm_size = scale * size * (num_procs - 1) / num_procs
                    size_key = f"{size // 1024 // 1024}MB"
                    latency_results[size_key] = elapsed * 1e6
                    bandwidth_results[size_key] = comm_size / elapsed / 1e9

            dist.barrier(device_ids=[torch.cuda.current_device()])
            if adjacent_group is not None:
                dist.destroy_process_group(adjacent_group)

            all_latency_results = [None for _ in range(WORLD_SIZE)]
            all_bandwidth_results = [None for _ in range(WORLD_SIZE)]
            dist.gather_object(latency_results, all_latency_results if RANK == 0 else None, dst=0)
            dist.gather_object(bandwidth_results, all_bandwidth_results if RANK == 0 else None, dst=0)

            if RANK == 0:
                _report_case(
                    args, comm, case_name, num_procs, all_latency_results, all_bandwidth_results
                )


def _run_collective(comm, tensor, group):
    """Issue one ``comm`` collective ("allreduce" or "alltoall") on ``group``."""
    if comm == "allreduce":
        dist.all_reduce(tensor, group=group)
    elif comm == "alltoall":
        dist.all_to_all_single(tensor, tensor, group=group)
    else:
        raise ValueError(f"Unsupported collective: {comm!r}")


def _append_markdown(markdown_file, text):
    """Append ``text`` to the markdown report file."""
    with open(markdown_file, "a", encoding="utf-8") as f:
        f.write(text)


def _write_result_table(f, title, keys, all_results, num_procs, max_len):
    """Write one table (first rank of each group only) to ``f`` and the log."""
    f.write(f"{title}\n")
    log(title)

    f.write(f"| Hostname | Node | Rank | {' | '.join(keys)}|\n")
    f.write(f"|----------|----------|----------{'|----------' * len(keys)}|\n")

    formatted_keys = [f"{key:<6}" for key in keys]
    log(f"{'Hostname':<{max_len}} {'Node':<5} {'Rank':<5} {' '.join(formatted_keys)}")

    hostnames = get_hostnames()
    for rank, r in enumerate(all_results):
        if rank % num_procs != 0:
            continue
        hostname = hostnames[rank]
        node_id = rank // LOCAL_WORLD_SIZE

        # Ranks that ran no benchmark have no entries and are shown as 0.
        formatted_values = [f"{(r or {}).get(key, 0):<6.2f}" for key in keys]
        log(f"{hostname:<{max_len}} {node_id:<5} {rank:<5} {' '.join(formatted_values)}")
        f.write(f"| {hostname} | {node_id} | {rank} | {' | '.join(formatted_values)}|\n")
    f.write("\n")


def _save_bandwidth_bar_chart(
    positions, values, xlabel, title, roofline_bandwidth, out_path, xtick_labels=None
):
    """Save a bandwidth bar chart with a roofline and per-bar value labels."""
    plt.figure(figsize=(10, 4))
    bars = plt.bar(positions, values)
    plt.xlabel(xlabel)
    plt.ylabel("Bandwidth")
    plt.title(title)
    if xtick_labels is not None:
        plt.xticks(positions, xtick_labels)
    plt.grid(True, axis="y")

    # Add roofline
    plt.axhline(
        y=roofline_bandwidth,
        color="red",
        linestyle="--",
        linewidth=2,
        label=f"IB Unidirectional BW Roofline: {roofline_bandwidth} GB/s",
    )
    plt.legend()

    # Print each bar's value just above it.
    for bar in bars:
        height = bar.get_height()
        plt.text(
            bar.get_x() + bar.get_width() / 2,
            height,
            f"{height:.2f}",
            ha="center",
            va="bottom",
        )

    plt.tight_layout()
    plt.savefig(out_path)
    plt.close()


def _report_case(args, comm, case_name, num_procs, all_latency_results, all_bandwidth_results):
    """Rank-0 reporting for one case: markdown tables, then optional charts."""
    keys = sorted(
        {k for r in all_bandwidth_results for k in (r or {})}, key=extract_number
    )
    max_len = max(len(s) for s in get_hostnames()) + 2

    with open(args.markdown_file, "a", encoding="utf-8") as f:
        _write_result_table(
            f,
            f"=======InterNodeComm - {case_name} (us)=======",
            keys,
            all_latency_results,
            num_procs,
            max_len,
        )
        _write_result_table(
            f,
            f"=======InterNodeComm - {case_name} (GB/s)=======",
            keys,
            all_bandwidth_results,
            num_procs,
            max_len,
        )

    if not args.plot:
        return

    log(f"=======Plot InterNode {case_name} Bandwidth=======")
    _append_markdown(args.markdown_file, f"=======Plot InterNode {case_name} Bandwidth=======\n")
    plot_case = f"inter_node_comm/{comm}"
    dump_path = f"{args.dump_path}/{plot_case}"
    create_dir(dump_path)
    print_keys = extract_first_middle_last(keys)

    # One bar per group, taken from the first rank of each group.
    first_rank_bandwidth_results = all_bandwidth_results[::num_procs]
    num_print_ranks = len(first_rank_bandwidth_results)
    for size_key in print_keys:
        # A leftover rank has an empty dict; plot it as 0 instead of failing.
        values = [(r or {}).get(size_key, 0) for r in first_rank_bandwidth_results]
        png_file = f"inter_node_{case_name}_bandwidth_{size_key.replace('x', '_')}.png"
        _save_bandwidth_bar_chart(
            range(num_print_ranks),
            values,
            xlabel=f"RankPair ({num_procs} ranks)",
            title=f"Inter Node {case_name} Bandwidth for {size_key}",
            roofline_bandwidth=args.ib_bw,
            out_path=f"{dump_path}/{png_file}",
            xtick_labels=[f"{i*num_procs}" for i in range(num_print_ranks)],
        )
        _append_markdown(args.markdown_file, f"![{plot_case}](./{plot_case}/{png_file})\n")

    # Bar chart visualization for rank 0
    rank_0_values = [all_bandwidth_results[0][size_key] for size_key in keys]
    png_file = f"inter_node_{case_name}_bandwidth_rank_0.png"
    _save_bandwidth_bar_chart(
        keys,
        rank_0_values,
        xlabel="Size",
        title=f"Inter Node {case_name} Bandwidth for Rank 0",
        roofline_bandwidth=args.ib_bw,
        out_path=f"{dump_path}/{png_file}",
    )
    _append_markdown(args.markdown_file, f"![{plot_case}](./{plot_case}/{png_file})\n\n")
    log("")

0 commit comments

Comments
 (0)