Open
Description
From the tutorials and recipes, it looks like QAT only supports dynamic int8 activations with int4 weights. Is that correct? Also, I cannot export the trained model to ONNX:
import torch
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
class MLP(torch.nn.Module):
    """Two-layer perceptron: Linear(256 -> 4096), ReLU, Linear(4096 -> 10)."""

    def __init__(self):
        super().__init__()
        # Attribute names are kept as-is so state_dict keys do not change.
        self.other_fc = torch.nn.Linear(256, 4096)
        self.relu = torch.nn.ReLU()
        self.fc = torch.nn.Linear(4096, 10)

    def forward(self, x):
        """Apply both linear layers with a ReLU in between.

        Args:
            x: Input tensor whose last dimension is 256 (the input size of
                ``other_fc``).

        Returns:
            Tensor whose last dimension is 10 (the output size of ``fc``).
        """
        x = self.other_fc(x)
        x = self.relu(x)
        x = self.fc(x)
        return x
# End-to-end QAT reproduction: prepare -> train -> convert -> ONNX export.
model = MLP().cuda()

# Quantizer for int8 dynamic per token activations +
# int4 grouped per channel weights, only for linear layers
qat_quantizer = Int8DynActInt4WeightQATQuantizer()

# Insert "fake quantize" operations into linear layers.
# These operations simulate quantization numerics during
# training without performing any dtype casting
model = qat_quantizer.prepare(model)

# Standard training loop (random inputs and targets, 10 steps)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
loss_fn = torch.nn.MSELoss()
for _ in range(10):
    example = torch.randn((2, 256)).cuda()
    target = torch.randn((2, 10)).cuda()
    output = model(example)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

# Convert fake quantize to actual quantize operations
# The quantized model has the exact same structure as the
# quantized model produced in the corresponding PTQ flow
# through `Int8DynActInt4WeightQuantizer`
model = qat_quantizer.convert(model)

# inference or generate
# The exporter runs the model on the example input while tracing, so the
# input must live on the same device as the model (CUDA, see above); a CPU
# tensor here is a device mismatch.
# NOTE(review): this only fixes the input device. The reported
# `RuntimeError: _Map_base::at` is raised inside the fake-quantize autograd
# function during tracing and may be a separate problem - confirm.
export_input = torch.randn(2, 256).cuda()
torch.onnx.export(model, export_input, "blah.onnx")
The export fails with the following error:
Traceback (most recent call last):
File "/persist/code/random/blah.py", line 48, in <module>
torch.onnx.export(model, torch.randn(2, 256), "blah.onnx")
File "/persist/envs/random/lib/python3.10/site-packages/torch/onnx/utils.py", line 551, in export
_export(
File "/persist/envs/random/lib/python3.10/site-packages/torch/onnx/utils.py", line 1648, in _export
graph, params_dict, torch_out = _model_to_graph(
File "/persist/envs/random/lib/python3.10/site-packages/torch/onnx/utils.py", line 1170, in _model_to_graph
graph, params, torch_out, module = _create_jit_graph(model, args)
File "/persist/envs/random/lib/python3.10/site-packages/torch/onnx/utils.py", line 1046, in _create_jit_graph
graph, torch_out = _trace_and_get_graph_from_model(model, args)
File "/persist/envs/random/lib/python3.10/site-packages/torch/onnx/utils.py", line 950, in _trace_and_get_graph_from_model
trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
File "/persist/envs/random/lib/python3.10/site-packages/torch/jit/_trace.py", line 1497, in _get_trace_graph
outs = ONNXTracedModule(
File "/persist/envs/random/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/persist/envs/random/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/persist/envs/random/lib/python3.10/site-packages/torch/jit/_trace.py", line 141, in forward
graph, out = torch._C._create_graph_by_tracing(
File "/persist/envs/random/lib/python3.10/site-packages/torch/jit/_trace.py", line 132, in wrapper
outs.append(self.inner(*trace_inputs))
File "/persist/envs/random/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/persist/envs/random/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/persist/envs/random/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1543, in _slow_forward
result = self.forward(*input, **kwargs)
File "/persist/code/random/blah.py", line 13, in forward
x = self.other_fc(x)
File "/persist/envs/random/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/persist/envs/random/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl
result = forward_call(*args, **kwargs)
File "/persist/envs/random/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1543, in _slow_forward
result = self.forward(*input, **kwargs)
File "/persist/envs/random/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 117, in forward
return F.linear(input, self.weight, self.bias)
File "/persist/envs/random/lib/python3.10/site-packages/torchao/utils.py", line 372, in _dispatch__torch_function__
return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, args, kwargs)
File "/persist/envs/random/lib/python3.10/site-packages/torchao/utils.py", line 355, in wrapper
return func(f, types, args, kwargs)
File "/persist/envs/random/lib/python3.10/site-packages/torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py", line 237, in _
input_tensor = input_tensor.get_value()
File "/persist/envs/random/lib/python3.10/site-packages/torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py", line 170, in get_value
return self.apply_fake_quant_fn(self)
File "/persist/envs/random/lib/python3.10/site-packages/torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py", line 58, in apply_fake_quant_fn
fq = _GenericFakeQuantize.apply(
File "/persist/envs/random/lib/python3.10/site-packages/torch/autograd/function.py", line 574, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
RuntimeError: _Map_base::at
Metadata
Assignees
Labels
No labels