diff --git a/flashinfer/norm/kernels/fused_add_rmsnorm.py b/flashinfer/norm/kernels/fused_add_rmsnorm.py index 406b604ac5..31e3840d4a 100644 --- a/flashinfer/norm/kernels/fused_add_rmsnorm.py +++ b/flashinfer/norm/kernels/fused_add_rmsnorm.py @@ -22,6 +22,7 @@ """ import functools +import math import cutlass import cutlass.cute as cute @@ -34,17 +35,20 @@ rcp_approx_ftz, cvt_and_store_f32_to_e4m3_hw, cvt_and_store_f32_to_e4m3_sw, + cvt_and_store_8xf32_to_e4m3_hw, + cvt_and_store_4xf32_to_e4m3_hw, + cvt_and_store_2xf32_to_e4m3_hw, has_hw_fp8_cvt, get_ptr_as_int64, - row_reduce_sum, + get_sm_version, + row_reduce_sum_multirow, predicate_k, - compute_optimal_vec_size, - compute_threads_per_row, - make_tv_layout, _torch_dtype_to_str, get_cutlass_dtype, ) +from .rmsnorm import RMSNormKernel + # ============================================================================= # FusedAddRMSNormKernel @@ -65,29 +69,120 @@ def __init__( dtype: cutlass.Numeric, H: int, weight_bias: float = 0.0, + sm_version: int | None = None, ): self.dtype = dtype self.H = H self.weight_bias = weight_bias + self.sm_version = sm_version if sm_version is not None else get_sm_version() + + self.cluster_n = self._compute_cluster_n(H, dtype, self.sm_version) + self.H_per_cta = H // self.cluster_n + + elem_bytes = dtype.width // 8 + max_vec_size = COPY_BITS // 8 // elem_bytes - # Vectorization parameters: use optimal vec_size for warp utilization - elem_bits = dtype.width - max_vec_size = COPY_BITS // elem_bits - self.vec_size = compute_optimal_vec_size(H, max_vec_size) - self.copy_bits = self.vec_size * elem_bits + h_align = self.H_per_cta & (-self.H_per_cta) + self.vec_size = min(h_align, max_vec_size) + self.copy_bits = self.vec_size * dtype.width - self.threads_per_row = compute_threads_per_row(H, self.vec_size) - self.num_threads = self.threads_per_row - self.num_warps = max(self.threads_per_row // 32, 1) + self.threads_per_row = RMSNormKernel._compute_threads_per_row(self.H_per_cta) + self.num_threads = RMSNormKernel._compute_num_threads(self.H_per_cta) + self.rows_per_block = self.num_threads // self.threads_per_row + self.warps_per_row = max(self.threads_per_row // 32, 1) self.num_vec_blocks = max( - 1, (H // self.vec_size + self.threads_per_row - 1) // self.threads_per_row + 1, + (self.H_per_cta // self.vec_size + self.threads_per_row - 1) + // self.threads_per_row, ) self.cols_per_tile = self.vec_size * self.num_vec_blocks * self.threads_per_row + if self.copy_bits >= 32: + tile_bytes_2 = 2 * self.rows_per_block * self.cols_per_tile * elem_bytes + props = torch.cuda.get_device_properties(torch.cuda.current_device()) + self.use_async_copy = ( + tile_bytes_2 <= props.shared_memory_per_block_optin // 2 + ) + else: + self.use_async_copy = False + + @staticmethod + def _compute_cluster_n(H: int, dtype: cutlass.Numeric, sm_version: int) -> int: + """Compute optimal cluster size for fused-add kernel (2 shared tiles). + + Because fused-add needs 2 tiles (input + residual) in shared memory, + we target smem <= max_smem // 2 so that at least 2 blocks can + co-schedule per SM for good occupancy. If no cluster_n achieves + that, fall back to the first cluster_n that fits at all. 
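+
+        Illustrative sizing (assumes ~227 KB opt-in shared memory per block,
+        as on Hopper-class GPUs, giving an occupancy target of ~113 KB): for
+        H = 32768 in bf16, cluster_n = 1 needs two 64 KB tiles (~128 KB) and
+        misses the target, while cluster_n = 2 halves H_per_cta so the two
+        tiles fit in ~64 KB, and 2 is selected.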
+ """ + if sm_version < 90: + return 1 + + props = torch.cuda.get_device_properties(torch.cuda.current_device()) + max_smem_bytes = props.shared_memory_per_block_optin + elem_size = dtype.width // 8 + occupancy_target = max_smem_bytes // 2 + + best_fit = 1 + for cluster_n in [1, 2, 4, 8, 16]: + if H % cluster_n != 0: + continue + smem_needed = FusedAddRMSNormKernel._estimate_smem_bytes( + H, cluster_n, elem_size + ) + if smem_needed <= occupancy_target: + return cluster_n + if smem_needed <= max_smem_bytes and best_fit == 1: + best_fit = cluster_n + + return best_fit + + @staticmethod + def _estimate_smem_bytes(H: int, cluster_n: int, elem_size: int) -> int: + """Estimate shared memory bytes (2 tiles for input + residual).""" + H_per_cta = H // cluster_n + threads_per_row = RMSNormKernel._compute_threads_per_row(H_per_cta) + num_threads = RMSNormKernel._compute_num_threads(H_per_cta) + rows_per_block = num_threads // threads_per_row + warps_per_row = max(threads_per_row // 32, 1) + + max_vec_size = COPY_BITS // 8 // elem_size + h_align = H_per_cta & (-H_per_cta) + vec_size = min(h_align, max_vec_size) + num_vec_blocks = max( + 1, (H_per_cta // vec_size + threads_per_row - 1) // threads_per_row + ) + cols_per_tile = vec_size * num_vec_blocks * threads_per_row + + tile_bytes = 2 * rows_per_block * cols_per_tile * elem_size + + if cluster_n == 1: + return tile_bytes + rows_per_block * warps_per_row * 4 + else: + return ( + tile_bytes + + rows_per_block * warps_per_row * cluster_n * 4 + + 8 # mbarrier + ) + def _smem_size_in_bytes(self) -> int: - # Only reduction buffer needed (register-based approach) - return self.num_warps * 4 + if self.use_async_copy: + tile_bytes = ( + 2 * self.rows_per_block * self.cols_per_tile * (self.dtype.width // 8) + ) + else: + tile_bytes = 0 + + if self.cluster_n == 1: + reduction_bytes = self.rows_per_block * self.warps_per_row * 4 + else: + reduction_bytes = ( + self.rows_per_block * self.warps_per_row * self.cluster_n * 4 + ) + + mbar_bytes = 8 if self.cluster_n > 1 else 0 + return tile_bytes + reduction_bytes + mbar_bytes @cute.jit def __call__( @@ -100,17 +195,21 @@ def __call__( enable_pdl: cutlass.Constexpr[bool], stream, ): - tv_shape, tv_stride = make_tv_layout( + tv_shape, tv_stride = RMSNormKernel._make_tv_layout( self.threads_per_row, + self.rows_per_block, self.vec_size, self.num_vec_blocks, ) tv_layout = cute.make_layout(tv_shape, stride=tv_stride) - tiler_mn = (1, self.cols_per_tile) + tiler_mn = (self.rows_per_block, self.cols_per_tile) + + cluster_n = self.cluster_n self.kernel(mX, mR, mW, M, eps, enable_pdl, tv_layout, tiler_mn).launch( - grid=[M, 1, 1], + grid=[cute.ceil_div(M, self.rows_per_block), cluster_n, 1], block=[self.num_threads, 1, 1], + cluster=[1, cluster_n, 1] if cutlass.const_expr(cluster_n > 1) else None, smem=self._smem_size_in_bytes(), stream=stream, use_pdl=enable_pdl, @@ -136,88 +235,180 @@ def kernel( cute.arch.griddepcontrol_wait() H = self.H + cluster_n = self.cluster_n weight_bias = self.weight_bias - threads_per_row = tv_layout.shape[0][0] - num_warps = self.num_warps copy_bits = self.copy_bits + threads_per_row = tv_layout.shape[0][0] + rows_per_block = tiler_mn[0] + warps_per_row = max(threads_per_row // 32, 1) + if cutlass.const_expr(cluster_n > 1): + cluster_y = cute.arch.block_idx()[1] + else: + cluster_y = cutlass.const_expr(0) + + # ===== Allocate shared memory ===== smem = cutlass.utils.SmemAllocator() - reduction_buffer = smem.allocate_tensor( - Float32, - cute.make_layout((num_warps,)), - byte_alignment=4, - ) 
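+        # Arena layout allocated below (illustrative; largest configuration,
+        # i.e. use_async_copy with cluster_n > 1):
+        #   sX tile   : rows_per_block x cols_per_tile elements, 16B-aligned
+        #   sR tile   : rows_per_block x cols_per_tile elements, 16B-aligned
+        #   reduction : rows_per_block x warps_per_row [x cluster_n] Float32
+        #   mbarrier  : one 8-byte Int64 slot (cluster_n > 1 only)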
+ if cutlass.const_expr(self.use_async_copy): + sX = smem.allocate_tensor( + mX.element_type, + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + byte_alignment=16, + ) + sR = smem.allocate_tensor( + mR.element_type, + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + byte_alignment=16, + ) + + if cutlass.const_expr(cluster_n == 1): + reduction_buffer = smem.allocate_tensor( + Float32, + cute.make_layout((rows_per_block, warps_per_row)), + byte_alignment=4, + ) + mbar_ptr = None + else: + reduction_buffer = smem.allocate_tensor( + Float32, + cute.make_layout((rows_per_block, (warps_per_row, cluster_n))), + byte_alignment=4, + ) + mbar_ptr = smem.allocate_array(cutlass.Int64, num_elems=1) + + # ===== Initialize cluster ===== + if cutlass.const_expr(cluster_n > 1): + if tidx == 0: + cute.arch.mbarrier_init(mbar_ptr, 1) + cute.arch.mbarrier_init_fence() + cute.arch.cluster_arrive_relaxed() + cute.arch.cluster_wait() + + # ===== Coordinate tracking and tiling ===== idX = cute.make_identity_tensor(mX.shape) - gX = cute.local_tile(mX, tiler_mn, (bidx, 0)) - gR = cute.local_tile(mR, tiler_mn, (bidx, 0)) - cX = cute.local_tile(idX, tiler_mn, (bidx, 0)) + gX = cute.local_tile(mX, tiler_mn, (bidx, cluster_y)) + gR = cute.local_tile(mR, tiler_mn, (bidx, cluster_y)) + cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y)) - mW_2d = cute.prepend_ones(mW, up_to_rank=2) + mW_expanded_layout = cute.prepend( + mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)) + ) + mW_2d = cute.make_tensor(mW.iterator, mW_expanded_layout) + gW = cute.local_tile(mW_2d, tiler_mn, (0, cluster_y)) - copy_atom = cute.make_copy_atom( + # ===== Create TiledCopy atoms ===== + copy_atom_sync = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + mX.element_type, + num_bits_per_copy=copy_bits, + ) + copy_atom_store = cute.make_copy_atom( cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=copy_bits, ) - tiled_copy = cute.make_tiled_copy(copy_atom, tv_layout, tiler_mn) - thr_copy = tiled_copy.get_slice(tidx) + if cutlass.const_expr(self.use_async_copy): + copy_atom_async = cute.make_copy_atom( + cute.nvgpu.cpasync.CopyG2SOp(), + mX.element_type, + num_bits_per_copy=copy_bits, + ) + tiled_copy_load = cute.make_tiled_copy(copy_atom_async, tv_layout, tiler_mn) + else: + tiled_copy_load = cute.make_tiled_copy(copy_atom_sync, tv_layout, tiler_mn) + + tiled_copy_W = cute.make_tiled_copy(copy_atom_sync, tv_layout, tiler_mn) + tiled_copy_store = cute.make_tiled_copy(copy_atom_store, tv_layout, tiler_mn) + + thr_copy_X = tiled_copy_load.get_slice(tidx) + thr_copy_W = tiled_copy_W.get_slice(tidx) + thr_copy_O = tiled_copy_store.get_slice(tidx) + + # Partition input + tXgX = thr_copy_X.partition_S(gX) + tXcX = thr_copy_X.partition_S(cX) + tXrX = cute.make_fragment_like(tXgX) + + # Partition residual (same load tiled copy) + tRgR = thr_copy_X.partition_S(gR) + tRrR = cute.make_fragment_like(tRgR) + + if cutlass.const_expr(self.use_async_copy): + tXsX = thr_copy_X.partition_D(sX) + tRsR = thr_copy_X.partition_D(sR) + + # Partition weight (sync, separate tiled copy) + tWgW = thr_copy_W.partition_S(gW) + tWrW = cute.make_fragment_like(tWgW) + tXrW = thr_copy_X.retile(tWrW) + + # Partition output destinations + tXgO = thr_copy_O.partition_D(gX) + tRgO = thr_copy_O.partition_D(gR) + tXrO = cute.make_fragment_like(tXgO) + + # ===== Bounds checking ===== + tXpX = predicate_k(tXcX, limit=H) + tWpW = predicate_k(thr_copy_W.partition_S(cX), limit=H) + row_coord = tXcX[(0, 0), 0, 0] + row_in_bounds = row_coord[0] < M - tXgX = 
thr_copy.partition_S(gX) - tXgR = thr_copy.partition_S(gR) - tXgW = thr_copy.partition_S(mW_2d) - tXcX = thr_copy.partition_S(cX) - tYgX = thr_copy.partition_D(gX) - tYgR = thr_copy.partition_D(gR) + # ===== Pass 1: Load input + residual, compute h, reduce ===== + if cutlass.const_expr(self.use_async_copy): + if row_in_bounds: + cute.copy(copy_atom_async, tXgX, tXsX, pred=tXpX) + cute.copy(copy_atom_async, tRgR, tRsR, pred=tXpX) + cute.arch.cp_async_commit_group() - # Register fragments - initialize to zero for proper handling of out-of-bounds threads - tXrX = cute.make_rmem_tensor(tXgX.shape, mX.element_type) - tXrR = cute.make_rmem_tensor(tXgR.shape, mR.element_type) - tXrW = cute.make_rmem_tensor(tXgW.shape, mW.element_type) - tXrX.store(cute.zeros_like(tXrX, dtype=mX.element_type)) - tXrR.store(cute.zeros_like(tXrR, dtype=mR.element_type)) - tXrW.store(cute.zeros_like(tXrW, dtype=mW.element_type)) + cute.copy(copy_atom_sync, tWgW, tWrW, pred=tWpW) - tXpX = predicate_k(tXcX, limit=H) + cute.arch.cp_async_wait_group(0) - # Phase 1: Load input and residual from global to register - cute.copy(copy_atom, tXgX, tXrX, pred=tXpX) - cute.copy(copy_atom, tXgR, tXrR, pred=tXpX) + cute.autovec_copy(tXsX, tXrX) + cute.autovec_copy(tRsR, tRrR) + else: + tXrX.store(cute.zeros_like(tXrX, dtype=mX.element_type)) + tRrR.store(cute.zeros_like(tRrR, dtype=mR.element_type)) + if row_in_bounds: + cute.copy(copy_atom_sync, tXgX, tXrX, pred=tXpX) + cute.copy(copy_atom_sync, tRgR, tRrR, pred=tXpX) - x_in = tXrX.load().to(Float32) - r_in = tXrR.load().to(Float32) - x = x_in + r_in + cute.copy(copy_atom_sync, tWgW, tWrW, pred=tWpW) - # Phase 2: Store x to residual (global) - tXrR_out = x.to(mR.element_type) - tXrR_store = cute.make_rmem_tensor(tYgR.shape, mR.element_type) - tXrR_store.store(tXrR_out) + x_in = tXrX.load().to(Float32) + r_in = tRrR.load().to(Float32) + h = x_in + r_in - cute.copy(copy_atom, tXrR_store, tYgR, pred=tXpX) + # Write h to residual (global) + tXrO.store(h.to(mR.element_type)) + if row_in_bounds: + cute.copy(copy_atom_store, tXrO, tRgO, pred=tXpX) - # Phase 3: Compute sum of squares (x is kept in registers) - x_sq = x * x - sum_sq = row_reduce_sum(x_sq, threads_per_row, reduction_buffer) + h_sq = h * h + sum_sq = row_reduce_sum_multirow( + h_sq, threads_per_row, reduction_buffer, mbar_ptr, cluster_n + ) mean_sq = sum_sq / Float32(H) rstd = cute.math.rsqrt(mean_sq + eps, fastmath=True) - # Phase 4: Load weight from global to register - cute.copy(copy_atom, tXgW, tXrW, pred=tXpX) + if cutlass.const_expr(cluster_n > 1): + cute.arch.cluster_arrive_relaxed() + cute.arch.cluster_wait() + else: + cute.arch.barrier() + # ===== Pass 2: Normalize and store output ===== w = tXrW.load().to(Float32) + y = h * rstd * (w + Float32(weight_bias)) - # output = x * rstd * (weight + weight_bias) - # x is still in registers from Phase 1 - y = x * rstd * (w + Float32(weight_bias)) + tXrO.store(y.to(mX.element_type)) - tYrY = y.to(mX.element_type) - tXrY = cute.make_rmem_tensor(tYgX.shape, mX.element_type) - tXrY.store(tYrY) - - cute.copy(copy_atom, tXrY, tYgX, pred=tXpX) + if row_in_bounds: + cute.copy(copy_atom_store, tXrO, tXgO, pred=tXpX) # PDL: Signal dependent kernels (SM90+ only) if enable_pdl: @@ -236,6 +427,7 @@ class FusedAddRMSNormQuantKernel: Computes: 1. residual = input + residual (in-place update) 2. 
output = clamp(residual / sqrt(mean(residual^2) + eps) * weight / scale, -448, 448) + """ def __init__( @@ -244,30 +436,66 @@ def __init__( H: int, weight_bias: float = 0.0, use_hw_fp8: bool = True, + sm_version: int | None = None, ): self.dtype = dtype self.H = H self.weight_bias = weight_bias self.use_hw_fp8 = use_hw_fp8 + self.sm_version = sm_version if sm_version is not None else get_sm_version() - # Vectorization parameters: use optimal vec_size for warp utilization - elem_bits = dtype.width - max_vec_size = COPY_BITS // elem_bits - self.vec_size = compute_optimal_vec_size(H, max_vec_size) - self.copy_bits = self.vec_size * elem_bits + self.cluster_n = FusedAddRMSNormKernel._compute_cluster_n( + H, dtype, self.sm_version + ) + self.H_per_cta = H // self.cluster_n + + elem_bytes = dtype.width // 8 + max_vec_size = COPY_BITS // 8 // elem_bytes - self.threads_per_row = compute_threads_per_row(H, self.vec_size) - self.num_threads = self.threads_per_row - self.num_warps = max(self.threads_per_row // 32, 1) + h_align = self.H_per_cta & (-self.H_per_cta) + self.vec_size = min(h_align, max_vec_size) + self.copy_bits = self.vec_size * dtype.width + + self.threads_per_row = RMSNormKernel._compute_threads_per_row(self.H_per_cta) + self.num_threads = RMSNormKernel._compute_num_threads(self.H_per_cta) + if self.H_per_cta > 8192 and self.num_threads < 256: + self.num_threads = 256 + self.rows_per_block = self.num_threads // self.threads_per_row + self.warps_per_row = max(self.threads_per_row // 32, 1) self.num_vec_blocks = max( - 1, (H // self.vec_size + self.threads_per_row - 1) // self.threads_per_row + 1, + (self.H_per_cta // self.vec_size + self.threads_per_row - 1) + // self.threads_per_row, ) self.cols_per_tile = self.vec_size * self.num_vec_blocks * self.threads_per_row + if self.copy_bits >= 32: + tile_bytes_2 = 2 * self.rows_per_block * self.cols_per_tile * elem_bytes + props = torch.cuda.get_device_properties(torch.cuda.current_device()) + self.use_async_copy = ( + tile_bytes_2 <= props.shared_memory_per_block_optin // 2 + ) + else: + self.use_async_copy = False + def _smem_size_in_bytes(self) -> int: - # Only reduction buffer needed (register-based approach) - return self.num_warps * 4 + if self.use_async_copy: + tile_bytes = ( + 2 * self.rows_per_block * self.cols_per_tile * (self.dtype.width // 8) + ) + else: + tile_bytes = 0 + + if self.cluster_n == 1: + reduction_bytes = self.rows_per_block * self.warps_per_row * 4 + else: + reduction_bytes = ( + self.rows_per_block * self.warps_per_row * self.cluster_n * 4 + ) + + mbar_bytes = 8 if self.cluster_n > 1 else 0 + return tile_bytes + reduction_bytes + mbar_bytes @cute.jit def __call__( @@ -282,17 +510,21 @@ def __call__( enable_pdl: cutlass.Constexpr[bool], stream, ): - tv_shape, tv_stride = make_tv_layout( + tv_shape, tv_stride = RMSNormKernel._make_tv_layout( self.threads_per_row, + self.rows_per_block, self.vec_size, self.num_vec_blocks, ) tv_layout = cute.make_layout(tv_shape, stride=tv_stride) - tiler_mn = (1, self.cols_per_tile) + tiler_mn = (self.rows_per_block, self.cols_per_tile) + + cluster_n = self.cluster_n self.kernel(mY, mX, mR, mW, M, mS, eps, enable_pdl, tv_layout, tiler_mn).launch( - grid=[M, 1, 1], + grid=[cute.ceil_div(M, self.rows_per_block), cluster_n, 1], block=[self.num_threads, 1, 1], + cluster=[1, cluster_n, 1] if cutlass.const_expr(cluster_n > 1) else None, smem=self._smem_size_in_bytes(), stream=stream, use_pdl=enable_pdl, @@ -320,106 +552,315 @@ def kernel( cute.arch.griddepcontrol_wait() H = self.H + 
cluster_n = self.cluster_n + cols_per_tile = self.cols_per_tile weight_bias = self.weight_bias - threads_per_row = tv_layout.shape[0][0] - num_warps = self.num_warps copy_bits = self.copy_bits vec_size = self.vec_size num_vec_blocks = self.num_vec_blocks + threads_per_row = tv_layout.shape[0][0] + rows_per_block = tiler_mn[0] + warps_per_row = max(threads_per_row // 32, 1) + + if cutlass.const_expr(cluster_n > 1): + cluster_y = cute.arch.block_idx()[1] + else: + cluster_y = cutlass.const_expr(0) inv_scale = rcp_approx_ftz(mS[0]) + # ===== Allocate shared memory ===== smem = cutlass.utils.SmemAllocator() - reduction_buffer = smem.allocate_tensor( - Float32, - cute.make_layout((num_warps,)), - byte_alignment=4, - ) + if cutlass.const_expr(self.use_async_copy): + sX = smem.allocate_tensor( + mX.element_type, + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + byte_alignment=16, + ) + sR = smem.allocate_tensor( + mR.element_type, + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + byte_alignment=16, + ) + + if cutlass.const_expr(cluster_n == 1): + reduction_buffer = smem.allocate_tensor( + Float32, + cute.make_layout((rows_per_block, warps_per_row)), + byte_alignment=4, + ) + mbar_ptr = None + else: + reduction_buffer = smem.allocate_tensor( + Float32, + cute.make_layout((rows_per_block, (warps_per_row, cluster_n))), + byte_alignment=4, + ) + mbar_ptr = smem.allocate_array(cutlass.Int64, num_elems=1) + + # ===== Initialize cluster ===== + if cutlass.const_expr(cluster_n > 1): + if tidx == 0: + cute.arch.mbarrier_init(mbar_ptr, 1) + cute.arch.mbarrier_init_fence() + cute.arch.cluster_arrive_relaxed() + cute.arch.cluster_wait() + + # ===== Coordinate tracking and tiling ===== idX = cute.make_identity_tensor(mX.shape) - gX = cute.local_tile(mX, tiler_mn, (bidx, 0)) - gR = cute.local_tile(mR, tiler_mn, (bidx, 0)) - cX = cute.local_tile(idX, tiler_mn, (bidx, 0)) + gX = cute.local_tile(mX, tiler_mn, (bidx, cluster_y)) + gR = cute.local_tile(mR, tiler_mn, (bidx, cluster_y)) + cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y)) - mW_2d = cute.prepend_ones(mW, up_to_rank=2) + mW_expanded_layout = cute.prepend( + mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)) + ) + mW_2d = cute.make_tensor(mW.iterator, mW_expanded_layout) + gW = cute.local_tile(mW_2d, tiler_mn, (0, cluster_y)) - copy_atom_load = cute.make_copy_atom( + # ===== Create TiledCopy atoms ===== + copy_atom_sync = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + mX.element_type, + num_bits_per_copy=copy_bits, + ) + copy_atom_store = cute.make_copy_atom( cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=copy_bits, ) - tiled_copy_load = cute.make_tiled_copy(copy_atom_load, tv_layout, tiler_mn) - thr_copy_load = tiled_copy_load.get_slice(tidx) + if cutlass.const_expr(self.use_async_copy): + copy_atom_async = cute.make_copy_atom( + cute.nvgpu.cpasync.CopyG2SOp(), + mX.element_type, + num_bits_per_copy=copy_bits, + ) + tiled_copy_load = cute.make_tiled_copy(copy_atom_async, tv_layout, tiler_mn) + else: + tiled_copy_load = cute.make_tiled_copy(copy_atom_sync, tv_layout, tiler_mn) + + tiled_copy_W = cute.make_tiled_copy(copy_atom_sync, tv_layout, tiler_mn) + tiled_copy_store = cute.make_tiled_copy(copy_atom_store, tv_layout, tiler_mn) + + thr_copy_X = tiled_copy_load.get_slice(tidx) + thr_copy_W = tiled_copy_W.get_slice(tidx) + thr_copy_O = tiled_copy_store.get_slice(tidx) + + # Partition input + tXgX = thr_copy_X.partition_S(gX) + tXcX = thr_copy_X.partition_S(cX) + tXrX = cute.make_fragment_like(tXgX) + + # 
Partition residual (same load tiled copy) + tRgR = thr_copy_X.partition_S(gR) + tRrR = cute.make_fragment_like(tRgR) + + if cutlass.const_expr(self.use_async_copy): + tXsX = thr_copy_X.partition_D(sX) + tRsR = thr_copy_X.partition_D(sR) + + # Partition weight (sync, separate tiled copy) + tWgW = thr_copy_W.partition_S(gW) + tWrW = cute.make_fragment_like(tWgW) + tXrW = thr_copy_X.retile(tWrW) + + # Partition residual store destination (match non-quant kernel pattern) + tRgO = thr_copy_O.partition_D(gR) + tXgO_r = thr_copy_O.partition_D(gX) + tRrO = cute.make_fragment_like(tXgO_r) + + # ===== Bounds checking ===== + tXpX = predicate_k(tXcX, limit=H) + tWpW = predicate_k(thr_copy_W.partition_S(cX), limit=H) + row_coord = tXcX[(0, 0), 0, 0] + row_in_bounds = row_coord[0] < M - tXgX = thr_copy_load.partition_S(gX) - tXgR = thr_copy_load.partition_S(gR) - tXgW = thr_copy_load.partition_S(mW_2d) - tXcX = thr_copy_load.partition_S(cX) - tYgR = thr_copy_load.partition_D(gR) + # ===== Pass 1: Load input + residual, compute h, reduce ===== + if cutlass.const_expr(self.use_async_copy): + if row_in_bounds: + cute.copy(copy_atom_async, tXgX, tXsX, pred=tXpX) + cute.copy(copy_atom_async, tRgR, tRsR, pred=tXpX) + cute.arch.cp_async_commit_group() - # Register fragments - initialize to zero for proper handling of out-of-bounds threads - tXrX = cute.make_rmem_tensor(tXgX.shape, mX.element_type) - tXrR = cute.make_rmem_tensor(tXgR.shape, mR.element_type) - tXrW = cute.make_rmem_tensor(tXgW.shape, mW.element_type) - tXrX.store(cute.zeros_like(tXrX, dtype=mX.element_type)) - tXrR.store(cute.zeros_like(tXrR, dtype=mR.element_type)) - tXrW.store(cute.zeros_like(tXrW, dtype=mW.element_type)) + cute.copy(copy_atom_sync, tWgW, tWrW, pred=tWpW) - tXpX = predicate_k(tXcX, limit=H) + cute.arch.cp_async_wait_group(0) - # Phase 1: Load input and residual from global to register - cute.copy(copy_atom_load, tXgX, tXrX, pred=tXpX) - cute.copy(copy_atom_load, tXgR, tXrR, pred=tXpX) + cute.autovec_copy(tXsX, tXrX) + cute.autovec_copy(tRsR, tRrR) + else: + tXrX.store(cute.zeros_like(tXrX, dtype=mX.element_type)) + tRrR.store(cute.zeros_like(tRrR, dtype=mR.element_type)) + if row_in_bounds: + cute.copy(copy_atom_sync, tXgX, tXrX, pred=tXpX) + cute.copy(copy_atom_sync, tRgR, tRrR, pred=tXpX) + + cute.copy(copy_atom_sync, tWgW, tWrW, pred=tWpW) x_in = tXrX.load().to(Float32) - r_in = tXrR.load().to(Float32) - x = x_in + r_in + r_in = tRrR.load().to(Float32) + h = x_in + r_in - # Store x to residual (global) - tXrR_out = x.to(mR.element_type) - tXrR_store = cute.make_rmem_tensor(tYgR.shape, mR.element_type) - tXrR_store.store(tXrR_out) - cute.copy(copy_atom_load, tXrR_store, tYgR, pred=tXpX) + # Write h to residual (global) + tRrO.store(h.to(mR.element_type)) + if row_in_bounds: + cute.copy(copy_atom_store, tRrO, tRgO, pred=tXpX) - # Phase 2: Compute sum of squares (x is kept in registers) - x_sq = x * x - sum_sq = row_reduce_sum(x_sq, threads_per_row, reduction_buffer) + h_sq = h * h + sum_sq = row_reduce_sum_multirow( + h_sq, threads_per_row, reduction_buffer, mbar_ptr, cluster_n + ) mean_sq = sum_sq / Float32(H) rstd = cute.math.rsqrt(mean_sq + eps, fastmath=True) - # Phase 3: Load weight from global to register - cute.copy(copy_atom_load, tXgW, tXrW, pred=tXpX) - w = tXrW.load().to(Float32) + if cutlass.const_expr(cluster_n > 1): + cute.arch.cluster_arrive_relaxed() + cute.arch.cluster_wait() + else: + cute.arch.barrier() - # output = x * rstd * (weight + weight_bias) * inv_scale - # x is still in registers from Phase 1 - y 
= x * rstd * (w + Float32(weight_bias)) * inv_scale + # ===== Pass 2: Normalize, quantize, and store FP8 output ===== + w = tXrW.load().to(Float32) + y = h * rstd * (w + Float32(weight_bias)) * inv_scale - # Phase 4: Clamp and store to FP8 output using PTX scalar stores - # (CuTe FP8 conversion requires vectorized ops, so we use PTX for scalar stores) - # Store y to register tensor for element-wise access - tYrY_f32 = cute.make_rmem_tensor(tXgX.shape, Float32) + tYrY_f32 = cute.make_rmem_tensor(tXrX.shape, Float32) tYrY_f32.store(y) - col_offset = tidx * vec_size - for v in cutlass.range_constexpr(num_vec_blocks): - for e in cutlass.range_constexpr(vec_size): - idx = col_offset + v * threads_per_row * vec_size + e - if idx < H: - # Clamp and convert - use flat index for register tensor - flat_idx = v * vec_size + e - clamped = max(tYrY_f32[flat_idx], Float32(-FLOAT8_E4M3_MAX)) - clamped = min(clamped, Float32(FLOAT8_E4M3_MAX)) - # Use PTX to convert and store FP8 byte - out_offset = bidx * H + idx - out_ptr = get_ptr_as_int64(mY, Int32(out_offset)) - if self.use_hw_fp8: - cvt_and_store_f32_to_e4m3_hw(clamped, out_ptr) - else: - cvt_and_store_f32_to_e4m3_sw(clamped, out_ptr) + lane_in_row = tidx % threads_per_row + row_in_block = tidx // threads_per_row + actual_row = bidx * rows_per_block + row_in_block + col_offset = lane_in_row * vec_size + + if cutlass.const_expr(self.use_hw_fp8 and vec_size == 8): + for v in cutlass.range_constexpr(num_vec_blocks): + local_col = col_offset + v * threads_per_row * vec_size + abs_col = cluster_y * cols_per_tile + local_col + if abs_col + 8 <= H and actual_row < M: + base = v * 8 + cvt_and_store_8xf32_to_e4m3_hw( + tYrY_f32[base], + tYrY_f32[base + 1], + tYrY_f32[base + 2], + tYrY_f32[base + 3], + tYrY_f32[base + 4], + tYrY_f32[base + 5], + tYrY_f32[base + 6], + tYrY_f32[base + 7], + get_ptr_as_int64( + mY, + cute.crd2idx( + (Int32(actual_row), Int32(abs_col)), mY.layout + ), + ), + ) + else: + for e in cutlass.range_constexpr(vec_size): + abs_col_e = cluster_y * cols_per_tile + local_col + e + if abs_col_e < H and actual_row < M: + flat_idx = v * vec_size + e + clamped = max(tYrY_f32[flat_idx], Float32(-FLOAT8_E4M3_MAX)) + clamped = min(clamped, Float32(FLOAT8_E4M3_MAX)) + cvt_and_store_f32_to_e4m3_hw( + clamped, + get_ptr_as_int64( + mY, + cute.crd2idx( + (Int32(actual_row), Int32(abs_col_e)), + mY.layout, + ), + ), + ) + elif cutlass.const_expr(self.use_hw_fp8 and vec_size == 4): + for v in cutlass.range_constexpr(num_vec_blocks): + local_col = col_offset + v * threads_per_row * vec_size + abs_col = cluster_y * cols_per_tile + local_col + if abs_col + 4 <= H and actual_row < M: + base = v * 4 + cvt_and_store_4xf32_to_e4m3_hw( + tYrY_f32[base], + tYrY_f32[base + 1], + tYrY_f32[base + 2], + tYrY_f32[base + 3], + get_ptr_as_int64( + mY, + cute.crd2idx( + (Int32(actual_row), Int32(abs_col)), mY.layout + ), + ), + ) + else: + for e in cutlass.range_constexpr(vec_size): + abs_col_e = cluster_y * cols_per_tile + local_col + e + if abs_col_e < H and actual_row < M: + flat_idx = v * vec_size + e + clamped = max(tYrY_f32[flat_idx], Float32(-FLOAT8_E4M3_MAX)) + clamped = min(clamped, Float32(FLOAT8_E4M3_MAX)) + cvt_and_store_f32_to_e4m3_hw( + clamped, + get_ptr_as_int64( + mY, + cute.crd2idx( + (Int32(actual_row), Int32(abs_col_e)), + mY.layout, + ), + ), + ) + elif cutlass.const_expr(self.use_hw_fp8 and vec_size == 2): + for v in cutlass.range_constexpr(num_vec_blocks): + local_col = col_offset + v * threads_per_row * vec_size + abs_col = cluster_y * 
cols_per_tile + local_col + if abs_col + 2 <= H and actual_row < M: + base = v * 2 + cvt_and_store_2xf32_to_e4m3_hw( + tYrY_f32[base], + tYrY_f32[base + 1], + get_ptr_as_int64( + mY, + cute.crd2idx( + (Int32(actual_row), Int32(abs_col)), mY.layout + ), + ), + ) + else: + for e in cutlass.range_constexpr(vec_size): + abs_col_e = cluster_y * cols_per_tile + local_col + e + if abs_col_e < H and actual_row < M: + flat_idx = v * vec_size + e + clamped = max(tYrY_f32[flat_idx], Float32(-FLOAT8_E4M3_MAX)) + clamped = min(clamped, Float32(FLOAT8_E4M3_MAX)) + cvt_and_store_f32_to_e4m3_hw( + clamped, + get_ptr_as_int64( + mY, + cute.crd2idx( + (Int32(actual_row), Int32(abs_col_e)), + mY.layout, + ), + ), + ) + else: + for v in cutlass.range_constexpr(num_vec_blocks): + for e in cutlass.range_constexpr(vec_size): + local_col = col_offset + v * threads_per_row * vec_size + e + abs_col = cluster_y * cols_per_tile + local_col + if abs_col < H and actual_row < M: + flat_idx = v * vec_size + e + clamped = max(tYrY_f32[flat_idx], Float32(-FLOAT8_E4M3_MAX)) + clamped = min(clamped, Float32(FLOAT8_E4M3_MAX)) + out_ptr = get_ptr_as_int64( + mY, + cute.crd2idx( + (Int32(actual_row), Int32(abs_col)), mY.layout + ), + ) + if self.use_hw_fp8: + cvt_and_store_f32_to_e4m3_hw(clamped, out_ptr) + else: + cvt_and_store_f32_to_e4m3_sw(clamped, out_ptr) # PDL: Signal dependent kernels (SM90+ only) if enable_pdl: @@ -433,22 +874,43 @@ def kernel( @functools.cache def _get_compiled_fused_add_rmsnorm_kernel( - dtype_str: str, H: int, weight_bias: float, enable_pdl: bool + dtype_str: str, + H: int, + weight_bias: float, + enable_pdl: bool, + sm_version: int, + contiguous: bool = True, ): - """Get a compiled Fused Add + RMSNorm kernel using TVM-FFI.""" + """Get a compiled Fused Add + RMSNorm kernel using TVM-FFI. + + When contiguous=True, tensors are compiled with compact (dense) layouts for + optimal codegen. When False, symbolic row strides are used to support + arbitrary row strides at the cost of some performance. 
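+
+    For contiguous inputs the assumed alignment is derived from the row size
+    via math.gcd(128, H * elem_bytes); e.g. (illustrative) H = 4096 in bf16
+    yields gcd(128, 8192) = 128-byte alignment, whereas H = 4000 yields
+    gcd(128, 8000) = 64.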
+ """ dtype = get_cutlass_dtype(dtype_str) - kernel_obj = FusedAddRMSNormKernel(dtype, H, weight_bias) + kernel_obj = FusedAddRMSNormKernel(dtype, H, weight_bias, sm_version=sm_version) sym_m = cute.sym_int() - sym_row_stride_x = cute.sym_int(divisibility=kernel_obj.vec_size) - sym_row_stride_r = cute.sym_int(divisibility=kernel_obj.vec_size) - x_fake = cute.runtime.make_fake_tensor( - dtype, (sym_m, H), (sym_row_stride_x, 1), assumed_align=16 - ) - r_fake = cute.runtime.make_fake_tensor( - dtype, (sym_m, H), (sym_row_stride_r, 1), assumed_align=16 - ) + if contiguous: + elem_bytes = dtype.width // 8 + tensor_align = math.gcd(128, H * elem_bytes) + x_fake = cute.runtime.make_fake_compact_tensor( + dtype, (sym_m, H), stride_order=(1, 0), assumed_align=tensor_align + ) + r_fake = cute.runtime.make_fake_compact_tensor( + dtype, (sym_m, H), stride_order=(1, 0), assumed_align=tensor_align + ) + else: + sym_row_stride_x = cute.sym_int(divisibility=kernel_obj.vec_size) + sym_row_stride_r = cute.sym_int(divisibility=kernel_obj.vec_size) + x_fake = cute.runtime.make_fake_tensor( + dtype, (sym_m, H), (sym_row_stride_x, 1), assumed_align=16 + ) + r_fake = cute.runtime.make_fake_tensor( + dtype, (sym_m, H), (sym_row_stride_r, 1), assumed_align=16 + ) + w_fake = cute.runtime.make_fake_compact_tensor(dtype, (H,), assumed_align=16) stream_fake = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True) @@ -476,28 +938,47 @@ def _get_compiled_fused_add_rmsnorm_quant_kernel( weight_bias: float, enable_pdl: bool, use_hw_fp8: bool = True, + sm_version: int = 80, + contiguous: bool = True, ): - """Get a compiled Fused Add + RMSNorm + Quant kernel using TVM-FFI.""" + """Get a compiled Fused Add + RMSNorm + Quant kernel using TVM-FFI. + + See _get_compiled_fused_add_rmsnorm_kernel for contiguous parameter semantics. 
+ """ dtype = get_cutlass_dtype(dtype_str) out_dtype = get_cutlass_dtype(out_dtype_str) kernel_obj = FusedAddRMSNormQuantKernel( - dtype, H, weight_bias, use_hw_fp8=use_hw_fp8 + dtype, H, weight_bias, use_hw_fp8=use_hw_fp8, sm_version=sm_version ) sym_m = cute.sym_int() - sym_row_stride_y = cute.sym_int(divisibility=kernel_obj.vec_size) - sym_row_stride_x = cute.sym_int(divisibility=kernel_obj.vec_size) - sym_row_stride_r = cute.sym_int(divisibility=kernel_obj.vec_size) - y_fake = cute.runtime.make_fake_tensor( - out_dtype, (sym_m, H), (sym_row_stride_y, 1), assumed_align=16 - ) - x_fake = cute.runtime.make_fake_tensor( - dtype, (sym_m, H), (sym_row_stride_x, 1), assumed_align=16 - ) - r_fake = cute.runtime.make_fake_tensor( - dtype, (sym_m, H), (sym_row_stride_r, 1), assumed_align=16 - ) + if contiguous: + in_align = math.gcd(128, H * (dtype.width // 8)) + out_align = math.gcd(128, H * (out_dtype.width // 8)) + y_fake = cute.runtime.make_fake_compact_tensor( + out_dtype, (sym_m, H), stride_order=(1, 0), assumed_align=out_align + ) + x_fake = cute.runtime.make_fake_compact_tensor( + dtype, (sym_m, H), stride_order=(1, 0), assumed_align=in_align + ) + r_fake = cute.runtime.make_fake_compact_tensor( + dtype, (sym_m, H), stride_order=(1, 0), assumed_align=in_align + ) + else: + sym_row_stride_y = cute.sym_int(divisibility=kernel_obj.vec_size) + sym_row_stride_x = cute.sym_int(divisibility=kernel_obj.vec_size) + sym_row_stride_r = cute.sym_int(divisibility=kernel_obj.vec_size) + y_fake = cute.runtime.make_fake_tensor( + out_dtype, (sym_m, H), (sym_row_stride_y, 1), assumed_align=16 + ) + x_fake = cute.runtime.make_fake_tensor( + dtype, (sym_m, H), (sym_row_stride_x, 1), assumed_align=16 + ) + r_fake = cute.runtime.make_fake_tensor( + dtype, (sym_m, H), (sym_row_stride_r, 1), assumed_align=16 + ) + w_fake = cute.runtime.make_fake_compact_tensor(dtype, (H,), assumed_align=16) s_fake = cute.runtime.make_fake_compact_tensor(Float32, (1,), assumed_align=4) @@ -535,17 +1016,24 @@ def fused_add_rmsnorm_cute( ) -> None: """CuTe DSL Fused Add + RMSNorm implementation. - Supports arbitrary stride - no need to call contiguous(). - Last dimension must be contiguous (stride[-1] == 1). + Supports non-contiguous tensors (stride[-1] must be 1). Uses an optimized + compact kernel for contiguous inputs and a general strided kernel otherwise. + Both input and residual are modified in-place. """ shape = input.shape H = shape[-1] M = shape[0] + is_contiguous = input.is_contiguous() and residual.is_contiguous() dtype_str = _torch_dtype_to_str(input.dtype) kernel = _get_compiled_fused_add_rmsnorm_kernel( - dtype_str, H, weight_bias, enable_pdl + dtype_str, + H, + weight_bias, + enable_pdl, + get_sm_version(input.device), + contiguous=is_contiguous, ) kernel(input, residual, weight, M, eps) @@ -562,14 +1050,18 @@ def fused_add_rmsnorm_quant_cute( ) -> None: """CuTe DSL Fused Add + RMSNorm + FP8 quantization implementation. - Supports arbitrary stride - no need to call contiguous(). - Last dimension must be contiguous (stride[-1] == 1). + Supports non-contiguous tensors (stride[-1] must be 1). Uses an optimized + compact kernel for contiguous inputs and a general strided kernel otherwise. + Residual is modified in-place with h = input + residual. 
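+
+    In effect (mirroring the class docstring), with h = input + residual and
+    rstd = 1 / sqrt(mean(h^2) + eps):
+
+        out = clamp(h * rstd * (weight + weight_bias) / scale, -448, 448)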
""" shape = input.shape H = shape[-1] M = shape[0] + is_contiguous = ( + input.is_contiguous() and residual.is_contiguous() and out.is_contiguous() + ) dtype_str = _torch_dtype_to_str(input.dtype) out_dtype_str = _torch_dtype_to_str(out.dtype) kernel = _get_compiled_fused_add_rmsnorm_quant_kernel( @@ -579,6 +1071,8 @@ def fused_add_rmsnorm_quant_cute( weight_bias, enable_pdl, use_hw_fp8=has_hw_fp8_cvt(input.device), + sm_version=get_sm_version(input.device), + contiguous=is_contiguous, ) kernel( out, diff --git a/flashinfer/norm/kernels/rmsnorm.py b/flashinfer/norm/kernels/rmsnorm.py index 7d7388fc27..a47a980036 100644 --- a/flashinfer/norm/kernels/rmsnorm.py +++ b/flashinfer/norm/kernels/rmsnorm.py @@ -23,7 +23,7 @@ """ import functools -import operator +import math import cutlass import cutlass.cute as cute @@ -36,17 +36,17 @@ rcp_approx_ftz, cvt_and_store_f32_to_e4m3_hw, cvt_and_store_f32_to_e4m3_sw, + cvt_and_store_8xf32_to_e4m3_hw, + cvt_and_store_4xf32_to_e4m3_hw, + cvt_and_store_2xf32_to_e4m3_hw, has_hw_fp8_cvt, get_ptr_as_int64, - warp_reduce, - row_reduce_sum, + get_sm_version, + row_reduce_sum_multirow, predicate_k, - compute_optimal_vec_size, - compute_threads_per_row, make_tv_layout, _torch_dtype_to_str, get_cutlass_dtype, - get_num_sm, ) @@ -60,11 +60,6 @@ class RMSNormKernel: RMSNorm Kernel using CuTe-DSL. Computes: output = input / sqrt(mean(input^2) + eps) * (weight + weight_bias) - - Key optimizations: - 1. 128-bit vectorized loads for input and weight - 2. Two-stage reduction: warp shuffle + cross-warp shared memory - 3. All computations in FP32 for numerical stability """ def __init__( @@ -72,32 +67,138 @@ def __init__( dtype: cutlass.Numeric, H: int, weight_bias: float = 0.0, + sm_version: int | None = None, ): self.dtype = dtype self.H = H self.weight_bias = weight_bias + self.sm_version = sm_version if sm_version is not None else get_sm_version() + + self.cluster_n = self._compute_cluster_n(H, dtype, self.sm_version) + self.H_per_cta = H // self.cluster_n + + elem_bytes = dtype.width // 8 + max_vec_size = COPY_BITS // 8 // elem_bytes - # Vectorization parameters: use optimal vec_size for warp utilization - elem_bits = dtype.width - max_vec_size = COPY_BITS // elem_bits # 8 for float16/bfloat16, 4 for float32 - self.vec_size = compute_optimal_vec_size(H, max_vec_size) - self.copy_bits = self.vec_size * elem_bits # Actual bits per copy + h_align = self.H_per_cta & (-self.H_per_cta) + self.vec_size = min(h_align, max_vec_size) + self.copy_bits = self.vec_size * dtype.width - # Thread configuration - self.threads_per_row = compute_threads_per_row(H, self.vec_size) - self.num_threads = self.threads_per_row # One row per block - self.num_warps = max(self.threads_per_row // 32, 1) + self.threads_per_row = self._compute_threads_per_row(self.H_per_cta) + self.num_threads = self._compute_num_threads(self.H_per_cta) + self.rows_per_block = self.num_threads // self.threads_per_row + self.warps_per_row = max(self.threads_per_row // 32, 1) - # Vectorization blocks self.num_vec_blocks = max( - 1, (H // self.vec_size + self.threads_per_row - 1) // self.threads_per_row + 1, + (self.H_per_cta // self.vec_size + self.threads_per_row - 1) + // self.threads_per_row, ) self.cols_per_tile = self.vec_size * self.num_vec_blocks * self.threads_per_row + if self.copy_bits >= 32: + tile_bytes = self.rows_per_block * self.cols_per_tile * elem_bytes + props = torch.cuda.get_device_properties(torch.cuda.current_device()) + self.use_async_copy = tile_bytes <= props.shared_memory_per_block_optin // 2 
+ else: + self.use_async_copy = False + + @staticmethod + def _compute_cluster_n(H: int, dtype: cutlass.Numeric, sm_version: int) -> int: + """Compute optimal cluster size based on H and device shared memory.""" + if sm_version < 90: + return 1 + + props = torch.cuda.get_device_properties(torch.cuda.current_device()) + max_smem_bytes = props.shared_memory_per_block_optin + elem_size = dtype.width // 8 + + for cluster_n in [1, 2, 4, 8, 16]: + if H % cluster_n != 0: + continue + smem_needed = RMSNormKernel._estimate_smem_bytes(H, cluster_n, elem_size) + if smem_needed <= max_smem_bytes: + return cluster_n + + return 16 + + @staticmethod + def _estimate_smem_bytes(H: int, cluster_n: int, elem_size: int) -> int: + """Estimate shared memory bytes for a given cluster configuration.""" + H_per_cta = H // cluster_n + threads_per_row = RMSNormKernel._compute_threads_per_row(H_per_cta) + num_threads = RMSNormKernel._compute_num_threads(H_per_cta) + rows_per_block = num_threads // threads_per_row + warps_per_row = max(threads_per_row // 32, 1) + + max_vec_size = COPY_BITS // 8 // elem_size + h_align = H_per_cta & (-H_per_cta) + vec_size = min(h_align, max_vec_size) + num_vec_blocks = max( + 1, (H_per_cta // vec_size + threads_per_row - 1) // threads_per_row + ) + cols_per_tile = vec_size * num_vec_blocks * threads_per_row + + tile_bytes = rows_per_block * cols_per_tile * elem_size + + if cluster_n == 1: + return tile_bytes + rows_per_block * warps_per_row * 4 + else: + return ( + tile_bytes + + rows_per_block * warps_per_row * cluster_n * 4 + + 8 # mbarrier + ) + + @staticmethod + def _compute_threads_per_row(H: int) -> int: + if H <= 64: + return 8 + elif H <= 128: + return 16 + elif H <= 3072: + return 32 + elif H <= 6144: + return 64 + elif H <= 16384: + return 128 + else: + return 256 + + @staticmethod + def _compute_num_threads(H: int) -> int: + return 128 if H <= 16384 else 256 + + @staticmethod + def _make_tv_layout(threads_per_row, rows_per_block, vec_size, num_vec_blocks): + """Create Thread-Value layout for multi-row coalesced vectorized access.""" + shape = ( + (threads_per_row, rows_per_block), + (vec_size, num_vec_blocks), + ) + stride = ( + (vec_size * rows_per_block, 1), + (rows_per_block, rows_per_block * vec_size * threads_per_row), + ) + return shape, stride + def _smem_size_in_bytes(self) -> int: - """Calculate shared memory requirement.""" - # Only reduction buffer needed (no shared memory for input/weight) - return self.num_warps * 4 + if self.use_async_copy: + tile_bytes = ( + self.rows_per_block * self.cols_per_tile * (self.dtype.width // 8) + ) + else: + tile_bytes = 0 + + if self.cluster_n == 1: + reduction_bytes = self.rows_per_block * self.warps_per_row * 4 + else: + reduction_bytes = ( + self.rows_per_block * self.warps_per_row * self.cluster_n * 4 + ) + + mbar_bytes = 8 if self.cluster_n > 1 else 0 + return tile_bytes + reduction_bytes + mbar_bytes @cute.jit def __call__( @@ -110,18 +211,21 @@ def __call__( enable_pdl: cutlass.Constexpr[bool], stream, ): - """Launch the RMSNorm kernel.""" - tv_shape, tv_stride = make_tv_layout( + tv_shape, tv_stride = self._make_tv_layout( self.threads_per_row, + self.rows_per_block, self.vec_size, self.num_vec_blocks, ) tv_layout = cute.make_layout(tv_shape, stride=tv_stride) - tiler_mn = (1, self.cols_per_tile) + tiler_mn = (self.rows_per_block, self.cols_per_tile) + + cluster_n = self.cluster_n self.kernel(mX, mW, mY, M, eps, enable_pdl, tv_layout, tiler_mn).launch( - grid=[M, 1, 1], + grid=[cute.ceil_div(M, self.rows_per_block), 
cluster_n, 1], block=[self.num_threads, 1, 1], + cluster=[1, cluster_n, 1] if cutlass.const_expr(cluster_n > 1) else None, smem=self._smem_size_in_bytes(), stream=stream, use_pdl=enable_pdl, @@ -139,7 +243,6 @@ def kernel( tv_layout: cute.Layout, tiler_mn: cute.Shape, ): - """Device kernel for RMSNorm.""" tidx, _, _ = cute.arch.thread_idx() bidx, _, _ = cute.arch.block_idx() @@ -148,86 +251,165 @@ def kernel( cute.arch.griddepcontrol_wait() H = self.H + cluster_n = self.cluster_n weight_bias = self.weight_bias - threads_per_row = tv_layout.shape[0][0] - num_warps = self.num_warps copy_bits = self.copy_bits + threads_per_row = tv_layout.shape[0][0] + rows_per_block = tiler_mn[0] + warps_per_row = max(threads_per_row // 32, 1) + + if cutlass.const_expr(cluster_n > 1): + cluster_y = cute.arch.block_idx()[1] + else: + cluster_y = cutlass.const_expr(0) - # Allocate shared memory (only reduction buffer needed) + # ===== Allocate shared memory ===== smem = cutlass.utils.SmemAllocator() - reduction_buffer = smem.allocate_tensor( - Float32, - cute.make_layout((num_warps,)), - byte_alignment=4, - ) - # Create identity tensor for coordinate tracking + if cutlass.const_expr(self.use_async_copy): + sX = smem.allocate_tensor( + mX.element_type, + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + byte_alignment=16, + ) + + if cutlass.const_expr(cluster_n == 1): + reduction_buffer = smem.allocate_tensor( + Float32, + cute.make_layout((rows_per_block, warps_per_row)), + byte_alignment=4, + ) + mbar_ptr = None + else: + reduction_buffer = smem.allocate_tensor( + Float32, + cute.make_layout((rows_per_block, (warps_per_row, cluster_n))), + byte_alignment=4, + ) + mbar_ptr = smem.allocate_array(cutlass.Int64, num_elems=1) + + # ===== Initialize cluster ===== + if cutlass.const_expr(cluster_n > 1): + if tidx == 0: + cute.arch.mbarrier_init(mbar_ptr, 1) + cute.arch.mbarrier_init_fence() + cute.arch.cluster_arrive_relaxed() + cute.arch.cluster_wait() + + # ===== Coordinate tracking and tiling ===== idX = cute.make_identity_tensor(mX.shape) - # Slice for this row - gX = cute.local_tile(mX, tiler_mn, (bidx, 0)) - gY = cute.local_tile(mY, tiler_mn, (bidx, 0)) - cX = cute.local_tile(idX, tiler_mn, (bidx, 0)) + gX = cute.local_tile(mX, tiler_mn, (bidx, cluster_y)) + gY = cute.local_tile(mY, tiler_mn, (bidx, cluster_y)) + cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y)) - # Expand weight to 2D for consistent tiling - mW_2d = cute.prepend_ones(mW, up_to_rank=2) + mW_expanded_layout = cute.prepend( + mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)) + ) + mW_2d = cute.make_tensor(mW.iterator, mW_expanded_layout) + gW = cute.local_tile(mW_2d, tiler_mn, (0, cluster_y)) - # Create TiledCopy for load and store (both use CopyUniversalOp for sync operations) - copy_atom = cute.make_copy_atom( + # ===== Create TiledCopy atoms ===== + copy_atom_sync = cute.make_copy_atom( cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=copy_bits, ) + copy_atom_store = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + mY.element_type, + num_bits_per_copy=copy_bits, + ) + + if cutlass.const_expr(self.use_async_copy): + copy_atom_async = cute.make_copy_atom( + cute.nvgpu.cpasync.CopyG2SOp(), + mX.element_type, + num_bits_per_copy=copy_bits, + ) + tiled_copy_load = cute.make_tiled_copy(copy_atom_async, tv_layout, tiler_mn) + else: + tiled_copy_load = cute.make_tiled_copy(copy_atom_sync, tv_layout, tiler_mn) + + tiled_copy_W = cute.make_tiled_copy(copy_atom_sync, tv_layout, tiler_mn) + tiled_copy_store = 
cute.make_tiled_copy(copy_atom_store, tv_layout, tiler_mn) + + thr_copy_X = tiled_copy_load.get_slice(tidx) + thr_copy_W = tiled_copy_W.get_slice(tidx) + thr_copy_O = tiled_copy_store.get_slice(tidx) + + # Partition input + tXgX = thr_copy_X.partition_S(gX) + tXcX = thr_copy_X.partition_S(cX) + tXrX = cute.make_fragment_like(tXgX) - tiled_copy = cute.make_tiled_copy(copy_atom, tv_layout, tiler_mn) - thr_copy = tiled_copy.get_slice(tidx) + if cutlass.const_expr(self.use_async_copy): + tXsX = thr_copy_X.partition_D(sX) - # Partition tensors - tXgX = thr_copy.partition_S(gX) - tXgW = thr_copy.partition_S(mW_2d) - tXgY = thr_copy.partition_D(gY) - tXcX = thr_copy.partition_S(cX) + # Partition weight (sync, separate tiled copy) + tWgW = thr_copy_W.partition_S(gW) + tWrW = cute.make_fragment_like(tWgW) + tXrW = thr_copy_X.retile(tWrW) - # Register fragments - initialize to zero for proper handling of out-of-bounds threads - tXrX = cute.make_rmem_tensor(tXgX.shape, mX.element_type) - tXrW = cute.make_rmem_tensor(tXgW.shape, mW.element_type) - tXrX.store(cute.zeros_like(tXrX, dtype=mX.element_type)) - tXrW.store(cute.zeros_like(tXrW, dtype=mW.element_type)) + # Partition output + tXgO = thr_copy_O.partition_D(gY) + tXrO = cute.make_fragment_like(tXgO) - # Bounds checking (column boundary only, row is always valid since grid=[M,1,1]) + # ===== Bounds checking ===== tXpX = predicate_k(tXcX, limit=H) + tWpW = predicate_k(thr_copy_W.partition_S(cX), limit=H) + row_coord = tXcX[(0, 0), 0, 0] + row_in_bounds = row_coord[0] < M + + # ===== Pass 1: Load input + compute sum of squares ===== + if cutlass.const_expr(self.use_async_copy): + if row_in_bounds: + cute.copy(copy_atom_async, tXgX, tXsX, pred=tXpX) + cute.arch.cp_async_commit_group() + + cute.copy(copy_atom_sync, tWgW, tWrW, pred=tWpW) + + cute.arch.cp_async_wait_group(0) + + cute.autovec_copy(tXsX, tXrX) + else: + tXrX.store(cute.zeros_like(tXrX, dtype=mX.element_type)) + if row_in_bounds: + cute.copy(copy_atom_sync, tXgX, tXrX, pred=tXpX) - # =================================================================== - # Phase 1: Load input from global to register - # =================================================================== - cute.copy(copy_atom, tXgX, tXrX, pred=tXpX) + cute.copy(copy_atom_sync, tWgW, tWrW, pred=tWpW) x = tXrX.load().to(Float32) x_sq = x * x - sum_sq = row_reduce_sum(x_sq, threads_per_row, reduction_buffer) + sum_sq = row_reduce_sum_multirow( + x_sq, threads_per_row, reduction_buffer, mbar_ptr, cluster_n + ) - # Compute rstd = 1 / sqrt(mean(x^2) + eps) mean_sq = sum_sq / Float32(H) rstd = cute.math.rsqrt(mean_sq + eps, fastmath=True) - # =================================================================== - # Phase 2: Load weight from global to register - # =================================================================== - cute.copy(copy_atom, tXgW, tXrW, pred=tXpX) + if cutlass.const_expr(cluster_n > 1): + cute.arch.cluster_arrive_relaxed() + cute.arch.cluster_wait() + else: + cute.arch.barrier() + + # ===== Pass 2: Normalize and store output ===== + # Re-load x from shared memory to relieve register pressure. + # Without this, x (up to 128 FP32 values/thread at large H) must + # survive across the reduction + barrier, causing spills to local mem. 
+ if cutlass.const_expr(self.use_async_copy): + cute.autovec_copy(tXsX, tXrX) + x = tXrX.load().to(Float32) w = tXrW.load().to(Float32) - - # output = input * rstd * (weight + weight_bias) y = x * rstd * (w + Float32(weight_bias)) - # Store output using cute.copy with predicate - tYrY = y.to(mY.element_type) - tXrY = cute.make_rmem_tensor(tXgY.shape, mY.element_type) - tXrY.store(tYrY) + tXrO.store(y.to(mY.element_type)) - cute.copy(copy_atom, tXrY, tXgY, pred=tXpX) + if row_in_bounds: + cute.copy(copy_atom_store, tXrO, tXgO, pred=tXpX) - # PDL: Signal dependent kernels (SM90+ only) if enable_pdl: cute.arch.griddepcontrol_launch_dependents() @@ -241,11 +423,13 @@ class QKRMSNormKernel: """ QK RMSNorm Kernel using CuTe-DSL for 3D tensors [batch, heads, head_dim]. - Supports arbitrary stride - no need for contiguous tensors. - Each warp processes one (batch, head) pair independently. - Uses warp-only reduction (no cross-warp shared memory sync needed). + Supports arbitrary stride (only stride[-1] == 1 required). Each block + processes rows_per_block rows, where each row is a (batch, head) pair + handled by threads_per_row threads. - Computes: output[b,h,:] = input[b,h,:] / sqrt(mean(input[b,h,:]^2) + eps) * (weight + weight_bias) + Architecture mirrors RMSNormKernel but uses per-row 3D->2D tiles instead + of a single multi-row 2D tile, since 3D strides may be non-uniform across + the batch/head dimensions. """ def __init__( @@ -253,34 +437,46 @@ def __init__( dtype: cutlass.Numeric, head_dim: int, weight_bias: float = 0.0, - num_warps: int = 4, ): self.dtype = dtype self.head_dim = head_dim self.weight_bias = weight_bias - self.num_warps = num_warps - # Vectorization: each warp (32 threads) processes head_dim elements - elem_bits = dtype.width - max_vec_size = COPY_BITS // elem_bits # 8 for float16/bfloat16 - self.vec_size = compute_optimal_vec_size(head_dim, max_vec_size) - self.copy_bits = self.vec_size * elem_bits + elem_bytes = dtype.width // 8 + max_vec_size = COPY_BITS // 8 // elem_bytes + + h_align = head_dim & (-head_dim) + self.vec_size = min(h_align, max_vec_size) + self.copy_bits = self.vec_size * dtype.width - # Threads per warp is always 32 - self.threads_per_warp = 32 - self.num_threads = self.threads_per_warp * num_warps + self.threads_per_row = RMSNormKernel._compute_threads_per_row(head_dim) + self.num_threads = RMSNormKernel._compute_num_threads(head_dim) + self.rows_per_block = self.num_threads // self.threads_per_row + self.warps_per_row = max(self.threads_per_row // 32, 1) - # Number of vectorized blocks per warp self.num_vec_blocks = max( 1, - (head_dim // self.vec_size + self.threads_per_warp - 1) - // self.threads_per_warp, + (head_dim // self.vec_size + self.threads_per_row - 1) + // self.threads_per_row, ) - self.cols_per_tile = self.vec_size * self.num_vec_blocks * self.threads_per_warp + self.cols_per_tile = self.vec_size * self.num_vec_blocks * self.threads_per_row + + if self.copy_bits >= 32: + tile_bytes = self.rows_per_block * self.cols_per_tile * elem_bytes + props = torch.cuda.get_device_properties(torch.cuda.current_device()) + self.use_async_copy = tile_bytes <= props.shared_memory_per_block_optin // 2 + else: + self.use_async_copy = False def _smem_size_in_bytes(self) -> int: - # No shared memory needed - warp-only reduction - return 0 + if self.use_async_copy: + tile_bytes = ( + self.rows_per_block * self.cols_per_tile * (self.dtype.width // 8) + ) + else: + tile_bytes = 0 + reduction_bytes = self.rows_per_block * self.warps_per_row * 4 + return 
tile_bytes + reduction_bytes @cute.jit def __call__( @@ -292,28 +488,18 @@ def __call__( N: Int32, eps: Float32, enable_pdl: cutlass.Constexpr[bool], - num_blocks: Int32, stream, ): - """Launch the QKRMSNorm kernel. - - Args: - mX: Input tensor of shape [B, N, H] with arbitrary stride. - mW: Weight tensor of shape [H]. - mY: Output tensor of shape [B, N, H] with arbitrary stride. - B: Batch size. - N: Number of heads. - eps: Epsilon for numerical stability. - enable_pdl: Enable PDL for SM90+. - num_blocks: Number of blocks to launch. - stream: CUDA stream. - """ - # Use 32 threads per warp for warp-level layout - tv_shape, tv_stride = make_tv_layout(32, self.vec_size, self.num_vec_blocks) + tv_shape, tv_stride = make_tv_layout( + self.threads_per_row, self.vec_size, self.num_vec_blocks + ) tv_layout = cute.make_layout(tv_shape, stride=tv_stride) + tiler_1d = (1, self.cols_per_tile) - self.kernel(mX, mW, mY, B, N, eps, enable_pdl, tv_layout).launch( - grid=[num_blocks, 1, 1], + M = B * N + + self.kernel(mX, mW, mY, N, M, eps, enable_pdl, tv_layout, tiler_1d).launch( + grid=[cute.ceil_div(M, self.rows_per_block), 1, 1], block=[self.num_threads, 1, 1], smem=self._smem_size_in_bytes(), stream=stream, @@ -326,15 +512,15 @@ def kernel( mX: cute.Tensor, mW: cute.Tensor, mY: cute.Tensor, - B: Int32, N: Int32, + M: Int32, eps: Float32, enable_pdl: cutlass.Constexpr[bool], tv_layout: cute.Layout, + tiler_1d: cute.Shape, ): - """Device kernel for QKRMSNorm with 3D tensor support and arbitrary stride.""" - bidx, _, _ = cute.arch.block_idx() tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() # PDL: Wait for previous kernel (SM90+ only) if enable_pdl: @@ -342,111 +528,148 @@ def kernel( head_dim = self.head_dim weight_bias = self.weight_bias - num_warps = self.num_warps copy_bits = self.copy_bits + threads_per_row = self.threads_per_row + rows_per_block = self.rows_per_block + warps_per_row = self.warps_per_row - # Thread indexing within block - lane_idx = tidx % 32 - warp_idx = tidx // 32 + # Each group of threads_per_row threads handles one row + lane_in_row = tidx % threads_per_row + row_in_block = tidx // threads_per_row + actual_row = bidx * rows_per_block + row_in_block - # Total workers and jobs - grid_dim_x, _, _ = cute.arch.grid_dim() - num_workers = grid_dim_x * num_warps - worker_idx = bidx * num_warps + warp_idx + batch_idx = actual_row // N + head_idx = actual_row % N + row_in_bounds = actual_row < M - # Total number of rows - M = B * N + # ===== Allocate shared memory ===== + smem = cutlass.utils.SmemAllocator() - # Create copy atom for vectorized loads/stores - copy_atom = cute.make_copy_atom( + if cutlass.const_expr(self.use_async_copy): + sX = smem.allocate_tensor( + mX.element_type, + cute.make_ordered_layout( + (rows_per_block, self.cols_per_tile), order=(1, 0) + ), + byte_alignment=16, + ) + + reduction_buffer = smem.allocate_tensor( + Float32, + cute.make_layout((rows_per_block, warps_per_row)), + byte_alignment=4, + ) + mbar_ptr = None + cluster_n = 1 + + # ===== Per-row 3D -> 2D tiles ===== + gX_3d = cute.local_tile( + mX, (1, 1, self.cols_per_tile), (batch_idx, head_idx, 0) + ) + gX = cute.group_modes(gX_3d, 0, 2) + + gY_3d = cute.local_tile( + mY, (1, 1, self.cols_per_tile), (batch_idx, head_idx, 0) + ) + gY = cute.group_modes(gY_3d, 0, 2) + + mW_2d = cute.prepend_ones(mW, up_to_rank=2) + + # ===== Create TiledCopy atoms ===== + copy_atom_sync = cute.make_copy_atom( cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=copy_bits, ) + 
copy_atom_store = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + mY.element_type, + num_bits_per_copy=copy_bits, + ) - # Expand weight to 2D for consistent tiling: [1, H] - mW_2d = cute.prepend_ones(mW, up_to_rank=2) + if cutlass.const_expr(self.use_async_copy): + copy_atom_async = cute.make_copy_atom( + cute.nvgpu.cpasync.CopyG2SOp(), + mX.element_type, + num_bits_per_copy=copy_bits, + ) + tiled_copy_load = cute.make_tiled_copy(copy_atom_async, tv_layout, tiler_1d) + else: + tiled_copy_load = cute.make_tiled_copy(copy_atom_sync, tv_layout, tiler_1d) - # Create tiled copy for warp-level access (32 threads) - tiler_2d = (1, self.cols_per_tile) - tiled_copy = cute.make_tiled_copy(copy_atom, tv_layout, tiler_2d) - thr_copy = tiled_copy.get_slice(lane_idx) + tiled_copy_W = cute.make_tiled_copy(copy_atom_sync, tv_layout, tiler_1d) + tiled_copy_store = cute.make_tiled_copy(copy_atom_store, tv_layout, tiler_1d) - # Create identity tensor matching tile shape for bounds checking - id2d = cute.make_identity_tensor(tiler_2d) + thr_copy_X = tiled_copy_load.get_slice(lane_in_row) + thr_copy_W = tiled_copy_W.get_slice(lane_in_row) + thr_copy_O = tiled_copy_store.get_slice(lane_in_row) - # Weight and predicate are the same for all rows - compute once - tXgW = thr_copy.partition_S(mW_2d) - tXcX = thr_copy.partition_S(id2d) - tXpX = predicate_k(tXcX, limit=head_dim) + # ===== Partition input ===== + tXgX = thr_copy_X.partition_S(gX) + tXrX = cute.make_fragment_like(tXgX) - # Load weight once (same for all rows) - tXrW = cute.make_rmem_tensor(tXgW.shape, mW.element_type) - tXrW.store(cute.zeros_like(tXrW, dtype=mW.element_type)) - cute.copy(copy_atom, tXgW, tXrW, pred=tXpX) - w = tXrW.load().to(Float32) + if cutlass.const_expr(self.use_async_copy): + sX_row = cute.local_tile(sX, tiler_1d, (row_in_block, 0)) + tXsX = thr_copy_X.partition_D(sX_row) - # Each warp processes multiple rows with grid-stride loop - row_idx = worker_idx - while row_idx < M: - batch_idx = row_idx // N - head_idx = row_idx % N + # ===== Partition weight (sync, separate tiled copy) ===== + tWgW = thr_copy_W.partition_S(mW_2d) + tWrW = cute.make_fragment_like(tWgW) + tXrW = thr_copy_X.retile(tWrW) - # Get 3D tile and collapse first two dims (both size 1) to 2D for tiled_copy - gX = cute.group_modes( - cute.local_tile( - mX, (1, 1, self.cols_per_tile), (batch_idx, head_idx, 0) - ), - 0, - 2, - ) - gY = cute.group_modes( - cute.local_tile( - mY, (1, 1, self.cols_per_tile), (batch_idx, head_idx, 0) - ), - 0, - 2, - ) + # ===== Partition output ===== + tXgO = thr_copy_O.partition_D(gY) + tXrO = cute.make_fragment_like(tXgO) - # Partition tensors for this thread - tXgX = thr_copy.partition_S(gX) - tXgY = thr_copy.partition_D(gY) + # ===== Bounds checking ===== + id1d = cute.make_identity_tensor(tiler_1d) + tXpX = predicate_k(thr_copy_X.partition_S(id1d), limit=head_dim) + tWpW = predicate_k(thr_copy_W.partition_S(id1d), limit=head_dim) - # Register fragment for input - initialize to zero - tXrX = cute.make_rmem_tensor(tXgX.shape, mX.element_type) - tXrX.store(cute.zeros_like(tXrX, dtype=mX.element_type)) + # ===== Pass 1: Load input + compute sum of squares ===== + if cutlass.const_expr(self.use_async_copy): + if row_in_bounds: + cute.copy(copy_atom_async, tXgX, tXsX, pred=tXpX) + cute.arch.cp_async_commit_group() - # Phase 1: Load input and compute sum of squares - cute.copy(copy_atom, tXgX, tXrX, pred=tXpX) + cute.copy(copy_atom_sync, tWgW, tWrW, pred=tWpW) - x = tXrX.load().to(Float32) - x_sq = x * x + 
cute.arch.cp_async_wait_group(0) - # Reduce within register tensor first - local_sum = x_sq.reduce( - cute.ReductionOp.ADD, init_val=Float32(0.0), reduction_profile=0 - ) + cute.autovec_copy(tXsX, tXrX) + else: + tXrX.store(cute.zeros_like(tXrX, dtype=mX.element_type)) + if row_in_bounds: + cute.copy(copy_atom_sync, tXgX, tXrX, pred=tXpX) + + cute.copy(copy_atom_sync, tWgW, tWrW, pred=tWpW) + + x = tXrX.load().to(Float32) + x_sq = x * x + sum_sq = row_reduce_sum_multirow( + x_sq, threads_per_row, reduction_buffer, mbar_ptr, cluster_n + ) - # Warp reduction for sum_sq - sum_sq = warp_reduce(local_sum, operator.add, width=32) + mean_sq = sum_sq / Float32(head_dim) + rstd = cute.math.rsqrt(mean_sq + eps, fastmath=True) - # Compute rstd - mean_sq = sum_sq / Float32(head_dim) - rstd = cute.math.rsqrt(mean_sq + eps, fastmath=True) + cute.arch.barrier() - # output = input * rstd * (weight + weight_bias) - # w is already loaded outside the loop - y = x * rstd * (w + Float32(weight_bias)) + # ===== Pass 2: Normalize and store output ===== + # Re-load x from shared memory to relieve register pressure. + # Without this, x (up to 128 FP32 values/thread at large H) must + # survive across the reduction + barrier, causing spills to local mem. + if cutlass.const_expr(self.use_async_copy): + cute.autovec_copy(tXsX, tXrX) + x = tXrX.load().to(Float32) - # Store output - tYrY = y.to(mY.element_type) - tXrY = cute.make_rmem_tensor(tXgY.shape, mY.element_type) - tXrY.store(tYrY) + w = tXrW.load().to(Float32) + y = x * rstd * (w + Float32(weight_bias)) - cute.copy(copy_atom, tXrY, tXgY, pred=tXpX) + tXrO.store(y.to(mY.element_type)) - # Next row for this warp - row_idx = row_idx + num_workers + if row_in_bounds: + cute.copy(copy_atom_store, tXrO, tXgO, pred=tXpX) # PDL: Signal dependent kernels (SM90+ only) if enable_pdl: @@ -472,35 +695,62 @@ def __init__( H: int, weight_bias: float = 0.0, use_hw_fp8: bool = True, + sm_version: int | None = None, ): self.dtype = dtype self.H = H self.weight_bias = weight_bias self.use_hw_fp8 = use_hw_fp8 + self.sm_version = sm_version if sm_version is not None else get_sm_version() + + self.cluster_n = RMSNormKernel._compute_cluster_n(H, dtype, self.sm_version) + self.H_per_cta = H // self.cluster_n - # Vectorization parameters: use optimal vec_size for warp utilization - elem_bits = dtype.width - max_vec_size_in = COPY_BITS // elem_bits # 8 for fp16/bf16 - self.vec_size = compute_optimal_vec_size(H, max_vec_size_in) - self.copy_bits = self.vec_size * elem_bits + elem_bytes = dtype.width // 8 + max_vec_size = COPY_BITS // 8 // elem_bytes - # For FP8 output: minimum 16 bits = 2 FP8 elements - # Use same vec_size to keep layouts aligned, but ensure copy_bits_out >= 16 - self.vec_size_out = self.vec_size - self.copy_bits_out = max(16, self.vec_size * 8) + h_align = self.H_per_cta & (-self.H_per_cta) + self.vec_size = min(h_align, max_vec_size) + self.copy_bits = self.vec_size * dtype.width - self.threads_per_row = compute_threads_per_row(H, self.vec_size) - self.num_threads = self.threads_per_row - self.num_warps = max(self.threads_per_row // 32, 1) + self.threads_per_row = RMSNormKernel._compute_threads_per_row(self.H_per_cta) + self.num_threads = RMSNormKernel._compute_num_threads(self.H_per_cta) + if self.H_per_cta > 8192 and self.num_threads < 256: + self.num_threads = 256 + self.rows_per_block = self.num_threads // self.threads_per_row + self.warps_per_row = max(self.threads_per_row // 32, 1) self.num_vec_blocks = max( - 1, (H // self.vec_size + self.threads_per_row - 1) // 
self.threads_per_row + 1, + (self.H_per_cta // self.vec_size + self.threads_per_row - 1) + // self.threads_per_row, ) self.cols_per_tile = self.vec_size * self.num_vec_blocks * self.threads_per_row + if self.copy_bits >= 32: + tile_bytes = self.rows_per_block * self.cols_per_tile * elem_bytes + props = torch.cuda.get_device_properties(torch.cuda.current_device()) + self.use_async_copy = tile_bytes <= props.shared_memory_per_block_optin // 2 + else: + self.use_async_copy = False + def _smem_size_in_bytes(self) -> int: - # Only reduction buffer needed - return self.num_warps * 4 + if self.use_async_copy: + tile_bytes = ( + self.rows_per_block * self.cols_per_tile * (self.dtype.width // 8) + ) + else: + tile_bytes = 0 + + if self.cluster_n == 1: + reduction_bytes = self.rows_per_block * self.warps_per_row * 4 + else: + reduction_bytes = ( + self.rows_per_block * self.warps_per_row * self.cluster_n * 4 + ) + + mbar_bytes = 8 if self.cluster_n > 1 else 0 + return tile_bytes + reduction_bytes + mbar_bytes @cute.jit def __call__( @@ -514,15 +764,21 @@ def __call__( enable_pdl: cutlass.Constexpr[bool], stream, ): - tv_shape, tv_stride = make_tv_layout( - self.threads_per_row, self.vec_size, self.num_vec_blocks + tv_shape, tv_stride = RMSNormKernel._make_tv_layout( + self.threads_per_row, + self.rows_per_block, + self.vec_size, + self.num_vec_blocks, ) tv_layout = cute.make_layout(tv_shape, stride=tv_stride) - tiler_mn = (1, self.cols_per_tile) + tiler_mn = (self.rows_per_block, self.cols_per_tile) + + cluster_n = self.cluster_n self.kernel(mX, mW, mY, M, mS, eps, enable_pdl, tv_layout, tiler_mn).launch( - grid=[M, 1, 1], + grid=[cute.ceil_div(M, self.rows_per_block), cluster_n, 1], block=[self.num_threads, 1, 1], + cluster=[1, cluster_n, 1] if cutlass.const_expr(cluster_n > 1) else None, smem=self._smem_size_in_bytes(), stream=stream, use_pdl=enable_pdl, @@ -549,83 +805,287 @@ def kernel( cute.arch.griddepcontrol_wait() H = self.H + cluster_n = self.cluster_n + cols_per_tile = self.cols_per_tile weight_bias = self.weight_bias - threads_per_row = tv_layout.shape[0][0] - num_warps = self.num_warps copy_bits = self.copy_bits vec_size = self.vec_size num_vec_blocks = self.num_vec_blocks + threads_per_row = tv_layout.shape[0][0] + rows_per_block = tiler_mn[0] + warps_per_row = max(threads_per_row // 32, 1) + + if cutlass.const_expr(cluster_n > 1): + cluster_y = cute.arch.block_idx()[1] + else: + cluster_y = cutlass.const_expr(0) inv_scale = rcp_approx_ftz(mS[0]) + # ===== Allocate shared memory ===== smem = cutlass.utils.SmemAllocator() - reduction_buffer = smem.allocate_tensor( - Float32, cute.make_layout((num_warps,)), byte_alignment=4 - ) + if cutlass.const_expr(self.use_async_copy): + sX = smem.allocate_tensor( + mX.element_type, + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + byte_alignment=16, + ) + + if cutlass.const_expr(cluster_n == 1): + reduction_buffer = smem.allocate_tensor( + Float32, + cute.make_layout((rows_per_block, warps_per_row)), + byte_alignment=4, + ) + mbar_ptr = None + else: + reduction_buffer = smem.allocate_tensor( + Float32, + cute.make_layout((rows_per_block, (warps_per_row, cluster_n))), + byte_alignment=4, + ) + mbar_ptr = smem.allocate_array(cutlass.Int64, num_elems=1) + + # ===== Initialize cluster ===== + if cutlass.const_expr(cluster_n > 1): + if tidx == 0: + cute.arch.mbarrier_init(mbar_ptr, 1) + cute.arch.mbarrier_init_fence() + cute.arch.cluster_arrive_relaxed() + cute.arch.cluster_wait() + + # ===== Coordinate tracking and tiling ===== idX = 
cute.make_identity_tensor(mX.shape) - gX = cute.local_tile(mX, tiler_mn, (bidx, 0)) - cX = cute.local_tile(idX, tiler_mn, (bidx, 0)) - mW_2d = cute.prepend_ones(mW, up_to_rank=2) + gX = cute.local_tile(mX, tiler_mn, (bidx, cluster_y)) + cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y)) - copy_atom_load = cute.make_copy_atom( - cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=copy_bits + mW_expanded_layout = cute.prepend( + mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)) ) + mW_2d = cute.make_tensor(mW.iterator, mW_expanded_layout) + gW = cute.local_tile(mW_2d, tiler_mn, (0, cluster_y)) - tiled_copy_load = cute.make_tiled_copy(copy_atom_load, tv_layout, tiler_mn) - thr_copy_load = tiled_copy_load.get_slice(tidx) + # ===== Create TiledCopy atoms ===== + copy_atom_sync = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + mX.element_type, + num_bits_per_copy=copy_bits, + ) - tXgX = thr_copy_load.partition_S(gX) - tXgW = thr_copy_load.partition_S(mW_2d) - tXcX = thr_copy_load.partition_S(cX) + if cutlass.const_expr(self.use_async_copy): + copy_atom_async = cute.make_copy_atom( + cute.nvgpu.cpasync.CopyG2SOp(), + mX.element_type, + num_bits_per_copy=copy_bits, + ) + tiled_copy_load = cute.make_tiled_copy(copy_atom_async, tv_layout, tiler_mn) + else: + tiled_copy_load = cute.make_tiled_copy(copy_atom_sync, tv_layout, tiler_mn) - # Register fragments - initialize to zero for proper handling of out-of-bounds threads - tXrX = cute.make_rmem_tensor(tXgX.shape, mX.element_type) - tXrW = cute.make_rmem_tensor(tXgW.shape, mW.element_type) - tXrX.store(cute.zeros_like(tXrX, dtype=mX.element_type)) - tXrW.store(cute.zeros_like(tXrW, dtype=mW.element_type)) + tiled_copy_W = cute.make_tiled_copy(copy_atom_sync, tv_layout, tiler_mn) + thr_copy_X = tiled_copy_load.get_slice(tidx) + thr_copy_W = tiled_copy_W.get_slice(tidx) + + # Partition input + tXgX = thr_copy_X.partition_S(gX) + tXcX = thr_copy_X.partition_S(cX) + tXrX = cute.make_fragment_like(tXgX) + + if cutlass.const_expr(self.use_async_copy): + tXsX = thr_copy_X.partition_D(sX) + + # Partition weight + tWgW = thr_copy_W.partition_S(gW) + tWrW = cute.make_fragment_like(tWgW) + tXrW = thr_copy_X.retile(tWrW) + + # ===== Bounds checking ===== tXpX = predicate_k(tXcX, limit=H) + tWpW = predicate_k(thr_copy_W.partition_S(cX), limit=H) + row_coord = tXcX[(0, 0), 0, 0] + row_in_bounds = row_coord[0] < M + + # ===== Pass 1: Load input + compute sum of squares ===== + if cutlass.const_expr(self.use_async_copy): + if row_in_bounds: + cute.copy(copy_atom_async, tXgX, tXsX, pred=tXpX) + cute.arch.cp_async_commit_group() - # Phase 1: Load input from global to register - cute.copy(copy_atom_load, tXgX, tXrX, pred=tXpX) + cute.copy(copy_atom_sync, tWgW, tWrW, pred=tWpW) + + cute.arch.cp_async_wait_group(0) + + cute.autovec_copy(tXsX, tXrX) + else: + tXrX.store(cute.zeros_like(tXrX, dtype=mX.element_type)) + if row_in_bounds: + cute.copy(copy_atom_sync, tXgX, tXrX, pred=tXpX) + + cute.copy(copy_atom_sync, tWgW, tWrW, pred=tWpW) x = tXrX.load().to(Float32) x_sq = x * x - sum_sq = row_reduce_sum(x_sq, threads_per_row, reduction_buffer) + sum_sq = row_reduce_sum_multirow( + x_sq, threads_per_row, reduction_buffer, mbar_ptr, cluster_n + ) mean_sq = sum_sq / Float32(H) rstd = cute.math.rsqrt(mean_sq + eps, fastmath=True) - # Phase 2: Load weight from global to register - cute.copy(copy_atom_load, tXgW, tXrW, pred=tXpX) + if cutlass.const_expr(cluster_n > 1): + cute.arch.cluster_arrive_relaxed() + cute.arch.cluster_wait() + else: + 
cute.arch.barrier() + + # ===== Pass 2: Normalize, quantize, and store FP8 output ===== + # Re-load x from shared memory to relieve register pressure. + # Without this, x (up to 128 FP32 values/thread at large H) must + # survive across the reduction + barrier, causing spills to local mem. + if cutlass.const_expr(self.use_async_copy): + cute.autovec_copy(tXsX, tXrX) + x = tXrX.load().to(Float32) w = tXrW.load().to(Float32) y = x * rstd * (w + Float32(weight_bias)) * inv_scale - # Phase 3: Clamp and store to FP8 output using PTX scalar stores - # (CuTe FP8 conversion requires vectorized ops, so we use PTX for scalar stores) - # Store y to register tensor for element-wise access - tYrY_f32 = cute.make_rmem_tensor(tXgX.shape, Float32) + tYrY_f32 = cute.make_rmem_tensor(tXrX.shape, Float32) tYrY_f32.store(y) - col_offset = tidx * vec_size - for v in cutlass.range_constexpr(num_vec_blocks): - for e in cutlass.range_constexpr(vec_size): - idx = col_offset + v * threads_per_row * vec_size + e - if idx < H: - # Clamp and convert - use flat index for register tensor - flat_idx = v * vec_size + e - clamped = max(tYrY_f32[flat_idx], Float32(-FLOAT8_E4M3_MAX)) - clamped = min(clamped, Float32(FLOAT8_E4M3_MAX)) - # Use PTX to convert and store FP8 byte - out_offset = bidx * H + idx - out_ptr = get_ptr_as_int64(mY, Int32(out_offset)) - if self.use_hw_fp8: - cvt_and_store_f32_to_e4m3_hw(clamped, out_ptr) - else: - cvt_and_store_f32_to_e4m3_sw(clamped, out_ptr) + lane_in_row = tidx % threads_per_row + row_in_block = tidx // threads_per_row + actual_row = bidx * rows_per_block + row_in_block + col_offset = lane_in_row * vec_size + + if cutlass.const_expr(self.use_hw_fp8 and vec_size == 8): + for v in cutlass.range_constexpr(num_vec_blocks): + local_col = col_offset + v * threads_per_row * vec_size + abs_col = cluster_y * cols_per_tile + local_col + if abs_col + 8 <= H and actual_row < M: + base = v * 8 + cvt_and_store_8xf32_to_e4m3_hw( + tYrY_f32[base], + tYrY_f32[base + 1], + tYrY_f32[base + 2], + tYrY_f32[base + 3], + tYrY_f32[base + 4], + tYrY_f32[base + 5], + tYrY_f32[base + 6], + tYrY_f32[base + 7], + get_ptr_as_int64( + mY, + cute.crd2idx( + (Int32(actual_row), Int32(abs_col)), mY.layout + ), + ), + ) + else: + for e in cutlass.range_constexpr(vec_size): + abs_col_e = cluster_y * cols_per_tile + local_col + e + if abs_col_e < H and actual_row < M: + flat_idx = v * vec_size + e + clamped = max(tYrY_f32[flat_idx], Float32(-FLOAT8_E4M3_MAX)) + clamped = min(clamped, Float32(FLOAT8_E4M3_MAX)) + cvt_and_store_f32_to_e4m3_hw( + clamped, + get_ptr_as_int64( + mY, + cute.crd2idx( + (Int32(actual_row), Int32(abs_col_e)), + mY.layout, + ), + ), + ) + elif cutlass.const_expr(self.use_hw_fp8 and vec_size == 4): + for v in cutlass.range_constexpr(num_vec_blocks): + local_col = col_offset + v * threads_per_row * vec_size + abs_col = cluster_y * cols_per_tile + local_col + if abs_col + 4 <= H and actual_row < M: + base = v * 4 + cvt_and_store_4xf32_to_e4m3_hw( + tYrY_f32[base], + tYrY_f32[base + 1], + tYrY_f32[base + 2], + tYrY_f32[base + 3], + get_ptr_as_int64( + mY, + cute.crd2idx( + (Int32(actual_row), Int32(abs_col)), mY.layout + ), + ), + ) + else: + for e in cutlass.range_constexpr(vec_size): + abs_col_e = cluster_y * cols_per_tile + local_col + e + if abs_col_e < H and actual_row < M: + flat_idx = v * vec_size + e + clamped = max(tYrY_f32[flat_idx], Float32(-FLOAT8_E4M3_MAX)) + clamped = min(clamped, Float32(FLOAT8_E4M3_MAX)) + cvt_and_store_f32_to_e4m3_hw( + clamped, + get_ptr_as_int64( + mY, + 
cute.crd2idx( + (Int32(actual_row), Int32(abs_col_e)), + mY.layout, + ), + ), + ) + elif cutlass.const_expr(self.use_hw_fp8 and vec_size == 2): + for v in cutlass.range_constexpr(num_vec_blocks): + local_col = col_offset + v * threads_per_row * vec_size + abs_col = cluster_y * cols_per_tile + local_col + if abs_col + 2 <= H and actual_row < M: + base = v * 2 + cvt_and_store_2xf32_to_e4m3_hw( + tYrY_f32[base], + tYrY_f32[base + 1], + get_ptr_as_int64( + mY, + cute.crd2idx( + (Int32(actual_row), Int32(abs_col)), mY.layout + ), + ), + ) + else: + for e in cutlass.range_constexpr(vec_size): + abs_col_e = cluster_y * cols_per_tile + local_col + e + if abs_col_e < H and actual_row < M: + flat_idx = v * vec_size + e + clamped = max(tYrY_f32[flat_idx], Float32(-FLOAT8_E4M3_MAX)) + clamped = min(clamped, Float32(FLOAT8_E4M3_MAX)) + cvt_and_store_f32_to_e4m3_hw( + clamped, + get_ptr_as_int64( + mY, + cute.crd2idx( + (Int32(actual_row), Int32(abs_col_e)), + mY.layout, + ), + ), + ) + else: + for v in cutlass.range_constexpr(num_vec_blocks): + for e in cutlass.range_constexpr(vec_size): + local_col = col_offset + v * threads_per_row * vec_size + e + abs_col = cluster_y * cols_per_tile + local_col + if abs_col < H and actual_row < M: + flat_idx = v * vec_size + e + clamped = max(tYrY_f32[flat_idx], Float32(-FLOAT8_E4M3_MAX)) + clamped = min(clamped, Float32(FLOAT8_E4M3_MAX)) + out_ptr = get_ptr_as_int64( + mY, + cute.crd2idx( + (Int32(actual_row), Int32(abs_col)), mY.layout + ), + ) + if self.use_hw_fp8: + cvt_and_store_f32_to_e4m3_hw(clamped, out_ptr) + else: + cvt_and_store_f32_to_e4m3_sw(clamped, out_ptr) # PDL: Signal dependent kernels (SM90+ only) if enable_pdl: @@ -639,38 +1099,54 @@ def kernel( @functools.cache def _get_compiled_rmsnorm_kernel( - dtype_str: str, H: int, weight_bias: float, enable_pdl: bool + dtype_str: str, + H: int, + weight_bias: float, + enable_pdl: bool, + sm_version: int, + contiguous: bool = True, ): - """Get a compiled RMSNorm kernel using TVM-FFI.""" + """Get a compiled RMSNorm kernel using TVM-FFI. + + When contiguous=True, tensors are compiled with compact (dense) layouts for + optimal codegen. When False, symbolic row strides are used to support + arbitrary row strides at the cost of some performance. 
+ """ dtype = get_cutlass_dtype(dtype_str) - kernel_obj = RMSNormKernel(dtype, H, weight_bias) + kernel_obj = RMSNormKernel(dtype, H, weight_bias, sm_version=sm_version) - # Use symbolic size for dynamic M dimension sym_m = cute.sym_int() - # Use symbolic stride for arbitrary row stride (last dim must be contiguous) - sym_row_stride_x = cute.sym_int(divisibility=kernel_obj.vec_size) - sym_row_stride_y = cute.sym_int(divisibility=kernel_obj.vec_size) - # Create fake tensors with symbolic stride for arbitrary stride support - x_fake = cute.runtime.make_fake_tensor( - dtype, (sym_m, H), (sym_row_stride_x, 1), assumed_align=16 - ) + if contiguous: + elem_bytes = dtype.width // 8 + tensor_align = math.gcd(128, H * elem_bytes) + x_fake = cute.runtime.make_fake_compact_tensor( + dtype, (sym_m, H), stride_order=(1, 0), assumed_align=tensor_align + ) + y_fake = cute.runtime.make_fake_compact_tensor( + dtype, (sym_m, H), stride_order=(1, 0), assumed_align=tensor_align + ) + else: + sym_row_stride_x = cute.sym_int(divisibility=kernel_obj.vec_size) + sym_row_stride_y = cute.sym_int(divisibility=kernel_obj.vec_size) + x_fake = cute.runtime.make_fake_tensor( + dtype, (sym_m, H), (sym_row_stride_x, 1), assumed_align=16 + ) + y_fake = cute.runtime.make_fake_tensor( + dtype, (sym_m, H), (sym_row_stride_y, 1), assumed_align=16 + ) + w_fake = cute.runtime.make_fake_compact_tensor(dtype, (H,), assumed_align=16) - y_fake = cute.runtime.make_fake_tensor( - dtype, (sym_m, H), (sym_row_stride_y, 1), assumed_align=16 - ) - # Create fake stream that uses environment stream at runtime stream_fake = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True) - # Compile with TVM-FFI enabled compiled_kernel = cute.compile( kernel_obj, x_fake, w_fake, y_fake, - Int32(1), # Dummy M - Float32(1e-6), # Dummy eps + Int32(1), + Float32(1e-6), enable_pdl, stream_fake, options="--enable-tvm-ffi", @@ -681,24 +1157,22 @@ def _get_compiled_rmsnorm_kernel( @functools.cache def _get_compiled_qk_rmsnorm_kernel( - dtype_str: str, head_dim: int, weight_bias: float, num_warps: int, enable_pdl: bool + dtype_str: str, head_dim: int, weight_bias: float, enable_pdl: bool ): """Get a compiled QKRMSNorm kernel for 3D tensors with arbitrary stride.""" dtype = get_cutlass_dtype(dtype_str) - kernel_obj = QKRMSNormKernel(dtype, head_dim, weight_bias, num_warps) + kernel_obj = QKRMSNormKernel(dtype, head_dim, weight_bias) - # Use symbolic sizes for B, N dimensions sym_b = cute.sym_int() sym_n = cute.sym_int() - # Use symbolic strides for arbitrary stride support - # stride[-1] must be 1 (contiguous in head_dim), but batch/head strides can be anything + # Stride divisibility = vec_size guarantees each row start is aligned + # for the chosen copy_bits (e.g. vec_size=8 for fp16 → 16-byte aligned). 
sym_batch_stride_x = cute.sym_int(divisibility=kernel_obj.vec_size) sym_head_stride_x = cute.sym_int(divisibility=kernel_obj.vec_size) sym_batch_stride_y = cute.sym_int(divisibility=kernel_obj.vec_size) sym_head_stride_y = cute.sym_int(divisibility=kernel_obj.vec_size) - # Create 3D fake tensors with arbitrary stride x_fake = cute.runtime.make_fake_tensor( dtype, (sym_b, sym_n, head_dim), @@ -715,7 +1189,6 @@ def _get_compiled_qk_rmsnorm_kernel( stream_fake = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True) - # Compile with TVM-FFI enabled compiled_kernel = cute.compile( kernel_obj, x_fake, @@ -725,7 +1198,6 @@ def _get_compiled_qk_rmsnorm_kernel( Int32(1), # Dummy N Float32(1e-6), # Dummy eps enable_pdl, - Int32(1), # Dummy num_blocks stream_fake, options="--enable-tvm-ffi", ) @@ -741,23 +1213,41 @@ def _get_compiled_rmsnorm_quant_kernel( weight_bias: float, enable_pdl: bool, use_hw_fp8: bool = True, + sm_version: int = 80, + contiguous: bool = True, ): - """Get a compiled RMSNorm + Quant kernel using TVM-FFI.""" + """Get a compiled RMSNorm + Quant kernel using TVM-FFI. + + See _get_compiled_rmsnorm_kernel for contiguous parameter semantics. + """ dtype = get_cutlass_dtype(dtype_str) out_dtype = get_cutlass_dtype(out_dtype_str) - kernel_obj = RMSNormQuantKernel(dtype, H, weight_bias, use_hw_fp8=use_hw_fp8) + kernel_obj = RMSNormQuantKernel( + dtype, H, weight_bias, use_hw_fp8=use_hw_fp8, sm_version=sm_version + ) sym_m = cute.sym_int() - sym_row_stride_x = cute.sym_int(divisibility=kernel_obj.vec_size) - sym_row_stride_y = cute.sym_int(divisibility=kernel_obj.vec_size_out) - x_fake = cute.runtime.make_fake_tensor( - dtype, (sym_m, H), (sym_row_stride_x, 1), assumed_align=16 - ) + if contiguous: + in_align = math.gcd(128, H * (dtype.width // 8)) + out_align = math.gcd(128, H * (out_dtype.width // 8)) + x_fake = cute.runtime.make_fake_compact_tensor( + dtype, (sym_m, H), stride_order=(1, 0), assumed_align=in_align + ) + y_fake = cute.runtime.make_fake_compact_tensor( + out_dtype, (sym_m, H), stride_order=(1, 0), assumed_align=out_align + ) + else: + sym_row_stride_x = cute.sym_int(divisibility=kernel_obj.vec_size) + sym_row_stride_y = cute.sym_int(divisibility=kernel_obj.vec_size) + x_fake = cute.runtime.make_fake_tensor( + dtype, (sym_m, H), (sym_row_stride_x, 1), assumed_align=16 + ) + y_fake = cute.runtime.make_fake_tensor( + out_dtype, (sym_m, H), (sym_row_stride_y, 1), assumed_align=16 + ) + w_fake = cute.runtime.make_fake_compact_tensor(dtype, (H,), assumed_align=16) - y_fake = cute.runtime.make_fake_tensor( - out_dtype, (sym_m, H), (sym_row_stride_y, 1), assumed_align=16 - ) s_fake = cute.runtime.make_fake_compact_tensor(Float32, (1,), assumed_align=4) stream_fake = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True) @@ -769,7 +1259,7 @@ def _get_compiled_rmsnorm_quant_kernel( y_fake, Int32(1), s_fake, - Float32(1e-6), # eps + Float32(1e-6), enable_pdl, stream_fake, options="--enable-tvm-ffi", @@ -793,24 +1283,29 @@ def rmsnorm_cute( ) -> None: """CuTe DSL RMSNorm implementation. - Supports arbitrary stride - no need to call contiguous(). - Last dimension must be contiguous (stride[-1] == 1). + Supports non-contiguous tensors (stride[-1] must be 1). Uses an optimized + compact kernel for contiguous inputs and a general strided kernel otherwise. 
""" - shape = input.shape H = shape[-1] if len(shape) == 3: M = shape[0] * shape[1] - input_2d = input.view(M, H) - out_2d = out.view(M, H) + input_2d = input.reshape(M, H) + out_2d = out.reshape(M, H) else: M = shape[0] input_2d = input out_2d = out + is_contiguous = input_2d.is_contiguous() and out_2d.is_contiguous() kernel = _get_compiled_rmsnorm_kernel( - _torch_dtype_to_str(input.dtype), H, weight_bias, enable_pdl + _torch_dtype_to_str(input.dtype), + H, + weight_bias, + enable_pdl, + get_sm_version(input.device), + contiguous=is_contiguous, ) kernel(input_2d, weight, out_2d, M, eps) @@ -825,41 +1320,21 @@ def qk_rmsnorm_cute( ) -> None: """CuTe DSL QKRMSNorm for 3D tensors [batch, heads, head_dim]. - Supports arbitrary stride - no need to call contiguous(). - Each warp processes one (batch, head) pair independently using warp-only reduction. - - Args: - input: Input tensor of shape [batch_size, num_heads, head_dim]. - Last dimension must be contiguous (stride[-1] == 1). - weight: Weight tensor of shape [head_dim]. - output: Output tensor (same shape as input). - eps: Small constant for numerical stability. - weight_bias: Bias added to weight (0 for standard RMSNorm, 1 for Gemma). - enable_pdl: Enable Programmatic Dependent Launch for SM90+ GPUs. + Supports arbitrary stride. Uses multi-row blocks with async/sync copy + depending on head_dim alignment. Each block processes multiple (batch, head) + rows independently. """ shape = input.shape assert len(shape) == 3, "QKRMSNorm expects 3D input [batch, heads, head_dim]" batch_size, num_heads, head_dim = shape - M = batch_size * num_heads - - # Kernel configuration - num_warps = 4 - - # Calculate grid size based on SM count and estimated occupancy - num_sms = get_num_sm(input.device) - blocks_per_sm = 16 # Theoretical max for 128-thread blocks - max_blocks = num_sms * blocks_per_sm - needed_blocks = (M + num_warps - 1) // num_warps - num_blocks = min(max_blocks, needed_blocks) dtype_str = _torch_dtype_to_str(input.dtype) kernel = _get_compiled_qk_rmsnorm_kernel( - dtype_str, head_dim, weight_bias, num_warps, enable_pdl + dtype_str, head_dim, weight_bias, enable_pdl ) - # Pass 3D tensors directly - kernel handles arbitrary stride - kernel(input, weight, output, batch_size, num_heads, eps, num_blocks) + kernel(input, weight, output, batch_size, num_heads, eps) def rmsnorm_quant_cute( @@ -873,14 +1348,14 @@ def rmsnorm_quant_cute( ) -> None: """CuTe DSL RMSNorm + FP8 quantization implementation. - Supports arbitrary stride - no need to call contiguous(). - Last dimension must be contiguous (stride[-1] == 1). + Supports non-contiguous tensors (stride[-1] must be 1). Uses an optimized + compact kernel for contiguous inputs and a general strided kernel otherwise. 
""" - shape = input.shape H = shape[-1] M = shape[0] + is_contiguous = input.is_contiguous() and out.is_contiguous() dtype_str = _torch_dtype_to_str(input.dtype) out_dtype_str = _torch_dtype_to_str(out.dtype) kernel = _get_compiled_rmsnorm_quant_kernel( @@ -890,6 +1365,8 @@ def rmsnorm_quant_cute( weight_bias, enable_pdl, use_hw_fp8=has_hw_fp8_cvt(input.device), + sm_version=get_sm_version(input.device), + contiguous=is_contiguous, ) kernel(input, weight, out, M, scale, eps) diff --git a/flashinfer/norm/utils.py b/flashinfer/norm/utils.py index 058b7c2276..27a03cbe07 100644 --- a/flashinfer/norm/utils.py +++ b/flashinfer/norm/utils.py @@ -25,6 +25,7 @@ - Type conversion utilities """ +import functools import math import operator from typing import Callable @@ -183,6 +184,124 @@ def cvt_and_store_f32_to_e4m3_sw(val: Float32, addr: Int64, *, loc=None, ip=None ) +@dsl_user_op +def cvt_and_store_8xf32_to_e4m3_hw( + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + v4: Float32, + v5: Float32, + v6: Float32, + v7: Float32, + addr: Int64, + *, + loc=None, + ip=None, +): + """Convert 8 float32 values to E4M3 and store as one 64-bit global store (sm_89+). + + Uses cvt.rn.satfinite.e4m3x2.f32 to convert pairs, then packs into two b32 + words and issues a single st.global.v2.b32. ~4x fewer instructions and ~8x + fewer store transactions compared to 8 scalar st.global.b8 calls. + """ + llvm.inline_asm( + None, + [ + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + Float32(v2).ir_value(loc=loc, ip=ip), + Float32(v3).ir_value(loc=loc, ip=ip), + Float32(v4).ir_value(loc=loc, ip=ip), + Float32(v5).ir_value(loc=loc, ip=ip), + Float32(v6).ir_value(loc=loc, ip=ip), + Float32(v7).ir_value(loc=loc, ip=ip), + Int64(addr).ir_value(loc=loc, ip=ip), + ], + """ + { + .reg .b16 p01, p23, p45, p67; + .reg .b32 lo, hi; + cvt.rn.satfinite.e4m3x2.f32 p01, $1, $0; + cvt.rn.satfinite.e4m3x2.f32 p23, $3, $2; + cvt.rn.satfinite.e4m3x2.f32 p45, $5, $4; + cvt.rn.satfinite.e4m3x2.f32 p67, $7, $6; + mov.b32 lo, {p01, p23}; + mov.b32 hi, {p45, p67}; + st.global.v2.b32 [$8], {lo, hi}; + } + """, + "f,f,f,f,f,f,f,f,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@dsl_user_op +def cvt_and_store_4xf32_to_e4m3_hw( + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + addr: Int64, + *, + loc=None, + ip=None, +): + """Convert 4 float32 values to E4M3 and store as one 32-bit global store (sm_89+).""" + llvm.inline_asm( + None, + [ + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + Float32(v2).ir_value(loc=loc, ip=ip), + Float32(v3).ir_value(loc=loc, ip=ip), + Int64(addr).ir_value(loc=loc, ip=ip), + ], + """ + { + .reg .b16 p01, p23; + .reg .b32 packed; + cvt.rn.satfinite.e4m3x2.f32 p01, $1, $0; + cvt.rn.satfinite.e4m3x2.f32 p23, $3, $2; + mov.b32 packed, {p01, p23}; + st.global.b32 [$4], packed; + } + """, + "f,f,f,f,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@dsl_user_op +def cvt_and_store_2xf32_to_e4m3_hw( + v0: Float32, v1: Float32, addr: Int64, *, loc=None, ip=None +): + """Convert 2 float32 values to E4M3 and store as one 16-bit global store (sm_89+).""" + llvm.inline_asm( + None, + [ + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + Int64(addr).ir_value(loc=loc, ip=ip), + ], + """ + { + .reg .b16 packed; + cvt.rn.satfinite.e4m3x2.f32 packed, $1, $0; + st.global.b16 [$2], packed; + } + """, + "f,f,l", + 
has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + def has_hw_fp8_cvt(device: torch.device = None) -> bool: """Check if the device supports hardware FP8 conversion (sm_89+).""" if device is None: @@ -191,6 +310,17 @@ def has_hw_fp8_cvt(device: torch.device = None) -> bool: return major > 8 or (major == 8 and minor >= 9) +@functools.lru_cache(maxsize=16) +def get_sm_version(device=None) -> int: + """Get the SM version of a CUDA device (e.g., 100 for SM100).""" + if not torch.cuda.is_available(): + return 80 + if device is None: + device = torch.cuda.current_device() + props = torch.cuda.get_device_properties(device) + return props.major * 10 + props.minor + + @dsl_user_op def get_ptr_as_int64(tensor: cute.Tensor, offset: Int32, *, loc=None, ip=None) -> Int64: """Get the memory address of tensor[offset] as Int64.""" @@ -199,6 +329,64 @@ def get_ptr_as_int64(tensor: cute.Tensor, offset: Int32, *, loc=None, ip=None) - return Int64(ptr_int) +# ============================================================================= +# PTX Intrinsics - Cluster Operations (SM90+) +# ============================================================================= + + +@dsl_user_op +def set_block_rank( + smem_ptr: cute.Pointer, peer_cta_rank_in_cluster: Int32, *, loc=None, ip=None +) -> Int32: + """Map smem pointer to address at another CTA rank in the cluster.""" + smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value() + return Int32( + llvm.inline_asm( + T.i32(), + [smem_ptr_i32, peer_cta_rank_in_cluster.ir_value()], + "mapa.shared::cluster.u32 $0, $1, $2;", + "=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def store_shared_remote( + val: Float32, + smem_ptr: cute.Pointer, + mbar_ptr: cute.Pointer, + peer_cta_rank_in_cluster: Int32, + *, + loc=None, + ip=None, +) -> None: + """Store Float32 value to shared memory on a remote CTA in the cluster.""" + remote_smem_ptr_i32 = set_block_rank( + smem_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip + ).ir_value() + remote_mbar_ptr_i32 = set_block_rank( + mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip + ).ir_value() + llvm.inline_asm( + None, + [remote_smem_ptr_i32, val.ir_value(loc=loc, ip=ip), remote_mbar_ptr_i32], + "st.async.shared::cluster.mbarrier::complete_tx::bytes.f32 [$0], $1, [$2];", + "r,f,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@dsl_user_op +def elem_pointer(x: cute.Tensor, coord, *, loc=None, ip=None) -> cute.Pointer: + """Get pointer to element at coordinate in tensor.""" + return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip) + + # ============================================================================= # Warp and Block Reduction Utilities # ============================================================================= @@ -263,6 +451,129 @@ def row_reduce_sum( return warp_val +@cute.jit +def block_reduce_multirow( + val: Float32, + op: Callable, + reduction_buffer: cute.Tensor, + init_val: Float32, +) -> Float32: + """Block reduction with 2D buffer (rows_per_block, warps_per_row). + + Each warp writes its partial sum to the row it belongs to, then + lane 0..warps_per_row-1 read back and do a final warp reduction. 
+ """ + lane_idx = cute.arch.lane_idx() + warp_idx = cute.arch.warp_idx() + warps_per_row = cute.size(reduction_buffer.shape[1]) + row_idx = warp_idx // warps_per_row + col_idx = warp_idx % warps_per_row + + if lane_idx == 0: + reduction_buffer[row_idx, col_idx] = val + cute.arch.barrier() + + block_reduce_val = init_val + if lane_idx < warps_per_row: + block_reduce_val = reduction_buffer[row_idx, lane_idx] + return warp_reduce(block_reduce_val, op) + + +@cute.jit +def cluster_reduce_multirow( + val: Float32, + op: Callable, + reduction_buffer: cute.Tensor, + mbar_ptr, + cluster_n: cutlass.Constexpr[int], + init_val: Float32, +) -> Float32: + """Cluster reduction across multiple CTAs using mbarrier. + + reduction_buffer has shape (rows_per_block, (warps_per_row, cluster_n)). + Each warp sends its partial result to all CTAs in the cluster via + st.async.shared::cluster, then every CTA reduces the collected values. + """ + cta_rank_in_cluster = cute.arch.block_idx_in_cluster() + lane_idx = cute.arch.lane_idx() + warp_idx = cute.arch.warp_idx() + + rows_per_block = reduction_buffer.shape[0] + warps_per_row = reduction_buffer.shape[1][0] + + row_idx = warp_idx // warps_per_row + col_idx = warp_idx % warps_per_row + + if warp_idx == 0: + with cute.arch.elect_one(): + num_warps = rows_per_block * warps_per_row + expected_bytes = num_warps * cluster_n * 4 + cute.arch.mbarrier_arrive_and_expect_tx(mbar_ptr, expected_bytes) + + if lane_idx < cluster_n: + store_shared_remote( + val, + elem_pointer(reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster))), + mbar_ptr, + peer_cta_rank_in_cluster=lane_idx, + ) + + cute.arch.mbarrier_wait(mbar_ptr, phase=0) + + num_total = warps_per_row * cluster_n + num_iter = cute.ceil_div(num_total, 32) + + block_reduce_val = init_val + for i in cutlass.range_constexpr(num_iter): + idx = lane_idx + i * 32 + if idx < num_total: + block_reduce_val = op(block_reduce_val, reduction_buffer[row_idx, idx]) + + return warp_reduce(block_reduce_val, op) + + +@cute.jit +def row_reduce_sum_multirow( + x: cute.TensorSSA, + threads_per_row: cutlass.Constexpr[int], + reduction_buffer: cute.Tensor, + mbar_ptr, + cluster_n: cutlass.Constexpr[int], +) -> Float32: + """Row reduction for sum with optional cluster support. + + When cluster_n == 1, uses block-level reduction with 2D buffer + (rows_per_block, warps_per_row). When cluster_n > 1, uses cross-CTA + cluster reduction with hierarchical buffer + (rows_per_block, (warps_per_row, cluster_n)). 
+ """ + local_val = x.reduce( + cute.ReductionOp.ADD, init_val=Float32(0.0), reduction_profile=0 + ) + + warp_width = min(threads_per_row, 32) + warp_val = warp_reduce(local_val, operator.add, width=warp_width) + + warps_per_row = max(threads_per_row // 32, 1) + + if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1): + if cutlass.const_expr(cluster_n == 1): + return block_reduce_multirow( + warp_val, operator.add, reduction_buffer, Float32(0.0) + ) + else: + return cluster_reduce_multirow( + warp_val, + operator.add, + reduction_buffer, + mbar_ptr, + cluster_n, + Float32(0.0), + ) + else: + return warp_val + + # ============================================================================= # Predicate Utility # ============================================================================= @@ -414,12 +725,24 @@ def _torch_dtype_to_str(dtype: torch.dtype) -> str: "rcp_approx_ftz", "cvt_and_store_f32_to_e4m3_hw", "cvt_and_store_f32_to_e4m3_sw", + "cvt_and_store_8xf32_to_e4m3_hw", + "cvt_and_store_4xf32_to_e4m3_hw", + "cvt_and_store_2xf32_to_e4m3_hw", "has_hw_fp8_cvt", "get_ptr_as_int64", + # PTX intrinsics - Cluster operations + "set_block_rank", + "store_shared_remote", + "elem_pointer", + # Device utilities + "get_sm_version", # Reduction utilities "warp_reduce", "block_reduce", "row_reduce_sum", + "block_reduce_multirow", + "cluster_reduce_multirow", + "row_reduce_sum_multirow", # Predicate utilities "predicate_k", "predicate_k_3d",