Skip to content

Commit d62000e

Browse files
authored
nsys-jax: make analysis script more tolerant (#1616)
- Tolerate multi-process HLO dumps that are not strictly identical. - Follow upstream renaming of thread names.
1 parent 668d58b commit d62000e

File tree

2 files changed

+20
-9
lines changed

2 files changed

+20
-9
lines changed

.github/container/nsys_jax/nsys_jax/analyses/pgle_costs.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,29 @@
66
load_profiler_data,
77
xla_module_metadata,
88
)
9+
from nsys_jax.protobuf import HloProto, HloProtoSet
910
import pathlib
1011

1112

12-
def write_pbtxt(outdir: pathlib.Path, series_ms, hlo_module):
13-
mod_proto = hlo_module.proto().hlo_module
14-
fingerprint = mod_proto.frontend_attributes.map["fingerprint_before_lhs"]
13+
def get_scheduling_name(module: HloProto, name: str) -> str:
14+
_, inst = module.find_instruction(name)
15+
return inst.proto().metadata.scheduling_name
16+
17+
18+
def write_pbtxt(outdir: pathlib.Path, series_ms, hlo_module_set: HloProtoSet):
19+
fingerprint = hlo_module_set.unique_result(
20+
lambda mod: mod.proto().hlo_module.frontend_attributes.map[
21+
"fingerprint_before_lhs"
22+
]
23+
)
1524
outdir.mkdir(exist_ok=True)
1625
fp_fname = f"{fingerprint}.pbtxt"
1726
null_names = 0
1827
with open(outdir / fp_fname, "w") as ofile:
1928
for name, cost_ms in series_ms.items():
20-
comp, inst = hlo_module.find_instruction(name)
21-
scheduling_name = inst.proto().metadata.scheduling_name
29+
scheduling_name = hlo_module_set.unique_result(
30+
lambda mod: get_scheduling_name(mod, name)
31+
)
2232
null_names += len(scheduling_name) == 0
2333
ofile.write(
2434
f'costs {{ name: "{scheduling_name}" cost_us: {cost_ms * 1000:.1f} }}\n'
@@ -61,7 +71,7 @@ def main():
6171
for row in module_ranking.itertuples():
6272
print(f"Processing module {row.Name} ({row.Index})")
6373
try:
64-
hlo_module = xla_module_metadata(row.Index, prefix=args.prefix)
74+
hlo_set = xla_module_metadata(row.Index, policy="all", prefix=args.prefix)
6575
except Exception as e:
6676
print(f"Skipping due to: {e}")
6777
continue
@@ -73,7 +83,7 @@ def main():
7383
write_pbtxt(
7484
pathlib.Path("./maxcomm_mincompute"),
7585
min_compute_max_comm(thunk_df.groupby("Name")),
76-
hlo_module,
86+
hlo_set,
7787
)
7888

7989

.github/container/nsys_jax/nsys_jax/data_loaders.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,11 +152,12 @@ def _load_nvtx_gpu_proj_trace_single(
152152
# Load the thread metadata used to map module/thunk executions to global device IDs
153153
meta_df = _load_parquet_file(meta_file)
154154
# Match XLA's launcher thread name. These threads launch work if >1 GPU is being
155-
# driven by the process.
155+
# driven by the process. xla#29725 renamed slice -> partition.
156156
device_by_pid_tid = (
157157
meta_df["Name"]
158158
.str.extract(
159-
r"^XlaLauncher:#global=(?P<Device>\d+),local=(?P<LocalDevice>\d+),process=(?P<Process>\d+),slice=(?P<Slice>\d+)#$"
159+
r"^XlaLauncher:#global=(?P<Device>\d+),local=(?P<LocalDevice>\d+),"
160+
r"process=(?P<Process>\d+),(?:partition|slice)=(?P<Slice>\d+)#$"
160161
)
161162
.dropna()
162163
.astype(np.int32)

0 commit comments

Comments
 (0)