
[RFC] Long Term QAT Flow #987

Open

andrewor14 opened this issue Oct 1, 2024 · 7 comments

@andrewor14
Contributor

andrewor14 commented Oct 1, 2024

Currently torchao QAT has two APIs: tensor subclasses and module swap. The original plan was to deprecate and eventually remove the old module swap API in favor of the tensor subclass API. However, users are starting to rely on the module swap API for production use due to gaps in the tensor subclass API. In this RFC, we discuss the long term plan for these two APIs in torchao.

API Today

We use a quantizer API today to abstract the implementation details from the user. Currently we support both tensor subclass and module swap APIs using different quantizers:

from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer 
from torchao.quantization.prototype.qat._module_swap_api import Int8DynActInt4WeightQATQuantizerModuleSwap

# tensor subclass version
qat_quantizer = Int8DynActInt4WeightQATQuantizer()
# module swap version
qat_quantizer = Int8DynActInt4WeightQATQuantizerModuleSwap()

# prepare: inserts fake quantizes but keeps everything in bf16
fake_quantized_model = qat_quantizer.prepare(model)

# train or fine-tune as before
train(fake_quantized_model)

# convert: actually quantizes the model to lower bit-widths
quantized_model = qat_quantizer.convert(fake_quantized_model)

Module Swap vs Tensor Subclass

Although tensor subclasses are generally adopted in torchao, the main gaps today are (1) the lack of general distributed support, and (2) a steep learning curve. For these two reasons, some users prefer the module swap flow and have begun implementing new features in it, such as embedding quantization and static QAT.

To summarize the pros and cons of both approaches:

| | Tensor subclass | Module swap |
|---|---|---|
| Consistency | ✓ The rest of torchao uses tensor subclasses, including PTQ, quantized training, float8 inference, and sparsity. These all use the same quantize_ API. | ✖ Diverges from the PTQ flow (not necessarily a con) |
| Composability | ✓ Better composability with other tensor subclasses such as DTensor and NestedJaggedTensor | ✖ Pure module swap misses out on potential composability benefits with other tensor subclasses (no clear benefits for QAT today) |
| Distributed support | ✖ Currently only supports FSDP2. The internal implementation of each distributed strategy is exposed to the subclass. There are problems with how tensor subclasses interact with FSDP1 and DDP, and fixing these is not a priority for the distributed team. | ✓ Works with any distributed strategy, including non-PyTorch ones like FAIR FSDP |
| Developer experience | ✖ Steep learning curve, difficult for new users to extend, confusing error messages | ✓ Easy to understand and extend. Supports module-level features like range learning |

We can separate tensor subclass usage into two categories (a minimal sketch of both injection styles follows this list):

  • Injection. This refers to how we insert fake quantization logic into the model. Tensor subclass injection means we look for nn.Linear modules and swap out only the weight tensor, while module swap injection means we look for nn.Linear modules and swap out the whole module with our custom QATLinear. Today, the tensor subclass flow in torchao uses the former, while the module swap flow uses the latter.
  • Fake quantization implementation. This refers to how we represent fake quantization during training. We can use our custom AffineFakeQuantizedTensor to encode the desired fake quantization configurations, or we can use plain torch.Tensor.

These two choices are independent and can be combined in the same flow, for example by using module swap for injection and a tensor subclass for fake quantization.
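To make the distinction concrete, below is a minimal sketch of the two injection styles. This is illustrative only and is not the torchao implementation: fake_quantize, QATLinear, module_swap_injection, and tensor_subclass_injection are hypothetical names, and the fake quantization here is a toy per-tensor symmetric scheme with a straight-through estimator.

import torch
import torch.nn as nn
import torch.nn.functional as F

def fake_quantize(w: torch.Tensor, bits: int = 4) -> torch.Tensor:
    # toy per-tensor symmetric fake quantization: quantize then dequantize
    qmax = 2 ** (bits - 1) - 1
    scale = w.abs().amax() / qmax
    fq = torch.clamp(torch.round(w / scale), -qmax - 1, qmax) * scale
    # straight-through estimator: forward uses fq, backward sees identity w.r.t. w
    return w + (fq - w).detach()

# Module swap injection: find nn.Linear modules and swap out the whole module
class QATLinear(nn.Linear):
    def forward(self, x):
        return F.linear(x, fake_quantize(self.weight), self.bias)

def module_swap_injection(model: nn.Module) -> nn.Module:
    for name, child in model.named_children():
        if isinstance(child, nn.Linear):
            qat = QATLinear(child.in_features, child.out_features,
                            bias=child.bias is not None)
            qat.weight, qat.bias = child.weight, child.bias
            setattr(model, name, qat)
        else:
            module_swap_injection(child)
    return model

# Tensor subclass injection: keep nn.Linear, swap out only the weight tensor
# (torchao wraps with AffineFakeQuantizedTensor; `wrap` stands in for that here)
def tensor_subclass_injection(model: nn.Module, wrap) -> nn.Module:
    for module in model.modules():
        if isinstance(module, nn.Linear):
            module.weight = nn.Parameter(wrap(module.weight.detach()))
    return model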

Long Term Flow

We propose to use module swap for injection and tensor subclass for implementing fake quantization in the long term. This has the following pros and cons compared to the alternatives:

  • ✓ Single QAT flow in torchao
  • ✓ Lower barrier to entry; new users can continue to contribute features quickly
  • ✓ Consistent with float8 training in torchao
  • ✓ Composes well with other tensor subclasses like DTensor (e.g. cast to int8 before all-gather)
  • ✖ Need additional work to support all distributed strategies
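To sketch how the proposed combination could look, the example below keeps module swap (or a plain nn.Linear) for injection but moves the fake quantization into a tensor subclass wrapped around the weight. It continues from the imports and fake_quantize helper in the sketch above; FakeQuantizedWeight is a hypothetical stand-in for something like AffineFakeQuantizedTensor, which would handle many more ops and configurations, and DisableTorchFunctionSubclass assumes a recent PyTorch version.

class FakeQuantizedWeight(torch.Tensor):
    # hypothetical tensor subclass that fake-quantizes itself when used in F.linear
    bits = 4  # class attribute so re-wrapped copies (e.g. from .detach()) keep it

    @staticmethod
    def __new__(cls, data: torch.Tensor):
        return torch.Tensor._make_subclass(cls, data, data.requires_grad)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is F.linear:
            x, w = args[0], args[1]
            bias = args[2] if len(args) > 2 else kwargs.get("bias")
            # run the real op with subclass dispatch disabled to avoid recursion
            with torch._C.DisableTorchFunctionSubclass():
                return F.linear(x, fake_quantize(w, cls.bits), bias)
        # default behavior for every other op (detach, to, etc.)
        return super().__torch_function__(func, types, args, kwargs)

# usage sketch: injection still swaps the module (or just the weight), but the
# fake quantization now lives in the tensor subclass rather than in forward()
layer = nn.Linear(16, 8)
layer.weight = nn.Parameter(FakeQuantizedWeight(layer.weight.detach()))
out = layer(torch.randn(2, 16))  # fake quantization happens inside F.linear
out.sum().backward()             # gradients flow back to the weight via the STE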

Note: In the short term, we will continue to use plain torch.Tensors for fake quantization due to the lack of general distributed support for tensor subclasses. The distributed strategies we should support before migrating to the long term flow include DDP and FSDP1. Additionally, we should migrate only if tensor subclass composability provides meaningful performance benefits, such as faster fake quantization through efficient int8 kernels.

@jerryzh168
Contributor

jerryzh168 commented Oct 1, 2024

@andrewor14 by "Data Representation" I think you meant how fake quantization is implemented, right (fake quantized tensor vs. using modules)? Not the final quantized tensor, right? Might be good to clarify (renaming it to something else might be better, I think).

andrewor14 added a commit that referenced this issue Oct 1, 2024
Summary: Following #987, this
commit makes module swap the main QAT flow today. We remove all
tensor subclass fake quantize injection logic since it is not
needed in either the long term or the short term plan for QAT.
In the short term, we will continue to use a full module swap
flow, and only migrate to the long term flow once there is
general distributed support for tensor subclasses and when
tensor subclass composability provides meaningful benefits.

Test Plan:
python test/quantization/test_qat.py
@andrewor14
Contributor Author

by "Data Representation" I think you meant how fake quantization is implemented, right (fake quantized tensor vs. using modules)? Not the final quantized tensor, right? Might be good to clarify (renaming it to something else might be better, I think).

Yeah, this is referring to how the data is represented during fake quantization in the training phase (after prepare but before convert), not the final quantized data (after convert). I'm open to renaming suggestions if you have any

@jerryzh168
Contributor

Yeah, this is referring to how the data is represented during fake quantization in the training phase (after prepare but before convert), not the final quantized data (after convert). I'm open to renaming suggestions if you have any

maybe just "Fake Quantization Implementation"?

@gau-nernst
Collaborator

Just curious. What are the problems that you observe with tensor subclass + DDP? I have used this combination in my other projects and it seems to work as expected (i.e. no errors, correct results)

@vkuzo
Contributor

vkuzo commented Oct 2, 2024

Note: In the short term, we will continue to use module swap for data representation due to the lack of general distributed support for tensor subclasses.

To clarify, not sure that "module swap for data representation" makes sense. Should this say "use plain torch.Tensor for data representation"?

For FSDP1 composability, from what I understand, as long as you don't have model parameter wrappers, you can still use a tensor subclass for data representation, and thus get the benefits of integrating with other distributed paradigms (TP/SP) and of easily using low precision gemms.

@andrewor14
Contributor Author

Just curious. What are the problems that you observe with tensor subclass + DDP? I have used this combination in my other projects and it seems to work as expected (i.e. no errors, correct results)

That's great to know. Can you share the links?

To clarify, not sure that "module swap for data representation" makes sense. Should this say "use plain torch.Tensor for data representation"?

Sounds good. For FSDP1, the issue I ran into was that moving the model to a different device moves only the outer tensor but not the inner tensor, and this is fundamental to how FSDP1 assigns model.data during model initialization (this line). I think @awgu mentioned this is something they'd like to fix eventually, but it's not the top priority at the moment.

andrewor14 added the rfc label Oct 2, 2024
@gau-nernst
Collaborator

@andrewor14 Train script https://github.com/gau-nernst/quantized-training/blob/main/llm_pretrain.py. DDP stuff is pretty standard, no changes. The subclasses are swapped by quantize_model(). This code is where I run some of the experiments for quantized training before opening PRs in torchao.

andrewor14 added a commit that referenced this issue Oct 4, 2024