Skip to content

Commit eb6d0d2

Browse files
authored
[nsys-jax] Add ratio of hidden communication time to total communication time (#1241)
1 parent 7036e87 commit eb6d0d2

File tree

3 files changed

+141
-22
lines changed

3 files changed

+141
-22
lines changed

.github/container/nsys_jax/nsys_jax/analyses/communication.py

100644100755
Lines changed: 140 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,21 @@
11
#!/usr/bin/env python
22
import argparse
3+
import csv
34
from collections import defaultdict
5+
46
from nsys_jax import (
57
align_profiler_data_timestamps,
68
apply_warmup_heuristics,
79
ensure_compiled_protos_are_importable,
810
load_profiler_data,
911
)
1012
from math import sqrt
13+
from prettytable import PrettyTable
1114
import pathlib
1215
from uncertainties import ufloat # type: ignore
1316

1417

15-
def main():
16-
parser = argparse.ArgumentParser(
17-
description="Summarise communication in an nsys-jax report"
18-
)
19-
parser.add_argument("prefix", type=pathlib.Path)
20-
args = parser.parse_args()
21-
# Make sure that the .proto files under protos/ have been compiled to .py, and
22-
# that those generated .py files are importable.
23-
ensure_compiled_protos_are_importable(prefix=args.prefix)
24-
# Load the profiler data; the compilation part is needed for the warmup heuristics
25-
all_data = load_profiler_data(args.prefix, frames={"communication", "compile"})
26-
# Align timestamps
27-
all_data, alignment_metadata = align_profiler_data_timestamps(all_data)
28-
# TODO: make this pretty
29-
# print(alignment_metadata)
30-
# Partition the profile data into initialisation and steady-state running
31-
_, steady_state = apply_warmup_heuristics(all_data)
32-
assert len(steady_state.communication), (
33-
"Communication summary was requested but no steady-state communication was "
34-
"identified."
35-
)
18+
def process_communication_data(steady_state):
3619
collective_types = set()
3720
summary_data = defaultdict(dict)
3821
for (collective, message_size), df in steady_state.communication.groupby(
@@ -52,7 +35,10 @@ def main():
5235
summary_data[message_size][collective] = ufloat(
5336
bandwidth.mean(), bandwidth.std() / sqrt(len(bandwidth))
5437
)
55-
collective_types = sorted(collective_types)
38+
return sorted(collective_types), summary_data
39+
40+
41+
def print_bandwidth_table(collective_types, summary_data):
5642
collective_widths = {
5743
collective: max(
5844
len(collective),
@@ -96,5 +82,137 @@ def format_bandwidth(data, collective):
9682
)
9783

9884

85+
def process_hidden_ms_to_total_ms(steady_state):
    """Compute, per collective, the mean ratio of hidden to total duration.

    For every row of ``steady_state.communication`` the ratio
    ``ProjDurHiddenMs / (ProjDurMs + ProjDurHiddenMs)`` is evaluated, and the
    ratios are averaged within each ``Collective`` group.

    Args:
        steady_state: object whose ``communication`` attribute is a pandas
            DataFrame with ``Collective``, ``ProjDurMs`` and
            ``ProjDurHiddenMs`` columns.

    Returns:
        A ``(collective_types, summary_data)`` pair. ``collective_types`` is a
        sorted list of group keys and ``summary_data`` maps each key to its
        mean ratio. Keys are always 1-tuples such as ``("all-reduce",)``.
        ``(None, None)`` is returned when the ``ProjDurHiddenMs`` column sums
        to zero.
    """
    communication = steady_state.communication
    if communication["ProjDurHiddenMs"].sum() == 0:
        return None, None

    summary_data = {}
    for collective, df in communication.groupby(["Collective"]):
        # Grouping by a one-element list yields 1-tuple keys on recent pandas
        # releases but bare scalars on older ones. Normalise to 1-tuples so
        # that consumers indexing the key with ``[0]`` behave the same on both.
        key = collective if isinstance(collective, tuple) else (collective,)
        ratio = df["ProjDurHiddenMs"] / (df["ProjDurMs"] + df["ProjDurHiddenMs"])
        summary_data[key] = ratio.mean()
    # Sorted for a deterministic row order, matching the sorted collective
    # list returned by process_communication_data.
    return sorted(summary_data), summary_data
98+
99+
100+
def print_hidden_ms_to_total_ms_table(
    collective_types, summary_data, overall_hidden_ms_to_total_ms
):
    """Print the per-collective hidden-to-total ratios as a table.

    Args:
        collective_types: iterable of keys into ``summary_data``; each key is
            either a 1-tuple ``(name,)`` or a plain name.
        summary_data: mapping from key to the mean ratio for that collective.
        overall_hidden_ms_to_total_ms: overall ratio, printed on a separate
            line after the table.
    """
    table = PrettyTable()
    table.field_names = ["Collective", "Mean HiddenToTotalMs"]

    # Sort so that the row order does not depend on set iteration order.
    for collective in sorted(collective_types):
        # Only unwrap tuple keys; indexing a plain string with [0] would
        # silently truncate the collective name to its first character.
        label = collective[0] if isinstance(collective, tuple) else collective
        table.add_row([label, summary_data[collective]])

    print(table)
    print("Overall HiddenMs to TotalMs:", overall_hidden_ms_to_total_ms)
112+
113+
114+
def calculate_overall_hidden_ms_to_total_ms(steady_state):
    """Return the overall hidden-to-total ratio across all communication rows.

    The result is the sum of the ``ProjDurHiddenMs`` column divided by the sum
    of ``ProjDurMs + ProjDurHiddenMs`` over ``steady_state.communication``.
    ``None`` is returned when the ``ProjDurHiddenMs`` column sums to zero.
    """
    comm = steady_state.communication
    hidden = comm["ProjDurHiddenMs"]
    hidden_total = hidden.sum()
    if hidden_total == 0:
        return None
    combined_total = (comm["ProjDurMs"] + hidden).sum()
    return hidden_total / combined_total
126+
127+
128+
def write_to_csv(
    collective_types,
    bandwidth_summary,
    hidden_to_total_summary,
    overall_hidden_ms_to_total_ms,
    output_file,
):
    """Write the bandwidth and hidden-to-total summaries to one CSV file.

    The file holds up to three sections separated by empty rows: the
    bandwidth table, the per-collective hidden-to-total table, and the
    overall hidden-to-total ratio.

    Args:
        collective_types: ordered collective names; these become the columns
            of the bandwidth table.
        bandwidth_summary: mapping ``message_size -> {collective: value}``.
            Values are rendered with the ``S`` format specifier; missing
            entries are written as ``-``.
        hidden_to_total_summary: mapping from collective key (a 1-tuple
            ``(name,)`` or a plain name) to its mean ratio, or ``None`` to
            omit the section.
        overall_hidden_ms_to_total_ms: overall ratio, or ``None`` to omit it.
        output_file: path of the CSV file to create or overwrite.
    """
    # newline="" lets the csv module control line endings itself.
    with open(output_file, "w", newline="", encoding="utf-8") as csvfile:
        writer = csv.writer(csvfile)

        # Write bandwidth table
        writer.writerow(["Bandwidth Table"])
        writer.writerow(["Size [B]"] + list(collective_types))
        for message_size in sorted(bandwidth_summary):
            per_collective = bandwidth_summary[message_size]
            row = [message_size]
            for collective in collective_types:
                if collective in per_collective:
                    row.append(f"{per_collective[collective]:S}")
                else:
                    row.append("-")
            writer.writerow(row)

        writer.writerow([])  # Empty row for separation

        # Write hidden to total table if data is available
        if hidden_to_total_summary is not None:
            writer.writerow(["HiddenMs to TotalMs Table"])
            writer.writerow(["Collective", "Mean HiddenToTotalMs"])
            for collective, mean_value in hidden_to_total_summary.items():
                # Only unwrap tuple keys; indexing a plain string with [0]
                # would truncate the collective name to its first character.
                label = (
                    collective[0] if isinstance(collective, tuple) else collective
                )
                writer.writerow([label, mean_value])

            writer.writerow([])  # Empty row for separation

        if overall_hidden_ms_to_total_ms is not None:
            writer.writerow(
                ["Overall HiddenMs to TotalMs", overall_hidden_ms_to_total_ms]
            )
165+
166+
167+
def main():
    """Summarise communication in an nsys-jax report.

    Loads the profile data found under the ``prefix`` command-line argument,
    keeps the steady-state part, prints the bandwidth table and (when hidden
    communication time is present) the hidden-to-total table, then writes all
    tables to ``communication_summary.csv`` in the current working directory.
    """
    parser = argparse.ArgumentParser(
        description="Summarise communication in an nsys-jax report"
    )
    parser.add_argument("prefix", type=pathlib.Path)
    args = parser.parse_args()

    # Make sure that the .proto files under protos/ have been compiled to .py, and
    # that those generated .py files are importable.
    ensure_compiled_protos_are_importable(prefix=args.prefix)
    # Load the profiler data; the compilation part is needed for the warmup heuristics
    all_data = load_profiler_data(args.prefix, frames={"communication", "compile"})
    # Align timestamps
    all_data, alignment_metadata = align_profiler_data_timestamps(all_data)
    # TODO: make this pretty
    # print(alignment_metadata)
    # Partition the profile data into initialisation and steady-state running
    _, steady_state = apply_warmup_heuristics(all_data)

    assert len(steady_state.communication), (
        "Communication summary was requested but no steady-state communication was "
        "identified."
    )

    collective_types, bandwidth_summary = process_communication_data(steady_state)
    print_bandwidth_table(collective_types, bandwidth_summary)

    hidden_to_total_collective_types, hidden_to_total_summary = (
        process_hidden_ms_to_total_ms(steady_state)
    )
    # Bind the overall ratio up front: it is passed to write_to_csv even when
    # there is no hidden-communication data, and would otherwise be unbound on
    # that path (UnboundLocalError).
    overall_hidden_ms_to_total_ms = None
    if hidden_to_total_summary is not None:
        overall_hidden_ms_to_total_ms = calculate_overall_hidden_ms_to_total_ms(
            steady_state
        )
        print_hidden_ms_to_total_ms_table(
            hidden_to_total_collective_types,
            hidden_to_total_summary,
            overall_hidden_ms_to_total_ms,
        )

    # Write all tables to a single CSV file
    write_to_csv(
        collective_types,
        bandwidth_summary,
        hidden_to_total_summary,
        overall_hidden_ms_to_total_ms,
        "communication_summary.csv",
    )
215+
216+
99217
if __name__ == "__main__":
100218
main()

.github/container/nsys_jax/nsys_jax/analysis.py

100644100755
File mode changed.

.github/container/nsys_jax/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ dependencies = [
99
"pyarrow",
1010
"requests", # for install-protoc
1111
"uncertainties", # communication analysis recipe
12+
"prettytable",
1213
]
1314
requires-python = ">= 3.10"
1415

0 commit comments

Comments
 (0)