1 change: 1 addition & 0 deletions vllm_rbln/__init__.py
@@ -46,6 +46,7 @@ def register_ops():
import vllm_rbln.attention.layer # noqa
import vllm_rbln.model_executor.layers.fused_moe.layer # noqa
import vllm_rbln.model_executor.layers.logits_processor # noqa
import vllm_rbln.model_executor.layers.quantization.kernels.mixed_precision # noqa
import vllm_rbln.model_executor.layers.rotary_embedding.base # noqa
import vllm_rbln.model_executor.layers.rotary_embedding.deepseek_scaling_rope # noqa
import vllm_rbln.model_executor.layers.vocab_parallel_embedding # noqa
@@ -0,0 +1,65 @@
# Copyright 2025 Rebellions Inc. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# NOTE: This module must be imported before the model config is created,
# because the function we patch here is imported (and bound) at that point.

from typing import Optional

import vllm.envs as envs
import vllm.model_executor.layers.quantization.kernels.mixed_precision as mp
from vllm.model_executor.layers.quantization.kernels.mixed_precision.MPLinearKernel import ( # noqa: E501
MPLinearKernel, MPLinearLayerConfig)

from vllm_rbln.model_executor.layers.quantization.kernels.mixed_precision.unpacked import ( # noqa: E501
RBLNInt8UnpackedLinearKernel)

choose_mp_linear_kernel_original = mp.choose_mp_linear_kernel

# Candidate RBLN mixed-precision kernels, tried in order.
_POSSIBLE_KERNELS: list[type[MPLinearKernel]] = [
RBLNInt8UnpackedLinearKernel,
]


def choose_mp_linear_kernel_rbln(
config: MPLinearLayerConfig,
compute_capability: Optional[int] = None,
) -> type[MPLinearKernel]:
from vllm.platforms import current_platform

if "rbln" not in current_platform.get_device_name().lower():
return choose_mp_linear_kernel_original(config, compute_capability)

failure_reasons = []
for kernel in _POSSIBLE_KERNELS:
if kernel.__name__ in envs.VLLM_DISABLED_KERNELS:
failure_reasons.append(
f' {kernel.__name__} disabled by environment variable')
continue

can_implement, failure_reason = kernel.can_implement(config)
if can_implement:
return kernel
else:
failure_reasons.append(
f' {kernel.__name__} cannot implement due to: {failure_reason}'
)

    raise ValueError(
        "Failed to find a kernel that can implement the "
        "WNA16 linear layer. Reasons:\n"
        + '\n'.join(failure_reasons))


mp.choose_mp_linear_kernel = choose_mp_linear_kernel_rbln
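# After this assignment, any import or module-attribute lookup of
# choose_mp_linear_kernel that happens later (e.g. while the model config is
# being created) resolves to the RBLN-aware selector above; callers that bound
# the original function before this module was loaded keep the original, hence
# the import-order NOTE at the top of this file.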
@@ -0,0 +1,97 @@
# Copyright 2025 Rebellions Inc. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Optional

import torch
from compressed_tensors.compressors.quantized_compressors import (
unpack_from_int32)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.MPLinearKernel import ( # noqa: E501
MPLinearKernel, MPLinearLayerConfig)
from vllm.model_executor.parameter import BasevLLMParameter
from vllm.scalar_type import scalar_types


class RBLNInt8UnpackedLinearKernel(MPLinearKernel):
"""
Torch native implementation of mixed precision Linear, based on
compressed_tensors' dequantize() function. rebel_compiler detects this
pattern and maps it to the actual kernel.
"""

    @classmethod
    def get_min_capability(cls) -> int:
        # Compute capability is a CUDA concept with no RBLN equivalent.
        raise NotImplementedError

@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
if c.weight_type not in (scalar_types.uint4b8, scalar_types.uint8b128):
return False, f"Weight type {c.weight_type} not supported"
if c.zero_points:
return False, "Asymmetric quantization not supported"
if c.group_size not in (-1, 64, 128):
return False, f"Group size {c.group_size} not supported"
if c.has_g_idx:
return False, "Group/dynamic activation ordering not supported"
return True, None

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
c = self.config
        # For unsigned integer scalar types, `mantissa` equals the bit width
        # (4 for uint4b8, 8 for uint8b128).
        bits = c.weight_type.mantissa

os.environ['RBLN_QUANT_BITS'] = str(bits)
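        # Assumption: RBLN_QUANT_BITS is consumed downstream (e.g. by
        # rebel_compiler) to learn the packed-weight bit width before the
        # int32-packed weights are unpacked below.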

if c.has_g_idx:
w_gidx = getattr(layer, self.w_gidx_name)
layer.perm = torch.argsort(w_gidx)

def transform_w_q(x: BasevLLMParameter):
in_features, out_features = c.full_weight_shape
x.data = unpack_from_int32(x.data, bits,
torch.Size((out_features, in_features)))
if c.has_g_idx:
x.data = x.data[:, layer.perm]
return x

def transform_w_s(x: BasevLLMParameter):
            if c.group_size == 128:
                # Only group size 64 is supported natively, so duplicate each
                # scale to split a size-128 group into two size-64 groups.
                x.data = x.data.repeat_interleave(2, dim=-1)
return x

self._transform_param(layer, self.w_q_name, transform_w_q)
self._transform_param(layer, self.w_s_name, transform_w_s)

def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
in_features, out_features = self.config.full_weight_shape

if self.config.has_g_idx:
x = x[..., layer.perm]

w_q, w_s, _, _ = self._get_weight_params(layer)
if self.config.group_size > 0:
w_q = w_q.view(out_features, in_features // 64,
64) # see transform_w_s
w_fp = w_q.type(x.dtype) * w_s.unsqueeze(-1)
w_fp = w_fp.view(out_features, in_features)
else:
w_fp = w_q.type(x.dtype) * w_s

return torch.nn.functional.linear(x, w_fp, bias)