 )
 
 
-# Target grid size for occupancy
-_TARGET_GRID = 132 * 4
+# Blocks per SM for occupancy target
+_BLOCKS_PER_SM = 4
+
+# Maximum threads per block (all modern NVIDIA GPUs support 1024)
+_MAX_THREADS_PER_BLOCK = 1024
+
+
+def _get_target_grid(device: torch.device = None) -> int:
+    """
+    Compute target grid size based on device SM count.
+
+    Args:
+        device: CUDA device. If None, uses current device.
+
+    Returns:
+        Target number of blocks for good occupancy (SM_count * _BLOCKS_PER_SM)
+    """
+    if device is None:
+        device = torch.cuda.current_device()
+    sm_count = torch.cuda.get_device_properties(device).multi_processor_count
+    return sm_count * _BLOCKS_PER_SM
+
 
 # Warp configuration bounds
 _MIN_WARPS = 4  # Minimum for reasonable occupancy (128 threads)
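For intuition: on a 132-SM part (the device the old hard-coded `_TARGET_GRID = 132 * 4` assumed), `_get_target_grid` reproduces the previous value of 528, and scales down automatically on smaller GPUs. A minimal sketch of the same query, assuming a CUDA build of PyTorch:

```python
import torch

# Mirror of _get_target_grid's device query (4 = _BLOCKS_PER_SM).
if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(torch.cuda.current_device())
    print(props.multi_processor_count * 4)  # 132 SMs -> 528, matching the old constant
```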
@@ -144,11 +164,16 @@ def __init__(
         dtype: cutlass.Numeric,
         K: int,
         enable_pdl: bool = False,
+        target_grid: int = None,
     ):
         self.dtype = dtype
         self.K = K
         self.is_bfloat16 = dtype == cutlass.BFloat16
         self.enable_pdl = enable_pdl
+        # Use provided target_grid or compute from current device
+        self.target_grid = (
+            target_grid if target_grid is not None else _get_target_grid()
+        )
 
         assert K % SF_VEC_SIZE == 0
         self.num_sf_blocks_per_row = K // SF_VEC_SIZE
@@ -165,16 +190,16 @@ def __call__(
         threads_per_block = self.WARPS_PER_BLOCK * WARP_SIZE
         sf_blocks_per_tb = self.WARPS_PER_BLOCK * SF_BLOCKS_PER_WARP
 
-        # Compute grid size at runtime
+        # Compute grid size at runtime (target_grid is device-specific)
         num_blocks = cutlass.min(
-            cute.ceil_div(total_sf_blocks, sf_blocks_per_tb), _TARGET_GRID
+            cute.ceil_div(total_sf_blocks, sf_blocks_per_tb), self.target_grid
         )
 
         self.kernel(mInput, mOutput, mScales, total_sf_blocks).launch(
             grid=[num_blocks, 1, 1],
             block=[threads_per_block, 1, 1],
-            max_number_threads=[512, 1, 1],
-            min_blocks_per_mp=4,
+            max_number_threads=[_MAX_THREADS_PER_BLOCK, 1, 1],
+            min_blocks_per_mp=_BLOCKS_PER_SM,
             stream=stream,
             use_pdl=self.enable_pdl,
         )
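The `cutlass.min` cap bounds the launch at the occupancy target while `ceil_div` covers small inputs exactly. A plain-Python model of the same arithmetic, with illustrative values for `sf_blocks_per_tb`:

```python
def ceil_div(a: int, b: int) -> int:
    return -(-a // b)

target_grid = 528       # e.g. 132 SMs * _BLOCKS_PER_SM
sf_blocks_per_tb = 64   # illustrative WARPS_PER_BLOCK * SF_BLOCKS_PER_WARP

for total_sf_blocks in (1_000, 1_000_000):
    num_blocks = min(ceil_div(total_sf_blocks, sf_blocks_per_tb), target_grid)
    print(total_sf_blocks, "->", num_blocks)  # 1000 -> 16; 1000000 -> capped at 528
```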
@@ -286,11 +311,16 @@ def __init__(
         dtype: cutlass.Numeric,
         K: int,
         enable_pdl: bool = False,
+        target_grid: int = None,
     ):
         self.dtype = dtype
         self.K = K
         self.is_bfloat16 = dtype == cutlass.BFloat16
         self.enable_pdl = enable_pdl
+        # Use provided target_grid or compute from current device
+        self.target_grid = (
+            target_grid if target_grid is not None else _get_target_grid()
+        )
 
         assert K % SF_VEC_SIZE == 0
         self.num_sf_blocks_per_row = K // SF_VEC_SIZE
@@ -328,16 +358,16 @@ def __call__(
         threads_per_block = self.warps_per_block * WARP_SIZE
         rows_per_block = self.rows_per_block
 
-        # Compute grid size at runtime
+        # Compute grid size at runtime (target_grid is device-specific)
         # Grid covers row batches, not individual rows
         total_row_batches = cute.ceil_div(padded_M, rows_per_block)
-        num_blocks = cutlass.min(total_row_batches, _TARGET_GRID)
+        num_blocks = cutlass.min(total_row_batches, self.target_grid)
 
         self.kernel(mInput, mOutput, mScales, M, padded_M).launch(
             grid=[num_blocks, 1, 1],
             block=[threads_per_block, 1, 1],
-            max_number_threads=[threads_per_block, 1, 1],
-            min_blocks_per_mp=4,
+            max_number_threads=[_MAX_THREADS_PER_BLOCK, 1, 1],
+            min_blocks_per_mp=_BLOCKS_PER_SM,
             stream=stream,
             use_pdl=self.enable_pdl,
         )
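Since `num_blocks` can be smaller than `total_row_batches`, each block must cover several row batches; the kernel body is not shown in this diff, but the implied work assignment is the usual grid-stride pattern. A generic sketch (not the kernel's actual code):

```python
def batches_for_block(bid: int, num_blocks: int, total_row_batches: int) -> list:
    # Block bid handles batches bid, bid + num_blocks, bid + 2 * num_blocks, ...
    return list(range(bid, total_row_batches, num_blocks))

assert batches_for_block(2, 528, 1600) == [2, 530, 1058, 1586]
```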
@@ -550,14 +580,17 @@ def _get_compiled_kernel_linear(
     is_bfloat16: bool,
     K: int,
     enable_pdl: bool = False,
+    target_grid: int = None,
 ) -> Callable:
     """
     Get or compile LINEAR layout kernel with TVM-FFI.
 
-    Cached by (K, dtype, pdl) - M-agnostic compilation.
+    Cached by (K, dtype, pdl, target_grid) - M-agnostic compilation.
     """
     cutlass_dtype = cutlass.BFloat16 if is_bfloat16 else cutlass.Float16
-    kernel_obj = MXFP8QuantizeLinearKernel(cutlass_dtype, K, enable_pdl)
+    kernel_obj = MXFP8QuantizeLinearKernel(
+        cutlass_dtype, K, enable_pdl, target_grid=target_grid
+    )
 
     # Use symbolic M for dynamic batch sizes
     sym_m = cute.sym_int()
@@ -594,14 +627,17 @@ def _get_compiled_kernel_swizzled(
     is_bfloat16: bool,
     K: int,
     enable_pdl: bool = False,
+    target_grid: int = None,
 ) -> Callable:
     """
     Get or compile SWIZZLED layout kernel with TVM-FFI.
 
-    Cached by (K, dtype, pdl) - M-agnostic compilation.
+    Cached by (K, dtype, pdl, target_grid) - M-agnostic compilation.
     """
     cutlass_dtype = cutlass.BFloat16 if is_bfloat16 else cutlass.Float16
-    kernel_obj = MXFP8QuantizeSwizzledKernel(cutlass_dtype, K, enable_pdl)
+    kernel_obj = MXFP8QuantizeSwizzledKernel(
+        cutlass_dtype, K, enable_pdl, target_grid=target_grid
+    )
 
     # Use symbolic M for dynamic batch sizes
     sym_m = cute.sym_int()
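The caching decorator itself is outside this diff; the docstring change matters because `target_grid` must be part of the cache key, or a kernel compiled for one GPU's SM count would be reused on another. A hypothetical sketch of that keying with `functools.lru_cache`:

```python
from functools import lru_cache

compilations = 0

@lru_cache(maxsize=None)
def fake_compile(is_bfloat16: bool, K: int, enable_pdl: bool, target_grid: int) -> str:
    global compilations
    compilations += 1
    return f"kernel<K={K}, grid={target_grid}>"

fake_compile(True, 4096, False, 528)
fake_compile(True, 4096, False, 528)  # cache hit
fake_compile(True, 4096, False, 456)  # different SM count -> separate compilation
assert compilations == 2
```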
@@ -699,6 +735,9 @@ def mxfp8_quantize_cute_dsl(
 
     is_bfloat16 = input.dtype == torch.bfloat16
 
+    # Compute device-specific target grid for kernel compilation
+    target_grid = _get_target_grid(input.device)
+
     # Compute M-dependent values outside the cached kernel
     num_sf_blocks_per_row = padded_k // SF_VEC_SIZE
 
@@ -708,7 +747,9 @@ def mxfp8_quantize_cute_dsl(
         padded_sf_cols = ((num_sf_blocks_per_row + 3) // 4) * 4
         scale_output_size = padded_m * padded_sf_cols
 
-        kernel_fn = _get_compiled_kernel_swizzled(is_bfloat16, padded_k, enable_pdl)
+        kernel_fn = _get_compiled_kernel_swizzled(
+            is_bfloat16, padded_k, enable_pdl, target_grid
+        )
 
         fp8_output = torch.empty(m, padded_k, dtype=torch.uint8, device=input.device)
         scale_output = torch.empty(
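The swizzled layout rounds the per-row scale-factor columns up to a multiple of 4. Worked numbers, assuming the standard MXFP8 block size `SF_VEC_SIZE = 32`:

```python
SF_VEC_SIZE = 32  # assumed MXFP8 scale-factor vector size

for padded_k in (4096, 2080):
    num_sf_blocks_per_row = padded_k // SF_VEC_SIZE
    padded_sf_cols = ((num_sf_blocks_per_row + 3) // 4) * 4
    print(padded_k, num_sf_blocks_per_row, padded_sf_cols)
# 4096 -> 128 columns, already a multiple of 4
# 2080 -> 65 columns, padded to 68
```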
@@ -721,7 +762,9 @@ def mxfp8_quantize_cute_dsl(
         total_sf_blocks = m * num_sf_blocks_per_row
         scale_output_size = total_sf_blocks
 
-        kernel_fn = _get_compiled_kernel_linear(is_bfloat16, padded_k, enable_pdl)
+        kernel_fn = _get_compiled_kernel_linear(
+            is_bfloat16, padded_k, enable_pdl, target_grid
+        )
 
         fp8_output = torch.empty(m, padded_k, dtype=torch.uint8, device=input.device)
         scale_output = torch.empty(
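End-to-end, callers never pass `target_grid`; it is derived from `input.device` at dispatch time. A hypothetical call (the function's full signature and return convention are not shown in this diff):

```python
import torch

x = torch.randn(256, 4096, dtype=torch.bfloat16, device="cuda")
fp8_out, scales = mxfp8_quantize_cute_dsl(x)  # assumed (fp8, scales) return pair
assert fp8_out.dtype == torch.uint8 and fp8_out.shape == (256, 4096)
```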