vllm_ascend/ops/layernorm.py: 123 changes (115 additions, 8 deletions)
@@ -20,6 +20,119 @@
import torch
from vllm.config import get_current_vllm_config
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
from vllm.triton_utils import tl, triton
from functools import cache


@triton.heuristics({
    "HAS_BIAS": lambda args: args["B"] is not None
})
@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
@triton.jit
def rms_norm_fwd_kernel(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    Z,  # pointer to the residual
    Z_Out,  # pointer to the residual output
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_z_row,
    stride_z_out_row,
    n_rows,  # number of rows in X_base
    n_cols,  # number of columns in X_base
    eps,  # epsilon to avoid division by zero
    BLOCK_N: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_Z: tl.constexpr,
):
    # Map the program id to the row of X_base and Y_base it should compute.

Check failure on line 50 in vllm_ascend/ops/layernorm.py (GitHub Actions / lint / pre-commit):
Ruff (F841): vllm_ascend/ops/layernorm.py:50:5: Local variable `row_mask` is assigned to but never used
    # Each program computes one row of X_base and stores it to Y_base
    row_idx = tl.program_id(0)
    row_mask = row_idx < n_rows
    offsets = tl.arange(0, BLOCK_N)
    col_mask = offsets < n_cols
    if HAS_BIAS:
        bias = tl.load(B + offsets, mask=col_mask, other=0.0).to(tl.float32)
    w = tl.load(W + offsets, mask=col_mask, other=0.0).to(tl.float32)
    start_x = X + row_idx * stride_x_row
    start_y = Y + row_idx * stride_y_row
    x = tl.load(start_x + offsets, mask=col_mask, other=0.0).to(tl.float32)
    if HAS_Z:
        start_z = Z + row_idx * stride_z_row
        start_z_out = Z_Out + row_idx * stride_z_out_row
        z = tl.load(start_z + offsets, mask=col_mask, other=0.0).to(tl.float32)
        x = x + z
        tl.store(start_z_out + offsets, x, mask=col_mask)
    var = tl.sum(x * x, axis=0) / n_cols
    rtsd = 1 / tl.sqrt(var + eps)
    x_hat = x * rtsd
    y = x_hat * w
    if HAS_BIAS:
        y = y + bias
    tl.store(start_y + offsets, y, mask=col_mask)


def _rms_norm_fwd_triton(
    x,
    weight,
    eps,
    residual=None,
    bias=None,
    out=None,
    residual_out=None,
):
    M, N = x.shape
    assert x.stride(-1) == 1
    assert weight.shape == (N,)
    assert weight.stride(-1) == 1
    # logger.info(f"bias is {bias}")
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N,)
    if residual is not None:
        assert residual.shape == x.shape
        assert residual.stride(-1) == 1
        if residual_out is None:
            residual_out = torch.empty_like(x)
    # allocate output
    if out is not None:
        assert out.shape == x.shape
    else:
        out = torch.empty_like(x)
    assert out.stride(-1) == 1
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        raise RuntimeError(
            "This rms norm doesn't support feature dim >= 64KB.")
    # heuristics for number of warps
    num_warps = min(max(BLOCK_N // 256, 1), 8)
    # _, num_vectorcore = get_device_properties()
    num_vectorcore = 48
    grid = (M if M < num_vectorcore else num_vectorcore,)
Contributor review comment (critical):

The grid for launching the Triton kernel is incorrectly calculated. It's capped at a hardcoded num_vectorcore value of 48. The rms_norm_fwd_kernel processes one row per program instance. With the current grid calculation, if the input tensor has more than 48 rows, only the first 48 rows will be processed, leading to incorrect output for the remaining rows. This is a critical correctness bug.

To process all M rows, the grid size must be (M,). The hardcoded num_vectorcore and the logic that uses it should be removed.

Suggested change:
-    # _, num_vectorcore = get_device_properties()
-    num_vectorcore = 48
-    grid = (M if M < num_vectorcore else num_vectorcore,)
+    grid = (M,)
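A quick way to observe the effect described in this comment, as a sketch only: it assumes an NPU-backed Triton build, that torch_npu registers the "npu" device, that _rms_norm_fwd_triton is importable from vllm_ascend.ops.layernorm, and the eps value and tolerance below are arbitrary choices for illustration.

import torch
import torch_npu  # noqa: F401  (registers the "npu" device)
from vllm_ascend.ops.layernorm import _rms_norm_fwd_triton

eps = 1e-6
x = torch.randn(64, 1024, device="npu")  # more rows than the 48-program grid
res = torch.zeros_like(x)                # zero residual keeps the reference simple
w = torch.ones(1024, device="npu")
out, _ = _rms_norm_fwd_triton(x, w, eps, residual=res)
ref = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * w
# Rows 48..63 are never written by the capped grid, so this is expected to print False.
print(torch.allclose(out, ref, atol=1e-3))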

    # with torch.npu.device(x.device.index):
    rms_norm_fwd_kernel[grid](
        x,
        out,
        weight,
        bias,
        residual,
        residual_out,
        x.stride(0),
        out.stride(0),
        residual.stride(0) if residual is not None else None,
        residual_out.stride(0) if residual is not None else None,
Contributor review comment (high):

When residual is None, None is passed to the stride_z_row and stride_z_out_row arguments of the Triton kernel. Triton's JIT compiler expects numerical values for non-pointer arguments and will likely raise a TypeError when it receives None. You should pass a dummy integer value like 0 instead. These arguments are not used when residual is None because of the HAS_Z conditional, so 0 is a safe value.

Suggested change:
-        residual.stride(0) if residual is not None else None,
-        residual_out.stride(0) if residual is not None else None,
+        residual.stride(0) if residual is not None else 0,
+        residual_out.stride(0) if residual is not None else 0,
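A minimal repro of the failure mode this comment describes, sketched under the same assumptions as the example above (NPU device, torch_npu installed, helper importable); it exercises only the residual-free path:

import torch
import torch_npu  # noqa: F401  (registers the "npu" device)
from vllm_ascend.ops.layernorm import _rms_norm_fwd_triton

x = torch.randn(4, 128, device="npu")
w = torch.ones(128, device="npu")
# With residual=None the two stride arguments are passed as None, which the
# reviewer expects Triton's JIT to reject at launch time.
out, _ = _rms_norm_fwd_triton(x, w, 1e-6)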

M,
N,
eps,
BLOCK_N=BLOCK_N,
num_warps=num_warps,
# multibuffer=True,
)
return out, residual_out


class AscendRMSNorm(RMSNorm):
@@ -57,16 +170,10 @@
                x, _ = torch_npu.npu_rms_norm(x, self.weight,
                                              self.variance_epsilon)
            else:
-               x, _, residual = torch_npu.npu_add_rms_norm(
-                   x, residual, self.weight, self.variance_epsilon)
-               if self.bias is not None:
-                   x.add_(self.bias)
+               x, residual = _rms_norm_fwd_triton(x, self.weight, self.variance_epsilon, residual, self.bias)
            return x, residual

-       x, residual = torch_npu.npu_rms_norm(x, self.weight,
-                                            self.variance_epsilon)
-       if self.bias is not None:
-           x.add_(self.bias)
+       x, _ = _rms_norm_fwd_triton(x, self.weight, self.variance_epsilon, residual, self.bias)
        return x
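For context, a sketch of how the two branches above are reached from the layer; this is illustrative only and assumes an NPU build of vllm-ascend (so forward dispatches to the Ascend path), a hidden size of 1024, and default dtypes:

import torch
import torch_npu  # noqa: F401  (registers the "npu" device)
from vllm_ascend.ops.layernorm import AscendRMSNorm

layer = AscendRMSNorm(1024, eps=1e-6).to("npu")
x = torch.randn(8, 1024, device="npu")

y = layer(x)  # no residual: returns only the normalized tensor
residual = torch.zeros_like(x)
# residual path: returns (normalized output, x + residual before normalization)
y2, new_residual = layer(x, residual)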

