1- from collections import defaultdict
21import functools
32import lzma
43import pathlib
54import typing
65
76
7+ def _host_memory_space (inst ):
8+ return inst .shape .layout .memory_space == 5
9+
10+
811class StackFrame (typing .NamedTuple ):
912 column : int
1013 file : str
@@ -25,6 +28,35 @@ def __init__(self, wrapped_hlo_proto, proto):
2528 # proto representing the actual collective, which will be different if the
2629 # async launch is handled by an async-start op
2730 # TODO: can any of copy-start, custom-call, recv, send represent communication?
31+ # This also aims to identify, and (for now) flag as communication, kernels that
32+ # implement device-to-host and host-to-device copies for memory offloading.
33+ # For example, a device-to-host offload might look like
34+ # computation {
35+ # ...
36+ # ROOT r1 = bf16[2,8,128,2048]{3,2,1,0:S(5)} dynamic-update-slice(...)
37+ # }
38+ # async_computation {
39+ # ...
40+ # ROOT r2 = bf16[2,8,128,2048]{3,2,1,0:S(5)} fusion(...), calls=computation
41+ # }
42+ # start = (...) async-start(...), calls=async_computation
43+ # where the :S(5) annotation shows that a buffer is in host memory.
44+ # A host-to-device load might look like
45+ # computation {
46+ # param_0 = bf16[2,8,128,2048]{3,2,1,0:S(5)} parameter(0)
47+ # ...
48+ # ROOT r1 = bf16[2,8,128,2048]{3,2,1,0} dynamic-slice(param_0, ...)
49+ # }
50+ # async_computation {
51+ # param_0 = bf16[2,8,128,2048]{3,2,1,0:S(5)} parameter(0)
52+ # ...
53+ # ROOT r2 = bf16[2,8,128,2048]{3,2,1,0} fusion(param_0, ...), calls=computation
54+ # }
55+ # start = (...) async-start(...), calls=async_computation
56+ # where the :S(5) memory space annotation is in a parameter instead of in the
57+ # return value.
58+ # For now, handling host-device kernels as single-device "collective"
59+ # communication should be sufficient.
2860 self ._comm_proto = None
2961 comm_opcodes = {
3062 "all-gather" ,
@@ -39,25 +71,50 @@ def __init__(self, wrapped_hlo_proto, proto):
3971 "all-reduce-start" ,
4072 "collective-permute-start" ,
4173 }
74+
75+ def _is_offloading_instruction (inst ):
76+ host_dest = _host_memory_space (inst )
77+
78+ def _host_operand (i ):
79+ _ , op = wrapped_hlo_proto .find_instruction_by_id (inst .operand_ids [i ])
80+ return _host_memory_space (op .proto ())
81+
82+ if inst .opcode == "dynamic-slice" and host_dest != _host_operand (0 ):
83+ return True
84+ elif (
85+ inst .opcode == "dynamic-update-slice"
86+ and host_dest == _host_operand (0 )
87+ and host_dest != _host_operand (1 )
88+ ):
89+ return True
90+ return False
91+
4292 if self ._proto .opcode in comm_opcodes | comm_start_opcodes :
4393 self ._comm_proto = self ._proto
44- elif self ._proto .opcode == "async-start" :
94+ elif self ._proto .opcode in {"async-start" , "fusion" }:
95+ # fusion example:
96+ # computation {
97+ # param_0 = f32[...]{...:S(5)} parameter(0)
98+ # ...
99+ # ROOT dus = f32[...]{...:S(5)} dynamic-update-slice(param_0, ...)
100+ # }
101+ # inst = f32[256,128,128]{2,1,0:S(5)} fusion(...), calls=computation
45102 # This might be thinly wrapping an opcode in `comm_opcodes`
46- other_opcodes = defaultdict (int )
47- for called_id in self ._proto .called_computation_ids :
48- for called_inst in wrapped_hlo_proto .find_computation (
49- called_id
50- ).instructions :
51- if called_inst .opcode in comm_opcodes :
103+ def _visit_computation (computation_id ):
104+ computation = wrapped_hlo_proto .find_computation (computation_id )
105+ for called_inst in computation .instructions :
106+ for called_id in called_inst .called_computation_ids :
107+ _visit_computation (called_id )
108+ if called_inst .opcode in comm_opcodes or _is_offloading_instruction (
109+ called_inst
110+ ):
52111 assert (
53112 self ._comm_proto is None
54113 ), f"Found { called_inst .opcode } child having already found { self ._comm_proto .opcode } "
55114 self ._comm_proto = called_inst
56- else :
57- other_opcodes [called_inst .opcode ] += 1
58- assert (
59- other_opcodes .keys () == {"parameter" }
60- ), f"async-start op { self ._proto .name } wrapped too many opcode types ({ dict (other_opcodes )} ) in addition to { self ._comm_proto } "
115+
116+ for called_id in self ._proto .called_computation_ids :
117+ _visit_computation (called_id )
61118
62119 def communication_proto (self ):
63120 return self ._comm_proto
@@ -68,12 +125,7 @@ def is_communication(self) -> bool:
68125 a little more complicated than you might hope, because async communications are
69126 not handled uniformly.
70127 """
71- if self ._comm_proto is None :
72- return False
73- assert (
74- self ._comm_proto .channel_id != 0
75- ), f"Got channel_id={ self ._comm_proto .channel_id } for { self ._comm_proto .name } "
76- return True
128+ return self ._comm_proto is not None
77129
78130 def proto (self ):
79131 """
0 commit comments