Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/brevitas_examples/common/generative/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerGroupFloat
from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerRowFloat
from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerTensorFloat
from brevitas_examples.common.generative.quantizers import Int8WeightPerChannelFloatSparse

WEIGHT_QUANT_MAP = {
'int': {
Expand Down Expand Up @@ -107,6 +108,10 @@
'asym': ShiftedUint8WeightPerChannelFloatHQO},
'per_group': {
'asym': ShiftedUint8WeightPerGroupFloatHQO}},},
'float_sparse_scale': {
'stats': {
'per_channel': {
'sym': Int8WeightPerChannelFloatSparse}}},
'po2_scale': {
'stats': {
'per_tensor': {
Expand Down
16 changes: 16 additions & 0 deletions src/brevitas_examples/common/generative/quantizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,3 +213,19 @@ class FP8e4m3FNUZDynamicActPerRowFloat(Fp8e4m3FNUZActPerTensorFloat):
scaling_stats_op = 'min_max'
scaling_per_output_channel = True
proxy_class = DynamicActFloatQuantProxyFromInjector

def sparse_ste(x, sparse_ratio):
    """Placeholder stub: performs no work and always returns ``None``.

    Both ``x`` and ``sparse_ratio`` are accepted but currently ignored.
    """
    # NOTE(review): the name suggests a sparsifying straight-through
    # estimator is planned here -- confirm the intended semantics before
    # anything starts relying on the return value.
    return None

class SparseRoundSte(torch.nn.Module):
    """Float-to-int rounding module that carries a sparsity ratio.

    The forward pass currently only delegates to ``round_ste``; the stored
    ``sparse_ratio`` is not applied yet.

    Args:
        sparse_ratio: Value stored on the instance as ``self.sparse_ratio``.
    """

    def __init__(self, sparse_ratio):
        # nn.Module must be initialised before attributes are assigned:
        # assigning a Parameter or sub-Module earlier raises AttributeError
        # because the module's registries do not exist yet.
        super().__init__()
        self.sparse_ratio = sparse_ratio

    def forward(self, x):
        """Return ``round_ste(x)``; ``self.sparse_ratio`` is not used yet."""
        # TODO: apply the sparse logic here (sparse_ratio is currently unused).
        return round_ste(x)

class Int8WeightPerChannelFloatSparse(Int8WeightPerChannelFloat):
    """Variant of ``Int8WeightPerChannelFloat`` that rounds via ``SparseRoundSte``.

    Overrides ``float_to_int_impl`` and sets ``sparse_ratio`` to 0.5.
    """
    # NOTE(review): presumably the injector passes ``sparse_ratio`` to the
    # ``SparseRoundSte`` constructor by name -- confirm against the
    # dependency-injection machinery.
    sparse_ratio = 0.5
    float_to_int_impl = SparseRoundSte
2 changes: 1 addition & 1 deletion src/brevitas_examples/llm/llm_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def create_args_parser() -> ArgumentParser:
'--weight-scale-precision',
type=str,
default='float_scale',
choices=['float_scale', 'po2_scale'],
choices=['float_scale', 'float_sparse_scale', 'po2_scale'],
help='Whether scale is a float value or a po2. Default: po2.')
parser.add_argument(
'--weight-quant-rescaling-init',
Expand Down