9090
9191@dataclasses .dataclass (frozen = True , kw_only = True )
9292class ResourceEstimatorContext :
93+ reduction_scratch_bytes : int
9394 axis_names : _AxisNames
9495 lowering_semantics : mgpu .LoweringSemantics
9596
@@ -361,7 +362,6 @@ def _run_scoped_resource_estimator(
361362 f"Unsupported memory space: { aval .memory_space } " )
362363 return rs + _estimate_resources (ctx , jaxpr )
363364
364- REDUCE_SCRATCH_ELEMS = 128 * 4 # vector of 4 elements per lane in each WG
365365
366366@_register_resource_estimator (lax .reduce_sum_p )
367367@_register_resource_estimator (lax .reduce_max_p )
@@ -370,10 +370,10 @@ def _reduce_resource_estimator(
370370 ctx : ResourceEstimatorContext , x_aval : jax_core .ShapedArray , * , axes ,
371371 ** kwargs
372372) -> Resources :
373- del ctx , axes # Unused.
373+ del x_aval , axes , kwargs # Unused.
374374 # We don't need SMEM for some reductions, but it depends on the layout, so we
375375 # conservatively request the maximum scratch space we might need.
376- return Resources (smem_scratch_bytes = REDUCE_SCRATCH_ELEMS * x_aval . dtype . itemsize )
376+ return Resources (smem_scratch_bytes = ctx . reduction_scratch_bytes )
377377
378378
379379@dataclasses .dataclass (frozen = True )
@@ -420,6 +420,8 @@ class ModuleContext:
420420 mesh_info : pallas_utils .MeshInfo | None
421421 # See the documentation of unsafe_no_auto_barriers in CompilerParams.
422422 auto_barriers : bool
423+ # See the documentation of reduction_scratch_bytes in CompilerParams.
424+ reduction_scratch_bytes : int
423425 warp_axis_name : str | None = None
424426
425427 @property
@@ -596,8 +598,9 @@ def replace(self, **changes: Any) -> LoweringRuleContext:
596598 @property
597599 def estimator_ctx (self ) -> ResourceEstimatorContext :
598600 return ResourceEstimatorContext (
601+ reduction_scratch_bytes = self .module_ctx .reduction_scratch_bytes ,
599602 axis_names = self .module_ctx .axis_names ,
600- lowering_semantics = self .module_ctx .lowering_semantics ,
603+ lowering_semantics = self .module_ctx .lowering_semantics
601604 )
602605
603606
@@ -886,6 +889,7 @@ def lower_jaxpr_to_module(
886889
887890 rs = _estimate_resources (
888891 ResourceEstimatorContext (
892+ reduction_scratch_bytes = params .reduction_scratch_bytes ,
889893 axis_names = axis_names , lowering_semantics = lowering_semantics
890894 ),
891895 jaxpr ,
@@ -982,6 +986,7 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value):
982986 if jax_mesh is not None
983987 else None ,
984988 auto_barriers = not params .unsafe_no_auto_barriers ,
989+ reduction_scratch_bytes = params .reduction_scratch_bytes ,
985990 )
986991 del runtime_smem , grouped_barriers , runtime_barriers
987992 _ = lower_jaxpr_to_mosaic_gpu (
@@ -2592,12 +2597,11 @@ def _reduce_lowering_rule(op, ctx: LoweringRuleContext, x, *, axes, **kwargs):
25922597 raise NotImplementedError ("Multi-axis reductions not supported" )
25932598 reduced_dim = x .layout .tiling .tile_dimension (axes [0 ])
25942599 if any (reduced_dim [d ] for d in x .layout .partitioned_warp_dims ):
2595- size = x .layout .vector_length * 128 # a vector per lane in each WG.
2596- if size > REDUCE_SCRATCH_ELEMS :
2597- raise NotImplementedError (
2598- f"Reduce scratch { size = } exceeds max={ REDUCE_SCRATCH_ELEMS } "
2599- )
2600- scratch_ty = jax .ShapeDtypeStruct (shape = (size ,), dtype = x_aval .dtype )
2600+ dtype_bitwidth = dtypes .itemsize_bits (x_aval .dtype )
2601+ if dtype_bitwidth % 8 :
2602+ raise NotImplementedError ("Sub-byte dtypes not supported" )
2603+ scratch_elems = ctx .module_ctx .reduction_scratch_bytes * 8 // dtype_bitwidth
2604+ scratch_ty = jax .ShapeDtypeStruct (shape = (scratch_elems ,), dtype = x_aval .dtype )
26012605 ctx = ctx .module_ctx .scratch_view (scratch_ty )
26022606 else :
26032607 ctx = contextlib .nullcontext (None )
@@ -2645,7 +2649,9 @@ def _reduce_lowering_rule_wg(
26452649 def i32_attr (value : int ) -> ir .IntegerAttr :
26462650 return ir .IntegerAttr .get (ir .IntegerType .get_signless (32 ), value )
26472651 reduction .attributes ["offset" ] = i32_attr (ctx .module_ctx .smem_used_bytes )
2648- reduction .attributes ["scratch_size" ] = i32_attr (REDUCE_SCRATCH_ELEMS )
2652+ # TODO(bchetioui): here, we could just donate all the remaining free SMEM that
2653+ # we have at this point in time.
2654+ reduction .attributes ["scratch_size" ] = i32_attr (ctx .module_ctx .reduction_scratch_bytes )
26492655 return reduction .result
26502656
26512657
0 commit comments