|
21 | 21 | " xla_module_metadata,\n", |
22 | 22 | ")\n", |
23 | 23 | "import matplotlib.pyplot as plt\n", |
24 | | - "import numpy as np" |
| 24 | + "import numpy as np\n", |
| 25 | + "import pathlib" |
25 | 26 | ] |
26 | 27 | }, |
27 | 28 | { |
|
33 | 34 | "source": [ |
34 | 35 | "# Set the input data to use. default_data_prefix() checks the NSYS_JAX_DEFAULT_PREFIX environment variable, and if that is\n", |
35 | 36 | "# not set then the current working directory is used. Use pathlib.Path if setting this explicitly.\n", |
| 37 | + "prefix = pathlib.Path(\".\") # modify this and comment out the next line\n", |
36 | 38 | "prefix = default_data_prefix()" |
37 | 39 | ] |
38 | 40 | }, |
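Since `default_data_prefix()` falls back to the `NSYS_JAX_DEFAULT_PREFIX` environment variable, the notebook can also be pointed at data extracted elsewhere without editing this cell. A minimal sketch, using a hypothetical extraction path:

```python
import os
import pathlib

# Hypothetical directory where an nsys-jax archive was extracted.
os.environ["NSYS_JAX_DEFAULT_PREFIX"] = "/tmp/nsys-jax-report"
prefix = default_data_prefix()  # resolves to the directory set above
```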
|
128 | 130 | "id": "7727d800-13d3-4505-89e8-80a5fed63512", |
129 | 131 | "metadata": {}, |
130 | 132 | "source": [ |
131 | | - "Here the index has four levels. `ProgramId`, `ProgramExecution` and `Device` have the same meanings as in `steady_state.module`.\n", |
132 | | - "The fourth level (in the 3rd position) shows that this row is the `ThunkIndex`-th thunk within the `ProgramExecution`-th execution of XLA module `ProgramId`.\n", |
133 | | - "Note that a given thunk can be executed multiple times within the same module, so indexing on the thunk name would not be unique.\n", |
| 133 | + "Here the index has five levels. `ProgramId`, `ProgramExecution` and `Device` have the same meanings as in `steady_state.module`.\n", |
| 134 | + "The two new levels, `Name` and `ThunkExecution`, show that a given row is the `ThunkExecution`-th execution within the `ProgramExecution`-th execution of XLA module `ProgramId` of thunk `Name`.\n", |
| 135 | + "The `ThunkExecution` value is needed because a given thunk can be executed multiple times within the same module.\n", |
| 136 | + "The `Name` of a thunk can be used, along with a `ProgramId`, to look up XLA metadata.\n", |
134 | 137 | "\n", |
135 | 138 | "The columns are as follows:\n", |
136 | | - "- `Name`: the name of the thunk; this should be unique within a given `ProgramId` and can be used as a key to look up XLA metadata\n", |
137 | 139 | "- `ProjStartMs`: see above, same meaning as in `steady_state.module`.\n", |
138 | 140 | "- `Communication`: does this thunk represent communication between GPUs (*i.e.* a NCCL collective)? XLA overlaps communication and computation kernels, and `load_profiler_data` triggers an overlap calculation. `ProjDurMs` for a communication kernel shows only the duration that was **not** overlapped with computation kernels, while `ProjDurHiddenMs` shows the duration that **was** overlapped.\n", |
139 | | - "- This is the `ThunkExecution`-th execution of this thunk for this `(ProgramId, ProgramExecution, Device)`\n", |
140 | 141 | "\n", |
141 | 142 | "The third data frame does not show any GPU execution, but is rather a host-side trace:" |
142 | 143 | ] |
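To make the five-level index concrete, here is a sketch of slicing the thunk frame, assuming the level order (`ProgramId`, `ProgramExecution`, `Name`, `ThunkExecution`, `Device`) implied by the `durations_ms` helper later in this diff; `program_id` and `thunk_name` are hypothetical placeholder values:

```python
# All executions of one thunk, across module executions and devices.
program_id, thunk_name = 7, "fusion.123"  # hypothetical values
rows = steady_state.thunk.loc[(program_id, slice(None), thunk_name), :]
# Per-device mean of the non-overlapped projected duration.
per_device_mean_ms = rows["ProjDurMs"].groupby(level="Device").mean()
```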
|
178 | 179 | "id": "2e82c357-4e9d-48e4-b758-fa5357b2c8bd", |
179 | 180 | "metadata": {}, |
180 | 181 | "source": [ |
181 | | - "The index structure, and many of the columns, are equivalent to `thunk_df`. Additional columns are:\n", |
| 182 | + "The index structure, and many of the columns, are equivalent to the `.thunk` data frame. Additional columns are:\n", |
182 | 183 | "\n", |
183 | 184 | "- `MessageSize`: the message size of the collective in bytes; this aims to follow the same conventions as the NCCL tests\n", |
184 | 185 | "- `Collective`: the type of collective communication\n", |
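Combining these columns with the duration columns carried over from the thunk frame gives a rough bus-bandwidth estimate in the NCCL-tests sense. A minimal sketch, assuming the communication frame from the cell above is bound to `comm_df` and has the same `ProjDurMs`/`ProjDurHiddenMs` split as the thunk frame:

```python
# Total busy time of the collective: non-overlapped + overlapped parts.
total_s = (comm_df["ProjDurMs"] + comm_df["ProjDurHiddenMs"]) / 1e3
# MessageSize is in bytes, so this yields GB/s per row.
comm_df["BandwidthGBps"] = comm_df["MessageSize"] / total_s / 1e9
```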
|
524 | 525 | " # program, there may be different sub-groupings that are participating in smaller\n", |
525 | 526 | " # collectives in the strict/NCCL sense. TODO: it would be better to identify those\n", |
526 | 527 | " # sub-groupings and group them, but we currently lack the relevant information.\n", |
527 | | - " collective_df = df.groupby([\"ProgramId\", \"ProgramExecution\", \"ThunkIndex\"])\n", |
| 528 | + " collective_df = df.groupby(\n", |
| 529 | + " [\"ProgramId\", \"ProgramExecution\", \"Name\", \"ThunkExecution\"]\n", |
| 530 | + " )\n", |
528 | 531 | " # Take the fastest device kernel as a proxy for the actual bandwidth of the\n", |
529 | 532 | " # collective.\n", |
530 | 533 | " bandwidth_df = collective_df.agg(\n", |
|
534 | 537 | " \"ProjStartMs\": \"min\",\n", |
535 | 538 | " \"ProjDurFullMs\": \"min\",\n", |
536 | 539 | " \"ProjEndMs\": \"max\",\n", |
537 | | - " \"Name\": \"count\",\n", |
538 | 540 | " }\n", |
539 | 541 | " )\n", |
540 | 542 | " axs[0].plot(\n", |
|
582 | 584 | "\n", |
583 | 585 | "# Calculate statistics over different devices and different executions of each thunk, including multiple executions of the same thunk within the same module\n", |
584 | 586 | "compute_durations = steady_state.thunk.loc[\n", |
585 | | - " ~steady_state.thunk[\"Communication\"], (\"Name\", \"ProjDurMs\")\n", |
| 587 | + " ~steady_state.thunk[\"Communication\"], \"ProjDurMs\"\n", |
586 | 588 | "].groupby([\"ProgramId\", \"Name\"])\n", |
587 | | - "compute_duration_stats = compute_durations[\"ProjDurMs\"].agg((\"mean\", \"std\"))\n", |
| 589 | + "compute_duration_stats = compute_durations.agg((\"mean\", \"std\"))\n", |
588 | 590 | "compute_duration_means = compute_duration_stats[\"mean\"]\n", |
589 | 591 | "compute_duration_rel_stds = compute_duration_stats[\"std\"] / compute_duration_means\n", |
590 | 592 | "\n", |
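A quick way to inspect the result of this cell is to rank thunks by their relative jitter:

```python
# Ten thunks whose duration varies most, relative to their mean duration.
print(compute_duration_rel_stds.sort_values(ascending=False).head(10))
```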
|
634 | 636 | "\n", |
635 | 637 | "def durations_ms(idx):\n", |
636 | 638 | " program_id, thunk_name = idx\n", |
637 | | - " tmp = steady_state.thunk.loc[program_id, (\"Name\", \"ProjDurMs\")]\n", |
638 | | - " return tmp.loc[tmp[\"Name\"] == thunk_name, \"ProjDurMs\"]\n", |
| 639 | + " return steady_state.thunk.loc[(program_id, slice(None), thunk_name), \"ProjDurMs\"]\n", |
639 | 640 | "\n", |
640 | 641 | "\n", |
641 | 642 | "detailed_index = high_variance_means[high_variance_means > mean_threshold].index\n", |
|
666 | 667 | " squeeze=False,\n", |
667 | 668 | " tight_layout=True,\n", |
668 | 669 | " )\n", |
| 670 | + " # Compute (non-comm) kernel timings\n", |
669 | 671 | " time_df = steady_state.thunk.loc[\n", |
670 | 672 | " ~steady_state.thunk[\"Communication\"], (\"ProjStartMs\", \"ProjDurMs\")\n", |
671 | 673 | " ]\n", |
|
688 | 690 | " ):\n", |
689 | 691 | " # Mean over devices to get a single [thunk0_start, thunk0_end, thunk1_start, ...]\n", |
690 | 692 | " # array for this execution of this module\n", |
691 | | - " mean_times = interleave(exec_df.groupby(\"ThunkIndex\").agg(\"mean\"))\n", |
| 693 | + " mean_times = interleave(\n", |
| 694 | + " exec_df.groupby([\"Name\", \"ThunkExecution\"], sort=False).agg(\"mean\")\n", |
| 695 | + " )\n", |
692 | 696 | " # x axis of the plot will be the average over executions of the module\n", |
693 | 697 | " x_values.append(mean_times - mean_times[0])\n", |
694 | 698 | " for device, device_values in exec_df.groupby(\"Device\"):\n", |
695 | 699 | " # [thunk0_start, thunk0_end, ...] array for one device within one module exec\n", |
696 | 700 | " # with the average over devices subtracted\n", |
697 | 701 | " y_values[device].append(interleave(device_values) - mean_times)\n", |
698 | 702 | " mean_start_time_ms = np.mean(x_values, axis=0)\n", |
| 703 | + " # all_values: (num_devices, num_module_executions, thunks_per_module)\n", |
699 | 704 | " all_values = np.array(list(y_values.values()))\n", |
700 | 705 | " ax.plot(\n", |
701 | 706 | " mean_start_time_ms,\n", |
|
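`interleave` is defined earlier in the notebook and untouched by this diff. For readability, a sketch that is consistent with how its result is used here (a flat `[thunk0_start, thunk0_end, thunk1_start, ...]` array), assuming the input frame carries `ProjStartMs` and `ProjDurMs` columns:

```python
def interleave(df):
    # Sketch only: alternate start/end times, with end = start + duration.
    starts = df["ProjStartMs"].to_numpy()
    ends = starts + df["ProjDurMs"].to_numpy()
    return np.stack([starts, ends], axis=1).ravel()
```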
728 | 733 | " exec_df[\"ProjEndMs\"]\n", |
729 | 734 | " - steady_state.module.loc[(program_id, module_execution), \"ProjStartMs\"]\n", |
730 | 735 | " )\n", |
731 | | - " tmp = exec_df.groupby(\"ThunkIndex\").agg(\n", |
| 736 | + " tmp = exec_df.groupby([\"Name\", \"ThunkExecution\"]).agg(\n", |
732 | 737 | " {\n", |
733 | | - " \"Name\": \"first\",\n", |
734 | 738 | " \"Collective\": \"first\",\n", |
735 | 739 | " \"CollectiveSize\": \"first\",\n", |
736 | 740 | " \"EndInModuleMs\": \"mean\",\n", |
737 | 741 | " }\n", |
738 | 742 | " )\n", |
739 | 743 | " for coll_size, values in tmp.groupby(\"CollectiveSize\"):\n", |
740 | 744 | " comm_x_values[coll_size].append(values[\"EndInModuleMs\"])\n", |
741 | | - " (_, xmax), (ymin, ymax) = ax.get_xlim(), ax.get_ylim()\n", |
742 | | - " ax.set_xlim(0, xmax)\n", |
| 745 | + " ymin, ymax = ax.get_ylim()\n", |
| 746 | + " ax.set_xlim(mean_start_time_ms[0], mean_start_time_ms[-1])\n", |
743 | 747 | " ax.set_ylim(ymin, ymax)\n", |
744 | 748 | " largest_collective = max(comm_x_values.keys())\n", |
745 | 749 | " for n_color, (coll_size, values) in enumerate(comm_x_values.items()):\n", |
|
748 | 752 | " collective_times,\n", |
749 | 753 | " ymin,\n", |
750 | 754 | " # Draw taller vertical lines for collectives involving more devices\n", |
751 | | - " ymin * (1 - coll_size / largest_collective),\n", |
| 755 | + " ymin * (1 - 0.75 * coll_size / largest_collective),\n", |
752 | 756 | " color=f\"C{n_color}\",\n", |
753 | 757 | " label=f\"{coll_size}-device collective\",\n", |
754 | | - " linestyle=\"--\",\n", |
| 758 | + " linestyle=\"-\",\n", |
755 | 759 | " )\n", |
756 | 760 | "\n", |
757 | 761 | " ax.set_title(\n", |
|
836 | 840 | "outputs": [], |
837 | 841 | "source": [ |
838 | 842 | "num_traces = {\n", |
839 | | - " module_id: xla_module_metadata(module_id, policy=\"all\").unique_result(\n", |
| 843 | + " module_id: xla_module_metadata(\n", |
| 844 | + " module_id, policy=\"all\", prefix=prefix\n", |
| 845 | + " ).unique_result(\n", |
840 | 846 | " lambda hlo_module: len(\n", |
841 | 847 | " hlo_module.proto().buffer_assignment.heap_simulator_traces\n", |
842 | 848 | " )\n", |
|
855 | 861 | " squeeze=False,\n", |
856 | 862 | ")\n", |
857 | 863 | "for n_module, module_id in enumerate(module_ids_with_traces):\n", |
858 | | - " protos = xla_module_metadata(module_id, policy=\"all\")\n", |
| 864 | + " protos = xla_module_metadata(module_id, policy=\"all\", prefix=prefix)\n", |
859 | 865 | " sizes_by_logical_id = protos.unique_result(\n", |
860 | 866 | " lambda proto: {\n", |
861 | 867 | " buffer.id: buffer.size\n", |
|
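The `xla_module_metadata(..., policy="all", prefix=prefix)` plus `unique_result(...)` pattern shown in these two hunks generalizes to other per-module queries; `unique_result` appears to apply the callback to each profiled process's copy of the metadata and assert a single agreed value. For example, a sketch that sums logical buffer sizes (assuming the `logical_buffers` field of the buffer assignment proto, with the `id`/`size` members iterated above):

```python
module_id = 1  # hypothetical module id
total_buffer_bytes = xla_module_metadata(
    module_id, policy="all", prefix=prefix
).unique_result(
    lambda hlo_module: sum(
        buffer.size
        for buffer in hlo_module.proto().buffer_assignment.logical_buffers
    )
)
```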