diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index 7018d4cc93..f3e2852a65 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -63,6 +63,7 @@ is_sm_at_least_89, is_sm_at_least_90, is_sm_at_least_100, + torch_version_at_least, unwrap_tensor_subclass, ) @@ -75,6 +76,7 @@ from torchao.prototype.mx_formats.inference_workflow import ( MXDynamicActivationMXWeightConfig, + NVFP4DynamicActivationNVFP4WeightConfig, ) @@ -1057,6 +1059,8 @@ def test_fqn_to_config_non_weight_param(self): ] if is_sm_at_least_100(): configs.append(MXDynamicActivationMXWeightConfig()) + if is_sm_at_least_100() and torch_version_at_least("2.8.0"): + configs.append(NVFP4DynamicActivationNVFP4WeightConfig()) for config in configs: with self.subTest(config=type(config).__name__): model = torch.nn.Sequential( diff --git a/torchao/prototype/mx_formats/inference_workflow.py b/torchao/prototype/mx_formats/inference_workflow.py index 0bd9039de6..ce2ecbaae0 100644 --- a/torchao/prototype/mx_formats/inference_workflow.py +++ b/torchao/prototype/mx_formats/inference_workflow.py @@ -240,7 +240,10 @@ def __post_init__(self): @register_quantize_module_handler(NVFP4DynamicActivationNVFP4WeightConfig) def _nvfp4_inference_linear_transform( - module: torch.nn.Linear, config: NVFP4DynamicActivationNVFP4WeightConfig + module: torch.nn.Linear, + config: NVFP4DynamicActivationNVFP4WeightConfig, + *, + parameter_name: str = "weight", ): """Quantization handler for NVFP4DynamicActivationNVFP4WeightConfig @@ -249,7 +252,7 @@ def _nvfp4_inference_linear_transform( - CONVERT: Extract amax from observer, compute static per_tensor_scale, quantize - None (default): Original dynamic quantization behavior """ - weight = module.weight + weight = getattr(module, parameter_name) if weight.shape[-2] % 16 != 0 or weight.shape[-1] % 16 != 0: raise RuntimeError( f"NVFP4 only supports weight shape with last 2 dims divisible by 16, got {weight.shape}" @@ -306,6 +309,8 @@ def _nvfp4_inference_linear_transform( "NVFP4 DYNAMIC mode is only supported on sm100+ machines" ) + weight = getattr(module, parameter_name) + per_tensor_scale = None if config.use_dynamic_per_tensor_scale: tensor_amax = torch.max(torch.abs(weight)) @@ -325,8 +330,19 @@ def _nvfp4_inference_linear_transform( act_quant_kwargs=act_quant_kwargs, ) quantized_weight.use_triton_kernel = config.use_triton_kernel - module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False) - module.extra_repr = types.MethodType(_linear_extra_repr, module) + setattr( + module, + parameter_name, + torch.nn.Parameter(quantized_weight, requires_grad=False), + ) + module.extra_repr = types.MethodType( + partial( + _module_extra_repr, + original_extra_repr=module.extra_repr, + parameter_name=parameter_name, + ), + module, + ) return module else: