-
Notifications
You must be signed in to change notification settings - Fork 463
Open
Description
I'm using the script below to quantize the activations and weights of a toy linear model, and I want to export it using torch.export.export, but I run into NotImplementedError: UserDefinedObjectVariable(QuantizeTensorToFloat8Kwargs). Based on my investigation, QuantizeTensorToFloat8Kwargs is a dataclass inheriting from QuantizeTensorKwargs(abc.ABC). Dynamo has no special handling for dataclasses, so it falls through to the generic UserDefinedObjectVariable, whose as_proxy() raises the base NotImplementedError. Is there a recommended workaround for this issue?
# Based on https://github.com/pytorch/ao/blob/main/tutorials/developer_api_guide/my_dtype_tensor_subclass.py
import torch
import torchao
from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig, PerRow
from torchao.quantization.quantize_.workflows import Float8Tensor
class Float8TensorNonDecomposed(Float8Tensor):
    """Float8Tensor whose dequantize() stays a single non-decomposed op.

    Keeping dequantization as one custom op (instead of a decomposed
    sequence of aten ops) lets downstream compilers/backends pattern-match
    and lower it as a unit.
    """

    @classmethod
    def from_float8_tensor(cls, tensor: Float8Tensor) -> "Float8TensorNonDecomposed":
        """Build an instance by copying the internals of an existing Float8Tensor.

        Uses the tensor-subclass flatten/unflatten protocol so every inner
        tensor and metadata attribute is carried over unchanged.
        """
        tensor_data, tensor_attrs = tensor.__tensor_flatten__()
        tensor_data_dict = {name: getattr(tensor, name) for name in tensor_data}
        # Fix: removed a leftover `print(tensor_attrs)` debug statement that
        # spammed stdout on every conversion.
        return cls.__tensor_unflatten__(
            tensor_data_dict, tensor_attrs, tensor.shape, tensor.stride()
        )

    def dequantize(self, output_dtype=None):
        """Dequantize via the non-decomposed torchao custom op.

        NOTE(review): defaults to float16, while the repro model runs in
        bfloat16 — confirm float16 is the intended fallback dtype.
        """
        if output_dtype is None:
            output_dtype = torch.float16
        return torch.ops.torchao.dequantize_affine_float8_non_decomposed.default(
            self.qdata, self.scale, output_dtype
        )
class linear_model(torch.nn.Module):
    """Toy model consisting of a single 16 -> 32 linear layer."""

    def __init__(self) -> None:
        super().__init__()
        # One fully-connected layer; weight shape (32, 16), bias shape (32,).
        self.linear = torch.nn.Linear(16, 32)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the linear projection of *x*."""
        out = self.linear(x)
        return out
def convert_float8_to_float8_non_decomposed(model: torch.nn.Module) -> torch.nn.Module:
    """Convert all Float8Tensor parameters in the model to Float8TensorNonDecomposed.

    Walks every submodule, replaces each Float8Tensor parameter with an
    equivalent Float8TensorNonDecomposed parameter, then unwraps the tensor
    subclasses so the model becomes exportable.
    """
    for name, module in model.named_modules():
        # Snapshot the direct parameters so replacement via setattr is safe.
        direct_params = list(module.named_parameters(recurse=False))
        for param_name, param in direct_params:
            if not isinstance(param, Float8Tensor):
                continue
            print(f"[DEBUG] converting {name}.{param_name} to Float8TensorNonDecomposed")
            converted = Float8TensorNonDecomposed.from_float8_tensor(param)
            setattr(
                module,
                param_name,
                torch.nn.Parameter(converted, requires_grad=param.requires_grad),
            )
    return torchao.utils.unwrap_tensor_subclass(model)
def main():
    """Quantize a toy linear model to fp8 and export it with torch.export.

    Requires a CUDA device; uses torchao's dynamic fp8 activation + fp8
    weight quantization with per-row scales.
    """
    # Build the toy model and a matching bf16 example input on GPU.
    model = linear_model().eval().to(torch.bfloat16).cuda()
    # Fix: renamed from `input`, which shadowed the builtin of the same name.
    example_input = torch.randn(1, 16).to(torch.bfloat16).cuda()

    # Quantize activations and weights to fp8 with per-row granularity.
    quant_config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
    quantize_(model, quant_config)

    # Swap quantized weights to the non-decomposed subclass, then export.
    model = convert_float8_to_float8_non_decomposed(model)
    exp_program = torch.export.export(model, (example_input,), strict=True)
    exp_program.graph_module.print_readable()
    torch.cuda.empty_cache()


if __name__ == "__main__":
    main()
Full error stacktrace
[DEBUG] converting linear.weight to Float8TensorNonDecomposed
{'block_size': [1, 16], 'mm_config': Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False), 'act_quant_kwargs': QuantizeTensorToFloat8Kwargs(float8_dtype=torch.float8_e4m3fn, granularity=PerRow(dim=-1), mm_config=None, hp_value_lb=None, hp_value_ub=None, kernel_preference=<KernelPreference.AUTO: 'auto'>), 'kernel_preference': <KernelPreference.AUTO: 'auto'>, 'dtype': torch.bfloat16}
NotImplementedError: UserDefinedObjectVariable(QuantizeTensorToFloat8Kwargs)
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/ao/torch-tensorrt-quantization/fp8_dynamic_quantization/linear_fp8_dynq_repro.py", line 64, in <module>
main()
File "/home/ao/torch-tensorrt-quantization/fp8_dynamic_quantization/linear_fp8_dynq_repro.py", line 58, in main
exp_program = torch.export.export(model, (input,), strict=True)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/export/__init__.py", line 311, in export
raise e
File "/usr/local/lib/python3.12/dist-packages/torch/export/__init__.py", line 277, in export
return _export(
^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1271, in wrapper
raise e
File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1237, in wrapper
ep = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/export/exported_program.py", line 124, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 2380, in _export
ep = _export_for_training(
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1271, in wrapper
raise e
File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1237, in wrapper
ep = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/export/exported_program.py", line 124, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 2188, in _export_for_training
export_artifact = export_func(
^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1526, in _strict_export
gm_torch_level = _export_to_torch_ir(
^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 861, in _export_to_torch_ir
gm_torch_level = dynamo_graph_capture(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/functional_export.py", line 759, in inner
out = fullgraph_capture(
^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 1208, in fullgraph_capture
return _fullgraph_capture_frame(
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 1273, in _fullgraph_capture_frame
raise e.with_traceback(None) from e.__cause__ # User compiler error
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.Unsupported: Failed to convert args/kwargs to proxy
Explanation: Missing `as_proxy()` implementation for some arg/kwarg.
Developer debug context: call_function args: TensorVariable() TensorVariable() LazyVariableTracker(realized: ListVariable(length=2)) LazyVariableTracker(realized: NamedTupleVariable(length=3)) LazyVariableTracker(realized: UserDefinedObjectVariable(QuantizeTensorToFloat8Kwargs)) LazyVariableTracker(unrealized: <enum 'KernelPreference'>) LazyVariableTracker(unrealized: <class 'torch.dtype'>)
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0055.html
from user code:
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/functional_export.py", line 216, in forward
res = self._export_root(*args, **kwargs)
File "/home/ao/torch-tensorrt-quantization/fp8_dynamic_quantization/linear_fp8_dynq_repro.py", line 32, in forward
return self.linear(x)
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/linear.py", line 134, in forward
return F.linear(input, self.weight, self.bias)
File "/usr/local/lib/python3.12/dist-packages/torch/nn/utils/parametrize.py", line 420, in get_parametrized
return parametrization()
File "/usr/local/lib/python3.12/dist-packages/torch/nn/utils/parametrize.py", line 315, in forward
x = self[0](*originals)
File "/home/.local/lib/python3.12/site-packages/torchao/utils.py", line 295, in forward
rebuilt = tp.__tensor_unflatten__(inner_tensors, meta, None, None)
File "/home/.local/lib/python3.12/site-packages/torchao/utils.py", line 955, in __tensor_unflatten__
return cls(
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels