Skip to content

Commit 039e00d

Browse files
[Pallas:MGPU] Lower multi-device operation with collective metadata.
In single-process multi-device configurations, the Mosaic custom call can collect pointers for each parameter across all local devices into a structure called "collective metadata". This change constructs the collective metadata in the custom call and uses it during Pallas:MGPU lowering to calculate the device rank and peer memory addresses. This enables lowering multi-device operations that require direct inter-device communication in single-process mode.

PiperOrigin-RevId: 860126391
1 parent 6b3c32f commit 039e00d

File tree

8 files changed

+666
-90
lines changed

8 files changed

+666
-90
lines changed

jax/_src/pallas/mosaic_gpu/lowering.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1043,21 +1043,20 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value):
10431043

10441044
# NOTE: new_out_shapes has out_shapes, then semaphores_shape and
10451045
# optionally the profiler buffer.
1046-
module, new_out_shapes, _, launch_ctx = (
1047-
mgpu_core._lower_as_gpu_kernel(
1048-
body,
1049-
grid=cuda_grid,
1050-
cluster=cluster,
1051-
block=block,
1052-
in_shapes=(*in_shapes, *scoped_semaphores_shape),
1053-
out_shape=(*out_shapes, *scoped_semaphores_shape),
1054-
inout_shape=(),
1055-
smem_scratch_shape=scratch_buffers,
1056-
lowering_semantics=lowering_semantics,
1057-
module_name=mlir.sanitize_name(debug_info.func_name),
1058-
kernel_name=mlir.sanitize_name(debug_info.func_name),
1059-
prof_spec=prof_spec,
1060-
)
1046+
module, new_out_shapes, _, launch_ctx = mgpu_core._lower_as_gpu_kernel(
1047+
body,
1048+
grid=cuda_grid,
1049+
cluster=cluster,
1050+
block=block,
1051+
in_shapes=(*in_shapes, *scoped_semaphores_shape),
1052+
out_shape=(*out_shapes, *scoped_semaphores_shape),
1053+
inout_shape=(),
1054+
smem_scratch_shape=scratch_buffers,
1055+
lowering_semantics=lowering_semantics,
1056+
module_name=mlir.sanitize_name(debug_info.func_name),
1057+
kernel_name=mlir.sanitize_name(debug_info.func_name),
1058+
prof_spec=prof_spec,
1059+
jax_mesh=jax_mesh,
10611060
)
10621061

10631062
if lowering_semantics == mgpu.LoweringSemantics.Warpgroup:
@@ -3805,7 +3804,6 @@ def _semaphore_signal_lowering_rule(
38053804
sem, transforms = _handle_transforms(ctx, sem, transforms)
38063805
if transforms:
38073806
raise NotImplementedError(f"Unhandled transforms for semaphore_signal: {transforms}")
3808-
sem_ptr = mgpu.utils.memref_ptr(sem)
38093807
if device_id is not None:
38103808
device_id, other_axes = primitives.device_id_to_logical(
38113809
ctx.module_ctx.mesh_info,
@@ -3817,7 +3815,9 @@ def _semaphore_signal_lowering_rule(
38173815
raise NotImplementedError(
38183816
f"Only JAX mesh axes can be used in device_id, but found {other_axes}"
38193817
)
3820-
sem_ptr = ctx.launch_ctx.to_remote(sem_ptr, device_id)
3818+
sem = ctx.launch_ctx.to_remote(sem, device_id)
3819+
sem_ptr = mgpu.utils.memref_ptr(sem)
3820+
38213821
# TODO(apaszke): Narrow the scope from .sys to .gpu when the semaphore is local.
38223822
val = _ir_constant(value, i32)
38233823
# We only signal the semaphore from a single lane, which does not guarantee

jax/_src/pallas/mosaic_gpu/primitives.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3443,7 +3443,6 @@ def _semaphore_signal_lowering_rule(
34433443
transformed_sems.append(sem)
34443444
del sems, transforms # Use transformed_sems instead.
34453445
for sem, value, device_id in zip(transformed_sems, values, device_ids, strict=True):
3446-
sem_ptr = mgpu.utils.memref_ptr(sem)
34473446
if device_id is not None:
34483447
device_id, other_axes = pallas_primitives.device_id_to_logical(
34493448
ctx.module_ctx.mesh_info,
@@ -3456,7 +3455,8 @@ def _semaphore_signal_lowering_rule(
34563455
f"Only JAX mesh axes can be used in device_id, but found {other_axes}"
34573456
)
34583457
device_id = lowering._ensure_ir_value(device_id, jnp.int32)
3459-
sem_ptr = ctx.launch_ctx.to_remote(sem_ptr, device_id)
3458+
sem = ctx.launch_ctx.to_remote(sem, device_id)
3459+
sem_ptr = mgpu.utils.memref_ptr(sem)
34603460
# TODO(apaszke): Narrow the scope from .sys to .gpu when the semaphore is local.
34613461
# We only signal the semaphore from a single lane, which does not guarantee
34623462
# anything about the state of the other three warps in the warpgroup (they

jax/experimental/mosaic/gpu/core.py

Lines changed: 112 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
import os
2727
import pathlib
2828
import time
29-
from typing import Any, Generic, TypeVar
29+
from typing import Any, Generic, TypeVar, TypedDict
3030
import weakref
3131

3232
import jax
@@ -127,7 +127,7 @@ def artificial_shared_memory_limit(limit):
127127
)
128128

129129

130-
def supports_cross_device_collectives():
130+
def is_nvshmem_available():
131131
try:
132132
nvshmem_bc_path = os.environ["MOSAIC_GPU_NVSHMEM_BC_PATH"]
133133
except KeyError:
@@ -146,6 +146,16 @@ def supports_cross_device_collectives():
146146
)
147147

148148

149+
def is_single_process_multi_device_topology():
150+
return (jax.device_count() > 1
151+
and jax.device_count() == jax.local_device_count())
152+
153+
154+
def supports_cross_device_collectives():
155+
return ((is_nvshmem_available() and jax.local_device_count() == 1)
156+
or is_single_process_multi_device_topology())
157+
158+
149159
mosaic_gpu_p = jax_core.Primitive("mosaic_gpu_p")
150160
mosaic_gpu_p.multiple_results = True
151161

@@ -160,6 +170,8 @@ def _mosaic_gpu_abstract_eval(*_, module, out_types, inout_types):
160170

161171

162172
def _has_communication(module, **_):
173+
if launch_context.uses_collective_metadata(module):
174+
return True
163175
empty_str_attr = ir.StringAttr.get("")
164176
for op in module.body:
165177
if "nvshmem" in getattr(op, "sym_name", empty_str_attr).value:
@@ -182,7 +194,8 @@ def _mosaic_gpu_lowering_rule(
182194
use_custom_barrier: bool = False,
183195
):
184196
axis_context = ctx.module_context.axis_context
185-
if _has_communication(module):
197+
is_multi_device_module = _has_communication(module)
198+
if is_multi_device_module:
186199
# Those checks are trying to ensure that the logical device ids are
187200
# consistent with the NVSHMEM PE ids that Mosaic will be using for
188201
# communication. Any divergence here would require us to implement a logical
@@ -236,22 +249,53 @@ def _mosaic_gpu_lowering_rule(
236249
else:
237250
KNOWN_KERNELS[kernel_id] = module_asm
238251

239-
op = mlir.custom_call(
240-
"mosaic_gpu_v2",
241-
result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
242-
operands=args,
243-
operand_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_in],
244-
result_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_out],
245-
backend_config=dict(
252+
class CustomCallArgs(TypedDict):
253+
call_target_name: str
254+
result_types: list[ir.Type]
255+
operands: tuple[ir.Value]
256+
operand_layouts: list[list[int]]
257+
result_layouts: list[list[int]]
258+
backend_config: dict[str, ir.Attribute]
259+
operand_output_aliases: dict[int, int]
260+
api_version: int
261+
262+
custom_call_kwargs : CustomCallArgs = {
263+
"call_target_name": "mosaic_gpu_v2",
264+
"result_types": [mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
265+
"operands": args,
266+
"operand_layouts": [list(reversed(range(a.ndim))) for a in ctx.avals_in],
267+
"result_layouts": [list(reversed(range(a.ndim))) for a in ctx.avals_out],
268+
"backend_config": dict(
246269
kernel_hash=ir.StringAttr.get(kernel_id),
247270
module=ir.StringAttr.get(module_asm),
248271
use_custom_barrier=ir.BoolAttr.get(use_custom_barrier),
272+
uses_xla_collective_metadata=ir.BoolAttr.get(
273+
launch_context.uses_collective_metadata(module)
274+
),
249275
),
250-
operand_output_aliases=dict(input_output_aliases),
251-
api_version=4,
276+
"operand_output_aliases": dict(input_output_aliases),
277+
"api_version": 4
278+
}
279+
280+
if not is_multi_device_module or not is_single_process_multi_device_topology():
281+
op = mlir.custom_call(**custom_call_kwargs)
282+
return op.results
283+
284+
# Add collective metadata as additional output buffer.
285+
# Collective metadata stores pointers to both input and output parameters.
286+
num_params = len(ctx.avals_out) + len(ctx.avals_in)
287+
num_peers = axis_context.mesh.size
288+
collective_metadata_size = (
289+
launch_context.COLLECTIVE_METADATA_SIZE + num_peers * num_params
252290
)
253-
return op.results
254291

292+
custom_call_kwargs["result_layouts"] += [[0]]
293+
custom_call_kwargs["result_types"] += [
294+
ir.RankedTensorType.get((collective_metadata_size,),
295+
ir.IntegerType.get_signless(64))]
296+
op = mlir.custom_call(**custom_call_kwargs)
297+
# Drop the collective metadata buffer from the results.
298+
return op.results[:-1]
255299

256300
mlir.register_lowering(mosaic_gpu_p, _mosaic_gpu_lowering_rule, "cuda")
257301

@@ -554,6 +598,8 @@ def _launch(
554598
module: ir.Module,
555599
profiler_spec: profiler.ProfilerSpec | None = None,
556600
maybe_prof_buffer: ir.Value | None = None,
601+
collective_metadata: ir.Value | None = None,
602+
num_peers: int = 0,
557603
):
558604
if (profiler_spec is None) != (maybe_prof_buffer is None):
559605
raise ValueError(
@@ -647,7 +693,12 @@ def _launch(
647693
prof = None
648694

649695
ctx = launch_context.LaunchContext(
650-
module, launch_context.Scratch(launch_op), cluster, prof
696+
module,
697+
launch_context.Scratch(launch_op),
698+
cluster,
699+
prof,
700+
collective_metadata=collective_metadata,
701+
num_peers=num_peers,
651702
)
652703
with ctx.named_region("Init"):
653704
tmem_allocs: list[_TMEMAlloc | _TMEMDialectAlloc] = []
@@ -739,6 +790,7 @@ def _lower_as_gpu_kernel(
739790
module_name: str,
740791
kernel_name: str,
741792
prof_spec: profiler.ProfilerSpec | None = None,
793+
jax_mesh: mesh_lib.Mesh | None = None,
742794
):
743795
ptr_ty = ir.Type.parse("!llvm.ptr")
744796
token_ty = ir.Type.parse("!gpu.async.token")
@@ -786,12 +838,54 @@ def main(token_ptr, buffers):
786838
arg_refs = []
787839
# XLA will pass in inout refs again as outputs, but we ignore them.
788840
for i, ref_ty in enumerate([*in_ref_tys, *inout_ref_tys, *out_ref_tys]):
789-
ptr = llvm.load(ptr_ty, utils.getelementptr(buffers, [i], ptr_ty))
790-
arg_refs.append(utils.ptr_as_memref(ptr, ir.MemRefType(ref_ty)))
841+
gep_op = utils.getelementptr(buffers, [i], ptr_ty)
842+
ptr = llvm.load(ptr_ty, gep_op)
843+
arg_memref = utils.ptr_as_memref(ptr, ir.MemRefType(ref_ty))
844+
# Annotate so we can find the corresponding kernel argument during the
845+
# lowering.
846+
arg_memref.owner.attributes[launch_context.KERNEL_ARG_ID_ATTR] = ir.IntegerAttr.get(i32, i)
847+
arg_refs.append(arg_memref)
848+
849+
collective_metadata = None
850+
num_peers = 0
851+
852+
# Collective metadata parameter is used to lower collective operations
853+
# in a single-process setup.
854+
if (
855+
jax_mesh is not None
856+
and jax_mesh.size > 1
857+
and is_single_process_multi_device_topology()
858+
):
859+
num_args = len(arg_refs)
860+
num_peers = jax_mesh.size
861+
collective_metadata_size = (
862+
launch_context.COLLECTIVE_METADATA_SIZE + num_args * num_peers
863+
)
864+
metadata_shape = jax.ShapeDtypeStruct(
865+
shape=(collective_metadata_size,), dtype=np.int64
866+
)
867+
metadata_ptr = llvm.load(
868+
ptr_ty,
869+
utils.getelementptr(buffers, [num_args], ptr_ty),
870+
)
871+
collective_metadata = utils.ptr_as_memref(
872+
metadata_ptr, _shape_to_ref_ty(metadata_shape)
873+
)
874+
791875
prof_buffer = arg_refs.pop() if prof_spec is not None else None
876+
792877
with _launch(
793-
token, grid, cluster, block, smem_scratch_shape,
794-
lowering_semantics, module, prof_spec, prof_buffer
878+
token,
879+
grid,
880+
cluster,
881+
block,
882+
smem_scratch_shape,
883+
lowering_semantics,
884+
module,
885+
prof_spec,
886+
prof_buffer,
887+
collective_metadata,
888+
num_peers,
795889
) as (_launch_ctx, smem_refs):
796890
nonlocal launch_ctx
797891
launch_ctx = _launch_ctx

0 commit comments

Comments
 (0)