Skip to content

Commit 47a34ad

Browse files
authored
nsys-jax: improve alignment for multi-GPU runs (#1353)
This removes assumptions about the ordering of thunks being identical across different GPUs, which are not correct in the presence of overlapped communications and enough jitter. The summary analysis script gains a wait time metric and prettier formatting. Also fix a race condition in the EKS-based CI job for nsys-jax, and make CI unit test jobs go red if some tests fail.
1 parent 29fce40 commit 47a34ad

File tree

10 files changed

+158
-72
lines changed

10 files changed

+158
-72
lines changed

.github/container/nsys_jax/nsys_jax/analyses/Analysis.ipynb

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121
" xla_module_metadata,\n",
2222
")\n",
2323
"import matplotlib.pyplot as plt\n",
24-
"import numpy as np"
24+
"import numpy as np\n",
25+
"import pathlib"
2526
]
2627
},
2728
{
@@ -33,6 +34,7 @@
3334
"source": [
3435
"# Set the input data to use. default_data_prefix() checks the NSYS_JAX_DEFAULT_PREFIX environment variable, and if that is\n",
3536
"# not set then the current working directory is used. Use pathlib.Path if setting this explicitly.\n",
37+
"prefix = pathlib.Path(\".\") # modify this and comment out the next line\n",
3638
"prefix = default_data_prefix()"
3739
]
3840
},
@@ -128,15 +130,14 @@
128130
"id": "7727d800-13d3-4505-89e8-80a5fed63512",
129131
"metadata": {},
130132
"source": [
131-
"Here the index has four levels. `ProgramId`, `ProgramExecution` and `Device` have the same meanings as in `steady_state.module`.\n",
132-
"The fourth level (in the 3rd position) shows that this row is the `ThunkIndex`-th thunk within the `ProgramExecution`-th execution of XLA module `ProgramId`.\n",
133-
"Note that a given thunk can be executed multiple times within the same module, so indexing on the thunk name would not be unique.\n",
133+
"Here the index has five levels. `ProgramId`, `ProgramExecution` and `Device` have the same meanings as in `steady_state.module`.\n",
134+
"The two new levels, `Name` and `ThunkExecution`, show that a given row is the `ThunkExecution`-th execution of thunk `Name` within the `ProgramExecution`-th execution of XLA module `ProgramId`.\n",
135+
"The `ThunkExecution` value is needed because a given thunk can be executed multiple times within the same module.\n",
136+
"The `Name` of a thunk can be used, along with a `ProgramId`, to look up XLA metadata.\n",
134137
"\n",
135138
"The columns are as follows:\n",
136-
"- `Name`: the name of the thunk; this should be unique within a given `ProgramId` and can be used as a key to look up XLA metadata\n",
137139
"- `ProjStartMs`: see above, same meaning as in `steady_state.module`.\n",
138140
"- `Communication`: does this thunk represent communication between GPUs (*i.e.* a NCCL collective)? XLA overlaps communication and computation kernels, and `load_profiler_data` triggers an overlap calculation. `ProjDurMs` for a communication kernel shows only the duration that was **not** overlapped with computation kernels, while `ProjDurHiddenMs` shows the duration that **was** overlapped.\n",
139-
"- This is the `ThunkExecution`-th execution of this thunk for this `(ProgramId, ProgramExecution, Device)`\n",
140141
"\n",
141142
"The third data frame does not show any GPU execution, but is rather a host-side trace:"
142143
]
@@ -178,7 +179,7 @@
178179
"id": "2e82c357-4e9d-48e4-b758-fa5357b2c8bd",
179180
"metadata": {},
180181
"source": [
181-
"The index structure, and many of the columns, are equivalent to `thunk_df`. Additional columns are:\n",
182+
"The index structure, and many of the columns, are equivalent to the `.thunk` data frame. Additional columns are:\n",
182183
"\n",
183184
"- `MessageSize`: the message size of the collective in bytes; this aims to follow the same conventions as the NCCL tests\n",
184185
"- `Collective`: the type of collective communication\n",
@@ -524,7 +525,9 @@
524525
" # program, there may be different sub-groupings that are participating in smaller\n",
525526
" # collectives in the strict/NCCL sense. TODO: it would be better to identify those\n",
526527
" # sub-groupings and group them, but we currently lack the relevant information.\n",
527-
" collective_df = df.groupby([\"ProgramId\", \"ProgramExecution\", \"ThunkIndex\"])\n",
528+
" collective_df = df.groupby(\n",
529+
" [\"ProgramId\", \"ProgramExecution\", \"Name\", \"ThunkExecution\"]\n",
530+
" )\n",
528531
" # Take the fastest device kernel as a proxy for the actual bandwidth of the\n",
529532
" # collective.\n",
530533
" bandwidth_df = collective_df.agg(\n",
@@ -534,7 +537,6 @@
534537
" \"ProjStartMs\": \"min\",\n",
535538
" \"ProjDurFullMs\": \"min\",\n",
536539
" \"ProjEndMs\": \"max\",\n",
537-
" \"Name\": \"count\",\n",
538540
" }\n",
539541
" )\n",
540542
" axs[0].plot(\n",
@@ -582,9 +584,9 @@
582584
"\n",
583585
"# Calculate statistics over different devices and different executions of each thunk, including multiple executions of the same thunk within the same module\n",
584586
"compute_durations = steady_state.thunk.loc[\n",
585-
" ~steady_state.thunk[\"Communication\"], (\"Name\", \"ProjDurMs\")\n",
587+
" ~steady_state.thunk[\"Communication\"], \"ProjDurMs\"\n",
586588
"].groupby([\"ProgramId\", \"Name\"])\n",
587-
"compute_duration_stats = compute_durations[\"ProjDurMs\"].agg((\"mean\", \"std\"))\n",
589+
"compute_duration_stats = compute_durations.agg((\"mean\", \"std\"))\n",
588590
"compute_duration_means = compute_duration_stats[\"mean\"]\n",
589591
"compute_duration_rel_stds = compute_duration_stats[\"std\"] / compute_duration_means\n",
590592
"\n",
@@ -634,8 +636,7 @@
634636
"\n",
635637
"def durations_ms(idx):\n",
636638
" program_id, thunk_name = idx\n",
637-
" tmp = steady_state.thunk.loc[program_id, (\"Name\", \"ProjDurMs\")]\n",
638-
" return tmp.loc[tmp[\"Name\"] == thunk_name, \"ProjDurMs\"]\n",
639+
" return steady_state.thunk.loc[(program_id, slice(None), thunk_name), \"ProjDurMs\"]\n",
639640
"\n",
640641
"\n",
641642
"detailed_index = high_variance_means[high_variance_means > mean_threshold].index\n",
@@ -666,6 +667,7 @@
666667
" squeeze=False,\n",
667668
" tight_layout=True,\n",
668669
" )\n",
670+
" # Compute (non-comm) kernel timings\n",
669671
" time_df = steady_state.thunk.loc[\n",
670672
" ~steady_state.thunk[\"Communication\"], (\"ProjStartMs\", \"ProjDurMs\")\n",
671673
" ]\n",
@@ -688,14 +690,17 @@
688690
" ):\n",
689691
" # Mean over devices to get a single [thunk0_start, thunk0_end, thunk1_start, ...]\n",
690692
" # array for this execution of this module\n",
691-
" mean_times = interleave(exec_df.groupby(\"ThunkIndex\").agg(\"mean\"))\n",
693+
" mean_times = interleave(\n",
694+
" exec_df.groupby([\"Name\", \"ThunkExecution\"], sort=False).agg(\"mean\")\n",
695+
" )\n",
692696
" # x axis of the plot will be the average over executions of the module\n",
693697
" x_values.append(mean_times - mean_times[0])\n",
694698
" for device, device_values in exec_df.groupby(\"Device\"):\n",
695699
" # [thunk0_start, thunk0_end, ...] array for one device within one module exec\n",
696700
" # with the average over devices subtracted\n",
697701
" y_values[device].append(interleave(device_values) - mean_times)\n",
698702
" mean_start_time_ms = np.mean(x_values, axis=0)\n",
703+
" # all_values: (num_devices, num_module_executions, thunks_per_module)\n",
699704
" all_values = np.array(list(y_values.values()))\n",
700705
" ax.plot(\n",
701706
" mean_start_time_ms,\n",
@@ -728,18 +733,17 @@
728733
" exec_df[\"ProjEndMs\"]\n",
729734
" - steady_state.module.loc[(program_id, module_execution), \"ProjStartMs\"]\n",
730735
" )\n",
731-
" tmp = exec_df.groupby(\"ThunkIndex\").agg(\n",
736+
" tmp = exec_df.groupby([\"Name\", \"ThunkExecution\"]).agg(\n",
732737
" {\n",
733-
" \"Name\": \"first\",\n",
734738
" \"Collective\": \"first\",\n",
735739
" \"CollectiveSize\": \"first\",\n",
736740
" \"EndInModuleMs\": \"mean\",\n",
737741
" }\n",
738742
" )\n",
739743
" for coll_size, values in tmp.groupby(\"CollectiveSize\"):\n",
740744
" comm_x_values[coll_size].append(values[\"EndInModuleMs\"])\n",
741-
" (_, xmax), (ymin, ymax) = ax.get_xlim(), ax.get_ylim()\n",
742-
" ax.set_xlim(0, xmax)\n",
745+
" ymin, ymax = ax.get_ylim()\n",
746+
" ax.set_xlim(mean_start_time_ms[0], mean_start_time_ms[-1])\n",
743747
" ax.set_ylim(ymin, ymax)\n",
744748
" largest_collective = max(comm_x_values.keys())\n",
745749
" for n_color, (coll_size, values) in enumerate(comm_x_values.items()):\n",
@@ -748,10 +752,10 @@
748752
" collective_times,\n",
749753
" ymin,\n",
750754
" # Draw taller vertical lines for collectives involving more devices\n",
751-
" ymin * (1 - coll_size / largest_collective),\n",
755+
" ymin * (1 - 0.75 * coll_size / largest_collective),\n",
752756
" color=f\"C{n_color}\",\n",
753757
" label=f\"{coll_size}-device collective\",\n",
754-
" linestyle=\"--\",\n",
758+
" linestyle=\"-\",\n",
755759
" )\n",
756760
"\n",
757761
" ax.set_title(\n",
@@ -836,7 +840,9 @@
836840
"outputs": [],
837841
"source": [
838842
"num_traces = {\n",
839-
" module_id: xla_module_metadata(module_id, policy=\"all\").unique_result(\n",
843+
" module_id: xla_module_metadata(\n",
844+
" module_id, policy=\"all\", prefix=prefix\n",
845+
" ).unique_result(\n",
840846
" lambda hlo_module: len(\n",
841847
" hlo_module.proto().buffer_assignment.heap_simulator_traces\n",
842848
" )\n",
@@ -855,7 +861,7 @@
855861
" squeeze=False,\n",
856862
")\n",
857863
"for n_module, module_id in enumerate(module_ids_with_traces):\n",
858-
" protos = xla_module_metadata(module_id, policy=\"all\")\n",
864+
" protos = xla_module_metadata(module_id, policy=\"all\", prefix=prefix)\n",
859865
" sizes_by_logical_id = protos.unique_result(\n",
860866
" lambda proto: {\n",
861867
" buffer.id: buffer.size\n",

.github/container/nsys_jax/nsys_jax/analyses/communication.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,9 @@ def process_communication_data(steady_state):
3838
collective_types.add(collective)
3939
# This grouped data frame will have a row for each device that is participating
4040
# in this instance of the collective.
41-
devices = df.groupby(["ProgramId", "ProgramExecution", "ThunkIndex"])
41+
devices = df.groupby(
42+
["ProgramId", "ProgramExecution", "Name", "ThunkExecution"]
43+
)
4244
# Take the fastest device bandwidth. Rationale: the slower devices appear
4345
# slower because they spend some time waiting for the last device, and then all
4446
# devices complete the collective at the same time. The fastest device is
@@ -134,8 +136,7 @@ def process_hidden_ms_to_total_ms(steady_state):
134136
for collective, df in grouped_data:
135137
collective_types.add(collective)
136138
total_ms = df["ProjDurMs"] + df["ProjDurHiddenMs"]
137-
mean_dur_hidden_ms_to_total_ms = (df["ProjDurHiddenMs"] / total_ms).mean()
138-
summary_data[collective] = mean_dur_hidden_ms_to_total_ms
139+
summary_data[collective] = df["ProjDurHiddenMs"].sum() / total_ms.sum()
139140

140141
return collective_types, summary_data
141142

@@ -253,8 +254,7 @@ def main():
253254
# Load the profiler data; the compilation part is needed for the warmup heuristics
254255
all_data = load_profiler_data(args.prefix, frames={"communication", "compile"})
255256
# Align timestamps
256-
all_data, alignment_metadata = align_profiler_data_timestamps(all_data)
257-
print(f"Alignment metadata: {alignment_metadata}")
257+
all_data, _ = align_profiler_data_timestamps(all_data)
258258
# Partition the profile data into initialisation and steady-state running
259259
_, steady_state = apply_warmup_heuristics(all_data)
260260

.github/container/nsys_jax/nsys_jax/analyses/summary.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#!/usr/bin/env python
22
import argparse
3+
import math
34
from nsys_jax import (
45
apply_warmup_heuristics,
56
ensure_compiled_protos_are_importable,
@@ -8,6 +9,8 @@
89
remove_autotuning_detail,
910
)
1011
import pathlib
12+
from prettytable import PrettyTable
13+
from uncertainties import ufloat # type: ignore
1114

1215

1316
def main():
@@ -45,12 +48,59 @@ def main():
4548
/ module_stats[("ProjDurMs", "sum")].sum()
4649
)
4750

51+
if steady_state.communication is not None and len(steady_state.communication):
52+
# Calculate the time spent waiting in collectives for each module.
53+
# Min/max over devices within individual communication thunk executions
54+
min_max_device_times = (
55+
steady_state.communication["ProjDurMs"]
56+
.groupby(["ProgramId", "ProgramExecution", "Name", "ThunkExecution"])
57+
.agg(("min", "max"))
58+
)
59+
# Define wait time as max-min *exposed* communication thunk times
60+
thunk_wait_times = min_max_device_times["max"] - min_max_device_times["min"]
61+
# Sum over thunks within each module
62+
module_wait_times = thunk_wait_times.groupby(
63+
["ProgramId", "ProgramExecution"]
64+
).agg("sum")
65+
# Stats over different executions of the module
66+
wait_averages = module_wait_times.groupby("ProgramId").agg(("mean", "std"))
67+
module_stats[("WaitMs", "mean")] = wait_averages["mean"]
68+
module_stats[("WaitMs", "std")] = wait_averages["std"]
69+
module_stats[("WaitMs", "percent")] = (
70+
100 * wait_averages["mean"] / module_stats[("ProjDurMs", "mean")]
71+
)
72+
4873
def dump(fname, df):
4974
with open(fname + ".json", "w") as ofile:
5075
df.to_json(ofile, orient="split")
5176

5277
dump("module-stats", module_stats)
53-
print(f" === MODULE EXECUTION SUMMARY ===\n{module_stats}")
78+
print(" === MODULE EXECUTION SUMMARY ===")
79+
fields = {
80+
"ID": lambda _, v: str(v),
81+
"Name": lambda _, v: v,
82+
"#execs": lambda _, v: str(v),
83+
"Thunks": lambda _, v: f"{v:S}" if v.s else f"{v.n:.0f}",
84+
"Duration [ms]": lambda _, v: f"{v:S}",
85+
"Duration [%]": lambda _, v: f"{v:.3f}",
86+
"Wait time [ms]": lambda _, v: "---" if math.isnan(v.n) else f"{v:S}",
87+
"Wait time [%]": lambda _, v: "---" if math.isnan(v) else f"{v:.3f}",
88+
}
89+
table = PrettyTable(align="r", custom_format=fields, field_names=fields.keys())
90+
for id, row in module_stats.iterrows():
91+
table.add_row(
92+
[
93+
id,
94+
row[("Name", "first")],
95+
row[("Name", "count")],
96+
ufloat(row[("NumThunks", "mean")], row[("NumThunks", "std")]),
97+
ufloat(row[("ProjDurMs", "mean")], row[("ProjDurMs", "std")]),
98+
row[("ProjDurMs", "percent")],
99+
ufloat(row[("WaitMs", "mean")], row[("WaitMs", "std")]),
100+
row[("WaitMs", "percent")],
101+
]
102+
)
103+
print(table)
54104

55105
compilation_stats = generate_compilation_statistics(init.compile)
56106
if len(compilation_stats):

.github/container/nsys_jax/nsys_jax/analysis.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,11 @@ def align_profiler_data_timestamps(
5454
)
5555
# For each collective, calculate the mean end time of each collective across devices
5656
mean_end_times = end_times.groupby(
57-
["ProgramId", "ProgramExecution", "ThunkIndex"]
57+
["ProgramId", "ProgramExecution", "Name", "ThunkExecution"], sort=False
5858
).agg("mean")
5959
# For each collective + device, calculate the delta of the end time from the mean
6060
end_time_skews = end_times - mean_end_times
61-
device_skews = end_time_skews.groupby("Device")
62-
median_device_skews = device_skews.agg("median")
61+
median_device_skews = end_time_skews.groupby("Device").agg("median")
6362
# Apply these corrections to the device-side timestamps
6463
for k in ["communication", "module", "thunk"]:
6564
df = getattr(frames, k)
@@ -78,11 +77,10 @@ def apply_warmup_heuristics(frames: ProfilerData) -> tuple[ProfilerData, Profile
7877
"""
7978
Given a ProfilerData dataclass, as returned by `load_profiler_data`, use heuristics
8079
to split the profile data into initialisation and steady state running. The current
81-
approach is to assume everything is steady state if compilation was not profiled,
82-
and if compilation *was* profiled then label the 0th execution as initialisation
83-
and the 2nd and later ones as steady state operation, discarding one execution in
84-
between. If there is no communication in the profile, that one in between is not
85-
discarded.
80+
approach is as follows: if compilation of a module was profiled, its first
81+
execution is classified as initialisation; in addition, if the profile data
82+
includes communication thunks, one further execution of each module is also
83+
classified as initialisation.
8684
8785
Returns a tuple of:
8886
ProfilerData dataclass, with only initialisation (and compile)
@@ -104,7 +102,9 @@ def apply_warmup_heuristics(frames: ProfilerData) -> tuple[ProfilerData, Profile
104102
#
105103
# then one-time costs (e.g. JIT compilation) of postamble(0) will affect when
106104
# step_function(1) is actually launched, whereas step_function(2) and later are
107-
# expected to launch closer to in lockstep across processes.
105+
# expected to launch closer to lockstep across processes. Even if compilation is
106+
# not profiled, profiler initialisation can take variable time across processes and
107+
# induce skews between the first profiled executions.
108108
init = ProfilerData(compile=frames.compile)
109109
steady = ProfilerData()
110110
steady_state_threshold = (
@@ -115,20 +115,19 @@ def apply_warmup_heuristics(frames: ProfilerData) -> tuple[ProfilerData, Profile
115115
if df is None:
116116
continue
117117
compile_mask = df.index.get_level_values("ProgramId").isin(compilation_ids_seen)
118+
threshold = compile_mask + steady_state_threshold
118119
prog_exec_values = df.index.get_level_values("ProgramExecution")
119-
init_mask = compile_mask & (prog_exec_values == 0)
120-
steady_mask = ~compile_mask | (prog_exec_values > steady_state_threshold)
120+
init_mask = prog_exec_values < threshold
121+
steady_mask = ~init_mask
121122
if len(df) != 0 and not steady_mask.any():
122123
print(
123-
f"WARNING: heuristics could not identify steady-state execution in {k} frame, assuming EVERYTHING is steady-state. You may want to increase the number of profiled executions."
124+
f"WARNING: heuristics could not identify steady-state execution in {k} "
125+
"frame, assuming EVERYTHING is steady-state. You may want to increase "
126+
"the number of profiled executions."
124127
)
125128
setattr(init, k, df[steady_mask])
126129
setattr(steady, k, df[~steady_mask])
127130
else:
128-
assert (
129-
steady_state_threshold == 0
130-
or (prog_exec_values[~init_mask & ~steady_mask] == 1).all()
131-
)
132131
setattr(init, k, df[init_mask])
133132
setattr(steady, k, df[steady_mask])
134133
return init, steady
@@ -303,12 +302,14 @@ def calculate_collective_metrics(
303302
if len(comm_df) == 0:
304303
return comm_df
305304

306-
def body(tup):
307-
idx, name = tup
308-
return get_message_size(idx[0], name, prefix=prefix)
305+
assert comm_df.index.names[0] == "ProgramId"
306+
assert comm_df.index.names[2] == "Name"
307+
308+
def body(idx):
309+
return get_message_size(idx[0], idx[2], prefix=prefix)
309310

310311
metrics_df = pd.DataFrame.from_records(
311-
map(body, comm_df["Name"].items()),
312+
map(body, comm_df.index),
312313
columns=[
313314
"MessageSize",
314315
"Collective",

0 commit comments

Comments
 (0)