-
Notifications
You must be signed in to change notification settings - Fork 665
rmsaddbias #5012
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
rmsaddbias #5012
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -20,6 +20,119 @@ | |||||||||
| import torch | ||||||||||
| from vllm.config import get_current_vllm_config | ||||||||||
| from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm | ||||||||||
| from vllm.triton_utils import tl, triton | ||||||||||
| from functools import cache | ||||||||||
|
|
||||||||||
|
|
||||||||||
| @triton.heuristics({ | ||||||||||
| "HAS_BIAS": lambda args: args["B"] is not None | ||||||||||
| }) | ||||||||||
| @triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None}) | ||||||||||
| @triton.jit | ||||||||||
| def rms_norm_fwd_kernel( | ||||||||||
| X, # pointer to the input | ||||||||||
| Y, # pointer to the output | ||||||||||
| W, # pointer to the weights | ||||||||||
| B, # pointer to the biases | ||||||||||
| Z, # pointer to the residual | ||||||||||
| Z_Out, # pointer to the residual output | ||||||||||
| stride_x_row, # how much to increase the pointer when moving by 1 row | ||||||||||
| stride_y_row, | ||||||||||
| stride_z_row, | ||||||||||
| stride_z_out_row, | ||||||||||
| n_rows, # number of rows in X_base | ||||||||||
| n_cols, # number of columns in X_base | ||||||||||
| eps, # epsilon to avoid division by zero | ||||||||||
| BLOCK_N: tl.constexpr, | ||||||||||
| HAS_BIAS: tl.constexpr, | ||||||||||
| HAS_Z: tl.constexpr, | ||||||||||
| ): | ||||||||||
| # Map the program id to the row of X_base and Y_base it should compute. | ||||||||||
| # Each program computes a row of X_base and store to Y_base | ||||||||||
| row_idx = tl.program_id(0) | ||||||||||
| row_mask = row_idx < n_rows | ||||||||||
| offsets = tl.arange(0, BLOCK_N) | ||||||||||
| col_mask = offsets < n_cols | ||||||||||
| if HAS_BIAS: | ||||||||||
| bias = tl.load(B + offsets, mask=col_mask, other=0.0).to(tl.float32) | ||||||||||
| w = tl.load(W + offsets, mask=col_mask, other=0.0).to(tl.float32) | ||||||||||
| start_x = X + row_idx * stride_x_row | ||||||||||
| start_y = Y + row_idx * stride_y_row | ||||||||||
| x = tl.load(start_x + offsets, mask=col_mask, other=0.0).to(tl.float32) | ||||||||||
| if HAS_Z: | ||||||||||
| start_z = Z + row_idx * stride_z_row | ||||||||||
| start_z_out = Z_Out + row_idx * stride_z_out_row | ||||||||||
| z = tl.load(start_z + offsets, mask=col_mask, other=0.0).to(tl.float32) | ||||||||||
| x = x + z | ||||||||||
| tl.store(start_z_out + offsets, x, mask=col_mask) | ||||||||||
| var = tl.sum(x * x, axis=0) / n_cols | ||||||||||
| rtsd = 1 / tl.sqrt(var + eps) | ||||||||||
| x_hat = x * rtsd | ||||||||||
| y = x_hat * w | ||||||||||
| if HAS_BIAS: | ||||||||||
| y = y + bias | ||||||||||
| tl.store(start_y + offsets, y, mask=col_mask) | ||||||||||
|
|
||||||||||
|
|
||||||||||
| def _rms_norm_fwd_triton( | ||||||||||
| x, | ||||||||||
| weight, | ||||||||||
| eps, | ||||||||||
| residual=None, | ||||||||||
| bias=None, | ||||||||||
| out=None, | ||||||||||
| residual_out=None, | ||||||||||
| ): | ||||||||||
| M, N = x.shape | ||||||||||
| assert x.stride(-1) == 1 | ||||||||||
| assert weight.shape == (N,) | ||||||||||
| assert weight.stride(-1) == 1 | ||||||||||
| # logger.info(f"bias is {bias}") | ||||||||||
| if bias is not None: | ||||||||||
| assert bias.stride(-1) == 1 | ||||||||||
| assert bias.shape == (N,) | ||||||||||
| if residual is not None: | ||||||||||
| assert residual.shape == x.shape | ||||||||||
| assert residual.stride(-1) == 1 | ||||||||||
| if residual_out is None: | ||||||||||
| residual_out = torch.empty_like(x) | ||||||||||
| # allocate output | ||||||||||
| if out is not None: | ||||||||||
| assert out.shape == x.shape | ||||||||||
| else: | ||||||||||
| out = torch.empty_like(x) | ||||||||||
| assert out.stride(-1) == 1 | ||||||||||
| # Less than 64KB per feature: enqueue fused kernel | ||||||||||
| MAX_FUSED_SIZE = 65536 // x.element_size() | ||||||||||
| BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) | ||||||||||
| if N > BLOCK_N: | ||||||||||
| raise RuntimeError( | ||||||||||
| "This rms norm doesn't support feature dim >= 64KB.") | ||||||||||
| # heuristics for number of warps | ||||||||||
| num_warps = min(max(BLOCK_N // 256, 1), 8) | ||||||||||
| # _, num_vectorcore = get_device_properties() | ||||||||||
| num_vectorcore = 48 | ||||||||||
| grid = (M if M < num_vectorcore else num_vectorcore,) | ||||||||||
| # with torch.npu.device(x.device.index): | ||||||||||
| rms_norm_fwd_kernel[grid]( | ||||||||||
| x, | ||||||||||
| out, | ||||||||||
| weight, | ||||||||||
| bias, | ||||||||||
| residual, | ||||||||||
| residual_out, | ||||||||||
| x.stride(0), | ||||||||||
| out.stride(0), | ||||||||||
| residual.stride(0) if residual is not None else None, | ||||||||||
| residual_out.stride(0) if residual is not None else None, | ||||||||||
|
||||||||||
| residual.stride(0) if residual is not None else None, | |
| residual_out.stride(0) if residual is not None else None, | |
| residual.stride(0) if residual is not None else 0, | |
| residual_out.stride(0) if residual is not None else 0, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The grid for launching the Triton kernel is incorrectly calculated. It's capped at a hardcoded
num_vectorcorevalue of 48. Therms_norm_fwd_kernelprocesses one row per program instance. With the current grid calculation, if the input tensor has more than 48 rows, only the first 48 rows will be processed, leading to incorrect output for the remaining rows. This is a critical correctness bug.To process all
Mrows, the grid size must be(M,). The hardcodednum_vectorcoreand the logic that uses it should be removed.