Commit 6a0b546

[KERNELS] fix hopper smem heuristic and blackwell shuffle transform (#9999)
Two commits forgotten in #9986.
1 parent 028e5da · commit 6a0b546

2 files changed: 12 additions & 7 deletions

python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_nvidia.py

Lines changed: 4 additions & 0 deletions
@@ -180,6 +180,10 @@ def compute_num_stages(
     # that is not fully captured by the simple stage_size model above.
     if is_persistent and (lhs_dtype == FP32 or rhs_dtype == FP32):
         smem_capacity -= 32 * 1024
+    if is_persistent and not has_native_mxfp and epilogue_reduction_n > 1:
+        # Hopper fused reductions materialize an additional reduced-N output
+        # tile in smem.
+        smem_capacity -= int(block_m * acc_block_n * out_itemsize)
     smem_capacity = max(smem_capacity, 0)
     max_stages = 5 if rhs_dtype == FP4 else 4  # maybe 5 everywhere; just haven't tested
     num_stages = min(smem_capacity // int(stage_size), max_stages)
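
For intuition, a minimal sketch of how the new subtraction can flip the chosen stage count. Every constant below (the 228 KiB smem budget, the 56 KiB per-stage footprint, the 128x128 fp16 output tile) is an illustrative assumption, not a value taken from this commit:

# Illustrative arithmetic only; all constants are assumptions.
smem_capacity = 228 * 1024  # assumed Hopper per-SM shared-memory budget
stage_size = 56 * 1024      # assumed A+B tile footprint per pipeline stage
max_stages = 4

# Without the fix, the heuristic picks 4 stages:
print(min(smem_capacity // stage_size, max_stages))  # 4

# With the fix, the fused-reduction output tile is budgeted first
# (assumed block_m=128, acc_block_n=128, out_itemsize=2 for fp16):
smem_capacity -= int(128 * 128 * 2)  # 32 KiB
smem_capacity = max(smem_capacity, 0)
print(min(smem_capacity // stage_size, max_stages))  # 3

Without the subtraction, the extra reduced-N output tile named in the diff comment goes unaccounted for, so the kernel can request one stage more than actually fits in shared memory.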

python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_value_shuffled.py

Lines changed: 8 additions & 7 deletions
@@ -104,13 +104,11 @@ def swizzle_data(self, data: torch.Tensor) -> torch.Tensor:
         Target layout: [E, num_tiles_k, num_tiles_n, tile_n, tile_k_packed]
         This matches the baseline TMA block shape [block_n, packed_block_k] after swapping.
         """
-        if data.ndim == 2:
-            data = data.unsqueeze(0)
-        if data.ndim != 3:
-            raise ValueError(f"Expected 2D or 3D canonical data, got {data.ndim}D")
-
         data = self._canonical_to_physical(data)
-        E, K_packed, N = data.shape
+        leading_shape = data.shape[:-2]
+        E = math.prod(leading_shape)
+        K_packed, N = data.shape[-2:]
+        data = data.reshape(E, K_packed, N)
         tile_k_packed, tile_n, padded_K_packed, padded_N, num_tiles_k, num_tiles_n = \
             self._compute_params(E, K_packed, N)

@@ -139,6 +137,7 @@ def unswizzle_data(self, data: torch.Tensor) -> torch.Tensor:
         Input layout: [E, num_tiles_k, num_tiles_n, tile_n, tile_k_packed]
         """
         E = data.shape[0]
+        leading_shape = self.shape[:-2]
         # Recover original shape from self.shape (the logical shape passed to convert_layout)
         orig_K_packed = self.shape[-2] // 2 if self.is_fp4 else self.shape[-2]
         orig_N = self.shape[-1]
@@ -159,4 +158,6 @@ def unswizzle_data(self, data: torch.Tensor) -> torch.Tensor:
         # Trim padding back to original shape
         data = data[:, :orig_K_packed, :orig_N].contiguous()
         data = self._physical_to_canonical(data)
-        return data if len(self.shape) == 3 else data.squeeze(0)
+        if not leading_shape:
+            return data.squeeze(0)
+        return data.reshape(*leading_shape, data.shape[-2], data.shape[-1])
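
The second file generalizes the swizzle/unswizzle pair from "2D or 3D only" to arbitrary leading dimensions via a flatten-and-restore pattern. Below is a standalone sketch of just that pattern; the helper names flatten_leading/restore_leading are hypothetical, and only the math.prod/reshape logic mirrors the diff:

import math
import torch

def flatten_leading(data: torch.Tensor):
    # Collapse all leading dims into one batch dim E. math.prod(()) == 1,
    # so a plain 2D tensor becomes a batch of one.
    leading_shape = data.shape[:-2]
    E = math.prod(leading_shape)
    K_packed, N = data.shape[-2:]
    return data.reshape(E, K_packed, N), leading_shape

def restore_leading(data: torch.Tensor, leading_shape):
    # Invert the collapse; 2D inputs get their synthetic batch dim squeezed away.
    if not leading_shape:
        return data.squeeze(0)
    return data.reshape(*leading_shape, data.shape[-2], data.shape[-1])

x = torch.randn(2, 3, 16, 32)  # assumed shapes, illustrative only
flat, lead = flatten_leading(x)
assert flat.shape == (6, 16, 32)
assert torch.equal(restore_leading(flat, lead), x)

y = torch.randn(16, 32)        # the 2D case the old code special-cased
flat, lead = flatten_leading(y)
assert flat.shape == (1, 16, 32)
assert torch.equal(restore_leading(flat, lead), y)

This is why the explicit ndim checks could be deleted from swizzle_data: any rank >= 2 flattens cleanly, and unswizzle_data recovers the leading shape from self.shape instead of hard-coding the 2D/3D cases.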
