1 change: 0 additions & 1 deletion src/brevitas/core/scaling/standalone.py
@@ -470,7 +470,6 @@ def training_forward(self, stats_input: Tensor, threshold: Tensor) -> Tensor:
            return clamped_stats / threshold
        elif self.counter == self.collect_stats_steps:
            self.init_scale()
            value = self.clamp_scaling(self.restrict_scaling(self.value))
            threshold = self.restrict_threshold(self.restrict_threshold_pre(threshold))
            value = self.restrict_scale_threshold(value / threshold)
            return value
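
For context, the hunk above shows how training_forward composes the two restriction paths once stats collection finishes: the raw scale is restricted and clamped first, then divided by the separately restricted threshold. A minimal sketch of that flow with toy stand-ins (illustration only, not part of the diff; the real collaborators are injected Brevitas modules, and the final restrict_scale_threshold step is omitted):

import torch

# Toy stand-ins: restrict the scale to powers of two, clamp it away from
# zero, and leave the threshold restriction as the identity.
restrict_scaling = lambda v: torch.exp2(torch.round(torch.log2(v)))
clamp_scaling = lambda v: torch.clamp_min(v, 1e-8)
restrict_threshold = lambda t: t

value = torch.tensor([0.013, 3.7])
threshold = restrict_threshold(torch.tensor(2.0))
scale = clamp_scaling(restrict_scaling(value)) / threshold  # same order as above
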
34 changes: 34 additions & 0 deletions src/brevitas_examples/common/generative/quant_blocks.py
@@ -5,12 +5,17 @@

from typing import Callable

from dependencies import this
from dependencies import value
from torch import Tensor
import torch.nn as nn

from brevitas.core.restrict_val import _RestrictClampValue
from brevitas.core.stats.stats_wrapper import SCALAR_SHAPE
from brevitas.core.zero_point import _ScaleShiftZeroPoint
from brevitas.function.ops_ste import abs_binary_sign_grad
from brevitas.inject import ExtendedInjector
from brevitas.inject.enum import ScalingPerOutputType


# TODO: restore JIT compatibility
@@ -74,3 +79,32 @@ def forward(self, x, scale, bit_width) -> Tensor:
        x = abs_binary_sign_grad(x)
        x = self.scale_shift_zero_point(x, scale, bit_width)
        return x


class QuantScaleScaleShapeMixin(ExtendedInjector):

    @value
    def scaling_shape(
            scaling_per_output,
            scaling_per_output_channel_shape,
            expanded_groupwise_shape,
            group_dim,
            upstream_scaling):
        if scaling_per_output == ScalingPerOutputType.TENSOR:
            scaling = SCALAR_SHAPE
        elif scaling_per_output == ScalingPerOutputType.CHANNEL:
            scaling = scaling_per_output_channel_shape
        elif scaling_per_output == ScalingPerOutputType.GROUP:
            # The scaling shape matches expanded_groupwise_shape, except that
            # position group_dim + 1 is set to 1
            assert expanded_groupwise_shape is not None, "Per Group scaling not correctly configured"
            assert group_dim is not None, "Per Group scaling not correctly configured"
            size = list(expanded_groupwise_shape)
            size[group_dim + 1] = 1
            scaling = tuple(size)

        # When quantizing the scale of a groupwise quantizer, there is one
        # extra dim compared to the normal case
        if upstream_scaling == ScalingPerOutputType.GROUP:
            scaling = list(scaling)
            scaling.insert(-1, 1)
            scaling = tuple(scaling)
        return scaling
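
A plain-Python mirror of the scaling_shape logic may help (illustration only, not part of the diff; the strings stand in for ScalingPerOutputType members):

SCALAR_SHAPE = ()  # same scalar-shape convention as brevitas

def scaling_shape_demo(per_output, channel_shape, expanded_shape, group_dim, upstream_groupwise):
    if per_output == 'tensor':
        scaling = SCALAR_SHAPE
    elif per_output == 'channel':
        scaling = channel_shape
    elif per_output == 'group':
        size = list(expanded_shape)
        size[group_dim + 1] = 1  # collapse the intra-group dimension
        scaling = tuple(size)
    if upstream_groupwise:
        scaling = list(scaling)
        scaling.insert(-1, 1)  # extra unit dim for the upstream group axis
        scaling = tuple(scaling)
    return scaling

# A weight expanded to groups of 32, e.g. (64 output channels, 8 groups, 32 elements):
assert scaling_shape_demo('group', None, (64, 8, 32), 1, False) == (64, 8, 1)
# If the upstream quantizer is itself groupwise, one extra unit dim appears:
assert scaling_shape_demo('group', None, (64, 8, 32), 1, True) == (64, 8, 1, 1)
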
46 changes: 42 additions & 4 deletions src/brevitas_examples/common/generative/quantize.py
@@ -61,6 +61,7 @@
from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloatMSE
from brevitas_examples.common.generative.nn import LoRACompatibleQuantConv2d
from brevitas_examples.common.generative.nn import LoRACompatibleQuantLinear
from brevitas_examples.common.generative.quantizers import DynamicQuantScaleMXFloat8e4m3Act
from brevitas_examples.common.generative.quantizers import Fp8e4m3DynamicActPerGroupFloat
from brevitas_examples.common.generative.quantizers import FP8e4m3FNUZDynamicActPerRowFloat
from brevitas_examples.common.generative.quantizers import Fp8e4m3FNUZDynamicActPerTensorFloat
@@ -76,6 +77,8 @@
from brevitas_examples.common.generative.quantizers import Int8DynamicActPerRowFloat
from brevitas_examples.common.generative.quantizers import Int8DynamicActPerTensorFloat
from brevitas_examples.common.generative.quantizers import IntWeightSymmetricGroupQuant
from brevitas_examples.common.generative.quantizers import QuantScaleMXFloat8e4m3Weight
from brevitas_examples.common.generative.quantizers import QuantScaleMXFloat8e4m3WeightMSE
from brevitas_examples.common.generative.quantizers import RuntimeDynamicStatsZeroPoint
from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerGroupFloat
from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerRowFloat
@@ -144,6 +147,13 @@
            'mse': {
                'per_channel': {
                    'sym': Fp8e4m3OCPWeightPerChannelFloatMSE}}},
+        'float_quant_scale': {
+            'stats': {
+                'per_group': {
+                    'sym': QuantScaleMXFloat8e4m3Weight}},
+            'mse': {
+                'per_group': {
+                    'sym': QuantScaleMXFloat8e4m3WeightMSE}}},
        'po2_scale': {
            'stats': {
                'per_group': {
@@ -223,6 +233,10 @@
                    'sym': FP8e4m3OCPDynamicActPerRowFloat},
                'per_group': {
                    'sym': Fp8e4m3OCPDynamicActPerGroupFloat}}},
+            'float_quant_scale': {
+                'stats': {
+                    'per_group': {
+                        'sym': DynamicQuantScaleMXFloat8e4m3Act}}},
            'po2_scale': {
                'stats': {
                    'per_row': {
@@ -284,19 +298,25 @@ def generate_quantizers(
"""

# Retrive base quantizer, match against custom float format, or return as-is
-    def quant_format_from_string(quant_format):
-        quant_format_re = re.compile(r'e[1-8]m[1-8]')
+    def quant_format_from_string(quant_format, scale=False):
+        quant_format_re = re.compile(r'e[1-8]m[0-8]')
        if quant_format_re.findall(quant_format):
            float_type = quant_format_re.findall(quant_format)[0]
-            quant_format = quant_format.replace('_' + float_type, '')
+            if scale:
+                quant_format = "float_quant_scale"
+            else:
+                quant_format = quant_format.replace('_' + float_type, '')
            float_format = {
                'exponent_bit_width': int(float_type[1]), 'mantissa_bit_width': int(float_type[3])}

        else:
            float_format = {}
        return quant_format, float_format

    weight_quant_format, weight_float_format = quant_format_from_string(weight_quant_format)
    input_quant_format, input_float_format = quant_format_from_string(input_quant_format)
+    weight_scale_precision, weight_scale_precision_format = quant_format_from_string(weight_scale_precision, scale=True)
+    input_scale_precision, input_scale_precision_format = quant_format_from_string(input_scale_precision, scale=True)

    weight_quant = WEIGHT_QUANT_MAP[weight_quant_format][weight_scale_precision][
        weight_param_method][weight_quant_granularity][weight_quant_type]
@@ -326,7 +346,7 @@ def quant_format_from_string(quant_format):

    attn_quant_format, attn_float_format = quant_format_from_string(attn_quant_format) if attn_quant_format is not None else (input_quant_format, input_float_format)
    attn_scale_type = attn_scale_type if attn_scale_type is not None else input_scale_type
-    attn_scale_precision = attn_scale_precision if attn_scale_precision is not None else input_scale_precision
+    attn_scale_precision, attn_scale_precision_format = quant_format_from_string(attn_scale_precision) if attn_scale_precision is not None else (input_scale_precision, input_scale_precision_format)
    attn_param_method = attn_param_method if attn_param_method is not None else input_param_method
    attn_quant_granularity = attn_quant_granularity if attn_quant_granularity is not None else input_quant_granularity
    attn_quant_type = attn_quant_type if attn_quant_type is not None else input_quant_type
@@ -345,6 +365,18 @@ def quant_format_from_string(quant_format):

    input_quant = input_quant.let(**input_kwargs)
    linear_input_quant = linear_input_quant.let(**input_kwargs)
+    if input_scale_precision == "float_quant_scale":
+        # Set the format of the input's quantized scale
+        input_quant = input_quant.let(
+            scaling_float_quant=input_quant.scaling_float_quant.let(
+                **input_scale_precision_format))
+        linear_input_quant = linear_input_quant.let(
+            scaling_float_quant=linear_input_quant.scaling_float_quant.let(
+                **input_scale_precision_format))
+    if attn_scale_precision == "float_quant_scale":
+        k_transposed_quant = k_transposed_quant.let(
+            scaling_float_quant=k_transposed_quant.scaling_float_quant.let(
+                **attn_scale_precision_format))
    k_transposed_quant = k_transposed_quant.let(
        **input_kwargs
    )  # later we define v_quant=k_transposed_quant, so don't instantiate it here
@@ -398,6 +430,12 @@ def quant_format_from_string(quant_format):
    if weight_quant_type == 'asym' and weight_scaling_impl_type == 'parameter_from_stats':
        weight_quant = weight_quant.let(zero_point_impl=ParameterFromStatsFromParameterZeroPoint)

+    # Set the format of the weight's quantized scale
+    if weight_scale_precision == "float_quant_scale":
+        weight_quant = weight_quant.let(
+            scaling_float_quant=weight_quant.scaling_float_quant.let(
+                **weight_scale_precision_format))

    if quant_attn_mode == 'sdpa':
        kv_permute_dims = (0, 1, 3, 2)
        kv_broadcastable_shape_lambda = lambda x, shape: x.view(shape[0], shape[1], 1, shape[-1])
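
To make the parsing concrete, a standalone copy of the helper behaves as follows (illustration only, not part of the diff; the format strings are examples):

import re

def quant_format_from_string(quant_format, scale=False):
    quant_format_re = re.compile(r'e[1-8]m[0-8]')
    if quant_format_re.findall(quant_format):
        float_type = quant_format_re.findall(quant_format)[0]
        if scale:
            quant_format = "float_quant_scale"
        else:
            quant_format = quant_format.replace('_' + float_type, '')
        float_format = {
            'exponent_bit_width': int(float_type[1]),
            'mantissa_bit_width': int(float_type[3])}
    else:
        float_format = {}
    return quant_format, float_format

# Weight/input formats: the float suffix is stripped and returned as kwargs.
assert quant_format_from_string('float_ocp_e4m3') == (
    'float_ocp', {'exponent_bit_width': 4, 'mantissa_bit_width': 3})
# Scale precision: any e*m* suffix routes to the float_quant_scale branch of the maps.
assert quant_format_from_string('float_e4m3', scale=True) == (
    'float_quant_scale', {'exponent_bit_width': 4, 'mantissa_bit_width': 3})
# Non-float formats pass through unchanged.
assert quant_format_from_string('int') == ('int', {})
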
97 changes: 97 additions & 0 deletions src/brevitas_examples/common/generative/quantizers.py
@@ -8,6 +9,9 @@
from brevitas.core.function_wrapper.ops_ste import FloorSte
from brevitas.core.function_wrapper.shape import OverOutputFeaturesView
from brevitas.core.function_wrapper.shape import OverTensorView
from brevitas.core.quant.float import FloatQuant
from brevitas.core.restrict_val import FloatRestrictValue
from brevitas.core.restrict_val import QuantRestrictValue
from brevitas.core.scaling.runtime import RuntimeDynamicGroupStatsScaling
from brevitas.core.stats import AbsMinMax
from brevitas.core.stats import NegativeMinOrZero
@@ -35,6 +38,9 @@
from brevitas.quant.experimental.float_quant_fnuz import Fp8e4m3FNUZActPerTensorFloat
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloat
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerChannelFloat
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloat
from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Act
from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Weight
from brevitas.quant.scaled_int import Int8ActPerTensorFloat
from brevitas.quant.scaled_int import Int8WeightPerChannelFloat
from brevitas.quant.scaled_int import Int8WeightPerChannelFloatHQO
@@ -213,3 +219,94 @@ class FP8e4m3FNUZDynamicActPerRowFloat(Fp8e4m3FNUZActPerTensorFloat):
    scaling_stats_op = 'min_max'
    scaling_per_output_channel = True
    proxy_class = DynamicActFloatQuantProxyFromInjector


class ConstActQuantScalingFloat(QuantScaleScaleShapeMixin, Fp8e4m3OCPActPerTensorFloat):
    module = (this << 1).module
    upstream_scaling = (this << 1).scaling_per_output_type
    scaling_impl_type = "const"
    scaling_init = 448.0


class DynamicActQuantScalingFloat(QuantScaleScaleShapeMixin,
                                  DynamicActProxyMixin,
                                  Fp8e4m3OCPActPerTensorFloat):
    module = (this << 1).module
    upstream_scaling = (this << 1).scaling_per_output_type
    scaling_impl = RuntimeDynamicStatsScaling
    scaling_stats_input_view_shape_impl = OverTensorView
    scaling_stats_op = 'min_max'
    dynamic_scaling_broadcastable_fn = lambda x, shape: x.view(SCALAR_SHAPE)


class StaticActQuantScalingFloat(QuantScaleScaleShapeMixin, Fp8e4m3OCPActPerTensorFloat):
    module = (this << 1).module
    upstream_scaling = (this << 1).scaling_per_output_type
    scaling_stats_op = 'min_max'
    scaling_shape = SCALAR_SHAPE


class DynamicQuantScaleMXFloat8e4m3Act(MXFloat8e4m3Act):
    scaling_float_quant = StaticActQuantScalingFloat
    restrict_scaling_impl = QuantRestrictValue
    restrict_threshold_impl = FloatRestrictValue

    @value
    def restrict_value_float_to_int_impl():
        return this.scaling_float_quant.tensor_quant

    @value
    def scale_dequantized_shape(scaling_per_output_type, scaling_shape):
        if scaling_per_output_type == ScalingPerOutputType.TENSOR or scaling_per_output_type == ScalingPerOutputType.CHANNEL:
            return None
        elif scaling_per_output_type == ScalingPerOutputType.GROUP:
            return scaling_shape


class ConstQuantWeightScalingFloat(QuantScaleScaleShapeMixin, Fp8e4m3OCPWeightPerTensorFloat):
    module = (this << 1).module
    tracked_parameter_list = (this << 1).tracked_parameter_list
    upstream_scaling = (this << 1).scaling_per_output_type
    scaling_impl_type = "const"
    scaling_init = 1.0


class QuantWeightScalingFloat(QuantScaleScaleShapeMixin, Fp8e4m3OCPWeightPerTensorFloat):
    module = (this << 1).module
    tracked_parameter_list = (this << 1).tracked_parameter_list
    upstream_scaling = (this << 1).scaling_per_output_type
    float_quant = FloatQuant


class QuantScaleMXFloat8e4m3Weight(MXFloat8e4m3Weight):
    scaling_float_quant = QuantWeightScalingFloat
    restrict_scaling_impl = QuantRestrictValue
    restrict_threshold_impl = FloatRestrictValue

    @value
    def restrict_value_float_to_int_impl():
        return this.scaling_float_quant.tensor_quant

    @value
    def scale_dequantized_shape(scaling_per_output_type, scaling_shape):
        if scaling_per_output_type == ScalingPerOutputType.TENSOR or scaling_per_output_type == ScalingPerOutputType.CHANNEL:
            return None
        elif scaling_per_output_type == ScalingPerOutputType.GROUP:
            return scaling_shape


class QuantScaleMXFloat8e4m3WeightMSE(MSESymmetricScale, MXFloat8e4m3Weight):
    scaling_float_quant = QuantWeightScalingFloat
    restrict_scaling_impl = QuantRestrictValue
    restrict_threshold_impl = FloatRestrictValue

    @value
    def restrict_value_float_to_int_impl():
        return this.scaling_float_quant.tensor_quant

    @value
    def scale_dequantized_shape(scaling_per_output_type, scaling_shape):
        if scaling_per_output_type == ScalingPerOutputType.TENSOR or scaling_per_output_type == ScalingPerOutputType.CHANNEL:
            return None
        elif scaling_per_output_type == ScalingPerOutputType.GROUP:
            return scaling_shape
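
The net effect of wiring QuantRestrictValue to scaling_float_quant is that the MX scale itself is round-tripped through an FP8 quantizer before it is applied. A minimal sketch of that effect, using torch's native e4m3 dtype as a stand-in for the injected FloatQuant (assumes PyTorch >= 2.1; illustration only, not the Brevitas code path):

import torch

def quantize_scale_e4m3(scale: torch.Tensor) -> torch.Tensor:
    # Round-trip through FP8 e4m3 so the scale only takes e4m3-representable values.
    return scale.to(torch.float8_e4m3fn).to(scale.dtype)

scale = torch.tensor([0.0123, 1.625, 3.7])
q_scale = quantize_scale_e4m3(scale)  # same shape, coarser value grid
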
43 changes: 43 additions & 0 deletions src/brevitas_examples/llm/benchmark/test_scale_format.py
@@ -0,0 +1,43 @@
import torch

import brevitas.nn as qnn
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloat


class Fp8e4m3OCPActPerTensorFloatConst(Fp8e4m3OCPActPerTensorFloat):
    scaling_impl_type = "const"
    scaling_init = 448.0


class Fp8e5m2OCPActPerTensorFloatConst(Fp8e4m3OCPActPerTensorFloatConst):
    exponent_bit_width = 5
    mantissa_bit_width = 2


def test_scale_quant(model):
    e4m3 = qnn.QuantIdentity(act_quant=Fp8e4m3OCPActPerTensorFloatConst)
    e5m2 = qnn.QuantIdentity(act_quant=Fp8e5m2OCPActPerTensorFloatConst)
    x = torch.rand((100, 100))
    layers_tested = 0
    layers_passed = 0
    layers_failed = 0
    for name, module in model.named_modules():
        if isinstance(module, qnn.QuantLinear):
            try:
                weight_scale = module.quant_weight().scale
                e4m3.to(device=weight_scale.device)
                e5m2.to(device=weight_scale.device)
                x = x.to(device=weight_scale.device)
                assert (weight_scale == e4m3(weight_scale)).all()
                assert not (weight_scale == e5m2(weight_scale)).all()
                module.input_quant.return_quant_tensor = True
                act_scale = module.input_quant(x).scale
                assert (act_scale == e4m3(act_scale)).all()
                assert not (act_scale == e5m2(act_scale)).all()
                layers_passed += 1
            except Exception:
                layers_failed += 1
            layers_tested += 1
    print(
        f"Layers passed: {layers_passed}, Layers failed: {layers_failed}, Layers tested: {layers_tested}"
    )
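
The test hinges on a round-trip property: scales produced by an e4m3 quantizer are exactly representable in e4m3, so re-quantizing them is the identity, while the e5m2 identity (one fewer mantissa bit) perturbs at least some entries. The same property with torch's native FP8 dtypes (assumes PyTorch >= 2.1; illustration only):

import torch

x = torch.tensor(1.625)  # 1.101 in binary: needs 3 mantissa bits
assert x.to(torch.float8_e4m3fn).to(torch.float32) == x  # exact round trip
assert x.to(torch.float8_e5m2).to(torch.float32) != x  # 2 mantissa bits round it away
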
17 changes: 16 additions & 1 deletion src/brevitas_examples/llm/llm_quant/calibrate.py
@@ -9,8 +9,23 @@
from brevitas.graph.calibrate import calibration_mode


class groupwise_calibration_mode(calibration_mode):

    def __init__(self, model):
        super(calibration_mode, self).__init__(
            model=model,
            call_act_quantizer_impl=True,
            disable_act_quant=False,
            disable_weight_quant=True,
            disable_bias_quant=True,
            is_training=True)
        self.enabled = True


@torch.no_grad()
def apply_calibration(model, dataloader):
-    with calibration_mode(model):
+    model.train()
+    with groupwise_calibration_mode(model):
        for inps in tqdm(dataloader):
            model(**inps)
+    model.eval()
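
A note on the construction: super(calibration_mode, self).__init__(...) deliberately skips calibration_mode's own constructor so this variant can pass its own flag combination (activation quant enabled, weight and bias quant disabled, is_training=True) straight to the shared base class, while apply_calibration now brackets the loop with model.train()/model.eval() so groupwise statistics are collected in training mode. A hypothetical call site (quant_model and calib_loader are placeholders):

# calib_loader yields dicts of forward() keyword arguments, e.g. {"input_ids": ...}.
apply_calibration(quant_model, calib_loader)
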
1 change: 1 addition & 0 deletions src/brevitas_examples/llm/main.py
@@ -34,6 +34,7 @@
from brevitas_examples.common.generative.quantize import generate_quantizers
from brevitas_examples.common.parse_utils import override_defaults
from brevitas_examples.common.parse_utils import parse_args
from brevitas_examples.llm.benchmark.test_scale_format import test_scale_quant
from brevitas_examples.llm.gguf_export.export import save_quantized_as_gguf
from brevitas_examples.llm.llm_args import create_args_parser
from brevitas_examples.llm.llm_args import validate