Skip to content

Commit 9c56b12

Browse files
committed
Minor bug fixes in example notebook
Plumb through `prefix` so it's more convenient to explicitly set the input data path.
1 parent 66715ec commit 9c56b12

File tree

2 files changed

+22
-10
lines changed

2 files changed

+22
-10
lines changed

.github/container/nsys_jax/nsys_jax/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from .data_loaders import load_profiler_data
88
from .protobuf import xla_module_metadata
99
from .protobuf_utils import compile_protos, ensure_compiled_protos_are_importable
10-
from .utils import remove_autotuning_detail, remove_child_ranges
10+
from .utils import default_data_prefix, remove_autotuning_detail, remove_child_ranges
1111
from .visualization import create_flamegraph, display_flamegraph
1212

1313
__all__ = [
@@ -16,6 +16,7 @@
1616
"calculate_collective_metrics",
1717
"compile_protos",
1818
"create_flamegraph",
19+
"default_data_prefix",
1920
"display_flamegraph",
2021
"ensure_compiled_protos_are_importable",
2122
"generate_compilation_statistics",

.github/container/nsys_jax/nsys_jax/analyses/Analysis.ipynb

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
"from nsys_jax import (\n",
1313
" align_profiler_data_timestamps,\n",
1414
" apply_warmup_heuristics,\n",
15+
" default_data_prefix,\n",
1516
" display_flamegraph,\n",
1617
" ensure_compiled_protos_are_importable,\n",
1718
" generate_compilation_statistics,\n",
@@ -23,6 +24,18 @@
2324
"import numpy as np"
2425
]
2526
},
27+
{
28+
"cell_type": "code",
29+
"execution_count": null,
30+
"id": "7a91f0e7-17da-4534-8ea9-29bcf3742567",
31+
"metadata": {},
32+
"outputs": [],
33+
"source": [
34+
"# Set the input data to use. default_data_prefix() checks the NSYS_JAX_DEFAULT_PREFIX environment variable, and if that is\n",
35+
"# not set then the current working directory is used. Use pathlib.Path if setting this explicitly.\n",
36+
"prefix = default_data_prefix()"
37+
]
38+
},
2639
{
2740
"cell_type": "code",
2841
"execution_count": null,
@@ -32,7 +45,7 @@
3245
"source": [
3346
"# Make sure that the .proto files under protos/ have been compiled to .py, and\n",
3447
"# that those generated .py files are importable.\n",
35-
"compiled_dir = ensure_compiled_protos_are_importable()"
48+
"compiled_dir = ensure_compiled_protos_are_importable(prefix=prefix)"
3649
]
3750
},
3851
{
@@ -43,7 +56,7 @@
4356
"outputs": [],
4457
"source": [
4558
"# Load the runtime profile data\n",
46-
"all_data = load_profiler_data()\n",
59+
"all_data = load_profiler_data(prefix)\n",
4760
"# Remove some detail from the autotuner\n",
4861
"all_data = remove_autotuning_detail(all_data)\n",
4962
"# Align GPU timestamps across profiles collected by different Nsight Systems processes\n",
@@ -82,16 +95,14 @@
8295
"source": [
8396
"This data frame has a three-level index:\n",
8497
"- `ProgramId` is an integer ID that uniquely identifies the XLA module\n",
85-
"- This is the `ProgramExecution`-th execution of the module within the profiles. You may see this starting from 1, not 0, because of the `warmup_removal_heuristics` option passed to `load_profiler_data`.\n",
98+
"- This is the `ProgramExecution`-th execution of the module within the profiles. You may see this starting from 2, not 0, because of the `warmup_removal_heuristics` option passed to `load_profiler_data`.\n",
8699
"- `Device` is the global (across multiple nodes and processes) index of the GPU on which the module execution took place\n",
87100
"\n",
88101
"The columns are as follows:\n",
89102
"- `Name`: the name of the XLA module; this should always be the same for a given `ProgramId`\n",
90103
"- `NumThunks`: the number of thunks executed inside this module execution\n",
91104
"- `ProjStartMs`: the timestamp of the start of the module execution on the GPU, in milliseconds\n",
92105
"- `ProjDurMs`: the duration of the module execution on the GPU, in milliseconds\n",
93-
"- `OrigStartMs`: the timestamp of the start of the module launch **on the host**, in milliseconds. *i.e.* `ProjStartMs-OrigStartMs` is something like the launch latency of the first kernel\n",
94-
"- `OrigDurMs`: the duration of the module launch **on the host**, in milliseconds\n",
95106
"- `LocalDevice`: the index within the node/slice of the GPU on which the module execution took place\n",
96107
"- `Process`: the global (across multiple nodes) index of the process\n",
97108
"- `Slice`: the global index of the node/slice; devices within the same node/slice should have faster interconnects than to devices in different slices\n",
@@ -117,13 +128,13 @@
117128
"id": "7727d800-13d3-4505-89e8-80a5fed63512",
118129
"metadata": {},
119130
"source": [
120-
"Here the index has four levels. `ProgramId`, `ProgramExecution` and `Device` have the same meanings as in `module_df`.\n",
131+
"Here the index has four levels. `ProgramId`, `ProgramExecution` and `Device` have the same meanings as in `steady_state.module`.\n",
121132
"The fourth level (in the 3rd position) shows that this row is the `ThunkIndex`-th thunk within the `ProgramExecution`-th execution of XLA module `ProgramId`.\n",
122133
"Note that a given thunk can be executed multiple times within the same module, so indexing on the thunk name would not be unique.\n",
123134
"\n",
124135
"The columns are as follows:\n",
125136
"- `Name`: the name of the thunk; this should be unique within a given `ProgramId` and can be used as a key to look up XLA metadata\n",
126-
"- `ProjStartMs`, `OrigStartMs`, `OrigDurMs`: see above, same meaning as in `module_df`.\n",
137+
"- `ProjStartMs`: see above, same meaning as in `steady_state.module`.\n",
127138
"- `Communication`: does this thunk represent communication between GPUs (*i.e.* a NCCL collective)? XLA overlaps communication and computation kernels, and `load_profiler_data` triggers an overlap calculation. `ProjDurMs` for a communication kernel shows only the duration that was **not** overlapped with computation kernels, while `ProjDurHiddenMs` shows the duration that **was** overlapped.\n",
128139
"- This is the `ThunkExecution`-th execution of this thunk for this `(ProgramId, ProgramExecution, Device)`\n",
129140
"\n",
@@ -299,7 +310,7 @@
299310
"# Print out the largest entries adding up to at least this fraction of the total\n",
300311
"threshold = 0.97\n",
301312
"compile_summary[\"FracNonChild\"] = compile_summary[\"DurNonChildMs\"] / total_compile_time\n",
302-
"print(f\"Top {threshold:.0%}+ of {total_compile_time*1e-9:.2f}s compilation time\")\n",
313+
"print(f\"Top {threshold:.0%}+ of {total_compile_time*1e-3:.2f}s compilation time\")\n",
303314
"for row in compile_summary[\n",
304315
" compile_summary[\"FracNonChild\"].cumsum() <= threshold\n",
305316
"].itertuples():\n",
@@ -378,7 +389,7 @@
378389
" program_id, thunk_name = thunk_row.Index\n",
379390
" # policy=\"all\" means we may get a set of HloProto instead of a single one, if\n",
380391
" # nsys-jax-combine was used and the dumped metadata were not bitwise identical\n",
381-
" hlo_modules = xla_module_metadata(program_id, policy=\"all\")\n",
392+
" hlo_modules = xla_module_metadata(program_id, policy=\"all\", prefix=prefix)\n",
382393
" thunk_opcode, inst_metadata, inst_frames = hlo_modules.unique_result(\n",
383394
" lambda proto: instructions_and_frames(proto, thunk_name)\n",
384395
" )\n",

0 commit comments

Comments
 (0)