Skip to content

Commit ef3fd66

Browse files
authored
nsys-jax post-processing: treat host-device copies as 1-device collectives (#1073)
This adds logic to treat `dynamic[-update]-slice` operations that have a source/destination operand in the host memory space as being communication operations, labelling them as single-device "collectives". The goal is to improve support for analysing profiles of execution including offloading to host memory. Also fix support for nsys 2024.6 by applying the same thread-ID-export patch that was already applied for 2024.5.
1 parent 3638a66 commit ef3fd66

File tree

4 files changed

+138
-36
lines changed

4 files changed

+138
-36
lines changed

.github/container/install-nsight.sh

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,12 @@ apt-get clean
1717

1818
rm -rf /var/lib/apt/lists/*
1919

20-
NSYS202451=/opt/nvidia/nsight-systems-cli/2024.5.1
21-
if [[ -d "${NSYS202451}" ]]; then
22-
# * can match at least sbsa-armv8 and x86
23-
(cd ${NSYS202451}/target-linux-*/python/packages && git apply < /opt/nvidia/nsys-2024.5-tid-export.patch)
24-
fi
20+
for NSYS in /opt/nvidia/nsight-systems-cli/2024.5.1 /opt/nvidia/nsight-systems-cli/2024.6.1; do
21+
if [[ -d "${NSYS}" ]]; then
22+
# * can match at least sbsa-armv8 and x86
23+
(cd ${NSYS}/target-linux-*/python/packages && git apply < /opt/nvidia/nsys-2024.5-tid-export.patch)
24+
fi
25+
done
2526

2627
# Install extra dependencies needed for `nsys recipe ...` commands. These are
2728
# used by the nsys-jax wrapper script.

.github/container/jax_nsys/python/jax_nsys/jax_nsys/analysis.py

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pathlib
77
from typing import Any
88

9-
from .protobuf import HloProto, xla_module_metadata
9+
from .protobuf import HloProto, _host_memory_space, xla_module_metadata
1010
from .utils import make_child_mask, ProfilerData
1111

1212
pd.options.mode.copy_on_write = True
@@ -38,6 +38,11 @@ def align_profiler_data_timestamps(
3838
# Determine which collective size will be used for the alignment
3939
num_profiled_devices = len(comm_df.index.get_level_values("Device").unique())
4040
max_collective_size = comm_df["CollectiveSize"].max()
41+
if max_collective_size == 1:
42+
print(
43+
f"WARNING: cannot align {num_profiled_devices} devices because max collective size is 1"
44+
)
45+
return frames, {}
4146
assert (
4247
num_profiled_devices == max_collective_size
4348
), f"Aligning {num_profiled_devices} using collectives of size {max_collective_size} is not implemented"
@@ -193,13 +198,51 @@ def _get_message_size(
193198
"all-to-all",
194199
"collective-broadcast",
195200
"collective-permute-start",
201+
"dynamic-slice",
202+
"dynamic-update-slice",
196203
"reduce-scatter",
197204
}
198205
), f"{instruction}: message size calculation for {comm_inst.opcode} has not yet been validated"
206+
207+
def _byte_size(inst) -> int:
208+
size_bits = math.prod(
209+
inst.shape.dimensions,
210+
start=element_type_width(inst.shape.element_type),
211+
)
212+
size_bytes, rem = divmod(size_bits, 8)
213+
assert rem == 0
214+
return size_bytes
215+
199216
if comm_inst.opcode == "collective-permute-start":
200217
# See https://openxla.org/xla/operation_semantics#collectivepermute, which
201218
# generates pair-wise send+recv between devices
202219
collective_size = 2
220+
elif comm_inst.opcode in {"dynamic-slice", "dynamic-update-slice"}:
221+
# Label host-device transfers orchestrated by dynamic[-update]-slice as single
222+
# device collectives.
223+
collective_size = 1
224+
if comm_inst.opcode == "dynamic-update-slice":
225+
# For dynamic-update-slice the second operand is the one being copied
226+
_, src_inst = module_proto.find_instruction_by_id(comm_inst.operand_ids[1])
227+
transfer_size = _byte_size(src_inst.proto())
228+
else:
229+
# For dynamic-slice the return type size is the transfer size
230+
assert comm_inst.opcode == "dynamic-slice"
231+
_, src_inst = module_proto.find_instruction_by_id(comm_inst.operand_ids[0])
232+
transfer_size = _byte_size(comm_inst)
233+
dest_on_host = _host_memory_space(comm_inst)
234+
src_on_host = _host_memory_space(src_inst.proto())
235+
assert src_on_host != dest_on_host, (
236+
'dynamic[-update]-slice is only considered "communication" if it '
237+
"represents a host-device transfer"
238+
)
239+
return (
240+
transfer_size,
241+
"device-to-host" if dest_on_host else "host-to-device",
242+
1, # collective size
243+
1.0, # bw_correction
244+
1.0, # bus_correction
245+
)
203246
else:
204247
# replica_groups is something like {{0,1},{4,5},{2,3},{6,7}}, if there are 8
205248
# devices that are doing pair-wise collectives
@@ -220,17 +263,12 @@ def _get_message_size(
220263
total_msg_size = 0
221264
for operand_id in comm_inst.operand_ids:
222265
_, operand = module_proto.find_instruction_by_id(operand_id)
223-
msg_size_bits = math.prod(
224-
operand.proto().shape.dimensions,
225-
start=element_type_width(operand.proto().shape.element_type),
226-
)
266+
msg_size_bytes = _byte_size(operand.proto())
227267
if comm_inst.opcode == "reduce-scatter":
228268
# NCCL's convention is that the message size of a reduce-scatter is the size of output buffer:
229269
# https://github.com/NVIDIA/nccl/blob/ab2b89c4c339bd7f816fbc114a4b05d386b66290/src/collectives.cc#L122
230-
msg_size_bits, rem = divmod(msg_size_bits, collective_size)
270+
msg_size_bytes, rem = divmod(msg_size_bytes, collective_size)
231271
assert rem == 0
232-
msg_size_bytes, rem = divmod(msg_size_bits, 8)
233-
assert rem == 0
234272
total_msg_size += msg_size_bytes
235273

236274
collective = comm_inst.opcode.removesuffix("-start")

.github/container/jax_nsys/python/jax_nsys/jax_nsys/data_loaders.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,9 @@ def is_communication(row):
103103
return _calculate_overlap(thunk_df)
104104

105105

106+
compile_prefix = "XlaCompile:#module="
107+
108+
106109
def _load_nvtx_gpu_proj_trace_single(
107110
prefix: pathlib.Path,
108111
file: pathlib.Path,
@@ -305,10 +308,21 @@ def _load_nvtx_gpu_proj_trace_single(
305308
unique_pid_tid_pairs = module_df.loc[:, ("PID", "TID")].drop_duplicates()
306309
if len(unique_pid_tid_pairs) == 1:
307310
main_pid_tid_candidates.add(tuple(unique_pid_tid_pairs.iloc[0]))
311+
# If the profile only includes N>1 modules, we may still be able to identify the
312+
# main thread as the one responsible for XlaCompile ranges projected onto the GPU
313+
# timeline
314+
compile_ranges = df.loc[~all_thunks, "Name"].str.startswith(
315+
tsl_prefix + compile_prefix
316+
)
317+
compile_range_ids = compile_ranges[compile_ranges].index
318+
unique_pid_tid_pairs = df.loc[compile_range_ids, ("PID", "TID")].drop_duplicates()
319+
if len(unique_pid_tid_pairs) == 1:
320+
main_pid_tid_candidates.add(tuple(unique_pid_tid_pairs.iloc[0]))
308321
assert len(main_pid_tid_candidates) < 2
309322
if len(main_pid_tid_candidates) == 1:
310323
# Possibly not correct if len(device_by_pid_tid) > 1
311324
assert len(device_by_pid_tid) > 0
325+
# Associate the main thread with the 0th device in device_by_pid_tid
312326
main_thread_df = device_by_pid_tid.iloc[:1]
313327
main_thread_df.index = pd.MultiIndex.from_tuples(
314328
main_pid_tid_candidates, names=["PID", "TID"]
@@ -425,16 +439,13 @@ def _load_nvtx_gpu_proj_trace(
425439
return output
426440

427441

428-
compile_prefix = "TSL:XlaCompile:#module="
429-
430-
431442
def _splice_parallel_ranges(compile_df: pd.DataFrame) -> pd.DataFrame:
432443
# When parallel compilation is enabled, we end up with worker threads that
433444
# emit NVTX ranges but which are not accounted for in the RangeStack tree.
434445
# Splice these in under the relevant XlaCompile ranges in the RangeStack tree and
435446
# drop everything else.
436447
retain_mask = pd.Series(False, index=compile_df.index)
437-
compile_mask = compile_df["Name"].str.startswith(compile_prefix)
448+
compile_mask = compile_df["Name"].str.startswith("TSL:" + compile_prefix)
438449
for compile_range in compile_df[compile_mask].itertuples():
439450
# Identify the slice of `compile_df` that overlaps in time with this XlaCompile
440451
# range

.github/container/jax_nsys/python/jax_nsys/jax_nsys/protobuf.py

Lines changed: 71 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1-
from collections import defaultdict
21
import functools
32
import lzma
43
import pathlib
54
import typing
65

76

7+
def _host_memory_space(inst):
8+
return inst.shape.layout.memory_space == 5
9+
10+
811
class StackFrame(typing.NamedTuple):
912
column: int
1013
file: str
@@ -25,6 +28,35 @@ def __init__(self, wrapped_hlo_proto, proto):
2528
# proto representing the actual collective, which will be different if the
2629
# async launch is handled by an async-start op
2730
# TODO: can any of copy-start, custom-call, recv, send represent communication?
31+
# This also aims to identify, and (for now) flag as communication, kernels that
32+
# implement device-to-host and host-to-device copies for memory offloading.
33+
# For example, a device-to-host offload might look like
34+
# computation {
35+
# ...
36+
# ROOT r1 = bf16[2,8,128,2048]{3,2,1,0:S(5)} dynamic-update-slice(...)
37+
# }
38+
# async_computation {
39+
# ...
40+
# ROOT r2 = bf16[2,8,128,2048]{3,2,1,0:S(5)} fusion(...), calls=computation
41+
# }
42+
# start = (...) async-start(...), calls=async_computation
43+
# where the :S(5) annotation shows that a buffer is in host memory.
44+
# A host-to-device load might look like
45+
# computation {
46+
# param_0 = bf16[2,8,128,2048]{3,2,1,0:S(5)} parameter(0)
47+
# ...
48+
# ROOT r1 = bf16[2,8,128,2048]{3,2,1,0} dynamic-slice(param_0, ...)
49+
# }
50+
# async_computation {
51+
# param_0 = bf16[2,8,128,2048]{3,2,1,0:S(5)} parameter(0)
52+
# ...
53+
# ROOT r2 = bf16[2,8,128,2048]{3,2,1,0} fusion(param_0, ...), calls=computation
54+
# }
55+
# start = (...) async-start(...), calls=async_computation
56+
# where the :S(5) memory space annotation is in a parameter instead of in the
57+
# return value.
58+
# For now, handling host-device kernels as single-device "collective"
59+
# communication should be sufficient.
2860
self._comm_proto = None
2961
comm_opcodes = {
3062
"all-gather",
@@ -39,25 +71,50 @@ def __init__(self, wrapped_hlo_proto, proto):
3971
"all-reduce-start",
4072
"collective-permute-start",
4173
}
74+
75+
def _is_offloading_instruction(inst):
76+
host_dest = _host_memory_space(inst)
77+
78+
def _host_operand(i):
79+
_, op = wrapped_hlo_proto.find_instruction_by_id(inst.operand_ids[i])
80+
return _host_memory_space(op.proto())
81+
82+
if inst.opcode == "dynamic-slice" and host_dest != _host_operand(0):
83+
return True
84+
elif (
85+
inst.opcode == "dynamic-update-slice"
86+
and host_dest == _host_operand(0)
87+
and host_dest != _host_operand(1)
88+
):
89+
return True
90+
return False
91+
4292
if self._proto.opcode in comm_opcodes | comm_start_opcodes:
4393
self._comm_proto = self._proto
44-
elif self._proto.opcode == "async-start":
94+
elif self._proto.opcode in {"async-start", "fusion"}:
95+
# fusion example:
96+
# computation {
97+
# param_0 = f32[...]{...:S(5)} parameter(0)
98+
# ...
99+
# ROOT dus = f32[...]{...:S(5)} dynamic-update-slice(param_0, ...)
100+
# }
101+
# inst = f32[256,128,128]{2,1,0:S(5)} fusion(...), calls=computation
45102
# This might be thinly wrapping an opcode in `comm_opcodes`
46-
other_opcodes = defaultdict(int)
47-
for called_id in self._proto.called_computation_ids:
48-
for called_inst in wrapped_hlo_proto.find_computation(
49-
called_id
50-
).instructions:
51-
if called_inst.opcode in comm_opcodes:
103+
def _visit_computation(computation_id):
104+
computation = wrapped_hlo_proto.find_computation(computation_id)
105+
for called_inst in computation.instructions:
106+
for called_id in called_inst.called_computation_ids:
107+
_visit_computation(called_id)
108+
if called_inst.opcode in comm_opcodes or _is_offloading_instruction(
109+
called_inst
110+
):
52111
assert (
53112
self._comm_proto is None
54113
), f"Found {called_inst.opcode} child having already found {self._comm_proto.opcode}"
55114
self._comm_proto = called_inst
56-
else:
57-
other_opcodes[called_inst.opcode] += 1
58-
assert (
59-
other_opcodes.keys() == {"parameter"}
60-
), f"async-start op {self._proto.name} wrapped too many opcode types ({dict(other_opcodes)}) in addition to {self._comm_proto}"
115+
116+
for called_id in self._proto.called_computation_ids:
117+
_visit_computation(called_id)
61118

62119
def communication_proto(self):
63120
return self._comm_proto
@@ -68,12 +125,7 @@ def is_communication(self) -> bool:
68125
a little more complicated than you might hope, because async communications are
69126
not handled uniformly.
70127
"""
71-
if self._comm_proto is None:
72-
return False
73-
assert (
74-
self._comm_proto.channel_id != 0
75-
), f"Got channel_id={self._comm_proto.channel_id} for {self._comm_proto.name}"
76-
return True
128+
return self._comm_proto is not None
77129

78130
def proto(self):
79131
"""

0 commit comments

Comments
 (0)