 import os
 import pathlib
 import time
-from typing import Any, Generic, TypeVar
+from typing import Any, Generic, TypeVar, TypedDict
 import weakref

 import jax
@@ -127,7 +127,7 @@ def artificial_shared_memory_limit(limit):
   )


-def supports_cross_device_collectives():
+def is_nvshmem_available():
   try:
     nvshmem_bc_path = os.environ["MOSAIC_GPU_NVSHMEM_BC_PATH"]
   except KeyError:
@@ -146,6 +146,16 @@ def supports_cross_device_collectives():
   )


+def is_single_process_multi_device_topology():
+  return (jax.device_count() > 1
+          and jax.device_count() == jax.local_device_count())
+
+
+def supports_cross_device_collectives():
+  return ((is_nvshmem_available() and jax.local_device_count() == 1)
+          or is_single_process_multi_device_topology())
+
+
 mosaic_gpu_p = jax_core.Primitive("mosaic_gpu_p")
 mosaic_gpu_p.multiple_results = True

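Together, these predicates gate collectives on one of two topologies: NVSHMEM with exactly one local device per process (multi-process), or a single process driving every device. A simplified, self-contained restatement for illustration only; the real is_nvshmem_available reads MOSAIC_GPU_NVSHMEM_BC_PATH and then performs further validation that this diff does not show:

    import os
    import jax

    def supports_cross_device_collectives_sketch() -> bool:
      # Stand-in for is_nvshmem_available(); the real check does more
      # than inspect the environment variable.
      nvshmem_ok = "MOSAIC_GPU_NVSHMEM_BC_PATH" in os.environ
      # One process owning every device enables the metadata-based path.
      single_process_multi_device = (
          jax.device_count() > 1
          and jax.device_count() == jax.local_device_count()
      )
      return ((nvshmem_ok and jax.local_device_count() == 1)
              or single_process_multi_device)

In multi-process runs each process sees exactly one local device, so only the NVSHMEM branch can apply; in single-process multi-device runs the new collective-metadata lowering applies instead.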
@@ -160,6 +170,8 @@ def _mosaic_gpu_abstract_eval(*_, module, out_types, inout_types):
 

 def _has_communication(module, **_):
+  if launch_context.uses_collective_metadata(module):
+    return True
   empty_str_attr = ir.StringAttr.get("")
   for op in module.body:
     if "nvshmem" in getattr(op, "sym_name", empty_str_attr).value:
@@ -182,7 +194,8 @@ def _mosaic_gpu_lowering_rule(
     use_custom_barrier: bool = False,
 ):
   axis_context = ctx.module_context.axis_context
-  if _has_communication(module):
+  is_multi_device_module = _has_communication(module)
+  if is_multi_device_module:
     # These checks try to ensure that the logical device ids are
     # consistent with the NVSHMEM PE ids that Mosaic will be using for
     # communication. Any divergence here would require us to implement a logical
@@ -236,22 +249,53 @@ def _mosaic_gpu_lowering_rule(
   else:
     KNOWN_KERNELS[kernel_id] = module_asm

-  op = mlir.custom_call(
-      "mosaic_gpu_v2",
-      result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
-      operands=args,
-      operand_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_in],
-      result_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_out],
-      backend_config=dict(
+  class CustomCallArgs(TypedDict):
+    call_target_name: str
+    result_types: list[ir.Type]
+    operands: tuple[ir.Value]
+    operand_layouts: list[list[int]]
+    result_layouts: list[list[int]]
+    backend_config: dict[str, ir.Attribute]
+    operand_output_aliases: dict[int, int]
+    api_version: int
+
+  custom_call_kwargs: CustomCallArgs = {
+      "call_target_name": "mosaic_gpu_v2",
+      "result_types": [mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
+      "operands": args,
+      "operand_layouts": [list(reversed(range(a.ndim))) for a in ctx.avals_in],
+      "result_layouts": [list(reversed(range(a.ndim))) for a in ctx.avals_out],
+      "backend_config": dict(
           kernel_hash=ir.StringAttr.get(kernel_id),
           module=ir.StringAttr.get(module_asm),
           use_custom_barrier=ir.BoolAttr.get(use_custom_barrier),
+          uses_xla_collective_metadata=ir.BoolAttr.get(
+              launch_context.uses_collective_metadata(module)
+          ),
       ),
-      operand_output_aliases=dict(input_output_aliases),
-      api_version=4,
+      "operand_output_aliases": dict(input_output_aliases),
+      "api_version": 4,
+  }
+
+  if not is_multi_device_module or not is_single_process_multi_device_topology():
+    op = mlir.custom_call(**custom_call_kwargs)
+    return op.results
+
+  # Add the collective metadata as an additional output buffer. The metadata
+  # stores pointers to both the input and the output parameters.
+  num_params = len(ctx.avals_out) + len(ctx.avals_in)
+  num_peers = axis_context.mesh.size
+  collective_metadata_size = (
+      launch_context.COLLECTIVE_METADATA_SIZE + num_peers * num_params
   )
-  return op.results

+  custom_call_kwargs["result_layouts"] += [[0]]
+  custom_call_kwargs["result_types"] += [
+      ir.RankedTensorType.get((collective_metadata_size,),
+                              ir.IntegerType.get_signless(64))]
+  op = mlir.custom_call(**custom_call_kwargs)
+  # Drop the collective metadata buffer from the results.
+  return op.results[:-1]
 
 mlir.register_lowering(mosaic_gpu_p, _mosaic_gpu_lowering_rule, "cuda")
 
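The extra result is a flat i64 buffer holding a fixed-size header plus one pointer slot per (peer, parameter) pair. A worked sizing example; the header length below is a made-up placeholder, since the real COLLECTIVE_METADATA_SIZE is defined in launch_context and not shown in this diff:

    COLLECTIVE_METADATA_SIZE = 16  # hypothetical header size, in i64 slots

    num_params = 2 + 1  # e.g. len(ctx.avals_in) + len(ctx.avals_out)
    num_peers = 8       # e.g. axis_context.mesh.size for an 8-device mesh

    # One i64 slot per (peer, parameter) pointer, plus the fixed header:
    collective_metadata_size = COLLECTIVE_METADATA_SIZE + num_peers * num_params
    assert collective_metadata_size == 40  # 16 + 8 * 3

Since this buffer is scratch space for the runtime rather than a kernel output, the lowering drops it from op.results before returning.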
@@ -554,6 +598,8 @@ def _launch(
     module: ir.Module,
     profiler_spec: profiler.ProfilerSpec | None = None,
     maybe_prof_buffer: ir.Value | None = None,
+    collective_metadata: ir.Value | None = None,
+    num_peers: int = 0,
 ):
   if (profiler_spec is None) != (maybe_prof_buffer is None):
     raise ValueError(
@@ -647,7 +693,12 @@ def _launch(
     prof = None

   ctx = launch_context.LaunchContext(
-      module, launch_context.Scratch(launch_op), cluster, prof
+      module,
+      launch_context.Scratch(launch_op),
+      cluster,
+      prof,
+      collective_metadata=collective_metadata,
+      num_peers=num_peers,
   )
   with ctx.named_region("Init"):
     tmem_allocs: list[_TMEMAlloc | _TMEMDialectAlloc] = []
@@ -739,6 +790,7 @@ def _lower_as_gpu_kernel(
     module_name: str,
     kernel_name: str,
     prof_spec: profiler.ProfilerSpec | None = None,
+    jax_mesh: mesh_lib.Mesh | None = None,
 ):
   ptr_ty = ir.Type.parse("!llvm.ptr")
   token_ty = ir.Type.parse("!gpu.async.token")
@@ -786,12 +838,54 @@ def main(token_ptr, buffers):
     arg_refs = []
     # XLA will pass in inout refs again as outputs, but we ignore them.
     for i, ref_ty in enumerate([*in_ref_tys, *inout_ref_tys, *out_ref_tys]):
-      ptr = llvm.load(ptr_ty, utils.getelementptr(buffers, [i], ptr_ty))
-      arg_refs.append(utils.ptr_as_memref(ptr, ir.MemRefType(ref_ty)))
+      gep_op = utils.getelementptr(buffers, [i], ptr_ty)
+      ptr = llvm.load(ptr_ty, gep_op)
+      arg_memref = utils.ptr_as_memref(ptr, ir.MemRefType(ref_ty))
+      # Annotate the op so we can find the corresponding kernel argument
+      # during lowering.
+      arg_memref.owner.attributes[launch_context.KERNEL_ARG_ID_ATTR] = ir.IntegerAttr.get(i32, i)
+      arg_refs.append(arg_memref)
+
+    collective_metadata = None
+    num_peers = 0
+
+    # The collective metadata parameter is used to lower collective
+    # operations in a single-process, multi-device setup.
+    if (
+        jax_mesh is not None
+        and jax_mesh.size > 1
+        and is_single_process_multi_device_topology()
+    ):
+      num_args = len(arg_refs)
+      num_peers = jax_mesh.size
+      collective_metadata_size = (
+          launch_context.COLLECTIVE_METADATA_SIZE + num_args * num_peers
+      )
+      metadata_shape = jax.ShapeDtypeStruct(
+          shape=(collective_metadata_size,), dtype=np.int64
+      )
+      metadata_ptr = llvm.load(
+          ptr_ty,
+          utils.getelementptr(buffers, [num_args], ptr_ty),
+      )
+      collective_metadata = utils.ptr_as_memref(
+          metadata_ptr, _shape_to_ref_ty(metadata_shape)
+      )
+
     prof_buffer = arg_refs.pop() if prof_spec is not None else None
+
     with _launch(
-        token, grid, cluster, block, smem_scratch_shape,
-        lowering_semantics, module, prof_spec, prof_buffer
+        token,
+        grid,
+        cluster,
+        block,
+        smem_scratch_shape,
+        lowering_semantics,
+        module,
+        prof_spec,
+        prof_buffer,
+        collective_metadata,
+        num_peers,
     ) as (_launch_ctx, smem_refs):
       nonlocal launch_ctx
       launch_ctx = _launch_ctx
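The generated main reads the metadata pointer from the same buffers array as the kernel arguments, at index num_args, i.e. immediately after the last argument ref, which matches the extra result the lowering rule appended. A sketch of the layout this assumes, for a hypothetical kernel with two inputs and one output (no inout refs, profiling disabled):

    # buffers[0], buffers[1] -> input refs   (in_ref_tys)
    # buffers[2]             -> output ref   (out_ref_tys)
    # buffers[3]             -> collective metadata buffer
    num_args = 2 + 0 + 1         # len(in) + len(inout) + len(out)
    metadata_index = num_args    # the metadata pointer follows the last arg
    assert metadata_index == 3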