diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index 6eb05929da..7018d4cc93 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -62,6 +62,7 @@ get_current_accelerator_device, is_sm_at_least_89, is_sm_at_least_90, + is_sm_at_least_100, unwrap_tensor_subclass, ) @@ -72,6 +73,10 @@ except ModuleNotFoundError: has_gemlite = False +from torchao.prototype.mx_formats.inference_workflow import ( + MXDynamicActivationMXWeightConfig, +) + def dynamic_quant(model, example_inputs): m = torch.export.export(model, example_inputs, strict=True).module() @@ -1050,6 +1055,8 @@ def test_fqn_to_config_non_weight_param(self): Float8WeightOnlyConfig(), Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor()), ] + if is_sm_at_least_100(): + configs.append(MXDynamicActivationMXWeightConfig()) for config in configs: with self.subTest(config=type(config).__name__): model = torch.nn.Sequential( diff --git a/torchao/prototype/mx_formats/inference_workflow.py b/torchao/prototype/mx_formats/inference_workflow.py index ea41b19169..f2f3d573fa 100644 --- a/torchao/prototype/mx_formats/inference_workflow.py +++ b/torchao/prototype/mx_formats/inference_workflow.py @@ -6,6 +6,7 @@ import types from dataclasses import dataclass +from functools import partial from typing import Optional import torch @@ -27,7 +28,7 @@ QuantizeTensorToNVFP4Kwargs, per_tensor_amax_to_scale, ) -from torchao.quantization.quant_api import _quantization_type +from torchao.quantization.quant_api import _module_extra_repr, _quantization_type from torchao.quantization.quantize_.common.kernel_preference import KernelPreference from torchao.quantization.quantize_.common.quantization_step import QuantizationStep from torchao.quantization.transform_module import ( @@ -133,9 +134,12 @@ def _linear_extra_repr(self): @register_quantize_module_handler(MXDynamicActivationMXWeightConfig) def _mx_inference_linear_transform( - module: torch.nn.Module, config: MXDynamicActivationMXWeightConfig + module: torch.nn.Module, + config: MXDynamicActivationMXWeightConfig, + *, + parameter_name: str = "weight", ): - weight = module.weight + weight = getattr(module, parameter_name) assert weight.dtype == torch.bfloat16, ( f"Only supporting bf16 out dtype for now, got {weight.dtype}" @@ -159,8 +163,19 @@ def _mx_inference_linear_transform( scaling_mode=config.scaling_mode, ) - module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False) - module.extra_repr = types.MethodType(_linear_extra_repr, module) + setattr( + module, + parameter_name, + torch.nn.Parameter(quantized_weight, requires_grad=False), + ) + module.extra_repr = types.MethodType( + partial( + _module_extra_repr, + original_extra_repr=module.extra_repr, + parameter_name=parameter_name, + ), + module, + ) return module