From d5a16380dc904b00e209a7576e2819f8f36a85ce Mon Sep 17 00:00:00 2001 From: rebel-jongho Date: Wed, 13 Aug 2025 17:04:04 +0900 Subject: [PATCH 1/7] initial run --- .../rbln/transformers/utils/qlinear.py | 55 ++++++++ .../transformers/utils/rbln_quantization.py | 120 ++++++------------ 2 files changed, 97 insertions(+), 78 deletions(-) create mode 100644 src/optimum/rbln/transformers/utils/qlinear.py diff --git a/src/optimum/rbln/transformers/utils/qlinear.py b/src/optimum/rbln/transformers/utils/qlinear.py new file mode 100644 index 000000000..f46f756f1 --- /dev/null +++ b/src/optimum/rbln/transformers/utils/qlinear.py @@ -0,0 +1,55 @@ +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class QLinear(nn.Module): + def __init__( + self, + weight: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + weight_scale: Optional[torch.Tensor] = None, + input_scale: Optional[torch.Tensor] = None, + # FIXME(jongho): Make it only holds k_scale or v_scale + k_scale: Optional[torch.Tensor] = None, + v_scale: Optional[torch.Tensor] = None, + ): + super().__init__() + + self.weight = weight + self.bias = bias + self.weight_scale = weight_scale + self.input_scale = input_scale + self.k_scale = k_scale + self.v_scale = v_scale + + if weight_scale is None: + raise ValueError("weight_scale is required") + + def dtype(self) -> torch.dtype: + return self.weight.dtype + + def forward(self, x: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + +class QIntLinear(QLinear): + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.input_scale: + raise NotImplementedError("Input scale is currently not supported for int quantization") + + weight = self.weight * self.weight_scale + return F.linear(x, weight, self.bias) + + +class QFloatLinear(QLinear): + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.input_scale: + finfo = torch.finfo(self.dtype()) + x = (x / self.input_scale).clamp(min=finfo.min, max=finfo.max) + + weight = self.weight.to(self.weight_scale.dtype) * self.weight_scale + + return F.linear(x, weight, self.bias) diff --git a/src/optimum/rbln/transformers/utils/rbln_quantization.py b/src/optimum/rbln/transformers/utils/rbln_quantization.py index 600ad3e58..806caf703 100644 --- a/src/optimum/rbln/transformers/utils/rbln_quantization.py +++ b/src/optimum/rbln/transformers/utils/rbln_quantization.py @@ -20,10 +20,10 @@ from huggingface_hub import hf_hub_download, list_repo_files from safetensors.torch import load_file from torch.nn import Linear, Parameter -from torch.nn import functional as F from ...configuration_utils import RBLNSerializableConfigProtocol from ...utils.logging import get_logger +from .qlinear import QFloatLinear, QIntLinear logger = get_logger() @@ -125,17 +125,17 @@ def __init__(self, quantization_config: RBLNQuantizationConfig): def create_linear(self, layer: Linear) -> Linear: if self.quantization_config.weights in ["int4", "int8"]: - return self.create_qlinear(layer) + return self.convert_to_qint_linear(layer) elif self.quantization_config.weights == "fp8": - return self.create_fp8linear(layer) + return self.convert_to_qfloat_linear(layer) else: raise ValueError(f"Invalid quantization weights: {self.quantization_config.weights}") - def create_qlinear(self, layer: Linear) -> Linear: - return create_qlinear(layer, self.quantization_config) + def convert_to_qint_linear(self, layer: Linear) -> Linear: + return convert_to_qint_linear(layer, self.quantization_config) - def create_fp8linear(self, layer: Linear) -> Linear: - return create_fp8linear(layer, self.quantization_config) + def convert_to_qfloat_linear(self, layer: Linear) -> Linear: + return convert_to_qfloat_linear(layer, self.quantization_config) def prepare_model_for_quantization( @@ -263,6 +263,11 @@ def _reduce_to_scalar(t: torch.Tensor) -> torch.Tensor: return t.reshape(-1).amax() +def _scalar_value_as_1d(scale: torch.Tensor) -> torch.Tensor: + v = _reduce_to_scalar(scale) + return v.reshape(1).contiguous() + + def _coerce_per_out_channel_scale(scale: torch.Tensor, out_features: int) -> torch.Tensor: s = scale if s.ndim == 0: @@ -334,13 +339,10 @@ def canonicalize_checkpoint_items( if len(wshape) == 2: out_features = int(wshape[0]) - if rbln_quantization.weights in ["int4", "int8"] and out_features is not None: + if out_features is not None: t = _coerce_per_out_channel_scale(t.to(torch.float32), out_features) - elif rbln_quantization.weights == "fp8": - # Use a conservative scalar scale to ensure broadcastability - t = _reduce_to_scalar(t.to(torch.float32)) else: - t = t.to(torch.float32) + t = _scalar_value_as_1d(t.to(torch.float32)) results.append((target_key, t)) continue @@ -348,7 +350,7 @@ def canonicalize_checkpoint_items( # Normalize input/activation scale variants if _matches_any_alias(key, "input_scale"): target_key = _replace_last_with(key, "input_scale") - t = _reduce_to_scalar(t.to(torch.float32)) + t = _scalar_value_as_1d(t.to(torch.float32)) results.append((target_key, t)) continue @@ -481,84 +483,46 @@ def access_attribute(obj: Any, attributes: list[str]) -> Any: return obj -def create_qlinear(layer: Linear, rbln_quantization: RBLNQuantizationConfig) -> Linear: +def convert_to_qint_linear(layer: Linear, rbln_quantization: RBLNQuantizationConfig) -> Linear: """ Converts a standard linear layer to a quantized linear (qlinear) layer with a custom forward pass. """ - def qlinear_forward(self, inputs: torch.Tensor) -> torch.Tensor: - weight_scale = self.weight_scale - if inputs.dtype != weight_scale.dtype: - raise TypeError(f"Expected input dtype {weight_scale.dtype}, but got {inputs.dtype}") - - w_fp = self.weight.type(inputs.dtype) - w_fp *= weight_scale.view(-1, 1) - return F.linear(inputs, w_fp, self.bias) - - # Convert weight to int8 and add scale parameter + # assign here to free weight from the original layer layer.weight = Parameter(layer.weight.to(torch.int8), requires_grad=False) - layer.weight_scale = Parameter(torch.ones(layer.out_features, 1, dtype=torch.float32), requires_grad=False) - layer.forward = lambda inputs: qlinear_forward(layer, inputs) + weight_scale = Parameter(torch.ones(layer.out_features, 1, dtype=torch.float32), requires_grad=False) + input_scale = None + + if rbln_quantization.activations == "int8": + # Keep non-scalar shape for consistency with fp path + input_scale = Parameter(torch.ones(1, dtype=torch.float32), requires_grad=False) - return layer + return QIntLinear(weight=layer.weight, bias=layer.bias, weight_scale=weight_scale, input_scale=input_scale) -def create_fp8linear(layer: Linear, rbln_quantization: RBLNQuantizationConfig) -> Linear: +def convert_to_qfloat_linear(layer: Linear, rbln_quantization: RBLNQuantizationConfig) -> Linear: """ Converts a standard linear layer to a fp8 linear layer with a custom forward pass. """ - - def static_per_tensor_quantize(tensor: torch.Tensor, inv_scale: float) -> torch.Tensor: - finfo = torch.finfo(torch.float8_e4m3fn) - qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max) - return qweight - - def fp8_gemm(A: torch.Tensor, A_scale, B: torch.Tensor, B_scale, bias, out_dtype: torch.dtype): - A = A.type(out_dtype) - B = B.type(out_dtype) - - if A_scale is not None: - A *= A_scale - if B_scale is not None: - B *= B_scale.to(out_dtype) - - output = torch.nn.functional.linear(A, B, bias=bias) - return output - - def fp8linear_forward(self, x: torch.Tensor) -> torch.Tensor: - if self.input_scale: - input = static_per_tensor_quantize(x, self.input_scale) - else: - input = x - - if self.weight_scale: - # broadcast weight_scale to vector - weight_scale = self.weight_scale.broadcast_to(self.weight.shape[-1:]) - else: - weight_scale = None - output = fp8_gemm( - A=input, - A_scale=self.input_scale, - B=self.weight, - B_scale=weight_scale, - bias=self.bias, - out_dtype=x.dtype, - ) - - return output - + # assign here to free weight from the original layer layer.weight = Parameter(layer.weight.to(torch.float8_e4m3fn), requires_grad=False) - layer.weight_scale = Parameter(torch.tensor(1, dtype=torch.float32), requires_grad=False) + weight_scale = Parameter(torch.ones(layer.out_features, 1, dtype=torch.float32), requires_grad=False) + input_scale = None if rbln_quantization.activations == "fp8": - layer.input_scale = Parameter(torch.tensor(1, dtype=torch.float32), requires_grad=False) - else: - layer.input_scale = None + # Keep a non-scalar shape for input scale as well ([1]) for consistency + input_scale = Parameter(torch.ones(1, dtype=torch.float32), requires_grad=False) + k_scale, v_scale = None, None if rbln_quantization.kv_caches == "fp8": - layer.k_scale = Parameter(torch.tensor(1, dtype=torch.float32), requires_grad=False) - layer.v_scale = Parameter(torch.tensor(1, dtype=torch.float32), requires_grad=False) - - layer.forward = lambda inputs: fp8linear_forward(layer, inputs) - - return layer + k_scale = Parameter(torch.tensor(1, dtype=torch.float32), requires_grad=False) + v_scale = Parameter(torch.tensor(1, dtype=torch.float32), requires_grad=False) + + return QFloatLinear( + weight=layer.weight, + bias=layer.bias, + weight_scale=weight_scale, + input_scale=input_scale, + k_scale=k_scale, + v_scale=v_scale, + ) From c90710134b6be514eba205380541400c9628e9e1 Mon Sep 17 00:00:00 2001 From: rebel-jongho Date: Wed, 13 Aug 2025 17:38:17 +0900 Subject: [PATCH 2/7] a8 --- src/optimum/rbln/transformers/utils/qlinear.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/optimum/rbln/transformers/utils/qlinear.py b/src/optimum/rbln/transformers/utils/qlinear.py index f46f756f1..b961b04e0 100644 --- a/src/optimum/rbln/transformers/utils/qlinear.py +++ b/src/optimum/rbln/transformers/utils/qlinear.py @@ -38,7 +38,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class QIntLinear(QLinear): def forward(self, x: torch.Tensor) -> torch.Tensor: if self.input_scale: - raise NotImplementedError("Input scale is currently not supported for int quantization") + x = (x / self.input_scale).clamp(min=-128, max=127) weight = self.weight * self.weight_scale return F.linear(x, weight, self.bias) From afe498666f12eabb70b3e79dc30e7707515df23e Mon Sep 17 00:00:00 2001 From: rebel-jongho Date: Mon, 1 Sep 2025 15:14:30 +0900 Subject: [PATCH 3/7] fix dtype --- .../transformers/utils/rbln_quantization.py | 49 ++++++++++--------- 1 file changed, 27 insertions(+), 22 deletions(-) diff --git a/src/optimum/rbln/transformers/utils/rbln_quantization.py b/src/optimum/rbln/transformers/utils/rbln_quantization.py index 34facb4ca..a7cc513ab 100644 --- a/src/optimum/rbln/transformers/utils/rbln_quantization.py +++ b/src/optimum/rbln/transformers/utils/rbln_quantization.py @@ -129,19 +129,19 @@ class QuantizedLayerFactory: def __init__(self, quantization_config: RBLNQuantizationConfig): self.quantization_config = quantization_config - def create_linear(self, layer: Linear) -> Linear: + def create_linear(self, layer: Linear, scale_dtype: torch.dtype) -> Linear: if self.quantization_config.weights in ["int4", "int8"]: - return self.convert_to_qint_linear(layer) + return self.convert_to_qint_linear(layer, scale_dtype) elif self.quantization_config.weights == "fp8": - return self.convert_to_qfloat_linear(layer) + return self.convert_to_qfloat_linear(layer, scale_dtype) else: raise ValueError(f"Invalid quantization weights: {self.quantization_config.weights}") - def convert_to_qint_linear(self, layer: Linear) -> Linear: - return convert_to_qint_linear(layer, self.quantization_config) + def convert_to_qint_linear(self, layer: Linear, scale_dtype: torch.dtype) -> Linear: + return convert_to_qint_linear(layer, self.quantization_config, scale_dtype) - def convert_to_qfloat_linear(self, layer: Linear) -> Linear: - return convert_to_qfloat_linear(layer, self.quantization_config) + def convert_to_qfloat_linear(self, layer: Linear, scale_dtype: torch.dtype) -> Linear: + return convert_to_qfloat_linear(layer, self.quantization_config, scale_dtype) def get_quantized_model( @@ -198,7 +198,7 @@ def get_quantized_model( model = hf_auto_model_class.from_config(config, torch_dtype=torch_dtype) # Quantize the model - update_layers_to_quantize(model, rbln_quantization) + update_layers_to_quantize(model, model.dtype, rbln_quantization) # Load weights into the model load_weights_from_files(model, safetensors, rbln_quantization) @@ -252,6 +252,7 @@ def load_weight_files( def update_layers_to_quantize( module: torch.nn.Module, + scale_dtype: torch.dtype, rbln_quantization: Optional[RBLNQuantizationConfig] = None, ) -> None: """ @@ -264,7 +265,7 @@ def update_layers_to_quantize( for name, layer in module.named_modules(): if is_target_for_qlinear_replacement(name, layer): parent_module, layer_name = get_parent_and_child(module, name) - setattr(parent_module, layer_name, quantized_layer_factory.create_linear(layer)) + setattr(parent_module, layer_name, quantized_layer_factory.create_linear(layer, scale_dtype)) processed_layers.append(name) if processed_layers: @@ -369,9 +370,9 @@ def canonicalize_checkpoint_items( out_features = int(wshape[0]) if out_features is not None: - t = _coerce_per_out_channel_scale(t.to(torch.float32), out_features) + t = _coerce_per_out_channel_scale(t, out_features) else: - t = _scalar_value_as_1d(t.to(torch.float32)) + t = _scalar_value_as_1d(t) results.append((target_key, t)) continue @@ -379,20 +380,20 @@ def canonicalize_checkpoint_items( # Normalize input/activation scale variants if _matches_any_alias(key, "input_scale"): target_key = _replace_last_with(key, "input_scale") - t = _scalar_value_as_1d(t.to(torch.float32)) + t = _scalar_value_as_1d(t) results.append((target_key, t)) continue # KV scale handling if _matches_any_alias(key, "kv_scale"): # For quark-like formats, expand to k/v - kv_items = _kv_split_items(key, t.to(torch.float32)) + kv_items = _kv_split_items(key, t) for k2, v2 in kv_items: results.append((k2, v2)) continue if _matches_any_alias(key, "k_scale") or _matches_any_alias(key, "v_scale"): - results.append((key, t.to(torch.float32))) + results.append((key, t)) continue # Default: passthrough @@ -500,40 +501,44 @@ def access_attribute(obj: Any, attributes: list[str]) -> Any: return obj -def convert_to_qint_linear(layer: Linear, rbln_quantization: RBLNQuantizationConfig) -> Linear: +def convert_to_qint_linear( + layer: Linear, rbln_quantization: RBLNQuantizationConfig, scale_dtype: torch.dtype +) -> Linear: """ Converts a standard linear layer to a quantized linear (qlinear) layer with a custom forward pass. """ # assign here to free weight from the original layer layer.weight = Parameter(layer.weight.to(torch.int8), requires_grad=False) - weight_scale = Parameter(torch.ones(layer.out_features, 1, dtype=torch.float32), requires_grad=False) + weight_scale = Parameter(torch.ones(layer.out_features, 1, dtype=scale_dtype), requires_grad=False) input_scale = None if rbln_quantization.activations == "int8": # Keep non-scalar shape for consistency with fp path - input_scale = Parameter(torch.ones(1, dtype=torch.float32), requires_grad=False) + input_scale = Parameter(torch.ones(1, dtype=scale_dtype), requires_grad=False) return QIntLinear(weight=layer.weight, bias=layer.bias, weight_scale=weight_scale, input_scale=input_scale) -def convert_to_qfloat_linear(layer: Linear, rbln_quantization: RBLNQuantizationConfig) -> Linear: +def convert_to_qfloat_linear( + layer: Linear, rbln_quantization: RBLNQuantizationConfig, scale_dtype: torch.dtype +) -> Linear: """ Converts a standard linear layer to a fp8 linear layer with a custom forward pass. """ # assign here to free weight from the original layer layer.weight = Parameter(layer.weight.to(torch.float8_e4m3fn), requires_grad=False) - weight_scale = Parameter(torch.ones(layer.out_features, 1, dtype=torch.float32), requires_grad=False) + weight_scale = Parameter(torch.ones(layer.out_features, 1, dtype=scale_dtype), requires_grad=False) input_scale = None if rbln_quantization.activations == "fp8": # Keep a non-scalar shape for input scale as well ([1]) for consistency - input_scale = Parameter(torch.ones(1, dtype=torch.float32), requires_grad=False) + input_scale = Parameter(torch.ones(1, dtype=scale_dtype), requires_grad=False) k_scale, v_scale = None, None if rbln_quantization.kv_caches == "fp8": - k_scale = Parameter(torch.tensor(1, dtype=torch.float32), requires_grad=False) - v_scale = Parameter(torch.tensor(1, dtype=torch.float32), requires_grad=False) + k_scale = Parameter(torch.tensor(1, dtype=scale_dtype), requires_grad=False) + v_scale = Parameter(torch.tensor(1, dtype=scale_dtype), requires_grad=False) return QFloatLinear( weight=layer.weight, From 084aa938d6bf2d9e88f4d2a5465aa5c2999068c9 Mon Sep 17 00:00:00 2001 From: rebel-jongho Date: Mon, 1 Sep 2025 15:44:36 +0900 Subject: [PATCH 4/7] remove n_layer if none --- src/optimum/rbln/transformers/utils/rbln_quantization.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/optimum/rbln/transformers/utils/rbln_quantization.py b/src/optimum/rbln/transformers/utils/rbln_quantization.py index a7cc513ab..84c4dc5f3 100644 --- a/src/optimum/rbln/transformers/utils/rbln_quantization.py +++ b/src/optimum/rbln/transformers/utils/rbln_quantization.py @@ -184,6 +184,14 @@ def get_quantized_model( # get the dtype of the model from the first safetensor file torch_dtype = get_state_dict_dtype(safetensors[0]) + # remove n_layer_keys from kwargs if they are None. + # otherwise AutoConfig.from_pretrained will raise an error. + n_layer_keys = ["num_hidden_layers", "n_layers"] + for n_layer_key in n_layer_keys: + if n_layer_key in kwargs: + if kwargs[n_layer_key] is None: + kwargs.pop(n_layer_key) + config = AutoConfig.from_pretrained( model_id, use_auth_token=use_auth_token, From 1784ef63f02636794bcc51f570c1566c30b66ac8 Mon Sep 17 00:00:00 2001 From: rebel-jongho Date: Mon, 1 Sep 2025 17:02:52 +0900 Subject: [PATCH 5/7] ruff --- src/optimum/rbln/transformers/utils/rbln_quantization.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/optimum/rbln/transformers/utils/rbln_quantization.py b/src/optimum/rbln/transformers/utils/rbln_quantization.py index 84c4dc5f3..c0f6ccadf 100644 --- a/src/optimum/rbln/transformers/utils/rbln_quantization.py +++ b/src/optimum/rbln/transformers/utils/rbln_quantization.py @@ -20,7 +20,6 @@ from huggingface_hub import hf_hub_download, list_repo_files from safetensors.torch import load_file from torch.nn import Linear, Parameter -from torch.nn import functional as F from transformers import AutoConfig from transformers.modeling_utils import get_state_dict_dtype, no_init_weights From 582785ef99257d2aa672258ac85869eb62b76bcd Mon Sep 17 00:00:00 2001 From: rebel-jongho Date: Wed, 29 Oct 2025 16:37:28 +0900 Subject: [PATCH 6/7] add dynamic --- src/optimum/rbln/__init__.py | 2 ++ src/optimum/rbln/transformers/__init__.py | 2 ++ .../rbln/transformers/utils/__init__.py | 1 + .../rbln/transformers/utils/qlinear.py | 21 ++++++++++++++++--- .../transformers/utils/rbln_quantization.py | 14 ++++++++++--- 5 files changed, 34 insertions(+), 6 deletions(-) diff --git a/src/optimum/rbln/__init__.py b/src/optimum/rbln/__init__.py index 4922f0a9e..c3ee35784 100644 --- a/src/optimum/rbln/__init__.py +++ b/src/optimum/rbln/__init__.py @@ -140,6 +140,7 @@ "RBLNPixtralVisionModelConfig", "RBLNPhiModel", "RBLNPhiModelConfig", + "RBLNQuantizationConfig", "RBLNQwen2ForCausalLM", "RBLNQwen2ForCausalLMConfig", "RBLNQwen2_5_VisionTransformerPretrainedModel", @@ -434,6 +435,7 @@ RBLNPhiModelConfig, RBLNPixtralVisionModel, RBLNPixtralVisionModelConfig, + RBLNQuantizationConfig, RBLNQwen2_5_VisionTransformerPretrainedModel, RBLNQwen2_5_VisionTransformerPretrainedModelConfig, RBLNQwen2_5_VLForConditionalGeneration, diff --git a/src/optimum/rbln/transformers/__init__.py b/src/optimum/rbln/transformers/__init__.py index 9fca4e122..4a7e86ef6 100644 --- a/src/optimum/rbln/transformers/__init__.py +++ b/src/optimum/rbln/transformers/__init__.py @@ -173,6 +173,7 @@ "RBLNXLMRobertaModel", "RBLNXLMRobertaModelConfig", ], + "utils": ["RBLNQuantizationConfig"], } if TYPE_CHECKING: @@ -329,6 +330,7 @@ RBLNXLMRobertaModel, RBLNXLMRobertaModelConfig, ) + from .utils import RBLNQuantizationConfig else: import sys diff --git a/src/optimum/rbln/transformers/utils/__init__.py b/src/optimum/rbln/transformers/utils/__init__.py index e69de29bb..56eaf86a2 100644 --- a/src/optimum/rbln/transformers/utils/__init__.py +++ b/src/optimum/rbln/transformers/utils/__init__.py @@ -0,0 +1 @@ +from .rbln_quantization import RBLNQuantizationConfig diff --git a/src/optimum/rbln/transformers/utils/qlinear.py b/src/optimum/rbln/transformers/utils/qlinear.py index b961b04e0..80bd5a084 100644 --- a/src/optimum/rbln/transformers/utils/qlinear.py +++ b/src/optimum/rbln/transformers/utils/qlinear.py @@ -15,6 +15,7 @@ def __init__( # FIXME(jongho): Make it only holds k_scale or v_scale k_scale: Optional[torch.Tensor] = None, v_scale: Optional[torch.Tensor] = None, + dynamic: bool = False, ): super().__init__() @@ -24,6 +25,7 @@ def __init__( self.input_scale = input_scale self.k_scale = k_scale self.v_scale = v_scale + self.dynamic = dynamic if weight_scale is None: raise ValueError("weight_scale is required") @@ -37,11 +39,24 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class QIntLinear(QLinear): def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.input_scale: - x = (x / self.input_scale).clamp(min=-128, max=127) + iinfo = torch.iinfo(self.dtype()) + if self.dynamic: + if self.input_scale: + raise NotImplementedError("Dynamic quantization with input_scale is not supported.") + x_max = x.abs().max(dim=-1, keepdim=True).values + x_scale = x_max / iinfo.max + x = (x / x_scale).clamp(min=iinfo.min, max=iinfo.max) + else: + if self.input_scale: + x = (x / self.input_scale).clamp(min=iinfo.min, max=iinfo.max) weight = self.weight * self.weight_scale - return F.linear(x, weight, self.bias) + qact = F.linear(x, weight, self.bias) + + if self.dynamic: # Dequantize + qact = qact * x_scale + + return qact class QFloatLinear(QLinear): diff --git a/src/optimum/rbln/transformers/utils/rbln_quantization.py b/src/optimum/rbln/transformers/utils/rbln_quantization.py index c0f6ccadf..ccf777cc7 100644 --- a/src/optimum/rbln/transformers/utils/rbln_quantization.py +++ b/src/optimum/rbln/transformers/utils/rbln_quantization.py @@ -68,6 +68,7 @@ def __init__( weights: Optional[str] = None, activations: Optional[str] = None, kv_caches: Optional[str] = None, + dynamic: Optional[bool] = None, *, precision: Optional[str] = None, ): @@ -89,6 +90,8 @@ def __init__( self.weights = weights or "fp16" self.activations = activations or "fp16" self.kv_caches = kv_caches or "fp16" + self.dynamic = dynamic if dynamic is not None else False + self._validate() def _validate(self): @@ -515,16 +518,21 @@ def convert_to_qint_linear( Converts a standard linear layer to a quantized linear (qlinear) layer with a custom forward pass. """ - # assign here to free weight from the original layer layer.weight = Parameter(layer.weight.to(torch.int8), requires_grad=False) weight_scale = Parameter(torch.ones(layer.out_features, 1, dtype=scale_dtype), requires_grad=False) input_scale = None - if rbln_quantization.activations == "int8": + if rbln_quantization.activations == "int8" and not rbln_quantization.dynamic: # Keep non-scalar shape for consistency with fp path input_scale = Parameter(torch.ones(1, dtype=scale_dtype), requires_grad=False) - return QIntLinear(weight=layer.weight, bias=layer.bias, weight_scale=weight_scale, input_scale=input_scale) + return QIntLinear( + weight=layer.weight, + bias=layer.bias, + weight_scale=weight_scale, + input_scale=input_scale, + dynamic=rbln_quantization.dynamic, + ) def convert_to_qfloat_linear( From 0fc9f443026c818b1235fb38bdb5493418c25f46 Mon Sep 17 00:00:00 2001 From: rebel-jongho Date: Thu, 30 Oct 2025 20:29:14 +0900 Subject: [PATCH 7/7] add clamp for numerical stability --- src/optimum/rbln/transformers/utils/qlinear.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/optimum/rbln/transformers/utils/qlinear.py b/src/optimum/rbln/transformers/utils/qlinear.py index 80bd5a084..34948aa61 100644 --- a/src/optimum/rbln/transformers/utils/qlinear.py +++ b/src/optimum/rbln/transformers/utils/qlinear.py @@ -40,11 +40,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class QIntLinear(QLinear): def forward(self, x: torch.Tensor) -> torch.Tensor: iinfo = torch.iinfo(self.dtype()) + finfo = torch.finfo(x.dtype) if self.dynamic: if self.input_scale: raise NotImplementedError("Dynamic quantization with input_scale is not supported.") x_max = x.abs().max(dim=-1, keepdim=True).values x_scale = x_max / iinfo.max + x_scale = torch.clamp(x_scale, min=finfo.eps) + x = (x / x_scale).clamp(min=iinfo.min, max=iinfo.max) else: if self.input_scale: