Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions test/quantization/test_quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@
get_current_accelerator_device,
is_sm_at_least_89,
is_sm_at_least_90,
is_sm_at_least_100,
torch_version_at_least,
unwrap_tensor_subclass,
)

Expand All @@ -72,6 +74,11 @@
except ModuleNotFoundError:
has_gemlite = False

from torchao.prototype.mx_formats.inference_workflow import (
MXDynamicActivationMXWeightConfig,
NVFP4DynamicActivationNVFP4WeightConfig,
)


def dynamic_quant(model, example_inputs):
m = torch.export.export(model, example_inputs, strict=True).module()
Expand Down Expand Up @@ -1050,6 +1057,10 @@ def test_fqn_to_config_non_weight_param(self):
Float8WeightOnlyConfig(),
Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor()),
]
if is_sm_at_least_100():
configs.append(MXDynamicActivationMXWeightConfig())
if is_sm_at_least_100() and torch_version_at_least("2.8.0"):
configs.append(NVFP4DynamicActivationNVFP4WeightConfig())
for config in configs:
with self.subTest(config=type(config).__name__):
model = torch.nn.Sequential(
Expand Down
49 changes: 40 additions & 9 deletions torchao/prototype/mx_formats/inference_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import types
from dataclasses import dataclass
from functools import partial
from typing import Optional

import torch
Expand All @@ -27,7 +28,7 @@
QuantizeTensorToNVFP4Kwargs,
per_tensor_amax_to_scale,
)
from torchao.quantization.quant_api import _quantization_type
from torchao.quantization.quant_api import _module_extra_repr, _quantization_type
from torchao.quantization.quantize_.common.kernel_preference import KernelPreference
from torchao.quantization.quantize_.common.quantization_step import QuantizationStep
from torchao.quantization.transform_module import (
Expand Down Expand Up @@ -133,9 +134,12 @@ def _linear_extra_repr(self):

@register_quantize_module_handler(MXDynamicActivationMXWeightConfig)
def _mx_inference_linear_transform(
module: torch.nn.Module, config: MXDynamicActivationMXWeightConfig
module: torch.nn.Module,
config: MXDynamicActivationMXWeightConfig,
*,
parameter_name: str = "weight",
):
weight = module.weight
weight = getattr(module, parameter_name)

assert weight.dtype == torch.bfloat16, (
f"Only supporting bf16 out dtype for now, got {weight.dtype}"
Expand All @@ -159,8 +163,19 @@ def _mx_inference_linear_transform(
scaling_mode=config.scaling_mode,
)

module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
module.extra_repr = types.MethodType(_linear_extra_repr, module)
setattr(
module,
parameter_name,
torch.nn.Parameter(quantized_weight, requires_grad=False),
)
module.extra_repr = types.MethodType(
partial(
_module_extra_repr,
original_extra_repr=module.extra_repr,
parameter_name=parameter_name,
),
module,
)
return module


Expand Down Expand Up @@ -231,7 +246,10 @@ def __post_init__(self):

@register_quantize_module_handler(NVFP4DynamicActivationNVFP4WeightConfig)
def _nvfp4_inference_linear_transform(
module: torch.nn.Linear, config: NVFP4DynamicActivationNVFP4WeightConfig
module: torch.nn.Linear,
config: NVFP4DynamicActivationNVFP4WeightConfig,
*,
parameter_name: str = "weight",
):
"""Quantization handler for NVFP4DynamicActivationNVFP4WeightConfig

Expand All @@ -240,7 +258,7 @@ def _nvfp4_inference_linear_transform(
- CONVERT: Extract amax from observer, compute static per_tensor_scale, quantize
- None (default): Original dynamic quantization behavior
"""
weight = module.weight
weight = getattr(module, parameter_name)
if weight.shape[-2] % 16 != 0 or weight.shape[-1] % 16 != 0:
raise RuntimeError(
f"NVFP4 only supports weight shape with last 2 dims divisible by 16, got {weight.shape}"
Expand Down Expand Up @@ -297,6 +315,8 @@ def _nvfp4_inference_linear_transform(
"NVFP4 DYNAMIC mode is only supported on sm100+ machines"
)

weight = getattr(module, parameter_name)

per_tensor_scale = None
if config.use_dynamic_per_tensor_scale:
tensor_amax = torch.max(torch.abs(weight))
Expand All @@ -316,8 +336,19 @@ def _nvfp4_inference_linear_transform(
act_quant_kwargs=act_quant_kwargs,
)
quantized_weight.use_triton_kernel = config.use_triton_kernel
module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
module.extra_repr = types.MethodType(_linear_extra_repr, module)
setattr(
module,
parameter_name,
torch.nn.Parameter(quantized_weight, requires_grad=False),
)
module.extra_repr = types.MethodType(
partial(
_module_extra_repr,
original_extra_repr=module.extra_repr,
parameter_name=parameter_name,
),
module,
)
return module

else:
Expand Down
Loading