
How do I perform Int8 activation and int8 weight QAT and export to onnx? #975

Open
ben-da6 opened this issue Sep 30, 2024 · 3 comments

ben-da6 commented Sep 30, 2024

From the tutorials and recipes it looks like QAT only supports int8 dynamic activations with int4 weights. Is that right? Also, I cannot export the trained model to ONNX.

import torch
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer


class MLP(torch.nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.other_fc = torch.nn.Linear(256, 4096)
        self.relu = torch.nn.ReLU()
        self.fc = torch.nn.Linear(4096, 10)

    def forward(self, x):
        x = self.other_fc(x)
        x = self.relu(x)
        x = self.fc(x)
        return x


model = MLP().cuda()
# Quantizer for int8 dynamic per token activations +
# int4 grouped per channel weights, only for linear layers
qat_quantizer = Int8DynActInt4WeightQATQuantizer()

# Insert "fake quantize" operations into linear layers.
# These operations simulate quantization numerics during
# training without performing any dtype casting
model = qat_quantizer.prepare(model)

# Standard training loop
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
loss_fn = torch.nn.MSELoss()
for i in range(10):
    example = torch.randn((2, 256)).cuda()
    target = torch.randn((2, 10)).cuda()
    output = model(example)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

# Convert fake quantize to actual quantize operations
# The quantized model has the exact same structure as the
# quantized model produced in the corresponding PTQ flow
# through `Int8DynActInt4WeightQuantizer`
model = qat_quantizer.convert(model)

# inference or generate
torch.onnx.export(model, torch.randn(2, 256), "blah.onnx")

Errors with

Traceback (most recent call last):
  File "/persist/code/random/blah.py", line 48, in <module>
    torch.onnx.export(model, torch.randn(2, 256), "blah.onnx")
  File "/persist/envs/random/lib/python3.10/site-packages/torch/onnx/utils.py", line 551, in export
    _export(
  File "/persist/envs/random/lib/python3.10/site-packages/torch/onnx/utils.py", line 1648, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/persist/envs/random/lib/python3.10/site-packages/torch/onnx/utils.py", line 1170, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "/persist/envs/random/lib/python3.10/site-packages/torch/onnx/utils.py", line 1046, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/persist/envs/random/lib/python3.10/site-packages/torch/onnx/utils.py", line 950, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
  File "/persist/envs/random/lib/python3.10/site-packages/torch/jit/_trace.py", line 1497, in _get_trace_graph
    outs = ONNXTracedModule(
  File "/persist/envs/random/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/persist/envs/random/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/persist/envs/random/lib/python3.10/site-packages/torch/jit/_trace.py", line 141, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "/persist/envs/random/lib/python3.10/site-packages/torch/jit/_trace.py", line 132, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/persist/envs/random/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/persist/envs/random/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/persist/envs/random/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1543, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/persist/code/random/blah.py", line 13, in forward
    x = self.other_fc(x)
  File "/persist/envs/random/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/persist/envs/random/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/persist/envs/random/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1543, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/persist/envs/random/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 117, in forward
    return F.linear(input, self.weight, self.bias)
  File "/persist/envs/random/lib/python3.10/site-packages/torchao/utils.py", line 372, in _dispatch__torch_function__
    return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, args, kwargs)
  File "/persist/envs/random/lib/python3.10/site-packages/torchao/utils.py", line 355, in wrapper
    return func(f, types, args, kwargs)
  File "/persist/envs/random/lib/python3.10/site-packages/torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py", line 237, in _
    input_tensor = input_tensor.get_value()
  File "/persist/envs/random/lib/python3.10/site-packages/torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py", line 170, in get_value
    return self.apply_fake_quant_fn(self)
  File "/persist/envs/random/lib/python3.10/site-packages/torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py", line 58, in apply_fake_quant_fn
    fq = _GenericFakeQuantize.apply(
  File "/persist/envs/random/lib/python3.10/site-packages/torch/autograd/function.py", line 574, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
RuntimeError: _Map_base::at
ben-da6 changed the title from "How do I perform Int8 activation and int8 weight QAT with an mlp?" to "How do I perform Int8 activation and int8 weight QAT and export to onnx?" on Sep 30, 2024
andrewor14 self-assigned this on Sep 30, 2024
andrewor14 (Contributor) commented

Hi @ben-da6, there's currently no API for 8-bit activations + 8-bit weights yet. We're working on making this API more flexible, but for now you can manually modify this quantizer to change n_bits=4 to n_bits=8 for the weights (this line; there may be others in the same file). I don't think we have support for exporting to ONNX in general, but maybe @jerryzh168 has more context.
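
To make the idea concrete, here is a minimal sketch in plain PyTorch rather than torchao's quantizer internals (the names fake_quantize_int8 and FakeQuantLinear are made up for illustration). It fake-quantizes both activations and weights to int8 during training, keeping everything in float and using a straight-through estimator so gradients flow:

import torch

def fake_quantize_int8(x: torch.Tensor) -> torch.Tensor:
    # Symmetric per-tensor int8 fake quantization with a
    # straight-through estimator for the rounding step.
    scale = x.abs().amax().clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(x / scale), -128, 127)
    x_fq = q * scale
    # Forward uses the fake-quantized values; backward is identity.
    return x + (x_fq - x).detach()

class FakeQuantLinear(torch.nn.Linear):
    # Linear layer simulating int8 activations + int8 weights during QAT.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.linear(
            fake_quantize_int8(x),
            fake_quantize_int8(self.weight),
            self.bias,
        )

Because this sketch only uses plain float ops (no tensor subclasses), it should also trace through torch.onnx.export, unlike the prototype fake-quantized tensor that produced the traceback above.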

jerryzh168 (Contributor) commented

We have some discussion of ONNX export here: #986

justinchuby commented

Please test with torch.onnx.export(..., dynamo=True, report=True) using the latest torch-nightly. Attach the generated report if there is an error. Thanks!
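
For reference, on a recent nightly that suggestion would look roughly like this (the dynamo and report keywords come from the comment above; the filename and input shape are taken from the repro, and exact behavior depends on the installed torch version):

import torch

example = torch.randn(2, 256).cuda()  # same device as the model in the repro
torch.onnx.export(
    model,
    (example,),
    "blah.onnx",
    dynamo=True,   # use the torch.export/dynamo-based exporter
    report=True,   # emit a report describing any export failure
)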
