Skip to content

Commit cb09ffe

Browse files
committed
Perf tuning for B200
1 parent 3442add commit cb09ffe

1 file changed

Lines changed: 59 additions & 16 deletions

File tree

flashinfer/cute_dsl/mxfp8_quantize.py

Lines changed: 59 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,28 @@
6565
)
6666

6767

68-
# Blocks per SM for occupancy target
_BLOCKS_PER_SM = 4

# Maximum threads per block (all modern NVIDIA GPUs support 1024)
_MAX_THREADS_PER_BLOCK = 1024


def _get_target_grid(device: "torch.device | None" = None) -> int:
    """
    Compute target grid size based on device SM count.

    Args:
        device: CUDA device (or device index). If None, uses the current
            CUDA device. Annotated explicitly as optional (PEP 484 forbids
            implicit Optional); the string form avoids evaluating the union
            at def time on older Python versions.

    Returns:
        Target number of blocks for good occupancy
        (SM_count * _BLOCKS_PER_SM).
    """
    if device is None:
        # current_device() returns an int index; get_device_properties
        # accepts either an index or a torch.device.
        device = torch.cuda.current_device()
    sm_count = torch.cuda.get_device_properties(device).multi_processor_count
    return sm_count * _BLOCKS_PER_SM
89+
7090

7191
# Warp configuration bounds
7292
_MIN_WARPS = 4 # Minimum for reasonable occupancy (128 threads)
@@ -144,11 +164,16 @@ def __init__(
144164
dtype: cutlass.Numeric,
145165
K: int,
146166
enable_pdl: bool = False,
167+
target_grid: int = None,
147168
):
148169
self.dtype = dtype
149170
self.K = K
150171
self.is_bfloat16 = dtype == cutlass.BFloat16
151172
self.enable_pdl = enable_pdl
173+
# Use provided target_grid or compute from current device
174+
self.target_grid = (
175+
target_grid if target_grid is not None else _get_target_grid()
176+
)
152177

153178
assert K % SF_VEC_SIZE == 0
154179
self.num_sf_blocks_per_row = K // SF_VEC_SIZE
@@ -165,16 +190,16 @@ def __call__(
165190
threads_per_block = self.WARPS_PER_BLOCK * WARP_SIZE
166191
sf_blocks_per_tb = self.WARPS_PER_BLOCK * SF_BLOCKS_PER_WARP
167192

168-
# Compute grid size at runtime
193+
# Compute grid size at runtime (target_grid is device-specific)
169194
num_blocks = cutlass.min(
170-
cute.ceil_div(total_sf_blocks, sf_blocks_per_tb), _TARGET_GRID
195+
cute.ceil_div(total_sf_blocks, sf_blocks_per_tb), self.target_grid
171196
)
172197

173198
self.kernel(mInput, mOutput, mScales, total_sf_blocks).launch(
174199
grid=[num_blocks, 1, 1],
175200
block=[threads_per_block, 1, 1],
176-
max_number_threads=[512, 1, 1],
177-
min_blocks_per_mp=4,
201+
max_number_threads=[_MAX_THREADS_PER_BLOCK, 1, 1],
202+
min_blocks_per_mp=_BLOCKS_PER_SM,
178203
stream=stream,
179204
use_pdl=self.enable_pdl,
180205
)
@@ -286,11 +311,16 @@ def __init__(
286311
dtype: cutlass.Numeric,
287312
K: int,
288313
enable_pdl: bool = False,
314+
target_grid: int = None,
289315
):
290316
self.dtype = dtype
291317
self.K = K
292318
self.is_bfloat16 = dtype == cutlass.BFloat16
293319
self.enable_pdl = enable_pdl
320+
# Use provided target_grid or compute from current device
321+
self.target_grid = (
322+
target_grid if target_grid is not None else _get_target_grid()
323+
)
294324

295325
assert K % SF_VEC_SIZE == 0
296326
self.num_sf_blocks_per_row = K // SF_VEC_SIZE
@@ -328,16 +358,16 @@ def __call__(
328358
threads_per_block = self.warps_per_block * WARP_SIZE
329359
rows_per_block = self.rows_per_block
330360

331-
# Compute grid size at runtime
361+
# Compute grid size at runtime (target_grid is device-specific)
332362
# Grid covers row batches, not individual rows
333363
total_row_batches = cute.ceil_div(padded_M, rows_per_block)
334-
num_blocks = cutlass.min(total_row_batches, _TARGET_GRID)
364+
num_blocks = cutlass.min(total_row_batches, self.target_grid)
335365

336366
self.kernel(mInput, mOutput, mScales, M, padded_M).launch(
337367
grid=[num_blocks, 1, 1],
338368
block=[threads_per_block, 1, 1],
339-
max_number_threads=[threads_per_block, 1, 1],
340-
min_blocks_per_mp=4,
369+
max_number_threads=[_MAX_THREADS_PER_BLOCK, 1, 1],
370+
min_blocks_per_mp=_BLOCKS_PER_SM,
341371
stream=stream,
342372
use_pdl=self.enable_pdl,
343373
)
@@ -550,14 +580,17 @@ def _get_compiled_kernel_linear(
550580
is_bfloat16: bool,
551581
K: int,
552582
enable_pdl: bool = False,
583+
target_grid: int = None,
553584
) -> Callable:
554585
"""
555586
Get or compile LINEAR layout kernel with TVM-FFI.
556587
557-
Cached by (K, dtype, pdl) - M-agnostic compilation.
588+
Cached by (K, dtype, pdl, target_grid) - M-agnostic compilation.
558589
"""
559590
cutlass_dtype = cutlass.BFloat16 if is_bfloat16 else cutlass.Float16
560-
kernel_obj = MXFP8QuantizeLinearKernel(cutlass_dtype, K, enable_pdl)
591+
kernel_obj = MXFP8QuantizeLinearKernel(
592+
cutlass_dtype, K, enable_pdl, target_grid=target_grid
593+
)
561594

562595
# Use symbolic M for dynamic batch sizes
563596
sym_m = cute.sym_int()
@@ -594,14 +627,17 @@ def _get_compiled_kernel_swizzled(
594627
is_bfloat16: bool,
595628
K: int,
596629
enable_pdl: bool = False,
630+
target_grid: int = None,
597631
) -> Callable:
598632
"""
599633
Get or compile SWIZZLED layout kernel with TVM-FFI.
600634
601-
Cached by (K, dtype, pdl) - M-agnostic compilation.
635+
Cached by (K, dtype, pdl, target_grid) - M-agnostic compilation.
602636
"""
603637
cutlass_dtype = cutlass.BFloat16 if is_bfloat16 else cutlass.Float16
604-
kernel_obj = MXFP8QuantizeSwizzledKernel(cutlass_dtype, K, enable_pdl)
638+
kernel_obj = MXFP8QuantizeSwizzledKernel(
639+
cutlass_dtype, K, enable_pdl, target_grid=target_grid
640+
)
605641

606642
# Use symbolic M for dynamic batch sizes
607643
sym_m = cute.sym_int()
@@ -699,6 +735,9 @@ def mxfp8_quantize_cute_dsl(
699735

700736
is_bfloat16 = input.dtype == torch.bfloat16
701737

738+
# Compute device-specific target grid for kernel compilation
739+
target_grid = _get_target_grid(input.device)
740+
702741
# Compute M-dependent values outside the cached kernel
703742
num_sf_blocks_per_row = padded_k // SF_VEC_SIZE
704743

@@ -708,7 +747,9 @@ def mxfp8_quantize_cute_dsl(
708747
padded_sf_cols = ((num_sf_blocks_per_row + 3) // 4) * 4
709748
scale_output_size = padded_m * padded_sf_cols
710749

711-
kernel_fn = _get_compiled_kernel_swizzled(is_bfloat16, padded_k, enable_pdl)
750+
kernel_fn = _get_compiled_kernel_swizzled(
751+
is_bfloat16, padded_k, enable_pdl, target_grid
752+
)
712753

713754
fp8_output = torch.empty(m, padded_k, dtype=torch.uint8, device=input.device)
714755
scale_output = torch.empty(
@@ -721,7 +762,9 @@ def mxfp8_quantize_cute_dsl(
721762
total_sf_blocks = m * num_sf_blocks_per_row
722763
scale_output_size = total_sf_blocks
723764

724-
kernel_fn = _get_compiled_kernel_linear(is_bfloat16, padded_k, enable_pdl)
765+
kernel_fn = _get_compiled_kernel_linear(
766+
is_bfloat16, padded_k, enable_pdl, target_grid
767+
)
725768

726769
fp8_output = torch.empty(m, padded_k, dtype=torch.uint8, device=input.device)
727770
scale_output = torch.empty(

0 commit comments

Comments (0)