Skip to content

How do I perform Int8 activation and int8 weight QAT and export to onnx? #975

Open
@ben-da6

Description

@ben-da6

From the tutorials and recipes, it looks like QAT only supports int8 dynamic activations with int4 weights. Is int8 activation + int8 weight QAT possible? Also, I cannot export the trained model to ONNX:

import torch
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer


class MLP(torch.nn.Module):  # Linear(256, 4096) -> ReLU -> Linear(4096, 10)
    def __init__(self) -> None:
        super(MLP, self).__init__()
        self.other_fc = torch.nn.Linear(256, 4096)  # first layer: 256 -> 4096 features
        self.relu = torch.nn.ReLU()
        self.fc = torch.nn.Linear(4096, 10)  # output layer: 4096 -> 10 features

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.other_fc(x)  # the export traceback below passes through this call
        x = self.relu(x)
        x = self.fc(x)
        return x


model = MLP().cuda()  # move the model's parameters to the CUDA device
# Quantizer for int8 dynamic per-token activations +
# int4 grouped per-channel weights; applies only to linear layers
qat_quantizer = Int8DynActInt4WeightQATQuantizer()

# Insert "fake quantize" operations into the linear layers.
# These operations simulate quantization numerics during
# training without performing any dtype casting
model = qat_quantizer.prepare(model)

# Training loop: 10 SGD steps with MSE loss on random inputs and targets
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
loss_fn = torch.nn.MSELoss()
for i in range(10):
    example = torch.randn((2, 256)).cuda()
    target = torch.randn((2, 10)).cuda()
    output = model(example)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

# Convert the fake quantize operations to actual quantize operations.
# The converted model is expected to have the same structure as the
# quantized model produced by the corresponding PTQ flow
# through `Int8DynActInt4WeightQuantizer`
model = qat_quantizer.convert(model)

# Export the converted model to ONNX; this is the call that raises the error below
torch.onnx.export(model, torch.randn(2, 256), "blah.onnx")  # NOTE(review): example input is a CPU tensor while the model was moved to CUDA -- confirm intended

The export call fails with the following error:

Traceback (most recent call last):
  File "/persist/code/random/blah.py", line 48, in <module>
    torch.onnx.export(model, torch.randn(2, 256), "blah.onnx")
  File "/persist/envs/random/lib/python3.10/site-packages/torch/onnx/utils.py", line 551, in export
    _export(
  File "/persist/envs/random/lib/python3.10/site-packages/torch/onnx/utils.py", line 1648, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/persist/envs/random/lib/python3.10/site-packages/torch/onnx/utils.py", line 1170, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "/persist/envs/random/lib/python3.10/site-packages/torch/onnx/utils.py", line 1046, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/persist/envs/random/lib/python3.10/site-packages/torch/onnx/utils.py", line 950, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
  File "/persist/envs/random/lib/python3.10/site-packages/torch/jit/_trace.py", line 1497, in _get_trace_graph
    outs = ONNXTracedModule(
  File "/persist/envs/random/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/persist/envs/random/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/persist/envs/random/lib/python3.10/site-packages/torch/jit/_trace.py", line 141, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "/persist/envs/random/lib/python3.10/site-packages/torch/jit/_trace.py", line 132, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/persist/envs/random/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/persist/envs/random/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/persist/envs/random/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1543, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/persist/code/random/blah.py", line 13, in forward
    x = self.other_fc(x)
  File "/persist/envs/random/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/persist/envs/random/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/persist/envs/random/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1543, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/persist/envs/random/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 117, in forward
    return F.linear(input, self.weight, self.bias)
  File "/persist/envs/random/lib/python3.10/site-packages/torchao/utils.py", line 372, in _dispatch__torch_function__
    return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, args, kwargs)
  File "/persist/envs/random/lib/python3.10/site-packages/torchao/utils.py", line 355, in wrapper
    return func(f, types, args, kwargs)
  File "/persist/envs/random/lib/python3.10/site-packages/torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py", line 237, in _
    input_tensor = input_tensor.get_value()
  File "/persist/envs/random/lib/python3.10/site-packages/torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py", line 170, in get_value
    return self.apply_fake_quant_fn(self)
  File "/persist/envs/random/lib/python3.10/site-packages/torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py", line 58, in apply_fake_quant_fn
    fq = _GenericFakeQuantize.apply(
  File "/persist/envs/random/lib/python3.10/site-packages/torch/autograd/function.py", line 574, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
RuntimeError: _Map_base::at

Metadata

Metadata

Assignees

Labels

No labels
No labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions