|
12 | 12 | "from nsys_jax import (\n", |
13 | 13 | " align_profiler_data_timestamps,\n", |
14 | 14 | " apply_warmup_heuristics,\n", |
| 15 | + " default_data_prefix,\n", |
15 | 16 | " display_flamegraph,\n", |
16 | 17 | " ensure_compiled_protos_are_importable,\n", |
17 | 18 | " generate_compilation_statistics,\n", |
|
23 | 24 | "import numpy as np" |
24 | 25 | ] |
25 | 26 | }, |
| 27 | + { |
| 28 | + "cell_type": "code", |
| 29 | + "execution_count": null, |
| 30 | + "id": "7a91f0e7-17da-4534-8ea9-29bcf3742567", |
| 31 | + "metadata": {}, |
| 32 | + "outputs": [], |
| 33 | + "source": [ |
| 34 | + "# Set the input data to use. default_data_prefix() checks the NSYS_JAX_DEFAULT_PREFIX environment variable, and if that is\n", |
| 35 | + "# not set then the current working directory is used. Use pathlib.Path if setting this explicitly.\n", |
| 36 | + "prefix = default_data_prefix()" |
| 37 | + ] |
| 38 | + }, |
26 | 39 | { |
27 | 40 | "cell_type": "code", |
28 | 41 | "execution_count": null, |
|
32 | 45 | "source": [ |
33 | 46 | "# Make sure that the .proto files under protos/ have been compiled to .py, and\n", |
34 | 47 | "# that those generated .py files are importable.\n", |
35 | | - "compiled_dir = ensure_compiled_protos_are_importable()" |
| 48 | + "compiled_dir = ensure_compiled_protos_are_importable(prefix=prefix)" |
36 | 49 | ] |
37 | 50 | }, |
38 | 51 | { |
|
43 | 56 | "outputs": [], |
44 | 57 | "source": [ |
45 | 58 | "# Load the runtime profile data\n", |
46 | | - "all_data = load_profiler_data()\n", |
| 59 | + "all_data = load_profiler_data(prefix)\n", |
47 | 60 | "# Remove some detail from the autotuner\n", |
48 | 61 | "all_data = remove_autotuning_detail(all_data)\n", |
49 | 62 | "# Align GPU timestamps across profiles collected by different Nsight Systems processes\n", |
|
82 | 95 | "source": [ |
83 | 96 | "This data frame has a three-level index:\n", |
84 | 97 | "- `ProgramId` is an integer ID that uniquely identifies the XLA module\n", |
85 | | - "- This is the `ProgramExecution`-th execution of the module within the profiles. You may see this starting from 1, not 0, because of the `warmup_removal_heuristics` option passed to `load_profiler_data`.\n", |
| 98 | + "- This is the `ProgramExecution`-th execution of the module within the profiles. You may see this starting from 2, not 0, because of the `warmup_removal_heuristics` option passed to `load_profiler_data`.\n", |
86 | 99 | "- `Device` is the global (across multiple nodes and processes) index of the GPU on which the module execution took place\n", |
87 | 100 | "\n", |
88 | 101 | "The columns are as follows:\n", |
89 | 102 | "- `Name`: the name of the XLA module; this should always be the same for a given `ProgramId`\n", |
90 | 103 | "- `NumThunks`: the number of thunks executed inside this module execution\n", |
91 | 104 | "- `ProjStartMs`: the timestamp of the start of the module execution on the GPU, in milliseconds\n", |
92 | 105 | "- `ProjDurMs`: the duration of the module execution on the GPU, in milliseconds\n", |
93 | | - "- `OrigStartMs`: the timestamp of the start of the module launch **on the host**, in milliseconds. *i.e.* `ProjStartMs-OrigStartMs` is something like the launch latency of the first kernel\n", |
94 | | - "- `OrigDurMs`: the duration of the module launch **on the host**, in milliseconds\n", |
95 | 106 | "- `LocalDevice`: the index within the node/slice of the GPU on which the module execution took place\n", |
96 | 107 | "- `Process`: the global (across multiple nodes) index of the process\n", |
97 | 108 | "- `Slice`: the global index of the node/slice; devices within the same node/slice should have faster interconnects than to devices in different slices\n", |
|
117 | 128 | "id": "7727d800-13d3-4505-89e8-80a5fed63512", |
118 | 129 | "metadata": {}, |
119 | 130 | "source": [ |
120 | | - "Here the index has four levels. `ProgramId`, `ProgramExecution` and `Device` have the same meanings as in `module_df`.\n", |
| 131 | + "Here the index has four levels. `ProgramId`, `ProgramExecution` and `Device` have the same meanings as in `steady_state.module`.\n", |
121 | 132 | "The fourth level (in the 3rd position) shows that this row is the `ThunkIndex`-th thunk within the `ProgramExecution`-th execution of XLA module `ProgramId`.\n", |
122 | 133 | "Note that a given thunk can be executed multiple times within the same module, so indexing on the thunk name would not be unique.\n", |
123 | 134 | "\n", |
124 | 135 | "The columns are as follows:\n", |
125 | 136 | "- `Name`: the name of the thunk; this should be unique within a given `ProgramId` and can be used as a key to look up XLA metadata\n", |
126 | | - "- `ProjStartMs`, `OrigStartMs`, `OrigDurMs`: see above, same meaning as in `module_df`.\n", |
| 137 | + "- `ProjStartMs`: see above, same meaning as in `steady_state.module`.\n", |
127 | 138 | "- `Communication`: does this thunk represent communication between GPUs (*i.e.* a NCCL collective)? XLA overlaps communication and computation kernels, and `load_profiler_data` triggers an overlap calculation. `ProjDurMs` for a communication kernel shows only the duration that was **not** overlapped with computation kernels, while `ProjDurHiddenMs` shows the duration that **was** overlapped.\n", |
128 | 139 | "- This is the `ThunkExecution`-th execution of this thunk for this `(ProgramId, ProgramExecution, Device)`\n", |
129 | 140 | "\n", |
|
299 | 310 | "# Print out the largest entries adding up to at least this fraction of the total\n", |
300 | 311 | "threshold = 0.97\n", |
301 | 312 | "compile_summary[\"FracNonChild\"] = compile_summary[\"DurNonChildMs\"] / total_compile_time\n", |
302 | | - "print(f\"Top {threshold:.0%}+ of {total_compile_time * 1e-9:.2f}s compilation time\")\n", |
| 313 | + "print(f\"Top {threshold:.0%}+ of {total_compile_time*1e-3:.2f}s compilation time\")\n", |
303 | 314 | "for row in compile_summary[\n", |
304 | 315 | " compile_summary[\"FracNonChild\"].cumsum() <= threshold\n", |
305 | 316 | "].itertuples():\n", |
|
378 | 389 | " program_id, thunk_name = thunk_row.Index\n", |
379 | 390 | " # policy=\"all\" means we may get a set of HloProto instead of a single one, if\n", |
380 | 391 | " # nsys-jax-combine was used and the dumped metadata were not bitwise identical\n", |
381 | | - " hlo_modules = xla_module_metadata(program_id, policy=\"all\")\n", |
| 392 | + " hlo_modules = xla_module_metadata(program_id, policy=\"all\", prefix=prefix)\n", |
382 | 393 | " thunk_opcode, inst_metadata, inst_frames = hlo_modules.unique_result(\n", |
383 | 394 | " lambda proto: instructions_and_frames(proto, thunk_name)\n", |
384 | 395 | " )\n", |
|
0 commit comments