
Commit b3d3dc8

nsys-jax: re-map module IDs (#1536)
This adds support for multi-process profiling in cases where the individual processes do not agree on the global numbering of XLA modules.
1 parent efb11b7 commit b3d3dc8
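
For context, XLA assigns program IDs as process-local counters at compile time, so two processes that compile the same modules in a different order end up with different numberings. A minimal, hypothetical sketch of the scenario this commit handles (modelled on the multi_process_program.py test referenced in the diff below):

```python
import jax

# Hypothetical non-SPMD setup: both processes JIT the same two functions,
# but in a different order, so the process-local program IDs that XLA
# assigns at compile time diverge between the processes.
f = jax.jit(lambda x: x + 1)
g = jax.jit(lambda x: x * 2)

if jax.process_index() == 0:
    f(1.0), g(1.0)  # f compiled first: f -> id N, g -> id N+1
else:
    g(1.0), f(1.0)  # g compiled first: g -> id N, f -> id N+1

# When nsys-jax-combine later merges the per-process profiles, "program N"
# would mix executions of f and g unless the IDs are re-mapped first to a
# process-independent key.
```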

File tree

9 files changed: +619 −75 lines changed

.github/container/nsys_jax/nsys_jax/analyses/Analysis.ipynb

Lines changed: 32 additions & 6 deletions
@@ -22,6 +22,7 @@
     ")\n",
     "import matplotlib.pyplot as plt\n",
     "import numpy as np\n",
+    "import os\n",
     "import pathlib"
    ]
   },
@@ -96,7 +97,7 @@
    "metadata": {},
    "source": [
     "This data frame has a three-level index:\n",
-    "- `ProgramId` is an integer ID that uniquely identifies the XLA module\n",
+    "- `ProgramId` is a string hash that uniquely identifies the XLA module\n",
     "- This is the `ProgramExecution`-th execution of the module within the profiles. You may see this starting from 2, not 0, because of the `warmup_removal_heuristics` option passed to `load_profiler_data`.\n",
     "- `Device` is the global (across multiple nodes and processes) index of the GPU on which the module execution took place\n",
     "\n",
@@ -371,6 +372,7 @@
     "gpu_active_unknown = gpu_active + [\"[Unknown]\"]\n",
     "gpu_idle_inside_modules = [\"[GPU idle during module execution]\"]\n",
     "gpu_idle_between_modules = [\"[GPU idle between module executions]\"]\n",
+    "inconsistent_metadata = [\"[inconsistent metadata]\"]\n",
     "\n",
     "\n",
     "@functools.cache\n",
@@ -382,17 +384,41 @@
     "        for called_inst in hlo_module.find_computation(called_comp_id).instructions\n",
     "    ]\n",
     "    metadata = [inst.metadata for inst in instructions]\n",
+    "    names = [meta.op_name for meta in metadata]\n",
     "    frames = [hlo_module.get_stack_frames(meta.stack_frame_id) for meta in metadata]\n",
-    "    return hlo_inst.proto().opcode, metadata, frames\n",
+    "    return hlo_inst.proto().opcode, names, frames\n",
+    "\n",
+    "\n",
+    "def reduce_instructions_and_frames(tup1, tup2):\n",
+    "    op1, names1, frames1 = tup1\n",
+    "    op2, names2, frames2 = tup2\n",
+    "    assert op1 == op2, (op1, op2)\n",
+    "    assert names1 == names2, (names1, names2)\n",
+    "    # If the call sites leading to the first JIT of a function were different in\n",
+    "    # different processes, the recorded stacks will be different in different\n",
+    "    # metadata dumps. Fudge that by keeping the common prefix and suffix and replacing\n",
+    "    # the middle with an \"inconsistent\" message.\n",
+    "    common_frames = []\n",
+    "    for stack1, stack2 in zip(frames1, frames2):\n",
+    "        if stack1 != stack2:\n",
+    "            common_prefix = os.path.commonprefix([stack1, stack2])\n",
+    "            stack1.reverse()\n",
+    "            stack2.reverse()\n",
+    "            common_suffix = os.path.commonprefix([stack1, stack2])\n",
+    "            common_frames.append(common_prefix + inconsistent_metadata + common_suffix)\n",
+    "        else:\n",
+    "            common_frames.append(stack1)\n",
+    "    return op1, names1, common_frames\n",
     "\n",
     "\n",
     "for thunk_row in thunk_summary.itertuples():\n",
     "    program_id, thunk_name = thunk_row.Index\n",
     "    # policy=\"all\" means we may get a set of HloProto instead of a single one, if\n",
     "    # nsys-jax-combine was used and the dumped metadata were not bitwise identical\n",
     "    hlo_modules = xla_module_metadata(program_id, policy=\"all\", prefix=prefix)\n",
-    "    thunk_opcode, inst_metadata, inst_frames = hlo_modules.unique_result(\n",
-    "        lambda proto: instructions_and_frames(proto, thunk_name)\n",
+    "    thunk_opcode, inst_op_names, inst_frames = hlo_modules.reduce_result(\n",
+    "        lambda proto: instructions_and_frames(proto, thunk_name),\n",
+    "        reduce_instructions_and_frames,\n",
     "    )\n",
     "\n",
     "    # Summarise by opcode, i.e. fusion/custom-call/...\n",
@@ -418,8 +444,8 @@
     "        # 2nd choice: gpu_active_unknown\n",
     "        {tuple(gpu_active_unknown)},\n",
     "    )\n",
-    "    for meta, frames in zip(inst_metadata, inst_frames):\n",
-    "        op_name = [meta.op_name] if len(meta.op_name) else []\n",
+    "    for op_name_str, frames in zip(inst_op_names, inst_frames):\n",
+    "        op_name = [op_name_str] if len(op_name_str) else []\n",
     "        if len(frames):\n",
     "            src_runtime_preferences[0].add(tuple(gpu_active + frames + op_name))\n",
     "        if len(op_name):\n",

.github/container/nsys_jax/nsys_jax/analyses/summary.py

Lines changed: 15 additions & 5 deletions
@@ -48,7 +48,10 @@ def main():
         / module_stats[("ProjDurMs", "sum")].sum()
     )
 
-    if steady_state.communication is not None and len(steady_state.communication):
+    have_comms = steady_state.communication is not None and len(
+        steady_state.communication
+    )
+    if have_comms:
         # Calculate the time spent waiting in collectives for each module.
         # Min/max over devices within individual communication thunk executions
         min_max_device_times = (
@@ -83,9 +86,10 @@ def dump(fname, df):
         "Thunks": lambda _, v: f"{v:S}" if v.s else f"{v.n:.0f}",
         "Duration [ms]": lambda _, v: f"{v:S}",
         "Duration [%]": lambda _, v: f"{v:.3f}",
-        "Wait time [ms]": lambda _, v: "---" if math.isnan(v.n) else f"{v:S}",
-        "Wait time [%]": lambda _, v: "---" if math.isnan(v) else f"{v:.3f}",
     }
+    if have_comms:
+        fields["Wait time [ms]"] = lambda _, v: "---" if math.isnan(v.n) else f"{v:S}"
+        fields["Wait time [%]"] = lambda _, v: "---" if math.isnan(v) else f"{v:.3f}"
     table = PrettyTable(align="r", custom_format=fields, field_names=fields.keys())
     for id, row in module_stats.iterrows():
         table.add_row(
@@ -96,9 +100,15 @@ def dump(fname, df):
                 ufloat(row[("NumThunks", "mean")], row[("NumThunks", "std")]),
                 ufloat(row[("ProjDurMs", "mean")], row[("ProjDurMs", "std")]),
                 row[("ProjDurMs", "percent")],
-                ufloat(row[("WaitMs", "mean")], row[("WaitMs", "std")]),
-                row[("WaitMs", "percent")],
             ]
+            + (
+                [
+                    ufloat(row[("WaitMs", "mean")], row[("WaitMs", "std")]),
+                    row[("WaitMs", "percent")],
+                ]
+                if have_comms
+                else []
+            )
         )
     print(table)
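
The same pattern appears twice here: the wait-time columns and their formatters are only defined when communication data exists. A minimal standalone sketch of that conditional-column pattern (`have_comms` and the row values are stand-ins for the real profile data):

```python
import math
from prettytable import PrettyTable

# Stand-in for "communication data was collected in this profile".
have_comms = False

# Build the formatter dict first, then extend it conditionally; field_names
# is derived from the dict, so the table only gains the extra columns when
# the corresponding formatters were added.
fields = {
    "Module": lambda _, v: str(v),
    "Duration [ms]": lambda _, v: f"{v:.3f}",
}
if have_comms:
    fields["Wait time [ms]"] = lambda _, v: "---" if math.isnan(v) else f"{v:.3f}"

table = PrettyTable(align="r", custom_format=fields, field_names=fields.keys())
table.add_row(["jit_train_step", 12.345] + ([1.234] if have_comms else []))
print(table)
```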

.github/container/nsys_jax/nsys_jax/data_loaders.py

Lines changed: 34 additions & 19 deletions
@@ -10,7 +10,8 @@
 import re
 
 from .analysis import calculate_collective_metrics
-from .protobuf import xla_module_metadata
+from .protobuf import _hlo_cache, _remap_program_id, xla_module_metadata
+from .protobuf_utils import ensure_compiled_protos_are_importable
 from .utils import default_data_prefix, make_child_mask, ProfilerData
 
 pd.options.mode.copy_on_write = True
@@ -20,7 +21,7 @@
 def _is_communication(
     program_id: int, prefix: pathlib.Path, instruction_name: str
 ) -> bool:
-    if program_id == -1:
+    if program_id == "unknown":
         # Assume this is an autotuning execution.
         return False
     try:
@@ -143,10 +144,11 @@ def _sort_thunk_frame(df: pd.DataFrame) -> pd.DataFrame:
 
 def _load_nvtx_gpu_proj_trace_single(
     prefix: pathlib.Path,
+    replica: str | None,
     file: pathlib.Path,
     meta_file: pathlib.Path,
     frames: set[str],
-) -> dict[str, pd.DataFrame]:
+) -> tuple[dict[str, pd.DataFrame], dict[tuple[pathlib.Path, str], set[pathlib.Path]]]:
     # Load the thread metadata used to map module/thunk executions to global device IDs
     meta_df = _load_parquet_file(meta_file)
     # Match XLA's launcher thread name. These threads launch work if >1 GPU is being
@@ -299,22 +301,25 @@ def _load_nvtx_gpu_proj_trace_single(
     # The classic example where it is not set is during autotuning, where ops
     # to be autotuned are extracted into new HloModule instances, which are not
     # propagated to the GpuExecutable that emits the XlaModule annotation.
-    # Those are probably not interesting, so setting the ProgramId to -1 in
-    # such cases is acceptable.
+    # Those are probably not interesting, so setting the ProgramId to
+    # "unknown" in such cases is acceptable.
     module_re = (
         "^"
         + tsl_prefix
         + r"XlaModule:#(?:prefix=(.*?),|)hlo_module=([a-z0-9._-]+)(?:,program_id=(\d+)|)#$"
     )
-    mod_program_ids = (
-        df.loc[mod_ids, "Name"]
-        .str.replace(
-            pat=module_re,
-            repl=lambda m: "-1" if m.group(3) is None else m.group(3),
-            n=1,
-            regex=True,
-        )
-        .astype(np.int32)
+    # Apply a transformation to the program IDs to handle the case where profiles are
+    # being combined from multiple processes, but the distributed application was not
+    # strictly SPMD - so the IDs collected from different processes do not match for
+    # "the same" program. The multi_process_program.py test in the nsys_jax test suite
+    # explicitly constructs this scenario.
+    mod_program_ids = df.loc[mod_ids, "Name"].str.replace(
+        pat=module_re,
+        repl=lambda m: _remap_program_id(
+            old_id_str=m.group(3), name=m.group(2), prefix=prefix, replica=replica
+        ),
+        n=1,
+        regex=True,
     )
     # Update each module and thunk row with the program ID it corresponds to
     df.loc[mod_ids, "ProgramId"] = mod_program_ids
@@ -385,7 +390,7 @@ def clean_data_frame(d):
             "RangeStack",
             "TID",
         ]
-    ).astype({"ProgramExecution": np.int32, "ProgramId": np.int32})
+    ).astype({"ProgramExecution": np.int32})
 
     output = {}
     if "thunk" in frames:
@@ -427,7 +432,7 @@ def clean_data_frame(d):
         ["ProgramId", "ProgramExecution", "Device"]
     )
 
-    return output
+    return output, _hlo_cache
 
 
 def _enough_processes(work_items: int) -> int:
@@ -440,33 +445,42 @@ def _load_nvtx_gpu_proj_trace(
     prefix: pathlib.Path,
     frames: set[str],
 ):
+    # _remap_program_id needs to load protos
+    ensure_compiled_protos_are_importable(prefix=prefix)
     path = prefix / "nvtx_gpu_proj_trace" / "trace.parquet"
     meta_path = prefix / "thread-metadata.parquet"
+    replica_slugs: list[str | None]
     if path.is_dir():
         # We're looking at the output of nsys-jax-combine
         assert meta_path.is_dir()
         filenames = sorted(path.iterdir())
+        replica_slugs = [fname.name for fname in filenames]
         meta_filenames = sorted(meta_path.iterdir())
     else:
         # We're looking at the output of nsys-jax
         assert not meta_path.is_dir()
         filenames = [path]
+        replica_slugs = [None]
         meta_filenames = [meta_path]
 
     if len(filenames) > 1:
         tmp = defaultdict(list)
         with multiprocessing.Pool(processes=_enough_processes(len(filenames))) as pool:
-            for single_trace in pool.starmap(
+            for single_trace, hlo_cache in pool.starmap(
                 _load_nvtx_gpu_proj_trace_single,
                 zip(
                     itertools.repeat(prefix),
+                    replica_slugs,
                     filenames,
                     meta_filenames,
                     itertools.repeat(frames),
                 ),
             ):
                 for k, v in single_trace.items():
                     tmp[k].append(v)
+                # Merge the caches from the pool worker processes into the main one.
+                for k2, v2 in hlo_cache.items():
+                    _hlo_cache[k2] |= v2
         output = {}
         for k, v in tmp.items():
             output[k] = pd.concat(v, verify_integrity=True)
@@ -477,8 +491,9 @@ def _load_nvtx_gpu_proj_trace(
     if "thunk" in output:
         output["thunk"] = _sort_thunk_frame(output["thunk"])
     else:
-        output = _load_nvtx_gpu_proj_trace_single(
-            prefix, filenames[0], meta_filenames[0], frames
+        # No explicit handling of the HLO cache, everything is in one process
+        output, _ = _load_nvtx_gpu_proj_trace_single(
+            prefix, None, filenames[0], meta_filenames[0], frames
         )
     if "module" in output:
         output["module"] = output["module"].sort_index()
