172 changes: 140 additions & 32 deletions vllm_ascend/ops/layernorm.py
@@ -15,22 +15,120 @@
# This file is a part of the vllm-ascend project.
#

from typing import Optional, Tuple, Union, cast
from typing import Optional, Tuple, Union, cast, Dict, Any

import torch
from vllm.config import get_current_vllm_config
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
from vllm.triton_utils import tl, triton
from functools import cache


@cache
def get_device_properties():
    # Placeholder: no device handle yet; assume 40 vector cores on the target NPU.
    return None, 40


@triton.jit
def add_rmsnorm_bias_kernel(
input_ptr,
residual_ptr,
norm_weight_ptr,
norm_bias_ptr,
    quant_scale_ptr,  # reserved for a quantized path; unused in this kernel
    quant_offset_ptr,  # reserved for a quantized path; unused in this kernel
output_ptr,
output2_ptr,
batch_size,
hidden_size: tl.constexpr,
eps: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
    COL_BLOCK_SIZE: tl.constexpr,  # reserved; unused in this kernel
):
row_start = tl.program_id(0)
row_step = tl.num_programs(0)
cols = tl.arange(0, BLOCK_SIZE)
valid_mask = cols < hidden_size
norm_weight_values = tl.load(norm_weight_ptr + cols, mask=valid_mask, other=0.0)
input_offsets = row_start * hidden_size + cols
for i in tl.range(row_start, batch_size, row_step):
# add
buffered_values = tl.load(input_ptr + input_offsets, mask=valid_mask, other=0.0)
buffered_values += tl.load(
residual_ptr + input_offsets, mask=valid_mask, other=0.0
)
tl.store(output2_ptr + input_offsets, buffered_values, mask=valid_mask)
buffered_values = buffered_values.to(tl.float32)
# rmsnorm
squares = buffered_values * buffered_values
variance = tl.sum(squares) / hidden_size
reciprocal_std = 1 / tl.sqrt(variance + eps)
buffered_values = buffered_values * reciprocal_std
buffered_values = buffered_values * norm_weight_values
# add bias
norm_bias_values = tl.load(norm_bias_ptr + cols, mask=valid_mask, other=0.0)
buffered_values = buffered_values + norm_bias_values
tl.store(output_ptr + input_offsets, buffered_values, mask=valid_mask)

input_offsets += row_step * hidden_size
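
A minimal eager-mode PyTorch reference of the fused computation (residual add, RMSNorm, bias add) is useful for checking the kernel numerically. This sketch is illustrative only — the helper name is hypothetical, and the float32 accumulation mirrors the kernel above; it is not part of the diff:

def add_rmsnorm_bias_reference(
    input: torch.Tensor,
    residual: torch.Tensor,
    norm_weight: torch.Tensor,
    norm_bias: torch.Tensor,
    eps: float,
):
    # Residual add in the input dtype; this is what the kernel stores to output2.
    added = input + residual
    # RMSNorm in float32 for stability, matching the kernel's .to(tl.float32).
    x = added.to(torch.float32)
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    x = x * norm_weight.to(torch.float32) + norm_bias.to(torch.float32)
    # Cast back to the input dtype, as the store into output_ptr does implicitly.
    return x.to(input.dtype), added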


# Kernel cache (annotated to satisfy the mypy check flagged in CI).
kernels: Dict[Any, Any] = {}


def add_rmsnorm_bias(
input: torch.Tensor,
    residual: torch.Tensor,
norm_weight: torch.Tensor,
norm_bias: Optional[torch.Tensor],
eps: float,
quant_scale: Optional[torch.Tensor] = None,
quant_offset: Optional[torch.Tensor] = None,
):
input = input.contiguous()
residual = residual.contiguous()
norm_weight = norm_weight.contiguous()
    norm_bias = (norm_bias.contiguous() if norm_bias is not None
                 else torch.zeros_like(norm_weight))

    _, num_vectorcore = get_device_properties()
batch_size = input.shape[0]
hidden_size = input.shape[1]
BLOCK_SIZE = triton.next_power_of_2(hidden_size)
COL_BLOCK_SIZE = 2048
n_rows = min(batch_size, num_vectorcore)
output = torch.empty(
batch_size, hidden_size, device=input.device, dtype=input.dtype
)
output2 = torch.empty(
batch_size, hidden_size, device=input.device, dtype=input.dtype
)

add_rmsnorm_bias_kernel[(n_rows, 1, 1)](
input,
residual,
norm_weight,
norm_bias,
quant_scale,
quant_offset,
output,
output2,
batch_size,
hidden_size,
eps,
BLOCK_SIZE,
COL_BLOCK_SIZE,
)
return output, output2
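
A hypothetical smoke test for the wrapper (the shapes, dtype, and "npu" device string below are assumptions for illustration; any device with a working Triton backend will do):

x = torch.randn(8, 4096, dtype=torch.float16, device="npu")
res = torch.randn_like(x)
w = torch.ones(4096, dtype=torch.float16, device="npu")
b = torch.zeros(4096, dtype=torch.float16, device="npu")
normed, added = add_rmsnorm_bias(x, res, w, b, eps=1e-6)
# normed == rmsnorm(x + res) * w + b, and added == x + res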


class AscendRMSNorm(RMSNorm):

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: Optional[int] = None,
        has_weight: bool = True,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
vllm_config = get_current_vllm_config()
Expand All @@ -42,9 +140,9 @@
requires_grad=False)

    def forward_oot(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
import torch_npu

Expand All @@ -57,37 +155,47 @@
x, _ = torch_npu.npu_rms_norm(x, self.weight,
self.variance_epsilon)
else:
x, _, residual = torch_npu.npu_add_rms_norm(
x, residual, self.weight, self.variance_epsilon)
if self.bias is not None:
x.add_(self.bias)
x, residual = add_rmsnorm_bias(
input=x,
residual=residual,
norm_weight=self.weight,
norm_bias=self.bias,
eps=self.variance_epsilon,
quant_scale=None,
quant_offset=None
)
return x, residual

x, residual = torch_npu.npu_rms_norm(x, self.weight,
self.variance_epsilon)
if self.bias is not None:
x.add_(self.bias)
residual = torch.zeros_like(x, device=x.device, dtype=x.dtype)
x, _ = add_rmsnorm_bias(
input=x,
residual=residual,
norm_weight=self.weight,
norm_bias=self.bias,
eps=self.variance_epsilon,
quant_scale=None,
quant_offset=None
)
return x
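
At the module level, the intended contract of forward_oot can be sketched as follows (construction requires a current vllm config and the NPU runtime, so this is a usage outline rather than a runnable snippet):

layer = AscendRMSNorm(hidden_size=4096, eps=1e-6)
y, r = layer.forward_oot(x, residual=res)  # fused add + RMSNorm (+ optional bias)
y = layer.forward_oot(x)                   # RMSNorm only; zero residual handled internally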


class AscendQuantRMSNorm(AscendRMSNorm):

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: Optional[int] = None,
        has_weight: bool = True,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
requires_grad=False)

    def forward_oot(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if residual is not None:
x, residual = super().forward_oot(x, residual)
Expand All @@ -98,9 +206,9 @@
class AscendGemmaRMSNorm(GemmaRMSNorm):

    def forward_oot(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
import torch_npu

Expand All @@ -119,4 +227,4 @@

x, _ = torch_npu.npu_rms_norm(x, 1.0 + self.weight,
self.variance_epsilon)
        return x