2 changes: 2 additions & 0 deletions src/optimum/rbln/__init__.py
@@ -140,6 +140,7 @@
"RBLNPixtralVisionModelConfig",
"RBLNPhiModel",
"RBLNPhiModelConfig",
"RBLNQuantizationConfig",
"RBLNQwen2ForCausalLM",
"RBLNQwen2ForCausalLMConfig",
"RBLNQwen2_5_VisionTransformerPretrainedModel",
@@ -434,6 +435,7 @@
RBLNPhiModelConfig,
RBLNPixtralVisionModel,
RBLNPixtralVisionModelConfig,
RBLNQuantizationConfig,
RBLNQwen2_5_VisionTransformerPretrainedModel,
RBLNQwen2_5_VisionTransformerPretrainedModelConfig,
RBLNQwen2_5_VLForConditionalGeneration,
2 changes: 2 additions & 0 deletions src/optimum/rbln/transformers/__init__.py
@@ -173,6 +173,7 @@
"RBLNXLMRobertaModel",
"RBLNXLMRobertaModelConfig",
],
"utils": ["RBLNQuantizationConfig"],
}

if TYPE_CHECKING:
@@ -329,6 +330,7 @@
RBLNXLMRobertaModel,
RBLNXLMRobertaModelConfig,
)
from .utils import RBLNQuantizationConfig
else:
import sys

1 change: 1 addition & 0 deletions src/optimum/rbln/transformers/utils/__init__.py
@@ -0,0 +1 @@
from .rbln_quantization import RBLNQuantizationConfig
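For illustration, a minimal construction sketch using the re-exports added above. The field names and accepted values ("int4"/"int8"/"fp8", plus the new dynamic flag) follow this diff; any additional constructor arguments are assumed to keep their defaults, and unspecified fields fall back to "fp16".

from optimum.rbln import RBLNQuantizationConfig

# Sketch only: values mirror the options referenced elsewhere in this diff.
qconfig = RBLNQuantizationConfig(
    weights="fp8",      # weight precision
    activations="fp8",  # activation precision
    kv_caches="fp8",    # KV-cache precision
    dynamic=False,      # static activation scales (the default added in this PR)
)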
73 changes: 73 additions & 0 deletions src/optimum/rbln/transformers/utils/qlinear.py
@@ -0,0 +1,73 @@
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class QLinear(nn.Module):
def __init__(
self,
weight: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
weight_scale: Optional[torch.Tensor] = None,
input_scale: Optional[torch.Tensor] = None,
# FIXME(jongho): Make it hold only k_scale or v_scale
k_scale: Optional[torch.Tensor] = None,
v_scale: Optional[torch.Tensor] = None,
dynamic: bool = False,
):
super().__init__()

self.weight = weight
self.bias = bias
self.weight_scale = weight_scale
self.input_scale = input_scale
self.k_scale = k_scale
self.v_scale = v_scale
self.dynamic = dynamic

if weight_scale is None:
raise ValueError("weight_scale is required")

def dtype(self) -> torch.dtype:
return self.weight.dtype

def forward(self, x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError


class QIntLinear(QLinear):
def forward(self, x: torch.Tensor) -> torch.Tensor:
iinfo = torch.iinfo(self.dtype())
finfo = torch.finfo(x.dtype)
if self.dynamic:
if self.input_scale:
raise NotImplementedError("Dynamic quantization with input_scale is not supported.")
x_max = x.abs().max(dim=-1, keepdim=True).values
x_scale = x_max / iinfo.max
x_scale = torch.clamp(x_scale, min=finfo.eps)

x = (x / x_scale).clamp(min=iinfo.min, max=iinfo.max)
else:
if self.input_scale:
x = (x / self.input_scale).clamp(min=iinfo.min, max=iinfo.max)

weight = self.weight * self.weight_scale
qact = F.linear(x, weight, self.bias)

if self.dynamic: # Dequantize
qact = qact * x_scale

return qact


class QFloatLinear(QLinear):
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.input_scale:
finfo = torch.finfo(self.dtype())
x = (x / self.input_scale).clamp(min=finfo.min, max=finfo.max)

weight = self.weight.to(self.weight_scale.dtype) * self.weight_scale

return F.linear(x, weight, self.bias)
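For reference, a minimal sketch exercising the dynamic int8 path of QIntLinear defined above, independent of checkpoint loading. The import path mirrors this file's location in the PR; the random int8 weights and the per-output-channel scale are placeholders.

import torch

from optimum.rbln.transformers.utils.qlinear import QIntLinear

out_features, in_features = 4, 8
weight = torch.randint(-128, 128, (out_features, in_features), dtype=torch.int8)
weight_scale = torch.full((out_features, 1), 0.01)  # per-output-channel dequantization scale

layer = QIntLinear(weight=weight, weight_scale=weight_scale, dynamic=True)

x = torch.randn(2, in_features)
y = layer(x)  # activations are quantized per token on the fly, then rescaled by x_scale
print(y.shape)  # torch.Size([2, 4])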
155 changes: 70 additions & 85 deletions src/optimum/rbln/transformers/utils/rbln_quantization.py
@@ -20,12 +20,12 @@
from huggingface_hub import hf_hub_download, list_repo_files
from safetensors.torch import load_file
from torch.nn import Linear, Parameter
from torch.nn import functional as F
from transformers import AutoConfig
from transformers.modeling_utils import get_state_dict_dtype, no_init_weights

from ...configuration_utils import RBLNSerializableConfigProtocol
from ...utils.logging import get_logger
from .qlinear import QFloatLinear, QIntLinear


if TYPE_CHECKING:
@@ -68,6 +68,7 @@ def __init__(
weights: Optional[str] = None,
activations: Optional[str] = None,
kv_caches: Optional[str] = None,
dynamic: Optional[bool] = None,
*,
precision: Optional[str] = None,
):
@@ -89,6 +90,8 @@
self.weights = weights or "fp16"
self.activations = activations or "fp16"
self.kv_caches = kv_caches or "fp16"
self.dynamic = dynamic if dynamic is not None else False

self._validate()

def _validate(self):
@@ -128,19 +131,19 @@ class QuantizedLayerFactory:
def __init__(self, quantization_config: RBLNQuantizationConfig):
self.quantization_config = quantization_config

def create_linear(self, layer: Linear) -> Linear:
def create_linear(self, layer: Linear, scale_dtype: torch.dtype) -> Linear:
if self.quantization_config.weights in ["int4", "int8"]:
return self.create_qlinear(layer)
return self.convert_to_qint_linear(layer, scale_dtype)
elif self.quantization_config.weights == "fp8":
return self.create_fp8linear(layer)
return self.convert_to_qfloat_linear(layer, scale_dtype)
else:
raise ValueError(f"Invalid quantization weights: {self.quantization_config.weights}")

def create_qlinear(self, layer: Linear) -> Linear:
return create_qlinear(layer, self.quantization_config)
def convert_to_qint_linear(self, layer: Linear, scale_dtype: torch.dtype) -> Linear:
return convert_to_qint_linear(layer, self.quantization_config, scale_dtype)

def create_fp8linear(self, layer: Linear) -> Linear:
return create_fp8linear(layer, self.quantization_config)
def convert_to_qfloat_linear(self, layer: Linear, scale_dtype: torch.dtype) -> Linear:
return convert_to_qfloat_linear(layer, self.quantization_config, scale_dtype)


def get_quantized_model(
@@ -183,6 +186,14 @@ def get_quantized_model(
# get the dtype of the model from the first safetensor file
torch_dtype = get_state_dict_dtype(safetensors[0])

# remove n_layer_keys from kwargs if they are None.
# otherwise AutoConfig.from_pretrained will raise an error.
n_layer_keys = ["num_hidden_layers", "n_layers"]
for n_layer_key in n_layer_keys:
if n_layer_key in kwargs:
if kwargs[n_layer_key] is None:
kwargs.pop(n_layer_key)

config = AutoConfig.from_pretrained(
model_id,
use_auth_token=use_auth_token,
@@ -197,7 +208,7 @@
model = hf_auto_model_class.from_config(config, torch_dtype=torch_dtype)

# Quantize the model
update_layers_to_quantize(model, rbln_quantization)
update_layers_to_quantize(model, model.dtype, rbln_quantization)

# Load weights into the model
load_weights_from_files(model, safetensors, rbln_quantization)
@@ -251,6 +262,7 @@ def load_weight_files(

def update_layers_to_quantize(
module: torch.nn.Module,
scale_dtype: torch.dtype,
rbln_quantization: Optional[RBLNQuantizationConfig] = None,
) -> None:
"""
@@ -263,7 +275,7 @@
for name, layer in module.named_modules():
if is_target_for_qlinear_replacement(name, layer):
parent_module, layer_name = get_parent_and_child(module, name)
setattr(parent_module, layer_name, quantized_layer_factory.create_linear(layer))
setattr(parent_module, layer_name, quantized_layer_factory.create_linear(layer, scale_dtype))
processed_layers.append(name)

if processed_layers:
@@ -291,6 +303,11 @@ def _reduce_to_scalar(t: torch.Tensor) -> torch.Tensor:
return t.reshape(-1).amax()


def _scalar_value_as_1d(scale: torch.Tensor) -> torch.Tensor:
v = _reduce_to_scalar(scale)
return v.reshape(1).contiguous()


def _coerce_per_out_channel_scale(scale: torch.Tensor, out_features: int) -> torch.Tensor:
s = scale
if s.ndim == 0:
@@ -362,34 +379,31 @@ def canonicalize_checkpoint_items(
if len(wshape) == 2:
out_features = int(wshape[0])

if rbln_quantization.weights in ["int4", "int8"] and out_features is not None:
t = _coerce_per_out_channel_scale(t.to(torch.float32), out_features)
elif rbln_quantization.weights == "fp8":
# Use a conservative scalar scale to ensure broadcastability
t = _reduce_to_scalar(t.to(torch.float32))
if out_features is not None:
t = _coerce_per_out_channel_scale(t, out_features)
else:
t = t.to(torch.float32)
t = _scalar_value_as_1d(t)

results.append((target_key, t))
continue

# Normalize input/activation scale variants
if _matches_any_alias(key, "input_scale"):
target_key = _replace_last_with(key, "input_scale")
t = _reduce_to_scalar(t.to(torch.float32))
t = _scalar_value_as_1d(t)
results.append((target_key, t))
continue

# KV scale handling
if _matches_any_alias(key, "kv_scale"):
# For quark-like formats, expand to k/v
kv_items = _kv_split_items(key, t.to(torch.float32))
kv_items = _kv_split_items(key, t)
for k2, v2 in kv_items:
results.append((k2, v2))
continue

if _matches_any_alias(key, "k_scale") or _matches_any_alias(key, "v_scale"):
results.append((key, t.to(torch.float32)))
results.append((key, t))
continue

# Default: passthrough
@@ -497,84 +511,55 @@ def access_attribute(obj: Any, attributes: list[str]) -> Any:
return obj


def create_qlinear(layer: Linear, rbln_quantization: RBLNQuantizationConfig) -> Linear:
def convert_to_qint_linear(
layer: Linear, rbln_quantization: RBLNQuantizationConfig, scale_dtype: torch.dtype
) -> Linear:
"""
Converts a standard linear layer to a quantized linear (qlinear) layer with a custom forward pass.
"""

def qlinear_forward(self, inputs: torch.Tensor) -> torch.Tensor:
weight_scale = self.weight_scale
if inputs.dtype != weight_scale.dtype:
raise TypeError(f"Expected input dtype {weight_scale.dtype}, but got {inputs.dtype}")

w_fp = self.weight.type(inputs.dtype)
w_fp *= weight_scale.view(-1, 1)
return F.linear(inputs, w_fp, self.bias)

# Convert weight to int8 and add scale parameter
layer.weight = Parameter(layer.weight.to(torch.int8), requires_grad=False)
layer.weight_scale = Parameter(torch.ones(layer.out_features, 1, dtype=torch.float32), requires_grad=False)
layer.forward = lambda inputs: qlinear_forward(layer, inputs)

return layer
weight_scale = Parameter(torch.ones(layer.out_features, 1, dtype=scale_dtype), requires_grad=False)
input_scale = None

if rbln_quantization.activations == "int8" and not rbln_quantization.dynamic:
# Keep non-scalar shape for consistency with fp path
input_scale = Parameter(torch.ones(1, dtype=scale_dtype), requires_grad=False)

return QIntLinear(
weight=layer.weight,
bias=layer.bias,
weight_scale=weight_scale,
input_scale=input_scale,
dynamic=rbln_quantization.dynamic,
)


def create_fp8linear(layer: Linear, rbln_quantization: RBLNQuantizationConfig) -> Linear:
def convert_to_qfloat_linear(
layer: Linear, rbln_quantization: RBLNQuantizationConfig, scale_dtype: torch.dtype
) -> Linear:
"""
Converts a standard linear layer to an fp8 linear layer with a custom forward pass.
"""

def static_per_tensor_quantize(tensor: torch.Tensor, inv_scale: float) -> torch.Tensor:
finfo = torch.finfo(torch.float8_e4m3fn)
qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
return qweight

def fp8_gemm(A: torch.Tensor, A_scale, B: torch.Tensor, B_scale, bias, out_dtype: torch.dtype):
A = A.type(out_dtype)
B = B.type(out_dtype)

if A_scale is not None:
A *= A_scale
if B_scale is not None:
B *= B_scale.to(out_dtype)

output = torch.nn.functional.linear(A, B, bias=bias)
return output

def fp8linear_forward(self, x: torch.Tensor) -> torch.Tensor:
if self.input_scale:
input = static_per_tensor_quantize(x, self.input_scale)
else:
input = x

if self.weight_scale:
# broadcast weight_scale to vector
weight_scale = self.weight_scale.broadcast_to(self.weight.shape[-1:])
else:
weight_scale = None
output = fp8_gemm(
A=input,
A_scale=self.input_scale,
B=self.weight,
B_scale=weight_scale,
bias=self.bias,
out_dtype=x.dtype,
)

return output

# assign here to free weight from the original layer
layer.weight = Parameter(layer.weight.to(torch.float8_e4m3fn), requires_grad=False)
layer.weight_scale = Parameter(torch.tensor(1, dtype=torch.float32), requires_grad=False)
weight_scale = Parameter(torch.ones(layer.out_features, 1, dtype=scale_dtype), requires_grad=False)
input_scale = None

if rbln_quantization.activations == "fp8":
layer.input_scale = Parameter(torch.tensor(1, dtype=torch.float32), requires_grad=False)
else:
layer.input_scale = None
# Keep a non-scalar shape for input scale as well ([1]) for consistency
input_scale = Parameter(torch.ones(1, dtype=scale_dtype), requires_grad=False)

k_scale, v_scale = None, None
if rbln_quantization.kv_caches == "fp8":
layer.k_scale = Parameter(torch.tensor(1, dtype=torch.float32), requires_grad=False)
layer.v_scale = Parameter(torch.tensor(1, dtype=torch.float32), requires_grad=False)

layer.forward = lambda inputs: fp8linear_forward(layer, inputs)

return layer
k_scale = Parameter(torch.tensor(1, dtype=scale_dtype), requires_grad=False)
v_scale = Parameter(torch.tensor(1, dtype=scale_dtype), requires_grad=False)

return QFloatLinear(
weight=layer.weight,
bias=layer.bias,
weight_scale=weight_scale,
input_scale=input_scale,
k_scale=k_scale,
v_scale=v_scale,
)
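As a rough end-to-end sketch of the refactored factory path above: the module path and the call shape are taken from this diff, while the Linear layer and its unit scales are placeholders (in the real flow, checkpoint loading via load_weights_from_files is expected to fill in the actual weights and scales).

import torch
from torch.nn import Linear

from optimum.rbln import RBLNQuantizationConfig
from optimum.rbln.transformers.utils.rbln_quantization import QuantizedLayerFactory

qconfig = RBLNQuantizationConfig(weights="fp8", activations="fp8", kv_caches="fp8")
factory = QuantizedLayerFactory(qconfig)

layer = Linear(8, 4, bias=False)
qlayer = factory.create_linear(layer, torch.float32)  # QFloatLinear with placeholder unit scales
y = qlayer(torch.randn(2, 8))  # forward casts the weight to the scale dtype and rescales it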