@@ -410,8 +410,8 @@ def __init__(
         vectorized_f32: bool,
         topk: cutlass.Int64,
         raster_along_m: bool = False,
-        enable_pdl: bool = True,
         b_tensor_l_sizes: Optional[Tuple[int, ...]] = None,
+        enable_pdl: bool = True,
     ):
         """Initializes the configuration for a Blackwell blockscaled dense GEMM kernel with
         gather operation and SwiGLU fusion.
@@ -533,25 +533,24 @@ def __init__(
         self.vectorized_f32 = vectorized_f32

         # Multi-B tensor configuration
-        # b_tensor_l_sizes is required — the Python wrapper layer always provides it
-        # as a tuple (even for single-B, e.g. (256,)).
         if b_tensor_l_sizes is None:
-            raise ValueError(
-                "b_tensor_l_sizes is required. Pass a tuple with the number of "
-                "experts per tensor, e.g. (num_experts,) for single-B."
+            self.num_b_tensors = 1
+            self.b_tensor_l_sizes = None
+            # Offsets padded for safe indexing in kernel
+            self.b_tensor_l_offsets = (0,) + (2**30,) * self.MAX_B_TENSORS
+        else:
+            assert len(b_tensor_l_sizes) <= self.MAX_B_TENSORS, (
+                f"Max {self.MAX_B_TENSORS} B tensors, got {len(b_tensor_l_sizes)}"
             )
-        assert len(b_tensor_l_sizes) <= self.MAX_B_TENSORS, (
-            f"Max {self.MAX_B_TENSORS} B tensors, got {len(b_tensor_l_sizes)}"
-        )
-        self.num_b_tensors = len(b_tensor_l_sizes)
-        self.b_tensor_l_sizes = b_tensor_l_sizes
-        offsets = [0]
-        for l_size in b_tensor_l_sizes:
-            offsets.append(offsets[-1] + l_size)
-        # Pad to MAX_B_TENSORS + 1 for safe indexing
-        while len(offsets) < self.MAX_B_TENSORS + 1:
-            offsets.append(2**30)
-        self.b_tensor_l_offsets = tuple(offsets)
+            self.num_b_tensors = len(b_tensor_l_sizes)
+            self.b_tensor_l_sizes = b_tensor_l_sizes
+            offsets = [0]
+            for l_size in b_tensor_l_sizes:
+                offsets.append(offsets[-1] + l_size)
+            # Pad to MAX_B_TENSORS + 1 for safe indexing
+            while len(offsets) < self.MAX_B_TENSORS + 1:
+                offsets.append(2**30)
+            self.b_tensor_l_offsets = tuple(offsets)

     def _setup_attributes(self):
         """Set up configurations that are dependent on GEMM inputs
@@ -4034,6 +4033,7 @@ def wrapper(
         m: cutlass.Int64,
         n: cutlass.Int64,
         k: cutlass.Int64,
+        l: cutlass.Int64,  # noqa: E741
         tile_size: cutlass.Constexpr,
         scaling_vector_size: cutlass.Constexpr,
         max_active_clusters: cutlass.Constexpr,
@@ -4043,12 +4043,19 @@ def wrapper(
         """Unified wrapper supporting both single-B and multi-B tensors.

         B tensors are always passed as tuples (length 1 for single-B).
-        L sizes are configured via b_tensor_l_sizes in __init__.
+        When b_tensor_l_sizes is provided, L sizes come from b_tensor_l_sizes;
+        otherwise the wrapper falls back to the l parameter (backward-compatible single-B).
         """
         scale_k = k // scaling_vector_size
         interm_size = n // 2
         num_tiles = m // tile_size
-        total_l = self.b_tensor_l_offsets[self.num_b_tensors]
+        # When b_tensor_l_sizes is provided, total_l comes from the precomputed offsets
+        # and l is ignored. Callers must ensure l == sum(b_tensor_l_sizes).
+        # When b_tensor_l_sizes is None (single-B backward compat), l is used directly.
+        if cutlass.const_expr(self.b_tensor_l_sizes is not None):
+            total_l = self.b_tensor_l_offsets[self.num_b_tensors]
+        else:
+            total_l = l

         a = cute.make_tensor(
             a_ptr, layout=cute.make_ordered_layout((orig_m, k, 1), order=(1, 0, 2))
@@ -4069,7 +4076,10 @@ def wrapper(
         )

         # Create B and alpha tensors using const_expr conditions
-        l_0 = self.b_tensor_l_sizes[0]
+        if cutlass.const_expr(self.b_tensor_l_sizes is not None):
+            l_0 = self.b_tensor_l_sizes[0]
+        else:
+            l_0 = l
         alpha_0 = cute.make_tensor(alpha_ptr_tuple[0], layout=cute.make_layout((l_0,)))
         b_0 = cute.make_tensor(
             b_ptr_tuple[0],
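
The cutlass.const_expr branches in the last two hunks fold at compile time, so each compiled kernel specialization contains exactly one path. A plain-Python analogue of the total_l selection, where resolve_total_l is a hypothetical helper for illustration only:

def resolve_total_l(b_tensor_l_sizes, b_tensor_l_offsets, num_b_tensors, l):
    # Multi-B: total L is the last valid prefix-sum entry; l is ignored,
    # though callers must still pass l == sum(b_tensor_l_sizes).
    if b_tensor_l_sizes is not None:
        return b_tensor_l_offsets[num_b_tensors]
    # Single-B backward compat: use l directly.
    return l

# Multi-B: sizes (3, 5) -> offsets (0, 3, 8, ...) -> total_l = 8
assert resolve_total_l((3, 5), (0, 3, 8, 2**30, 2**30), 2, 8) == 8
# Single-B fallback: b_tensor_l_sizes is None -> total_l = l
assert resolve_total_l(None, (0, 2**30, 2**30, 2**30, 2**30), 1, 256) == 256

l_0 follows the same pattern: b_tensor_l_sizes[0] in the multi-B case, the l argument otherwise.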