Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions test/quantization/test_quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
from torchao.prototype.mx_formats.inference_workflow import (
MXDynamicActivationMXWeightConfig,
NVFP4DynamicActivationNVFP4WeightConfig,
NVFP4WeightOnlyConfig,
)


Expand Down Expand Up @@ -1061,6 +1062,7 @@ def test_fqn_to_config_non_weight_param(self):
configs.append(MXDynamicActivationMXWeightConfig())
if is_sm_at_least_100() and torch_version_at_least("2.8.0"):
configs.append(NVFP4DynamicActivationNVFP4WeightConfig())
configs.append(NVFP4WeightOnlyConfig())
for config in configs:
with self.subTest(config=type(config).__name__):
model = torch.nn.Sequential(
Expand Down
22 changes: 18 additions & 4 deletions torchao/prototype/mx_formats/inference_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,10 +377,13 @@ def __post_init__(self):

@register_quantize_module_handler(NVFP4WeightOnlyConfig)
def _nvfp4_weight_only_linear_transform(
module: torch.nn.Linear, config: NVFP4WeightOnlyConfig
module: torch.nn.Linear,
config: NVFP4WeightOnlyConfig,
*,
parameter_name: str = "weight",
):
"""Quantization handler for NVFP4WeightOnlyConfig"""
weight = module.weight
weight = getattr(module, parameter_name)

if weight.shape[-2] % 16 != 0 or weight.shape[-1] % 16 != 0:
raise RuntimeError(
Expand All @@ -399,8 +402,19 @@ def _nvfp4_weight_only_linear_transform(
act_quant_kwargs=None,
)
# Set triton preference after construction
module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
module.extra_repr = types.MethodType(_linear_extra_repr, module)
setattr(
module,
parameter_name,
torch.nn.Parameter(quantized_weight, requires_grad=False),
)
module.extra_repr = types.MethodType(
partial(
_module_extra_repr,
original_extra_repr=module.extra_repr,
parameter_name=parameter_name,
),
module,
)
return module


Expand Down
Loading