Skip to content

Commit 5054cb1

Browse files
yyhclaude
authored and committed
fix: restore l parameter in wrapper for backward compat when b_tensor_l_sizes is None
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 9edbe9b commit 5054cb1

2 files changed

Lines changed: 36 additions & 22 deletions

File tree

flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -410,8 +410,8 @@ def __init__(
410410
vectorized_f32: bool,
411411
topk: cutlass.Int64,
412412
raster_along_m: bool = False,
413-
enable_pdl: bool = True,
414413
b_tensor_l_sizes: Optional[Tuple[int, ...]] = None,
414+
enable_pdl: bool = True,
415415
):
416416
"""Initializes the configuration for a Blackwell blockscaled dense GEMM kernel with
417417
gather operation and SwiGLU fusion.
@@ -533,25 +533,24 @@ def __init__(
533533
self.vectorized_f32 = vectorized_f32
534534

535535
# Multi-B tensor configuration
536-
# b_tensor_l_sizes is required — the Python wrapper layer always provides it
537-
# as a tuple (even for single-B, e.g. (256,)).
538536
if b_tensor_l_sizes is None:
539-
raise ValueError(
540-
"b_tensor_l_sizes is required. Pass a tuple with the number of "
541-
"experts per tensor, e.g. (num_experts,) for single-B."
537+
self.num_b_tensors = 1
538+
self.b_tensor_l_sizes = None
539+
# Offsets padded for safe indexing in kernel
540+
self.b_tensor_l_offsets = (0,) + (2**30,) * self.MAX_B_TENSORS
541+
else:
542+
assert len(b_tensor_l_sizes) <= self.MAX_B_TENSORS, (
543+
f"Max {self.MAX_B_TENSORS} B tensors, got {len(b_tensor_l_sizes)}"
542544
)
543-
assert len(b_tensor_l_sizes) <= self.MAX_B_TENSORS, (
544-
f"Max {self.MAX_B_TENSORS} B tensors, got {len(b_tensor_l_sizes)}"
545-
)
546-
self.num_b_tensors = len(b_tensor_l_sizes)
547-
self.b_tensor_l_sizes = b_tensor_l_sizes
548-
offsets = [0]
549-
for l_size in b_tensor_l_sizes:
550-
offsets.append(offsets[-1] + l_size)
551-
# Pad to MAX_B_TENSORS + 1 for safe indexing
552-
while len(offsets) < self.MAX_B_TENSORS + 1:
553-
offsets.append(2**30)
554-
self.b_tensor_l_offsets = tuple(offsets)
545+
self.num_b_tensors = len(b_tensor_l_sizes)
546+
self.b_tensor_l_sizes = b_tensor_l_sizes
547+
offsets = [0]
548+
for l_size in b_tensor_l_sizes:
549+
offsets.append(offsets[-1] + l_size)
550+
# Pad to MAX_B_TENSORS + 1 for safe indexing
551+
while len(offsets) < self.MAX_B_TENSORS + 1:
552+
offsets.append(2**30)
553+
self.b_tensor_l_offsets = tuple(offsets)
555554

556555
def _setup_attributes(self):
557556
"""Set up configurations that are dependent on GEMM inputs
@@ -4034,6 +4033,7 @@ def wrapper(
40344033
m: cutlass.Int64,
40354034
n: cutlass.Int64,
40364035
k: cutlass.Int64,
4036+
l: cutlass.Int64, # noqa: E741
40374037
tile_size: cutlass.Constexpr,
40384038
scaling_vector_size: cutlass.Constexpr,
40394039
max_active_clusters: cutlass.Constexpr,
@@ -4043,12 +4043,19 @@ def wrapper(
40434043
"""Unified wrapper supporting both single-B and multi-B tensors.
40444044
40454045
B tensors are always passed as tuples (length 1 for single-B).
4046-
L sizes are configured via b_tensor_l_sizes in __init__.
4046+
When b_tensor_l_sizes is provided, L sizes come from b_tensor_l_sizes;
4047+
otherwise falls back to the l parameter (backward compatible single-B).
40474048
"""
40484049
scale_k = k // scaling_vector_size
40494050
interm_size = n // 2
40504051
num_tiles = m // tile_size
4051-
total_l = self.b_tensor_l_offsets[self.num_b_tensors]
4052+
# When b_tensor_l_sizes is provided, total_l comes from the precomputed offsets
4053+
# and l is ignored. Callers must ensure l == sum(b_tensor_l_sizes).
4054+
# When b_tensor_l_sizes is None (single-B backward compat), l is used directly.
4055+
if cutlass.const_expr(self.b_tensor_l_sizes is not None):
4056+
total_l = self.b_tensor_l_offsets[self.num_b_tensors]
4057+
else:
4058+
total_l = l
40524059

40534060
a = cute.make_tensor(
40544061
a_ptr, layout=cute.make_ordered_layout((orig_m, k, 1), order=(1, 0, 2))
@@ -4069,7 +4076,10 @@ def wrapper(
40694076
)
40704077

40714078
# Create B and alpha tensors using const_expr conditions
4072-
l_0 = self.b_tensor_l_sizes[0]
4079+
if cutlass.const_expr(self.b_tensor_l_sizes is not None):
4080+
l_0 = self.b_tensor_l_sizes[0]
4081+
else:
4082+
l_0 = l
40734083
alpha_0 = cute.make_tensor(alpha_ptr_tuple[0], layout=cute.make_layout((l_0,)))
40744084
b_0 = cute.make_tensor(
40754085
b_ptr_tuple[0],

flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,8 +274,9 @@ def _get_compiled_gather_kernel(
274274
# Order must match wrapper signature:
275275
# (a_ptr, b_ptr_tuple, a_sf_ptr, b_sf_ptr_tuple, c_ptr, c_sf_ptr, alpha_ptr_tuple,
276276
# tile_idx_to_group_idx_ptr, tile_idx_to_mn_limit_ptr, token_id_mapping_ptr,
277-
# num_non_exiting_tiles_ptr, norm_const_ptr, orig_m, m, n, k,
277+
# num_non_exiting_tiles_ptr, norm_const_ptr, orig_m, m, n, k, l,
278278
# tile_size, scaling_vector_size, max_active_clusters, stream)
279+
num_experts = sum(b_tensor_l_sizes)
279280
compile_args = [
280281
a_ptr,
281282
b_ptr,
@@ -293,6 +294,7 @@ def _get_compiled_gather_kernel(
293294
permuted_m,
294295
n,
295296
k,
297+
num_experts,
296298
]
297299

298300
compiled_gemm = cute.compile(
@@ -620,6 +622,7 @@ def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4(
620622
)
621623

622624
# Execute kernel
625+
num_experts = sum(b_tensor_l_sizes)
623626
exec_args = [
624627
a_ptr,
625628
b_ptr,
@@ -637,6 +640,7 @@ def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4(
637640
permuted_m,
638641
n,
639642
k,
643+
num_experts, # l
640644
]
641645
compiled_gemm(*exec_args, stream=stream)
642646

0 commit comments

Comments (0)