From the tutorials and recipes it looks like QAT only supports dynamic int8 activations with int4 weights? Also, I cannot export the trained model to ONNX.
import torch
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer

class MLP(torch.nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.other_fc = torch.nn.Linear(256, 4096)
        self.relu = torch.nn.ReLU()
        self.fc = torch.nn.Linear(4096, 10)

    def forward(self, x):
        x = self.other_fc(x)
        x = self.relu(x)
        x = self.fc(x)
        return x

model = MLP().cuda()

# Quantizer for int8 dynamic per-token activations +
# int4 grouped per-channel weights, only for linear layers
qat_quantizer = Int8DynActInt4WeightQATQuantizer()

# Insert "fake quantize" operations into linear layers.
# These operations simulate quantization numerics during
# training without performing any dtype casting
model = qat_quantizer.prepare(model)

# Standard training loop
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
loss_fn = torch.nn.MSELoss()
for i in range(10):
    example = torch.randn((2, 256)).cuda()
    target = torch.randn((2, 10)).cuda()
    output = model(example)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

# Convert fake quantize to actual quantize operations.
# The quantized model has the exact same structure as the
# one produced in the corresponding PTQ flow through
# `Int8DynActInt4WeightQuantizer`
model = qat_quantizer.convert(model)

# inference or generate
torch.onnx.export(model, torch.randn(2, 256), "blah.onnx")
The export errors with:
Traceback (most recent call last):
  File "/persist/code/random/blah.py", line 48, in <module>
    torch.onnx.export(model, torch.randn(2, 256), "blah.onnx")
  File "/persist/envs/random/lib/python3.10/site-packages/torch/onnx/utils.py", line 551, in export
    _export(
  File "/persist/envs/random/lib/python3.10/site-packages/torch/onnx/utils.py", line 1648, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/persist/envs/random/lib/python3.10/site-packages/torch/onnx/utils.py", line 1170, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "/persist/envs/random/lib/python3.10/site-packages/torch/onnx/utils.py", line 1046, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/persist/envs/random/lib/python3.10/site-packages/torch/onnx/utils.py", line 950, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
  File "/persist/envs/random/lib/python3.10/site-packages/torch/jit/_trace.py", line 1497, in _get_trace_graph
    outs = ONNXTracedModule(
  File "/persist/envs/random/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/persist/envs/random/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/persist/envs/random/lib/python3.10/site-packages/torch/jit/_trace.py", line 141, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "/persist/envs/random/lib/python3.10/site-packages/torch/jit/_trace.py", line 132, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/persist/envs/random/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/persist/envs/random/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/persist/envs/random/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1543, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/persist/code/random/blah.py", line 13, in forward
    x = self.other_fc(x)
  File "/persist/envs/random/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/persist/envs/random/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/persist/envs/random/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1543, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/persist/envs/random/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 117, in forward
    return F.linear(input, self.weight, self.bias)
  File "/persist/envs/random/lib/python3.10/site-packages/torchao/utils.py", line 372, in _dispatch__torch_function__
    return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, args, kwargs)
  File "/persist/envs/random/lib/python3.10/site-packages/torchao/utils.py", line 355, in wrapper
    return func(f, types, args, kwargs)
  File "/persist/envs/random/lib/python3.10/site-packages/torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py", line 237, in _
    input_tensor = input_tensor.get_value()
  File "/persist/envs/random/lib/python3.10/site-packages/torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py", line 170, in get_value
    return self.apply_fake_quant_fn(self)
  File "/persist/envs/random/lib/python3.10/site-packages/torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py", line 58, in apply_fake_quant_fn
    fq = _GenericFakeQuantize.apply(
  File "/persist/envs/random/lib/python3.10/site-packages/torch/autograd/function.py", line 574, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
RuntimeError: _Map_base::at
ben-da6 changed the title from "How do I perform Int8 activation and int8 weight QAT with an mlp?" to "How do I perform Int8 activation and int8 weight QAT and export to onnx?" on Sep 30, 2024
Hi @ben-da6, there's currently no API for 8-bit activations + 8-bit weights yet. We're working on making this API more flexible, but for now you can manually modify this quantizer and change n_bits=4 to n_bits=8 for the weights (this line; there may be others in the same file). I don't think we have support for exporting to ONNX in general, but maybe @jerryzh168 has more context.
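(For anyone landing here: below is a minimal sketch of what that bit-width change means numerically. This is not the torchao API; the function name and the symmetric per-channel scheme are assumptions for illustration, and the real quantizer also fake-quantizes activations and uses group-wise weight scales.)

import torch

# Hypothetical illustration only: symmetric per-channel weight fake
# quantization. n_bit=4 gives the int4 range [-8, 7]; n_bit=8 gives
# the int8 range [-128, 127].
def fake_quantize_weight(w: torch.Tensor, n_bit: int = 8) -> torch.Tensor:
    qmin, qmax = -(2 ** (n_bit - 1)), 2 ** (n_bit - 1) - 1
    # one scale per output channel (row), so each row's max maps to qmax
    scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-12) / qmax
    q = torch.clamp(torch.round(w / scale), qmin, qmax)
    return q * scale  # dequantize: same shape/dtype, quantized numerics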
Please test with torch.onnx.export(..., dynamo=True, report=True) using the latest torch-nightly. Attach the generated report if there is an error. Thanks!
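For concreteness, the suggested call would look something like this (a sketch assuming a recent nightly where torch.onnx.export accepts dynamo and report, per the request above; keeping the example input on the same device as the model is my assumption):

import torch

# Sketch of the requested repro, continuing from the code above:
# dynamo=True routes through the torch.export-based ONNX exporter,
# and report=True emits a report to attach here if the export fails.
example = torch.randn(2, 256).cuda()  # model from the repro is on CUDA
onnx_program = torch.onnx.export(model, (example,), "blah.onnx",
                                 dynamo=True, report=True)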