
Commit 724eb02

bchetioui authored and Google-ML-Automation committed
[Pallas/Mosaic GPU] Add the reduction_scratch_bytes field to CompilerParams.
This field allows configuring the number of shared memory bytes to reserve for performing cross-warp reductions. The more bytes are allocated to such a reduction, the more registers can be reduced in parallel, yielding faster reductions.

PiperOrigin-RevId: 860115656
1 parent bd995da commit 724eb02
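For orientation, here is how a kernel opts into a larger reduction scratch. This is a minimal sketch, assuming a GPU where the Mosaic GPU lowering is active; the kernel body and shapes are illustrative and not part of this commit.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu

def row_sum_kernel(x_ref, y_ref):
  # Reducing over axis 0 crosses warps, so it draws on the SMEM scratch
  # that reduction_scratch_bytes reserves.
  y_ref[...] = jnp.sum(x_ref[...], axis=0)

x = jnp.ones((128, 128), jnp.float32)
y = pl.pallas_call(
    row_sum_kernel,
    out_shape=jax.ShapeDtypeStruct((128,), jnp.float32),
    # 6144 bytes is the value the new docstring recommends for H100/B200;
    # the default is 128 * 4 * 4 = 2048 bytes.
    compiler_params=plgpu.CompilerParams(reduction_scratch_bytes=6144),
)(x)
```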

File tree

6 files changed (+85, -23 lines)


docs/pallas/CHANGELOG.md

Lines changed: 8 additions & 0 deletions
@@ -13,6 +13,14 @@ Remember to align the itemized text with the first line of an item within a list
 
 ## Unreleased
 
+* New features:
+
+  * Added a `reduction_scratch_bytes` field to
+    {class}`jax.experimental.pallas.mosaic_gpu.CompilerParams`. This gives users
+    control over how much shared memory Pallas is allowed to reserve for
+    cross-warp reductions on GPU. Increasing this value typically allows for
+    faster reductions.
+
 * Changes
 
   * The default lowering path on GPU now goes through Mosaic GPU. To keep using

jax/_src/pallas/mosaic_gpu/core.py

Lines changed: 6 additions & 0 deletions
@@ -102,6 +102,11 @@ class CompilerParams(pallas_core.CompilerParams):
       thread ever calls commit_smem(), reads from the committed SMEM and then
       issues an async copy overwriting that region (this is a very artificial
       and highly unlikely scenario).
+    reduction_scratch_bytes: The number of shared memory bytes to reserve as
+      scratch space for cross-warp reductions. The higher this value, the more
+      registers can be reduced in parallel. 2 * 128 * 6 * 4 = 6144 bytes is
+      typically a good value in order to extract most of the potential gains on
+      H100 and B200.
     profile_space: The number of profiler events that can be collected in a
       single invocation. It is undefined behavior if a thread collects more
       events than this.
@@ -112,6 +117,7 @@ class CompilerParams(pallas_core.CompilerParams):
   dimension_semantics: Sequence[DimensionSemantics] | None = None
   max_concurrent_steps: int = 1
   unsafe_no_auto_barriers: bool = False
+  reduction_scratch_bytes: int = 128 * 4 * 4
   profile_space: int = 0
   profile_dir: str = ""
   lowering_semantics: mgpu.core.LoweringSemantics = mgpu.core.LoweringSemantics.Lane
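As a sanity check on the two constants in this file (the per-factor breakdown below is my reading of the old `REDUCE_SCRATCH_ELEMS` comment removed from `lowering.py`, not something this docstring states):

```python
# Default: 128 lanes per warpgroup * 4-element vector * 4 bytes per element.
default_bytes = 128 * 4 * 4               # 2048 bytes
# For float32 this is exactly the old fixed scratch of 128 * 4 = 512 elements,
# so the default preserves the previous behavior for 4-byte dtypes.
default_f32_elems = default_bytes // 4    # 512
# The docstring's suggested value for H100/B200:
recommended_bytes = 2 * 128 * 6 * 4       # 6144 bytes
assert (default_bytes, default_f32_elems, recommended_bytes) == (2048, 512, 6144)
```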

jax/_src/pallas/mosaic_gpu/lowering.py

Lines changed: 17 additions & 11 deletions
@@ -90,6 +90,7 @@
 
 @dataclasses.dataclass(frozen=True, kw_only=True)
 class ResourceEstimatorContext:
+  reduction_scratch_bytes: int
   axis_names: _AxisNames
   lowering_semantics: mgpu.LoweringSemantics
 
@@ -361,7 +362,6 @@ def _run_scoped_resource_estimator(
         f"Unsupported memory space: {aval.memory_space}")
   return rs + _estimate_resources(ctx, jaxpr)
 
-REDUCE_SCRATCH_ELEMS = 128 * 4  # vector of 4 elements per lane in each WG
 
 @_register_resource_estimator(lax.reduce_sum_p)
 @_register_resource_estimator(lax.reduce_max_p)
@@ -370,10 +370,10 @@ def _reduce_resource_estimator(
     ctx: ResourceEstimatorContext, x_aval: jax_core.ShapedArray, *, axes,
     **kwargs
 ) -> Resources:
-  del ctx, axes  # Unused.
+  del x_aval, axes, kwargs  # Unused.
   # We don't need SMEM for some reductions, but it depends on the layout, so we
   # conservatively request the maximum scratch space we might need.
-  return Resources(smem_scratch_bytes=REDUCE_SCRATCH_ELEMS * x_aval.dtype.itemsize)
+  return Resources(smem_scratch_bytes=ctx.reduction_scratch_bytes)
 
 
 @dataclasses.dataclass(frozen=True)
@@ -420,6 +420,8 @@ class ModuleContext:
   mesh_info: pallas_utils.MeshInfo | None
   # See the documentation of unsafe_no_auto_barriers in CompilerParams.
   auto_barriers: bool
+  # See the documentation of reduction_scratch_bytes in CompilerParams.
+  reduction_scratch_bytes: int
   warp_axis_name: str | None = None
 
   @property
@@ -596,8 +598,9 @@ def replace(self, **changes: Any) -> LoweringRuleContext:
   @property
   def estimator_ctx(self) -> ResourceEstimatorContext:
     return ResourceEstimatorContext(
+        reduction_scratch_bytes=self.module_ctx.reduction_scratch_bytes,
         axis_names=self.module_ctx.axis_names,
-        lowering_semantics=self.module_ctx.lowering_semantics,
+        lowering_semantics=self.module_ctx.lowering_semantics
     )
 
 
@@ -886,6 +889,7 @@ def lower_jaxpr_to_module(
 
   rs = _estimate_resources(
       ResourceEstimatorContext(
+          reduction_scratch_bytes=params.reduction_scratch_bytes,
           axis_names=axis_names, lowering_semantics=lowering_semantics
       ),
       jaxpr,
@@ -982,6 +986,7 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value):
             if jax_mesh is not None
             else None,
         auto_barriers=not params.unsafe_no_auto_barriers,
+        reduction_scratch_bytes=params.reduction_scratch_bytes,
     )
     del runtime_smem, grouped_barriers, runtime_barriers
     _ = lower_jaxpr_to_mosaic_gpu(
@@ -2592,12 +2597,11 @@ def _reduce_lowering_rule(op, ctx: LoweringRuleContext, x, *, axes, **kwargs):
     raise NotImplementedError("Multi-axis reductions not supported")
   reduced_dim = x.layout.tiling.tile_dimension(axes[0])
   if any(reduced_dim[d] for d in x.layout.partitioned_warp_dims):
-    size = x.layout.vector_length * 128  # a vector per lane in each WG.
-    if size > REDUCE_SCRATCH_ELEMS:
-      raise NotImplementedError(
-          f"Reduce scratch {size=} exceeds max={REDUCE_SCRATCH_ELEMS}"
-      )
-    scratch_ty = jax.ShapeDtypeStruct(shape=(size,), dtype=x_aval.dtype)
+    dtype_bitwidth = dtypes.itemsize_bits(x_aval.dtype)
+    if dtype_bitwidth % 8:
+      raise NotImplementedError("Sub-byte dtypes not supported")
+    scratch_elems = ctx.module_ctx.reduction_scratch_bytes * 8 // dtype_bitwidth
+    scratch_ty = jax.ShapeDtypeStruct(shape=(scratch_elems,), dtype=x_aval.dtype)
     ctx = ctx.module_ctx.scratch_view(scratch_ty)
   else:
     ctx = contextlib.nullcontext(None)
@@ -2645,7 +2649,9 @@ def _reduce_lowering_rule_wg(
   def i32_attr(value: int) -> ir.IntegerAttr:
     return ir.IntegerAttr.get(ir.IntegerType.get_signless(32), value)
   reduction.attributes["offset"] = i32_attr(ctx.module_ctx.smem_used_bytes)
-  reduction.attributes["scratch_size"] = i32_attr(REDUCE_SCRATCH_ELEMS)
+  # TODO(bchetioui): here, we could just donate all the remaining free SMEM that
+  # we have at this point in time.
+  reduction.attributes["scratch_size"] = i32_attr(ctx.module_ctx.reduction_scratch_bytes)
   return reduction.result
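To make the sizing concrete, here is the bytes-to-elements conversion from `_reduce_lowering_rule` above, evaluated by hand for two common dtypes (a worked example, not code from the commit):

```python
reduction_scratch_bytes = 128 * 4 * 4  # the default, 2048 bytes

# float32 (32-bit): 2048 * 8 // 32 = 512 scratch elements.
f32_elems = reduction_scratch_bytes * 8 // 32
# bfloat16 (16-bit): 2048 * 8 // 16 = 1024 scratch elements.
bf16_elems = reduction_scratch_bytes * 8 // 16
assert (f32_elems, bf16_elems) == (512, 1024)
# Sub-byte dtypes (bitwidth % 8 != 0, e.g. 4-bit) raise NotImplementedError.
```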

jax/experimental/mosaic/gpu/dialect_lowering.py

Lines changed: 2 additions & 8 deletions
@@ -758,15 +758,9 @@ def _vector_multi_dim_reduction_op_lowering_rule(
   if any(reduced_dim[d] for d in src.layout.partitioned_warp_dims):
     # cross-warp reductions require scratch space.
     dtype = op.source.type.element_type
-    size = src.layout.vector_length * 128  # a vector per lane in each WG.
-    scratch_size = ir.IntegerAttr(op.attributes["scratch_size"]).value
-    if size > scratch_size:
-      raise ValueError(
-          f"Required scratch space ({size}) is larger than the available"
-          f" scratch size ({scratch_size})"
-      )
+    allocation_size = ir.IntegerAttr(op.attributes["scratch_size"]).value * 8 // utils.bitwidth(dtype)
     scratch = _slice_smem(
-        ir.MemRefType.get([size], dtype, memory_space=utils.smem()),
+        ir.MemRefType.get([allocation_size], dtype, memory_space=utils.smem()),
         arith.constant(None, op.attributes["offset"]),
         ctx.smem_requested_bytes,
     )

jax/experimental/mosaic/gpu/fragmented_array.py

Lines changed: 9 additions & 4 deletions
@@ -2795,9 +2795,9 @@ def swizzle_warp_idx_fn(lane_idx: ir.Value, vec_len: int):
         num_banks % num_banks_per_output != 0
     ):
       raise NotImplementedError(
-        "Unoptimized configuration for cross-warp reduction: "
-        f"{self.mlir_dtype} with {vec_len=}"
-        )
+          "Unoptimized configuration for cross-warp reduction: "
+          f"{self.mlir_dtype} with {vec_len=}"
+      )
     # Define one row to be 128 bytes (32 banks of 4 bytes). For a given lane
     # index, we want to store the data coming from all 4 warps
     # contiguously in order to enable vectorized loads later on. If we
@@ -2981,7 +2981,12 @@ def reduce_stored(
     scratch_ty = ir.MemRefType(scratch.type)
     scratch_elems_per_register = WARPS_IN_WARPGROUP * unique_lanes * vec_len
     if scratch_ty.shape[0] < scratch_elems_per_register:
-      raise ValueError("Insufficient scratch space for cross-warp reduction")
+      available_bytes = scratch_ty.shape[0] * utils.bitwidth(scratch_ty.element_type) // 8
+      required_bytes = scratch_elems_per_register * utils.bitwidth(scratch_ty.element_type) // 8
+      raise ValueError(
+          f"Required reduction scratch size ({required_bytes} bytes) is "
+          f"larger than the available scratch size ({available_bytes} bytes)"
+      )
     if scratch_ty.get_strides_and_offset()[0] != [1]:
       raise ValueError("Expected scratch to be contiguous")
     num_concurrent_cross_warp_reductions = scratch_ty.shape[0] // scratch_elems_per_register
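The 1024-byte figure asserted in the tests below falls out of this check. The per-layout values here are my inference for the WGMMA float32 reduction used in the tests; the source does not spell them out:

```python
WARPS_IN_WARPGROUP = 4
unique_lanes = 32  # assumed: every lane of a warp holds distinct data
vec_len = 2        # assumed: WGMMA layout vector length

scratch_elems_per_register = WARPS_IN_WARPGROUP * unique_lanes * vec_len  # 256
required_bytes = scratch_elems_per_register * 4  # float32 -> 1024 bytes
assert required_bytes == 1024
```

With 1024 bytes of scratch, `num_concurrent_cross_warp_reductions` is 1, so registers are reduced one barrier round at a time; 2048 bytes doubles the concurrency, which is why the second test below sees fewer `BAR.SYNC` instructions in the SASS.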

tests/pallas/mosaic_gpu_test.py

Lines changed: 43 additions & 0 deletions
@@ -2474,6 +2474,49 @@ def kernel(x_ref, out_ref):
     self.assertAllClose(expected, jnp.sum(row))
     self.assertArraysEqual(result, jax.lax.broadcast_in_dim(expected, (128,), ()))
 
+  def test_reduction_fails_on_too_little_scratch_bytes_for_cross_warp_reduction(self):
+    @functools.partial(
+        self.pallas_call,
+        out_shape=jax.ShapeDtypeStruct((128,), jnp.float32),
+        compiler_params=plgpu.CompilerParams(reduction_scratch_bytes=0),
+    )
+    def kernel(x_ref, y_ref):
+      x_val = plgpu.load(x_ref, (), layout=plgpu.Layout.WGMMA, optimized=False)
+      y_ref[...] = jnp.sum(x_val, axis=0)
+
+    with self.assertRaisesRegex(
+        ValueError,
+        r"Required reduction scratch size \(1024 bytes\) is larger than the "
+        r"available scratch size \(0 bytes\)"
+    ):
+      kernel(jnp.zeros((128, 128), dtype=jnp.float32))
+
+  @jtu.thread_unsafe_test()  # Modifies ``os.environ``.
+  def test_reduction_with_more_scratch_uses_less_synchronization(self):
+    def run_kernel(x, scratch_bytes):
+      def kernel(x_ref, y_ref):
+        x_val = plgpu.load(x_ref, (), layout=plgpu.Layout.WGMMA, optimized=False)
+        y_ref[...] = jnp.sum(x_val, axis=0)
+      return self.pallas_call(
+          kernel,
+          out_shape=jax.ShapeDtypeStruct((128,), jnp.float32),
+          compiler_params=plgpu.CompilerParams(reduction_scratch_bytes=scratch_bytes)
+      )(x)
+
+    x = jax.random.uniform(jax.random.key(0), shape=(128, 128), dtype=jnp.float32)
+    with jtu.set_env(MOSAIC_GPU_DUMP_SASS="1"), jtu.capture_stdout() as sass0:
+      out0 = run_kernel(x, 1024).block_until_ready()
+
+    with jtu.set_env(MOSAIC_GPU_DUMP_SASS="1"), jtu.capture_stdout() as sass1:
+      out1 = run_kernel(x, 2 * 1024).block_until_ready()
+
+    self.assertAllClose(out0, jnp.sum(x, axis=0))
+    self.assertArraysEqual(out0, out1)
+
+    syncs0 = re.findall(r"BAR.SYNC", sass0())
+    syncs1 = re.findall(r"BAR.SYNC", sass1())
+    self.assertLess(len(syncs1), len(syncs0))
+
   @parameterized.product(
       layout=(
           plgpu.Layout.WGMMA,
