Description
How can I export an FP8-quantized model with torch.export?
I got the following error:
RuntimeError: We found a fake tensor in the exported program constant's list. This typically means our tracing system encountered an op that we can't trace through. For the potential source, you can refer to following model attribute: linear.lifted_tensor_0. Please file an issue on github.
Code to reproduce:
import torch
from torchao.quantization.quant_api import (
    quantize_,
    Float8DynamicActivationFloat8WeightConfig,
)


class SimpleNetwork(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(in_features=32, out_features=16, bias=False)

    def forward(self, x):
        return self.linear(x)


model = SimpleNetwork().eval().cuda()
input = torch.randn(2, 32).cuda()

# Quantize in place: FP8 dynamic activations, FP8 weights
config = Float8DynamicActivationFloat8WeightConfig()
quantize_(model, config)

# This call raises the RuntimeError shown above
ep = torch.export.export(model, (input,), strict=False)
The same code worked with torchao v0.11.0 (around June 2025).
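Because this looks like a regression between torchao releases, the exact installed versions matter when comparing a working and a failing environment. The snippet below is a minimal sketch, not part of torchao: `collect_versions` is a hypothetical helper that uses only the standard library (`importlib.metadata`) to report the installed `torch` and `torchao` versions alongside the Python version.

```python
# Hypothetical helper (not part of torchao): collect installed versions of the
# packages involved in this report so the regression can be pinned down.
import importlib.metadata
import platform


def collect_versions(packages=("torch", "torchao")):
    """Return {name: version}, using None for packages that are not installed."""
    info = {"python": platform.python_version()}
    for name in packages:
        try:
            info[name] = importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            info[name] = None
    return info


if __name__ == "__main__":
    for key, value in collect_versions().items():
        print(f"{key}: {value if value is not None else 'not installed'}")
```

Running it in both environments (for example, one with torchao v0.11.0 and one with the current release) shows exactly which version pair separates the passing run from the failing one.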