Skip to content

Commit 564ec47

Browse files
authored
Merge branch 'main' into sbosisio/cuda-dl-base
2 parents bd066f1 + eb6d0d2 commit 564ec47

File tree

12 files changed

+521
-262
lines changed

12 files changed

+521
-262
lines changed

.github/container/nsys_jax/nsys_jax/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from .data_loaders import load_profiler_data
88
from .protobuf import xla_module_metadata
99
from .protobuf_utils import compile_protos, ensure_compiled_protos_are_importable
10-
from .utils import remove_autotuning_detail, remove_child_ranges
10+
from .utils import default_data_prefix, remove_autotuning_detail, remove_child_ranges
1111
from .visualization import create_flamegraph, display_flamegraph
1212

1313
__all__ = [
@@ -16,6 +16,7 @@
1616
"calculate_collective_metrics",
1717
"compile_protos",
1818
"create_flamegraph",
19+
"default_data_prefix",
1920
"display_flamegraph",
2021
"ensure_compiled_protos_are_importable",
2122
"generate_compilation_statistics",

.github/container/nsys_jax/nsys_jax/analyses/Analysis.ipynb

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
"from nsys_jax import (\n",
1313
" align_profiler_data_timestamps,\n",
1414
" apply_warmup_heuristics,\n",
15+
" default_data_prefix,\n",
1516
" display_flamegraph,\n",
1617
" ensure_compiled_protos_are_importable,\n",
1718
" generate_compilation_statistics,\n",
@@ -23,6 +24,18 @@
2324
"import numpy as np"
2425
]
2526
},
27+
{
28+
"cell_type": "code",
29+
"execution_count": null,
30+
"id": "7a91f0e7-17da-4534-8ea9-29bcf3742567",
31+
"metadata": {},
32+
"outputs": [],
33+
"source": [
34+
"# Set the input data to use. default_data_prefix() checks the NSYS_JAX_DEFAULT_PREFIX environment variable, and if that is\n",
35+
"# not set then the current working directory is used. Use pathlib.Path if setting this explicitly.\n",
36+
"prefix = default_data_prefix()"
37+
]
38+
},
2639
{
2740
"cell_type": "code",
2841
"execution_count": null,
@@ -32,7 +45,7 @@
3245
"source": [
3346
"# Make sure that the .proto files under protos/ have been compiled to .py, and\n",
3447
"# that those generated .py files are importable.]\n",
35-
"compiled_dir = ensure_compiled_protos_are_importable()"
48+
"compiled_dir = ensure_compiled_protos_are_importable(prefix=prefix)"
3649
]
3750
},
3851
{
@@ -43,7 +56,7 @@
4356
"outputs": [],
4457
"source": [
4558
"# Load the runtime profile data\n",
46-
"all_data = load_profiler_data()\n",
59+
"all_data = load_profiler_data(prefix)\n",
4760
"# Remove some detail from the autotuner\n",
4861
"all_data = remove_autotuning_detail(all_data)\n",
4962
"# Align GPU timestamps across profiles collected by different Nsight Systems processes\n",
@@ -82,16 +95,14 @@
8295
"source": [
8396
"This data frame has a three-level index:\n",
8497
"- `ProgramId` is an integer ID that uniquely identifies the XLA module\n",
85-
"- This is the `ProgramExecution`-th execution of the module within the profiles. You may see this starting from 1, not 0, because of the `warmup_removal_heuristics` option passed to `load_profiler_data`.\n",
98+
"- This is the `ProgramExecution`-th execution of the module within the profiles. You may see this starting from 2, not 0, because of the `warmup_removal_heuristics` option passed to `load_profiler_data`.\n",
8699
"- `Device` is the global (across multiple nodes and processes) index of the GPU on which the module execution took place\n",
87100
"\n",
88101
"The columns are as follows:\n",
89102
"- `Name`: the name of the XLA module; this should always be the same for a given `ProgramId`\n",
90103
"- `NumThunks`: the number of thunks executed inside this module execution\n",
91104
"- `ProjStartMs`: the timestamp of the start of the module execution on the GPU, in milliseconds\n",
92105
"- `ProjDurMs`: the duration of the module execution on the GPU, in milliseconds\n",
93-
"- `OrigStartMs`: the timestamp of the start of the module launch **on the host**, in milliseconds. *i.e.* `ProjStartMs-OrigStartMs` is something like the launch latency of the first kernel\n",
94-
"- `OrigDurMs`: the duration of the module launch **on the host**, in milliseconds\n",
95106
"- `LocalDevice`: the index within the node/slice of the GPU on which the module execution took place\n",
96107
"- `Process`: the global (across multiple nodes) index of the process\n",
97108
"- `Slice`: the global index of the node/slice; devices within the same node/slice should have faster interconnects than to devices in different slices\n",
@@ -117,13 +128,13 @@
117128
"id": "7727d800-13d3-4505-89e8-80a5fed63512",
118129
"metadata": {},
119130
"source": [
120-
"Here the index has four levels. `ProgramId`, `ProgramExecution` and `Device` have the same meanings as in `module_df`.\n",
131+
"Here the index has four levels. `ProgramId`, `ProgramExecution` and `Device` have the same meanings as in `steady_state.module`.\n",
121132
"The fourth level (in the 3rd position) shows that this row is the `ThunkIndex`-th thunk within the `ProgramExecution`-th execution of XLA module `ProgramId`.\n",
122133
"Note that a given thunk can be executed multiple times within the same module, so indexing on the thunk name would not be unique.\n",
123134
"\n",
124135
"The columns are as follows:\n",
125136
"- `Name`: the name of the thunk; this should be unique within a given `ProgramId` and can be used as a key to look up XLA metadata\n",
126-
"- `ProjStartMs`, `OrigStartMs`, `OrigDurMs`: see above, same meaning as in `module_df`.\n",
137+
"- `ProjStartMs`: see above, same meaning as in `steady_state.module`.\n",
127138
"- `Communication`: does this thunk represent communication between GPUs (*i.e.* a NCCL collective)? XLA overlaps communication and computation kernels, and `load_profiler_data` triggers an overlap calculation. `ProjDurMs` for a communication kernel shows only the duration that was **not** overlapped with computation kernels, while `ProjDurHiddenMs` shows the duration that **was** overlapped.\n",
128139
"- This is the `ThunkExecution`-th execution of this thunk for this `(ProgramId, ProgramExecution, Device)`\n",
129140
"\n",
@@ -299,7 +310,7 @@
299310
"# Print out the largest entries adding up to at least this fraction of the total\n",
300311
"threshold = 0.97\n",
301312
"compile_summary[\"FracNonChild\"] = compile_summary[\"DurNonChildMs\"] / total_compile_time\n",
302-
"print(f\"Top {threshold:.0%}+ of {total_compile_time * 1e-9:.2f}s compilation time\")\n",
313+
"print(f\"Top {threshold:.0%}+ of {total_compile_time * 1e-3:.2f}s compilation time\")\n",
303314
"for row in compile_summary[\n",
304315
" compile_summary[\"FracNonChild\"].cumsum() <= threshold\n",
305316
"].itertuples():\n",
@@ -378,7 +389,7 @@
378389
" program_id, thunk_name = thunk_row.Index\n",
379390
" # policy=\"all\" means we may get a set of HloProto instead of a single one, if\n",
380391
" # nsys-jax-combine was used and the dumped metadata were not bitwise identical\n",
381-
" hlo_modules = xla_module_metadata(program_id, policy=\"all\")\n",
392+
" hlo_modules = xla_module_metadata(program_id, policy=\"all\", prefix=prefix)\n",
382393
" thunk_opcode, inst_metadata, inst_frames = hlo_modules.unique_result(\n",
383394
" lambda proto: instructions_and_frames(proto, thunk_name)\n",
384395
" )\n",

.github/container/nsys_jax/nsys_jax/analyses/communication.py

100644100755
Lines changed: 140 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,21 @@
11
#!/usr/bin/env python
22
import argparse
3+
import csv
34
from collections import defaultdict
5+
46
from nsys_jax import (
57
align_profiler_data_timestamps,
68
apply_warmup_heuristics,
79
ensure_compiled_protos_are_importable,
810
load_profiler_data,
911
)
1012
from math import sqrt
13+
from prettytable import PrettyTable
1114
import pathlib
1215
from uncertainties import ufloat # type: ignore
1316

1417

15-
def main():
16-
parser = argparse.ArgumentParser(
17-
description="Summarise communication in an nsys-jax report"
18-
)
19-
parser.add_argument("prefix", type=pathlib.Path)
20-
args = parser.parse_args()
21-
# Make sure that the .proto files under protos/ have been compiled to .py, and
22-
# that those generated .py files are importable.
23-
ensure_compiled_protos_are_importable(prefix=args.prefix)
24-
# Load the profiler data; the compilation part is needed for the warmup heuristics
25-
all_data = load_profiler_data(args.prefix, frames={"communication", "compile"})
26-
# Align timestamps
27-
all_data, alignment_metadata = align_profiler_data_timestamps(all_data)
28-
# TODO: make this pretty
29-
# print(alignment_metadata)
30-
# Partition the profile data into initialisation and steady-state running
31-
_, steady_state = apply_warmup_heuristics(all_data)
32-
assert len(steady_state.communication), (
33-
"Communication summary was requested but no steady-state communication was "
34-
"identified."
35-
)
18+
def process_communication_data(steady_state):
3619
collective_types = set()
3720
summary_data = defaultdict(dict)
3821
for (collective, message_size), df in steady_state.communication.groupby(
@@ -52,7 +35,10 @@ def main():
5235
summary_data[message_size][collective] = ufloat(
5336
bandwidth.mean(), bandwidth.std() / sqrt(len(bandwidth))
5437
)
55-
collective_types = sorted(collective_types)
38+
return sorted(collective_types), summary_data
39+
40+
41+
def print_bandwidth_table(collective_types, summary_data):
5642
collective_widths = {
5743
collective: max(
5844
len(collective),
@@ -96,5 +82,137 @@ def format_bandwidth(data, collective):
9682
)
9783

9884

85+
def process_hidden_ms_to_total_ms(steady_state):
    """Summarise, per collective type, how much communication time was hidden.

    For each collective, compute the mean over steady-state thunk executions of
    ``ProjDurHiddenMs / (ProjDurMs + ProjDurHiddenMs)``, i.e. the fraction of
    the total communication time that was overlapped with (hidden by)
    computation.

    Args:
        steady_state: profile data whose ``communication`` data frame has
            ``Collective``, ``ProjDurMs`` and ``ProjDurHiddenMs`` columns.

    Returns:
        ``(collective_types, summary_data)`` where ``collective_types`` is the
        set of groupby keys seen and ``summary_data`` maps each key to its mean
        hidden-to-total ratio, or ``(None, None)`` if no communication time was
        hidden at all.
    """
    if steady_state.communication["ProjDurHiddenMs"].sum() == 0:
        # Nothing was overlapped; signal to callers that there is no table.
        return None, None

    collective_types = set()
    # Values are plain floats, so use a plain dict; the previous
    # defaultdict(dict) wrongly suggested nested dictionaries.
    summary_data = {}
    for collective, df in steady_state.communication.groupby(["Collective"]):
        collective_types.add(collective)
        mean_dur_hidden_ms_to_total_ms = (
            df["ProjDurHiddenMs"] / (df["ProjDurMs"] + df["ProjDurHiddenMs"])
        ).mean()
        summary_data[collective] = mean_dur_hidden_ms_to_total_ms
    return collective_types, summary_data
98+
99+
100+
def print_hidden_ms_to_total_ms_table(
    collective_types, summary_data, overall_hidden_ms_to_total_ms
):
    """Print a table of mean hidden-to-total communication time ratios.

    ``summary_data`` maps groupby keys to the mean hidden-to-total ratio for
    that collective; ``overall_hidden_ms_to_total_ms`` is the aggregate ratio
    across all collectives.
    """
    table = PrettyTable()
    table.field_names = ["Collective", "Mean HiddenToTotalMs"]
    # Keys come from groupby(["Collective"]) and are 1-tuples; unpack the
    # collective name for display.
    for key in collective_types:
        table.add_row([key[0], summary_data[key]])
    print(table)
    print("Overall HiddenMs to TotalMs:", overall_hidden_ms_to_total_ms)
112+
113+
114+
def calculate_overall_hidden_ms_to_total_ms(steady_state):
    """Return the aggregate fraction of communication time hidden by overlap.

    This is ``sum(ProjDurHiddenMs) / sum(ProjDurMs + ProjDurHiddenMs)`` over
    all steady-state communication thunks, or ``None`` when no communication
    time was hidden at all.
    """
    comm = steady_state.communication
    hidden_ms = comm["ProjDurHiddenMs"].sum()
    if hidden_ms == 0:
        # No overlap anywhere; callers skip the summary in this case.
        return None
    total_ms = (comm["ProjDurMs"] + comm["ProjDurHiddenMs"]).sum()
    return hidden_ms / total_ms
126+
127+
128+
def write_to_csv(
    collective_types,
    bandwidth_summary,
    hidden_to_total_summary,
    overall_hidden_ms_to_total_ms,
    output_file,
):
    """Write the bandwidth and hidden-to-total summaries to ``output_file``.

    The CSV contains the bandwidth table (one row per message size, one column
    per collective), then — when overlap data is available — the per-collective
    hidden-to-total table and the overall hidden-to-total ratio, separated by
    blank rows.
    """
    with open(output_file, "w", newline="") as csvfile:
        writer = csv.writer(csvfile)

        # Bandwidth table header + one row per message size.
        writer.writerow(["Bandwidth Table"])
        writer.writerow(["Size [B]"] + list(collective_types))
        for message_size in sorted(bandwidth_summary):
            per_size = bandwidth_summary[message_size]
            # ":S" is the shorthand-uncertainty format of the ufloat values;
            # "-" marks collectives with no data at this message size.
            writer.writerow(
                [message_size]
                + [
                    f"{per_size[collective]:S}" if collective in per_size else "-"
                    for collective in collective_types
                ]
            )

        writer.writerow([])  # Empty row for separation

        # Hidden-to-total table, only when overlap data is available.
        if hidden_to_total_summary is not None:
            writer.writerow(["HiddenMs to TotalMs Table"])
            writer.writerow(["Collective", "Mean HiddenToTotalMs"])
            for key, mean_ratio in hidden_to_total_summary.items():
                # Keys are 1-tuples from groupby(["Collective"]).
                writer.writerow([key[0], mean_ratio])

            writer.writerow([])  # Empty row for separation

        if overall_hidden_ms_to_total_ms is not None:
            writer.writerow(
                ["Overall HiddenMs to TotalMs", overall_hidden_ms_to_total_ms]
            )
165+
166+
167+
def main():
    """Entry point: summarise communication in an nsys-jax report.

    Loads profiler data from the prefix given on the command line, applies the
    warmup heuristics, prints the bandwidth and hidden-to-total tables, and
    writes the same tables to communication_summary.csv.
    """
    parser = argparse.ArgumentParser(
        description="Summarise communication in an nsys-jax report"
    )
    parser.add_argument("prefix", type=pathlib.Path)
    args = parser.parse_args()

    # Make sure that the .proto files under protos/ have been compiled to .py, and
    # that those generated .py files are importable.
    ensure_compiled_protos_are_importable(prefix=args.prefix)
    # Load the profiler data; the compilation part is needed for the warmup heuristics
    all_data = load_profiler_data(args.prefix, frames={"communication", "compile"})
    # Align timestamps
    all_data, alignment_metadata = align_profiler_data_timestamps(all_data)
    # TODO: make this pretty
    # print(alignment_metadata)
    # Partition the profile data into initialisation and steady-state running
    _, steady_state = apply_warmup_heuristics(all_data)

    assert len(steady_state.communication), (
        "Communication summary was requested but no steady-state communication was "
        "identified."
    )

    collective_types, bandwidth_summary = process_communication_data(steady_state)
    print_bandwidth_table(collective_types, bandwidth_summary)

    hidden_to_total_collective_types, hidden_to_total_summary = (
        process_hidden_ms_to_total_ms(steady_state)
    )
    # Bind before the branch: previously this name was only assigned inside the
    # `if` below, so write_to_csv() raised NameError whenever no communication
    # time was hidden (hidden_to_total_summary is None).
    overall_hidden_ms_to_total_ms = None
    if hidden_to_total_summary is not None:
        overall_hidden_ms_to_total_ms = calculate_overall_hidden_ms_to_total_ms(
            steady_state
        )
        print_hidden_ms_to_total_ms_table(
            hidden_to_total_collective_types,
            hidden_to_total_summary,
            overall_hidden_ms_to_total_ms,
        )

    # Write all tables to a single CSV file
    write_to_csv(
        collective_types,
        bandwidth_summary,
        hidden_to_total_summary,
        overall_hidden_ms_to_total_ms,
        "communication_summary.csv",
    )
215+
216+
99217
if __name__ == "__main__":
100218
main()

.github/container/nsys_jax/nsys_jax/analysis.py

100644100755
Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -286,17 +286,8 @@ def get_message_size(
286286
of the semantics. This implementation aims to follow the same conventions that NCCL
287287
uses in its NVTX payloads and tests.
288288
"""
289-
return pd.Series(
290-
xla_module_metadata(program_id, prefix=prefix, policy="all").unique_result(
291-
lambda proto: _get_message_size(proto, instruction)
292-
),
293-
index=[
294-
"MessageSize",
295-
"Collective",
296-
"CollectiveSize",
297-
"BandwidthCorrection",
298-
"BusBandwidthCorrection",
299-
],
289+
return xla_module_metadata(program_id, prefix=prefix, policy="all").unique_result(
290+
lambda proto: _get_message_size(proto, instruction)
300291
)
301292

302293

@@ -311,13 +302,26 @@ def calculate_collective_metrics(
311302
comm_df = thunk_df[thunk_df["Communication"]].drop(columns=["Communication"])
312303
if len(comm_df) == 0:
313304
return comm_df
305+
306+
def body(tup):
307+
idx, name = tup
308+
return get_message_size(idx[0], name, prefix=prefix)
309+
310+
metrics_df = pd.DataFrame.from_records(
311+
map(body, comm_df["Name"].items()),
312+
columns=[
313+
"MessageSize",
314+
"Collective",
315+
"CollectiveSize",
316+
"BandwidthCorrection",
317+
"BusBandwidthCorrection",
318+
],
319+
index=comm_df.index,
320+
)
314321
comm_df = pd.concat(
315322
[
316323
comm_df,
317-
comm_df.apply(
318-
lambda row: get_message_size(row.name[0], row.Name, prefix=prefix),
319-
axis=1,
320-
),
324+
metrics_df,
321325
],
322326
axis=1,
323327
)

0 commit comments

Comments
 (0)