diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index f3e2852a65..fce977a971 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -77,6 +77,7 @@ from torchao.prototype.mx_formats.inference_workflow import (
     MXDynamicActivationMXWeightConfig,
     NVFP4DynamicActivationNVFP4WeightConfig,
+    NVFP4WeightOnlyConfig,
 )
@@ -1061,6 +1062,7 @@ def test_fqn_to_config_non_weight_param(self):
         configs.append(MXDynamicActivationMXWeightConfig())
         if is_sm_at_least_100() and torch_version_at_least("2.8.0"):
             configs.append(NVFP4DynamicActivationNVFP4WeightConfig())
+            configs.append(NVFP4WeightOnlyConfig())
         for config in configs:
             with self.subTest(config=type(config).__name__):
                 model = torch.nn.Sequential(
diff --git a/torchao/prototype/mx_formats/inference_workflow.py b/torchao/prototype/mx_formats/inference_workflow.py
index ce2ecbaae0..447e5229e5 100644
--- a/torchao/prototype/mx_formats/inference_workflow.py
+++ b/torchao/prototype/mx_formats/inference_workflow.py
@@ -377,10 +377,13 @@ def __post_init__(self):
 @register_quantize_module_handler(NVFP4WeightOnlyConfig)
 def _nvfp4_weight_only_linear_transform(
-    module: torch.nn.Linear, config: NVFP4WeightOnlyConfig
+    module: torch.nn.Linear,
+    config: NVFP4WeightOnlyConfig,
+    *,
+    parameter_name: str = "weight",
 ):
     """Quantization handler for NVFP4WeightOnlyConfig"""
-    weight = module.weight
+    weight = getattr(module, parameter_name)
 
     if weight.shape[-2] % 16 != 0 or weight.shape[-1] % 16 != 0:
         raise RuntimeError(
@@ -399,8 +402,19 @@
         act_quant_kwargs=None,
     )
     # Set triton preference after construction
-    module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
-    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    setattr(
+        module,
+        parameter_name,
+        torch.nn.Parameter(quantized_weight, requires_grad=False),
+    )
+    module.extra_repr = types.MethodType(
+        partial(
+            _module_extra_repr,
+            original_extra_repr=module.extra_repr,
+            parameter_name=parameter_name,
+        ),
+        module,
+    )
     return module