Skip to content

Commit 8dd4054

Browse files
Stebossolupton
andauthored
Clean up the analysis script and fix the error (#1274)
This PR introduces a fix for this [error](https://github.com/NVIDIA/JAX-Toolbox/actions/runs/12986857460/job/36216005203#step:7:4387) in the CI and cleans up the `communication.py` script --------- Co-authored-by: Olli Lupton <[email protected]>
1 parent 6fe699b commit 8dd4054

File tree

1 file changed

+128
-48
lines changed

1 file changed

+128
-48
lines changed

.github/container/nsys_jax/nsys_jax/analyses/communication.py

Lines changed: 128 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,40 @@
11
#!/usr/bin/env python
22
import argparse
33
import csv
4+
import pathlib
45
from collections import defaultdict
6+
from math import sqrt
57

68
from nsys_jax import (
79
align_profiler_data_timestamps,
810
apply_warmup_heuristics,
911
ensure_compiled_protos_are_importable,
1012
load_profiler_data,
1113
)
12-
from math import sqrt
1314
from prettytable import PrettyTable
14-
import pathlib
1515
from uncertainties import ufloat # type: ignore
1616

1717

1818
def process_communication_data(steady_state):
19+
"""
20+
Process communication data from a steady state to compute bandwidth summaries.
21+
22+
Args:
23+
steady_state: A steady state data frame.
24+
25+
Return:
26+
A tuple of (collective_types, summary_data), where:
27+
collective_types (List[str]): sorted list of collective operation types
28+
summary_data (Dict[int, Dict[str, ufloat]]): Dictionary with summaries for bandwidth data
29+
"""
1930
collective_types = set()
2031
summary_data = defaultdict(dict)
21-
for (collective, message_size), df in steady_state.communication.groupby(
32+
33+
communication_grouped_by = steady_state.communication.groupby(
2234
["Collective", "MessageSize"]
23-
):
35+
)
36+
37+
for (collective, message_size), df in communication_grouped_by:
2438
collective_types.add(collective)
2539
# This grouped data frame will have a row for each device that is participating
2640
# in this instance of the collective.
@@ -31,14 +45,43 @@ def process_communication_data(steady_state):
3145
# therefore the last one to join the collective and its bandwidth estimate does
3246
# not contain a wait time component. The .mean() is over the different
3347
# (ProgramId, ProgramExecution, ThunkIndex) values.
34-
bandwidth = devices["BusBandwidthGBPerSec"].agg("max")
48+
bandwidth_of_fastest_device = devices["BusBandwidthGBPerSec"].agg("max")
49+
mean_bandwidth = bandwidth_of_fastest_device.mean()
50+
stderr_bandwidth = bandwidth_of_fastest_device.std() / sqrt(
51+
len(bandwidth_of_fastest_device)
52+
)
53+
3554
summary_data[message_size][collective] = ufloat(
36-
bandwidth.mean(), bandwidth.std() / sqrt(len(bandwidth))
55+
mean_bandwidth, stderr_bandwidth
3756
)
57+
3858
return sorted(collective_types), summary_data
3959

4060

4161
def print_bandwidth_table(collective_types, summary_data):
62+
"""
63+
This function prints a table for summarizing the bandwidth for each collective operation
64+
65+
Args:
66+
collective_types (List[str]): sorted list of collective operation types
67+
summary_data (Dict[int, Dict[str, ufloat]]): Dictionary with summaries for bandwidth data
68+
"""
69+
70+
def format_message_size(message_size):
71+
"""
72+
Function to format the message size
73+
"""
74+
return f"{message_size:<{size_width},}"
75+
76+
def format_bandwidth(data, collective):
77+
"""
78+
Function to format the bandwidth
79+
"""
80+
width = collective_widths[collective]
81+
if collective not in data:
82+
return "-" * width
83+
return f"{data[collective]:>{width}S}"
84+
4285
collective_widths = {
4386
collective: max(
4487
len(collective),
@@ -52,54 +95,62 @@ def print_bandwidth_table(collective_types, summary_data):
5295
}
5396
size_heading = "Size [B]"
5497
size_width = max(len(size_heading), max(len(f"{s:,}") for s in summary_data.keys()))
55-
print(f"{'':<{size_width}} | Bus bandwidth [GB/s]")
56-
print(
57-
" | ".join(
58-
[f"{size_heading:<{size_width}}"]
59-
+ [f"{coll:<{collective_widths[coll]}}" for coll in collective_types]
60-
)
61-
)
62-
63-
def format_message_size(message_size):
64-
return f"{message_size:<{size_width},}"
6598

66-
def format_bandwidth(data, collective):
67-
width = collective_widths[collective]
68-
if collective not in data:
69-
return "-" * width
70-
return f"{data[collective]:>{width}S}"
99+
header_log = f"{'':<{size_width}} | Bus bandwidth [GB/s]"
100+
print(header_log)
101+
log_specs = " | ".join(
102+
[f"{size_heading:<{size_width}}"]
103+
+ [f"{coll:<{collective_widths[coll]}}" for coll in collective_types]
104+
)
105+
print(log_specs)
71106

72107
for message_size in sorted(summary_data.keys()):
73108
data = summary_data[message_size]
74-
print(
75-
" | ".join(
76-
[format_message_size(message_size)]
77-
+ [
78-
format_bandwidth(data, collective)
79-
for collective in collective_types
80-
]
81-
)
109+
log_row = " | ".join(
110+
[format_message_size(message_size)]
111+
+ [format_bandwidth(data, collective) for collective in collective_types]
82112
)
113+
print(log_row)
83114

84115

85116
def process_hidden_ms_to_total_ms(steady_state):
    """
    Compute, per collective, the mean fraction of communication time that is hidden.

    For every row the ratio ``ProjDurHiddenMs / (ProjDurMs + ProjDurHiddenMs)`` is
    evaluated, and the ratios are then averaged within each ``Collective`` group.

    Args:
        steady_state: object whose ``communication`` attribute is a data frame with
            ``Collective``, ``ProjDurMs`` and ``ProjDurHiddenMs`` columns.

    Returns:
        ``(None, None)`` when the ``ProjDurHiddenMs`` column sums to zero, otherwise
        a tuple ``(collective_types, summary_data)`` where:
            collective_types: set of 1-tuple group keys, one per collective
            summary_data: mapping from each 1-tuple group key to the mean
                hidden-to-total milliseconds ratio of that collective
    """
    communication = steady_state.communication
    if communication["ProjDurHiddenMs"].sum() == 0:
        return None, None

    collective_types = set()
    # The default factory is never triggered in this function; the defaultdict
    # type is kept so the returned object is unchanged for existing callers.
    summary_data = defaultdict(dict)

    for collective, df in communication.groupby(["Collective"]):
        # Grouping by a one-element list yields 1-tuple keys on recent pandas
        # releases but bare scalars on older ones. print_hidden_ms_to_total_ms_table
        # reads collective[0], so always hand out 1-tuples; otherwise a scalar
        # string key would be truncated to its first character there.
        if not isinstance(collective, tuple):
            collective = (collective,)
        collective_types.add(collective)
        total_ms = df["ProjDurMs"] + df["ProjDurHiddenMs"]
        summary_data[collective] = (df["ProjDurHiddenMs"] / total_ms).mean()

    return collective_types, summary_data
98141

99142

100143
def print_hidden_ms_to_total_ms_table(
101144
collective_types, summary_data, overall_hidden_ms_to_total_ms
102145
):
146+
"""
147+
Print the hidden ms to total ms
148+
149+
Args:
150+
collective_types (Set[str]): set of collective operation types
151+
summary_data (Dict[str, float]): mean hidden-to-total milliseconds ratio
152+
overall_hidden_ms_to_total_ms (float or None): overall hidden-to-total milliseconds ratio; not printed when None
153+
"""
103154
table = PrettyTable()
104155
table.field_names = ["Collective", "Mean HiddenToTotalMs"]
105156

@@ -108,20 +159,31 @@ def print_hidden_ms_to_total_ms_table(
108159
table.add_row([collective[0], mean_value])
109160

110161
print(table)
111-
print("Overall HiddenMs to TotalMs:", overall_hidden_ms_to_total_ms)
162+
if overall_hidden_ms_to_total_ms is not None:
163+
print(f"Overall HiddenMs to TotalMs: {overall_hidden_ms_to_total_ms:.4f}")
112164

113165

114166
def calculate_overall_hidden_ms_to_total_ms(steady_state):
    """
    Calculate the overall hidden-to-total communication time ratio.

    The ratio is the sum of ``ProjDurHiddenMs`` divided by the sum of
    ``ProjDurMs + ProjDurHiddenMs`` over all rows of the communication data.

    Args:
        steady_state: object whose ``communication`` attribute is a data frame with
            ``ProjDurMs`` and ``ProjDurHiddenMs`` columns.

    Returns:
        The overall hidden-to-total milliseconds ratio, or ``None`` when the hidden
        time sums to zero or the total time sums to zero (no ratio is defined).
    """
    communication = steady_state.communication
    total_hidden_ms = communication["ProjDurHiddenMs"].sum()
    if total_hidden_ms == 0:
        return None

    total_ms = (communication["ProjDurMs"] + communication["ProjDurHiddenMs"]).sum()
    if total_ms == 0:
        # Dividing by zero would yield inf/nan; report "no ratio" with the same
        # None value used above instead.
        return None

    return total_hidden_ms / total_ms
126188

127189

@@ -132,6 +194,17 @@ def write_to_csv(
132194
overall_hidden_ms_to_total_ms,
133195
output_file,
134196
):
197+
"""
198+
Function to write the summaries to a csv file
199+
200+
Args:
201+
collective_types (List[str]): list of collective operation types
202+
bandwidth_summary (Dict[int, Dict[str, ufloat]]): bandwidth summary data
203+
hidden_to_total_summary (Dict[str, float]): hidden-to-total milliseconds ratio summary
204+
overall_hidden_ms_to_total_ms (float): overall hidden-to-total milliseconds ratio
205+
output_file (str): output CSV file path
206+
207+
"""
135208
with open(output_file, "w", newline="") as csvfile:
136209
writer = csv.writer(csvfile)
137210

@@ -165,6 +238,9 @@ def write_to_csv(
165238

166239

167240
def main():
241+
"""
242+
Main entry point to process the nsys-jax report and generate communication summaries
243+
"""
168244
parser = argparse.ArgumentParser(
169245
description="Summarise communication in an nsys-jax report"
170246
)
@@ -178,22 +254,26 @@ def main():
178254
all_data = load_profiler_data(args.prefix, frames={"communication", "compile"})
179255
# Align timestamps
180256
all_data, alignment_metadata = align_profiler_data_timestamps(all_data)
181-
# TODO: make this pretty
182-
# print(alignment_metadata)
257+
print(f"Alignment metadata: {alignment_metadata}")
183258
# Partition the profile data into initialisation and steady-state running
184259
_, steady_state = apply_warmup_heuristics(all_data)
185260

186-
assert len(steady_state.communication), (
187-
"Communication summary was requested but no steady-state communication was "
188-
"identified."
189-
)
261+
if len(steady_state.communication) == 0:
262+
print(
263+
"Communication summary was requested but no steady-state communication was identified."
264+
)
265+
return
190266

191267
collective_types, bandwidth_summary = process_communication_data(steady_state)
192268
print_bandwidth_table(collective_types, bandwidth_summary)
193269

194270
hidden_to_total_collective_types, hidden_to_total_summary = (
195271
process_hidden_ms_to_total_ms(steady_state)
196272
)
273+
274+
# initialise overall_hidden_ms_to_total_ms
275+
overall_hidden_ms_to_total_ms = None
276+
197277
if hidden_to_total_summary is not None:
198278
overall_hidden_ms_to_total_ms = calculate_overall_hidden_ms_to_total_ms(
199279
steady_state

0 commit comments

Comments
 (0)