Skip to content

[Float8DynamicActivationFloat8WeightConfig] Enable Dynamic Quantized models to be exportable using torch.export.export #3928

@asfiyab-nvidia

Description

@asfiyab-nvidia

I'm using the script below to quantize the activation and weights of a toy linear model and want to export it using torch.export.export, but I run into NotImplementedError: UserDefinedObjectVariable(QuantizeTensorToFloat8Kwargs). Based on my investigation, QuantizeTensorToFloat8Kwargs is a dataclass inheriting from QuantizeTensorKwargs(abc.ABC). Dynamo has no special handling for dataclasses, so it falls through to the generic UserDefinedObjectVariable, whose as_proxy() hits the base NotImplementedError. Is there a recommendation on how I can work around this issue?

# Based on https://github.com/pytorch/ao/blob/main/tutorials/developer_api_guide/my_dtype_tensor_subclass.py
import torch
import torchao
from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig, PerRow
from torchao.quantization.quantize_.workflows import Float8Tensor

class Float8TensorNonDecomposed(Float8Tensor):
    """Float8Tensor whose dequantize lowers to a single non-decomposed custom op.

    Keeping dequantization as one opaque ``torch.ops.torchao`` call preserves it
    as a single node in the exported graph instead of a decomposed subgraph.
    """

    @classmethod
    def from_float8_tensor(cls, tensor: Float8Tensor) -> "Float8TensorNonDecomposed":
        """Convert a Float8Tensor to Float8TensorNonDecomposed by copying its internals.

        Uses the tensor-subclass flatten/unflatten protocol so all inner
        tensors (qdata, scale) and metadata attrs are carried over unchanged.
        """
        tensor_data, tensor_attrs = tensor.__tensor_flatten__()
        tensor_data_dict = {name: getattr(tensor, name) for name in tensor_data}
        # (removed stray debug print of tensor_attrs)
        return cls.__tensor_unflatten__(
            tensor_data_dict, tensor_attrs, tensor.shape, tensor.stride()
        )

    def dequantize(self, output_dtype=None):
        """Dequantize to a high-precision tensor via the non-decomposed op.

        When ``output_dtype`` is None, default to this tensor's recorded
        high-precision dtype (e.g. torch.bfloat16 for a bf16-quantized model)
        rather than hard-coding float16, so the result matches the dtype the
        weights were quantized from. Falls back to float16 if no dtype attr
        is present, preserving the previous behavior.
        """
        if output_dtype is None:
            output_dtype = getattr(self, "dtype", torch.float16)
        return torch.ops.torchao.dequantize_affine_float8_non_decomposed.default(
            self.qdata, self.scale, output_dtype
        )

class linear_model(torch.nn.Module):
    """Toy model: a single fully-connected 16 -> 32 layer."""

    def __init__(self) -> None:
        super().__init__()
        # The weight of this layer is what gets quantized in the repro.
        self.linear = torch.nn.Linear(16, 32)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer to *x* and return the result."""
        out = self.linear(x)
        return out

def convert_float8_to_float8_non_decomposed(model: torch.nn.Module) -> torch.nn.Module:
    """Convert all Float8Tensor parameters in the model to Float8TensorNonDecomposed.

    Walks every submodule, rebuilds each Float8Tensor parameter as the
    non-decomposed subclass, then unwraps tensor subclasses so the model can
    be traced by torch.export.
    """
    for name, module in model.named_modules():
        # Snapshot the (non-recursive) parameter list before mutating the module.
        own_params = list(module.named_parameters(recurse=False))
        for param_name, param in own_params:
            if not isinstance(param, Float8Tensor):
                continue
            print(f"[DEBUG] converting {name}.{param_name} to Float8TensorNonDecomposed")
            converted = Float8TensorNonDecomposed.from_float8_tensor(param)
            setattr(
                module,
                param_name,
                torch.nn.Parameter(converted, requires_grad=param.requires_grad),
            )
    return torchao.utils.unwrap_tensor_subclass(model)

def main():
    """Quantize a toy linear model to dynamic fp8 and export it with torch.export."""
    dtype = torch.bfloat16

    # Build the model and a matching example input on GPU.
    model = linear_model().eval().to(dtype).cuda()
    input = torch.randn(1, 16).to(dtype).cuda()

    # Dynamic fp8 activations + fp8 weights with per-row scales.
    config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
    quantize_(model, config)

    # Swap Float8Tensor parameters for the non-decomposed variant.
    model = convert_float8_to_float8_non_decomposed(model)

    # Export and dump the captured graph.
    exported = torch.export.export(model, (input,), strict=True)
    exported.graph_module.print_readable()

    torch.cuda.empty_cache()

if __name__ == "__main__":
    main()

Full error stacktrace

[DEBUG] converting linear.weight to Float8TensorNonDecomposed
{'block_size': [1, 16], 'mm_config': Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False), 'act_quant_kwargs': QuantizeTensorToFloat8Kwargs(float8_dtype=torch.float8_e4m3fn, granularity=PerRow(dim=-1), mm_config=None, hp_value_lb=None, hp_value_ub=None, kernel_preference=<KernelPreference.AUTO: 'auto'>), 'kernel_preference': <KernelPreference.AUTO: 'auto'>, 'dtype': torch.bfloat16}
NotImplementedError: UserDefinedObjectVariable(QuantizeTensorToFloat8Kwargs)

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/ao/torch-tensorrt-quantization/fp8_dynamic_quantization/linear_fp8_dynq_repro.py", line 64, in <module>
    main()
  File "/home/ao/torch-tensorrt-quantization/fp8_dynamic_quantization/linear_fp8_dynq_repro.py", line 58, in main
    exp_program = torch.export.export(model, (input,), strict=True)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/export/__init__.py", line 311, in export
    raise e
  File "/usr/local/lib/python3.12/dist-packages/torch/export/__init__.py", line 277, in export
    return _export(
           ^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1271, in wrapper
    raise e
  File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1237, in wrapper
    ep = fn(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/export/exported_program.py", line 124, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 2380, in _export
    ep = _export_for_training(
         ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1271, in wrapper
    raise e
  File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1237, in wrapper
    ep = fn(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/export/exported_program.py", line 124, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 2188, in _export_for_training
    export_artifact = export_func(
                      ^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1526, in _strict_export
    gm_torch_level = _export_to_torch_ir(
                     ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 861, in _export_to_torch_ir
    gm_torch_level = dynamo_graph_capture(*args, **kwargs)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/functional_export.py", line 759, in inner
    out = fullgraph_capture(
          ^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 1208, in fullgraph_capture
    return _fullgraph_capture_frame(
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 1273, in _fullgraph_capture_frame
    raise e.with_traceback(None) from e.__cause__  # User compiler error
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.Unsupported: Failed to convert args/kwargs to proxy
  Explanation: Missing `as_proxy()` implementation for some arg/kwarg.


  Developer debug context: call_function args: TensorVariable() TensorVariable() LazyVariableTracker(realized: ListVariable(length=2)) LazyVariableTracker(realized: NamedTupleVariable(length=3)) LazyVariableTracker(realized: UserDefinedObjectVariable(QuantizeTensorToFloat8Kwargs)) LazyVariableTracker(unrealized: <enum 'KernelPreference'>) LazyVariableTracker(unrealized: <class 'torch.dtype'>)

 For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0055.html

from user code:
   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/functional_export.py", line 216, in forward
    res = self._export_root(*args, **kwargs)
  File "/home/ao/torch-tensorrt-quantization/fp8_dynamic_quantization/linear_fp8_dynq_repro.py", line 32, in forward
    return self.linear(x)
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/linear.py", line 134, in forward
    return F.linear(input, self.weight, self.bias)
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/utils/parametrize.py", line 420, in get_parametrized
    return parametrization()
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/utils/parametrize.py", line 315, in forward
    x = self[0](*originals)
  File "/home/.local/lib/python3.12/site-packages/torchao/utils.py", line 295, in forward
    rebuilt = tp.__tensor_unflatten__(inner_tensors, meta, None, None)
  File "/home/.local/lib/python3.12/site-packages/torchao/utils.py", line 955, in __tensor_unflatten__
    return cls(

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

Metadata

Metadata

Assignees

Labels

No labels
No labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions